// polars_mem_engine/predicate.rs

1use core::fmt;
2use std::sync::Arc;
3
4use arrow::bitmap::Bitmap;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
7use polars_core::scalar::Scalar;
8use polars_core::schema::{Schema, SchemaRef};
9use polars_error::PolarsResult;
10use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
11use polars_expr::state::ExecutionState;
12use polars_io::predicates::{
13    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr,
14};
15use polars_utils::pl_str::PlSmallStr;
16use polars_utils::{IdxSize, format_pl_smallstr};
17
/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The full filter expression evaluated over the scanned rows.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on its statistics.
    ///
    /// This expression will be given a batch size along with a `min`, `max` and `null count` for
    /// each live column (set to `null` when it is not known) and the expression evaluates to
    /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to
    /// `false` even when the batch could theoretically be skipped.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial predicates for each column for filter when loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,

    /// Predicate only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    // NOTE(review): presumably `true` when `hive_predicate` is equivalent to the full
    // `predicate` — confirm at the construction site.
    pub hive_predicate_is_full_predicate: bool,
}
41
42impl fmt::Debug for ScanPredicate {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.write_str("scan_predicate")
45    }
46}
47
/// Per-column partial predicates used to filter while loading columnar formats.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    /// Maps a column name to its physical predicate expression, together with an optional
    /// specialized form (see [`SpecializedColumnPredicateExpr`]) that readers may push down.
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    // NOTE(review): presumably `true` when the per-column predicates together cover the full
    // predicate — confirm with the code that builds this value.
    pub is_sumwise_complete: bool,
}
59
/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    /// Expression evaluated over a statistics `DataFrame` (see `evaluate_with_stat_df`).
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    /// Schema returned by [`SkipBatchPredicate::schema`].
    schema: SchemaRef,
}
65
/// Helper for the [`PhysicalExpr`] trait to include constant columns.
///
/// Wraps `child` so that, before evaluation, each `(name, scalar)` in `constants` is appended
/// to the input `DataFrame` as a scalar column broadcast to the frame's height.
pub struct PhysicalExprWithConstCols {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: Arc<dyn PhysicalExpr>,
}
71
72impl PhysicalExpr for PhysicalExprWithConstCols {
73    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
74        let mut df = df.clone();
75        for (name, scalar) in &self.constants {
76            df.with_column(Column::new_scalar(
77                name.clone(),
78                scalar.clone(),
79                df.height(),
80            ))?;
81        }
82
83        self.child.evaluate(&df, state)
84    }
85
86    fn evaluate_on_groups<'a>(
87        &self,
88        df: &DataFrame,
89        groups: &'a GroupPositions,
90        state: &ExecutionState,
91    ) -> PolarsResult<AggregationContext<'a>> {
92        let mut df = df.clone();
93        for (name, scalar) in &self.constants {
94            df.with_column(Column::new_scalar(
95                name.clone(),
96                scalar.clone(),
97                df.height(),
98            ))?;
99        }
100
101        self.child.evaluate_on_groups(&df, groups, state)
102    }
103
104    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
105        self.child.to_field(input_schema)
106    }
107    fn is_scalar(&self) -> bool {
108        self.child.is_scalar()
109    }
110}
111
impl ScanPredicate {
    /// Bake constant column values (e.g. known-up-front scalar columns) into the predicate.
    ///
    /// Each constant column the predicate references is removed from `live_columns` and instead
    /// supplied as a scalar column at evaluation time. When a skip-batch predicate is present,
    /// matching `{name}_min` / `{name}_max` / `{name}_nc` constants are injected into it as well.
    /// The hive predicate is dropped in the result.
    pub fn with_constant_columns(
        &self,
        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
    ) -> Self {
        let constant_columns = constant_columns.into_iter();

        let mut live_columns = self.live_columns.as_ref().clone();
        // Three stat constants (min / max / null count) per constant column; capacity stays 0
        // (`Default::default()`) when there is no skip-batch predicate to feed them to.
        let mut skip_batch_predicate_constants =
            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
                1 + constant_columns.size_hint().0 * 3
            } else {
                Default::default()
            });

        let predicate_constants = constant_columns
            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
                // Ignore constants the predicate never references; `swap_remove` also drops the
                // column from the surviving `live_columns` set.
                if !live_columns.swap_remove(&name) {
                    return None;
                }

                if self.skip_batch_predicate.is_some() {
                    let mut null_count: Scalar = (0 as IdxSize).into();

                    // If the constant value is Null, we don't know how many nulls there are
                    // because the length of the batch may vary.
                    if scalar.is_null() {
                        null_count.update(AnyValue::Null);
                    }

                    // min == max == the constant itself; the batch statistics are exact here.
                    skip_batch_predicate_constants.extend([
                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
                        (format_pl_smallstr!("{name}_nc"), null_count),
                    ]);
                }

                Some((name, scalar))
            })
            .collect();

        // Wrap both predicates so the constants get injected on every evaluation.
        let predicate = Arc::new(PhysicalExprWithConstCols {
            constants: predicate_constants,
            child: self.predicate.clone(),
        });
        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
            Arc::new(PhysicalExprWithConstCols {
                constants: skip_batch_predicate_constants,
                child: skp.clone(),
            }) as _
        });

        Self {
            predicate,
            live_columns: Arc::new(live_columns),
            skip_batch_predicate,
            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull
            // predicates.
            hive_predicate: None,
            hive_predicate_is_full_predicate: false,
        }
    }

    /// Create a predicate to skip batches using statistics.
    ///
    /// Returns `None` when this scan predicate has no skip-batch expression.
    pub(crate) fn to_dyn_skip_batch_predicate(
        &self,
        schema: SchemaRef,
    ) -> Option<Arc<dyn SkipBatchPredicate>> {
        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
        Some(Arc::new(SkipBatchPredicateHelper {
            skip_batch_predicate,
            schema,
        }))
    }

    /// Convert to the I/O-layer predicate representation consumed by scan readers.
    ///
    /// An explicitly passed `skip_batch_predicate` takes precedence over the one derived from
    /// `self` via `to_dyn_skip_batch_predicate`.
    pub fn to_io(
        &self,
        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
            live_columns: self.live_columns.clone(),
            skip_batch_predicate: skip_batch_predicate
                .cloned()
                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
            column_predicates: Arc::new(ColumnPredicates {
                predicates: self
                    .column_predicates
                    .predicates
                    .iter()
                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
                    .collect(),
                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
            }),
            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
        }
    }
}
212
213impl SkipBatchPredicate for SkipBatchPredicateHelper {
214    fn schema(&self) -> &SchemaRef {
215        &self.schema
216    }
217
218    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
219        let array = self
220            .skip_batch_predicate
221            .evaluate(df, &Default::default())?;
222        let array = array.bool()?;
223        let array = array.downcast_as_array();
224
225        if let Some(validity) = array.validity() {
226            Ok(array.values() & validity)
227        } else {
228            Ok(array.values().clone())
229        }
230    }
231}