polars_mem_engine/scan_predicate/mod.rs

1pub mod functions;
2pub mod skip_files_mask;
3use core::fmt;
4use std::sync::Arc;
5
6use arrow::bitmap::Bitmap;
7pub use functions::{create_scan_predicate, initialize_scan_predicate};
8use polars_core::frame::DataFrame;
9use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
10use polars_core::scalar::Scalar;
11use polars_core::schema::{Schema, SchemaRef};
12use polars_error::PolarsResult;
13use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
14use polars_expr::state::ExecutionState;
15use polars_io::predicates::{
16    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
17};
18use polars_utils::pl_str::PlSmallStr;
19use polars_utils::{IdxSize, format_pl_smallstr};
20
21/// All the expressions and metadata used to filter out rows using predicates.
22#[derive(Clone)]
23pub struct ScanPredicate {
24    pub predicate: Arc<dyn PhysicalExpr>,
25
26    /// Column names that are used in the predicate.
27    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,
28
29    /// A predicate expression used to skip record batches based on its statistics.
30    ///
31    /// This expression will be given a batch size along with a `min`, `max` and `null count` for
32    /// each live column (set to `null` when it is not known) and the expression evaluates to
33    /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to
34    /// `false` even when the batch could theoretically be skipped.
35    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,
36
37    /// Partial predicates for each column for filter when loading columnar formats.
38    pub column_predicates: PhysicalColumnPredicates,
39
40    /// Predicate only referring to hive columns.
41    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
42    pub hive_predicate_is_full_predicate: bool,
43}
44
45impl fmt::Debug for ScanPredicate {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.write_str("scan_predicate")
48    }
49}
50
51#[derive(Clone)]
52pub struct PhysicalColumnPredicates {
53    pub predicates:
54        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
55    pub is_sumwise_complete: bool,
56}
57
58/// Helper to implement [`SkipBatchPredicate`].
59struct SkipBatchPredicateHelper {
60    skip_batch_predicate: Arc<dyn PhysicalExpr>,
61    schema: SchemaRef,
62}
63
64/// Helper for the [`PhysicalExpr`] trait to include constant columns.
65pub struct PhysicalExprWithConstCols {
66    constants: Vec<(PlSmallStr, Scalar)>,
67    child: Arc<dyn PhysicalExpr>,
68}
69
70impl PhysicalExpr for PhysicalExprWithConstCols {
71    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
72        let mut df = df.clone();
73        for (name, scalar) in &self.constants {
74            df.with_column(Column::new_scalar(
75                name.clone(),
76                scalar.clone(),
77                df.height(),
78            ))?;
79        }
80
81        self.child.evaluate(&df, state)
82    }
83
84    fn evaluate_on_groups<'a>(
85        &self,
86        df: &DataFrame,
87        groups: &'a GroupPositions,
88        state: &ExecutionState,
89    ) -> PolarsResult<AggregationContext<'a>> {
90        let mut df = df.clone();
91        for (name, scalar) in &self.constants {
92            df.with_column(Column::new_scalar(
93                name.clone(),
94                scalar.clone(),
95                df.height(),
96            ))?;
97        }
98
99        self.child.evaluate_on_groups(&df, groups, state)
100    }
101
102    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
103        self.child.to_field(input_schema)
104    }
105    fn is_scalar(&self) -> bool {
106        self.child.is_scalar()
107    }
108}
109
110impl ScanPredicate {
111    pub fn with_constant_columns(
112        &self,
113        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
114    ) -> Self {
115        let constant_columns = constant_columns.into_iter();
116
117        let mut live_columns = self.live_columns.as_ref().clone();
118        let mut skip_batch_predicate_constants =
119            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
120                1 + constant_columns.size_hint().0 * 3
121            } else {
122                Default::default()
123            });
124
125        let predicate_constants = constant_columns
126            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
127                if !live_columns.swap_remove(&name) {
128                    return None;
129                }
130
131                if self.skip_batch_predicate.is_some() {
132                    let mut null_count: Scalar = (0 as IdxSize).into();
133
134                    // If the constant value is Null, we don't know how many nulls there are
135                    // because the length of the batch may vary.
136                    if scalar.is_null() {
137                        null_count.update(AnyValue::Null);
138                    }
139
140                    skip_batch_predicate_constants.extend([
141                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
142                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
143                        (format_pl_smallstr!("{name}_nc"), null_count),
144                    ]);
145                }
146
147                Some((name, scalar))
148            })
149            .collect();
150
151        let predicate = Arc::new(PhysicalExprWithConstCols {
152            constants: predicate_constants,
153            child: self.predicate.clone(),
154        });
155        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
156            Arc::new(PhysicalExprWithConstCols {
157                constants: skip_batch_predicate_constants,
158                child: skp.clone(),
159            }) as _
160        });
161
162        Self {
163            predicate,
164            live_columns: Arc::new(live_columns),
165            skip_batch_predicate,
166            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull
167            // predicates.
168            hive_predicate: None,
169            hive_predicate_is_full_predicate: false,
170        }
171    }
172
173    /// Create a predicate to skip batches using statistics.
174    pub(crate) fn to_dyn_skip_batch_predicate(
175        &self,
176        schema: SchemaRef,
177    ) -> Option<Arc<dyn SkipBatchPredicate>> {
178        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
179        Some(Arc::new(SkipBatchPredicateHelper {
180            skip_batch_predicate,
181            schema,
182        }))
183    }
184
185    pub fn to_io(
186        &self,
187        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
188        schema: SchemaRef,
189    ) -> ScanIOPredicate {
190        ScanIOPredicate {
191            predicate: phys_expr_to_io_expr(self.predicate.clone()),
192            live_columns: self.live_columns.clone(),
193            skip_batch_predicate: skip_batch_predicate
194                .cloned()
195                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
196            column_predicates: Arc::new(ColumnPredicates {
197                predicates: self
198                    .column_predicates
199                    .predicates
200                    .iter()
201                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
202                    .collect(),
203                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
204            }),
205            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
206            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
207        }
208    }
209}
210
211impl SkipBatchPredicate for SkipBatchPredicateHelper {
212    fn schema(&self) -> &SchemaRef {
213        &self.schema
214    }
215
216    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
217        let array = self
218            .skip_batch_predicate
219            .evaluate(df, &Default::default())?;
220        let array = array.bool()?.rechunk();
221        let array = array.downcast_as_array();
222
223        let array = if let Some(validity) = array.validity() {
224            array.values() & validity
225        } else {
226            array.values().clone()
227        };
228
229        // @NOTE: Certain predicates like `1 == 1` will only output 1 value. We need to broadcast
230        // the result back to the dataframe length.
231        if array.len() == 1 && df.height() != 0 {
232            return Ok(Bitmap::new_with_value(array.get_bit(0), df.height()));
233        }
234
235        assert_eq!(array.len(), df.height());
236        Ok(array)
237    }
238}