// polars_mem_engine/predicate.rs

1use core::fmt;
2use std::sync::Arc;
3
4use arrow::bitmap::Bitmap;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
7use polars_core::scalar::Scalar;
8use polars_core::schema::{Schema, SchemaRef};
9use polars_error::PolarsResult;
10use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
11use polars_expr::state::ExecutionState;
12use polars_io::predicates::{
13    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
14};
15use polars_utils::pl_str::PlSmallStr;
16use polars_utils::{IdxSize, format_pl_smallstr};
17
/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The full filter expression applied to each row.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on its statistics.
    ///
    /// This expression will be given a batch size along with a `min`, `max` and `null count` for
    /// each live column (set to `null` when it is not known) and the expression evaluates to
    /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to
    /// `false` even when the batch could theoretically be skipped.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial predicates for each column for filter when loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,

    /// Predicate only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    // Presumably: `true` when `hive_predicate` is equivalent to the whole of `predicate`
    // (i.e. the predicate touches hive columns only) — TODO confirm against producer code.
    pub hive_predicate_is_full_predicate: bool,
}
41
42impl fmt::Debug for ScanPredicate {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.write_str("scan_predicate")
45    }
46}
47
/// Physical (executable) per-column predicates for filtering while loading columnar formats.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    /// Maps a column name to its partial predicate and, when available, a specialized
    /// form of that predicate.
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
    // NOTE(review): name suggests the combination of all column predicates is equivalent
    // to the full predicate — confirm against the code that sets this flag.
    pub is_sumwise_complete: bool,
}
54
/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    /// Expression evaluated over a statistics `DataFrame`; yields `true` for batches
    /// that can be skipped (see [`ScanPredicate::skip_batch_predicate`]).
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    /// Schema returned by [`SkipBatchPredicate::schema`].
    schema: SchemaRef,
}
60
/// Helper for the [`PhysicalExpr`] trait to include constant columns.
pub struct PhysicalExprWithConstCols {
    /// Columns appended (broadcast as scalars) to the input frame before evaluating `child`.
    constants: Vec<(PlSmallStr, Scalar)>,
    /// The wrapped expression that is evaluated on the extended frame.
    child: Arc<dyn PhysicalExpr>,
}
66
67impl PhysicalExpr for PhysicalExprWithConstCols {
68    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
69        let mut df = df.clone();
70        for (name, scalar) in &self.constants {
71            df.with_column(Column::new_scalar(
72                name.clone(),
73                scalar.clone(),
74                df.height(),
75            ))?;
76        }
77
78        self.child.evaluate(&df, state)
79    }
80
81    fn evaluate_on_groups<'a>(
82        &self,
83        df: &DataFrame,
84        groups: &'a GroupPositions,
85        state: &ExecutionState,
86    ) -> PolarsResult<AggregationContext<'a>> {
87        let mut df = df.clone();
88        for (name, scalar) in &self.constants {
89            df.with_column(Column::new_scalar(
90                name.clone(),
91                scalar.clone(),
92                df.height(),
93            ))?;
94        }
95
96        self.child.evaluate_on_groups(&df, groups, state)
97    }
98
99    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
100        self.child.to_field(input_schema)
101    }
102    fn is_scalar(&self) -> bool {
103        self.child.is_scalar()
104    }
105}
106
107impl ScanPredicate {
108    pub fn with_constant_columns(
109        &self,
110        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
111    ) -> Self {
112        let constant_columns = constant_columns.into_iter();
113
114        let mut live_columns = self.live_columns.as_ref().clone();
115        let mut skip_batch_predicate_constants =
116            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
117                1 + constant_columns.size_hint().0 * 3
118            } else {
119                Default::default()
120            });
121
122        let predicate_constants = constant_columns
123            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
124                if !live_columns.swap_remove(&name) {
125                    return None;
126                }
127
128                if self.skip_batch_predicate.is_some() {
129                    let mut null_count: Scalar = (0 as IdxSize).into();
130
131                    // If the constant value is Null, we don't know how many nulls there are
132                    // because the length of the batch may vary.
133                    if scalar.is_null() {
134                        null_count.update(AnyValue::Null);
135                    }
136
137                    skip_batch_predicate_constants.extend([
138                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
139                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
140                        (format_pl_smallstr!("{name}_nc"), null_count),
141                    ]);
142                }
143
144                Some((name, scalar))
145            })
146            .collect();
147
148        let predicate = Arc::new(PhysicalExprWithConstCols {
149            constants: predicate_constants,
150            child: self.predicate.clone(),
151        });
152        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
153            Arc::new(PhysicalExprWithConstCols {
154                constants: skip_batch_predicate_constants,
155                child: skp.clone(),
156            }) as _
157        });
158
159        Self {
160            predicate,
161            live_columns: Arc::new(live_columns),
162            skip_batch_predicate,
163            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull
164            // predicates.
165            hive_predicate: None,
166            hive_predicate_is_full_predicate: false,
167        }
168    }
169
170    /// Create a predicate to skip batches using statistics.
171    pub(crate) fn to_dyn_skip_batch_predicate(
172        &self,
173        schema: SchemaRef,
174    ) -> Option<Arc<dyn SkipBatchPredicate>> {
175        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
176        Some(Arc::new(SkipBatchPredicateHelper {
177            skip_batch_predicate,
178            schema,
179        }))
180    }
181
182    pub fn to_io(
183        &self,
184        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
185        schema: SchemaRef,
186    ) -> ScanIOPredicate {
187        ScanIOPredicate {
188            predicate: phys_expr_to_io_expr(self.predicate.clone()),
189            live_columns: self.live_columns.clone(),
190            skip_batch_predicate: skip_batch_predicate
191                .cloned()
192                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
193            column_predicates: Arc::new(ColumnPredicates {
194                predicates: self
195                    .column_predicates
196                    .predicates
197                    .iter()
198                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
199                    .collect(),
200                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
201            }),
202            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
203            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
204        }
205    }
206}
207
208impl SkipBatchPredicate for SkipBatchPredicateHelper {
209    fn schema(&self) -> &SchemaRef {
210        &self.schema
211    }
212
213    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
214        let array = self
215            .skip_batch_predicate
216            .evaluate(df, &Default::default())?;
217        let array = array.bool()?;
218        let array = array.downcast_as_array();
219
220        if let Some(validity) = array.validity() {
221            Ok(array.values() & validity)
222        } else {
223            Ok(array.values().clone())
224        }
225    }
226}