//! `polars_mem_engine/predicate.rs` — scan predicate expressions and helpers.

1use core::fmt;
2use std::sync::Arc;
3
4use arrow::bitmap::Bitmap;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
7use polars_core::scalar::Scalar;
8use polars_core::schema::{Schema, SchemaRef};
9use polars_error::PolarsResult;
10use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
11use polars_expr::state::ExecutionState;
12use polars_io::predicates::{
13    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr,
14};
15use polars_utils::pl_str::PlSmallStr;
16use polars_utils::{IdxSize, format_pl_smallstr};
17
/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The physical expression evaluated to filter rows during a scan.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on its statistics.
    ///
    /// This expression will be given a batch size along with a `min`, `max` and `null count` for
    /// each live column (set to `null` when it is not known) and the expression evaluates to
    /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to
    /// `false` even when the batch could theoretically be skipped.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial predicates for each column for filter when loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,
}
37
38impl fmt::Debug for ScanPredicate {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.write_str("scan_predicate")
41    }
42}
43
/// Per-column physical predicates used when loading columnar formats.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    /// Maps a column name to its partial predicate, together with an optional
    /// specialized form of that predicate.
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    // Whether the column predicates together fully cover the original predicate.
    // NOTE(review): inferred from the name — confirm against the producer of this flag.
    pub is_sumwise_complete: bool,
}
55
/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    /// Expression evaluated against a statistics `DataFrame`; see
    /// [`ScanPredicate::skip_batch_predicate`] for its contract.
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    /// Schema returned by [`SkipBatchPredicate::schema`].
    schema: SchemaRef,
}
61
/// Helper for the [`PhysicalExpr`] trait to include constant columns.
pub struct PhysicalExprWithConstCols {
    /// Constant `(name, value)` pairs materialized as broadcast scalar columns
    /// before the child expression is evaluated.
    constants: Vec<(PlSmallStr, Scalar)>,
    /// The wrapped expression, evaluated on the extended `DataFrame`.
    child: Arc<dyn PhysicalExpr>,
}
67
68impl PhysicalExpr for PhysicalExprWithConstCols {
69    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
70        let mut df = df.clone();
71        for (name, scalar) in &self.constants {
72            df.with_column(Column::new_scalar(
73                name.clone(),
74                scalar.clone(),
75                df.height(),
76            ))?;
77        }
78
79        self.child.evaluate(&df, state)
80    }
81
82    fn evaluate_on_groups<'a>(
83        &self,
84        df: &DataFrame,
85        groups: &'a GroupPositions,
86        state: &ExecutionState,
87    ) -> PolarsResult<AggregationContext<'a>> {
88        let mut df = df.clone();
89        for (name, scalar) in &self.constants {
90            df.with_column(Column::new_scalar(
91                name.clone(),
92                scalar.clone(),
93                df.height(),
94            ))?;
95        }
96
97        self.child.evaluate_on_groups(&df, groups, state)
98    }
99
100    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
101        self.child.to_field(input_schema)
102    }
103    fn is_scalar(&self) -> bool {
104        self.child.is_scalar()
105    }
106}
107
108impl ScanPredicate {
109    pub fn with_constant_columns(
110        &self,
111        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
112    ) -> Self {
113        let constant_columns = constant_columns.into_iter();
114
115        let mut live_columns = self.live_columns.as_ref().clone();
116        let mut skip_batch_predicate_constants = Vec::with_capacity(
117            self.skip_batch_predicate
118                .is_some()
119                .then_some(1 + constant_columns.size_hint().0 * 3)
120                .unwrap_or_default(),
121        );
122
123        let predicate_constants = constant_columns
124            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
125                if !live_columns.swap_remove(&name) {
126                    return None;
127                }
128
129                if self.skip_batch_predicate.is_some() {
130                    let mut null_count: Scalar = (0 as IdxSize).into();
131
132                    // If the constant value is Null, we don't know how many nulls there are
133                    // because the length of the batch may vary.
134                    if scalar.is_null() {
135                        null_count.update(AnyValue::Null);
136                    }
137
138                    skip_batch_predicate_constants.extend([
139                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
140                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
141                        (format_pl_smallstr!("{name}_nc"), null_count),
142                    ]);
143                }
144
145                Some((name, scalar))
146            })
147            .collect();
148
149        let predicate = Arc::new(PhysicalExprWithConstCols {
150            constants: predicate_constants,
151            child: self.predicate.clone(),
152        });
153        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
154            Arc::new(PhysicalExprWithConstCols {
155                constants: skip_batch_predicate_constants,
156                child: skp.clone(),
157            }) as _
158        });
159
160        Self {
161            predicate,
162            live_columns: Arc::new(live_columns),
163            skip_batch_predicate,
164            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull
165                                                               // predicates.
166        }
167    }
168
169    /// Create a predicate to skip batches using statistics.
170    pub(crate) fn to_dyn_skip_batch_predicate(
171        &self,
172        schema: SchemaRef,
173    ) -> Option<Arc<dyn SkipBatchPredicate>> {
174        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
175        Some(Arc::new(SkipBatchPredicateHelper {
176            skip_batch_predicate,
177            schema,
178        }))
179    }
180
181    pub fn to_io(
182        &self,
183        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
184        schema: SchemaRef,
185    ) -> ScanIOPredicate {
186        ScanIOPredicate {
187            predicate: phys_expr_to_io_expr(self.predicate.clone()),
188            live_columns: self.live_columns.clone(),
189            skip_batch_predicate: skip_batch_predicate
190                .cloned()
191                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
192            column_predicates: Arc::new(ColumnPredicates {
193                predicates: self
194                    .column_predicates
195                    .predicates
196                    .iter()
197                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
198                    .collect(),
199                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
200            }),
201        }
202    }
203}
204
205impl SkipBatchPredicate for SkipBatchPredicateHelper {
206    fn schema(&self) -> &SchemaRef {
207        &self.schema
208    }
209
210    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
211        let array = self
212            .skip_batch_predicate
213            .evaluate(df, &Default::default())?;
214        let array = array.bool()?;
215        let array = array.downcast_as_array();
216
217        if let Some(validity) = array.validity() {
218            Ok(array.values() & validity)
219        } else {
220            Ok(array.values().clone())
221        }
222    }
223}