polars_io/predicates.rs

use std::fmt;

use arrow::array::Array;
use arrow::bitmap::{Bitmap, BitmapBuilder};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

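/// A physical expression that is evaluated on a [`DataFrame`] during IO.
///
/// A minimal sketch of an implementor (the `NotNull` type and its field are
/// hypothetical, shown only for illustration):
///
/// ```ignore
/// struct NotNull {
///     column: PlSmallStr,
/// }
///
/// impl PhysicalIoExpr for NotNull {
///     fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
///         // The mask is `true` for every row where the column is non-null.
///         let col = df.column(self.column.as_str())?;
///         Ok(col.as_materialized_series().is_not_null().into_series())
///     }
/// }
/// ```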
pub trait PhysicalIoExpr: Send + Sync {
    /// Takes a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}

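/// A fast-path predicate that compares a column against a single [`Scalar`].
///
/// `EqMissing` also matches missing (null) values, while `Eq` never does.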
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicateExpr {
    Eq(Scalar),
    EqMissing(Scalar),
}

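/// A predicate that refers to a single column, with an optional specialized
/// fast path.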
#[derive(Clone)]
pub struct ColumnPredicateExpr {
    column_name: PlSmallStr,
    dtype: DataType,
    specialized: Option<SpecializedColumnPredicateExpr>,
    expr: Arc<dyn PhysicalIoExpr>,
}

impl ColumnPredicateExpr {
    pub fn new(
        column_name: PlSmallStr,
        dtype: DataType,
        expr: Arc<dyn PhysicalIoExpr>,
        specialized: Option<SpecializedColumnPredicateExpr>,
    ) -> Self {
        Self {
            column_name,
            dtype,
            specialized,
            expr,
        }
    }

    pub fn is_eq_scalar(&self) -> bool {
        self.to_eq_scalar().is_some()
    }
    pub fn to_eq_scalar(&self) -> Option<&Scalar> {
        match &self.specialized {
            // A plain equality against a null scalar can never match, so it is
            // not usable as an equality fast path; `EqMissing` also matches nulls.
            Some(SpecializedColumnPredicateExpr::Eq(sc)) if !sc.is_null() => Some(sc),
            Some(SpecializedColumnPredicateExpr::EqMissing(sc)) => Some(sc),
            _ => None,
        }
    }
}

#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // We should never evaluate nulls with this.
        assert!(values.validity().is_none_or(|v| v.unset_bits() == 0));

        // @TODO: Probably these unwraps should be removed.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }
    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        true_mask.get(0).unwrap_or(false)
    }

    fn to_equals_scalar(&self) -> Option<ParquetScalar> {
        self.to_eq_scalar()
            .and_then(|s| cast_to_parquet_scalar(s.clone()))
    }

    fn to_range_scalar(&self) -> Option<ParquetScalarRange> {
        None
    }
}

#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _, _)
        | A::CategoricalOwned(_, _, _)
        | A::Enum(_, _, _)
        | A::EnumOwned(_, _, _) => return None,

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}

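/// Filters `df` in place with the given predicate mask, using the parallel or
/// sequential filter implementation depending on `parallel`.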
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
        let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");

        if parallel {
            *df = df.filter(mask)?;
        } else {
            *df = df._filter_seq(mask)?;
        }
    }
    Ok(())
}

/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
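///
/// A sketch of stats for two row groups of an `Int32` column (the values are
/// made up for illustration):
///
/// ```ignore
/// let stats = ColumnStats::new(
///     Field::new("a".into(), DataType::Int32),
///     Some(Series::new("a".into(), [0u32, 1])),  // null count per row group
///     Some(Series::new("a".into(), [1i32, 10])), // minimum per row group
///     Some(Series::new("a".into(), [5i32, 42])), // maximum per row group
/// );
/// assert_eq!(stats.null_count(), Some(1));
/// ```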
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    field: Field,
    // Each Series contains the stats for each row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}

impl ColumnStats {
    /// Constructs a new [`ColumnStats`].
    pub fn new(
        field: Field,
        null_count: Option<Series>,
        min_value: Option<Series>,
        max_value: Option<Series>,
    ) -> Self {
        Self {
            field,
            null_count,
            min_value,
            max_value,
        }
    }

    /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
    pub fn from_field(field: Field) -> Self {
        Self {
            field,
            null_count: None,
            min_value: None,
            max_value: None,
        }
    }

    /// Constructs a new [`ColumnStats`] from a single-value [`Series`].
    pub fn from_column_literal(s: Series) -> Self {
        debug_assert_eq!(s.len(), 1);
        Self {
            field: s.field().into_owned(),
            null_count: None,
            min_value: Some(s.clone()),
            max_value: Some(s),
        }
    }

    /// Returns the name of the column.
    pub fn field_name(&self) -> &PlSmallStr {
        self.field.name()
    }

    /// Returns the [`DataType`] of the column.
    pub fn dtype(&self) -> &DataType {
        self.field.dtype()
    }

    /// Returns the null count of each row group of the column.
    pub fn get_null_count_state(&self) -> Option<&Series> {
        self.null_count.as_ref()
    }

    /// Returns the minimum value of each row group of the column.
    pub fn get_min_state(&self) -> Option<&Series> {
        self.min_value.as_ref()
    }

    /// Returns the maximum value of each row group of the column.
    pub fn get_max_state(&self) -> Option<&Series> {
        self.max_value.as_ref()
    }

    /// Returns the total null count of the column.
    pub fn null_count(&self) -> Option<usize> {
        match self.dtype() {
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => None,
            _ => {
                let s = self.get_null_count_state()?;
                // If all values are null, there are no statistics.
                if s.null_count() != s.len() {
                    s.sum().ok()
                } else {
                    None
                }
            },
        }
    }

    /// Returns the minimum and maximum values of the column as a single [`Series`].
    pub fn to_min_max(&self) -> Option<Series> {
        let min_val = self.get_min_state()?;
        let max_val = self.get_max_state()?;
        let dtype = self.dtype();

        if !use_min_max(dtype) {
            return None;
        }

        let mut min_max_values = min_val.clone();
        min_max_values.append(max_val).unwrap();
        if min_max_values.null_count() > 0 {
            None
        } else {
            Some(min_max_values)
        }
    }

    /// Returns the minimum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no minimum value is available.
    pub fn to_min(&self) -> Option<&Series> {
        // @scalar-opt
        let min_val = self.min_value.as_ref()?;
        let dtype = min_val.dtype();

        if !use_min_max(dtype) || min_val.len() != 1 {
            return None;
        }

        if min_val.null_count() > 0 {
            None
        } else {
            Some(min_val)
        }
    }

    /// Returns the maximum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no maximum value is available.
    pub fn to_max(&self) -> Option<&Series> {
        // @scalar-opt
        let max_val = self.max_value.as_ref()?;
        let dtype = max_val.dtype();

        if !use_min_max(dtype) || max_val.len() != 1 {
            return None;
        }

        if max_val.null_count() > 0 {
            None
        } else {
            Some(max_val)
        }
    }
}

/// Returns whether the [`DataType`] supports minimum/maximum operations.
fn use_min_max(dtype: &DataType) -> bool {
    dtype.is_primitive_numeric()
        || dtype.is_temporal()
        || matches!(
            dtype,
            DataType::String | DataType::Binary | DataType::Boolean
        )
}

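/// Minimum value, maximum value, and null count for a single column in a
/// single batch.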
pub struct ColumnStatistics {
    pub dtype: DataType,
    pub min: AnyValue<'static>,
    pub max: AnyValue<'static>,
    pub null_count: Option<IdxSize>,
}

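/// A predicate that is evaluated against batch statistics to decide whether a
/// whole batch can be skipped.
///
/// The provided [`can_skip_batch`][SkipBatchPredicate::can_skip_batch] builds a
/// single-row statistics [`DataFrame`] containing a `len` column plus
/// `{col}_min`, `{col}_max` and `{col}_nc` columns for every live column, and
/// passes it to [`evaluate_with_stat_df`][SkipBatchPredicate::evaluate_with_stat_df].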
pub trait SkipBatchPredicate: Send + Sync {
    fn schema(&self) -> &SchemaRef;

    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY:
        // * Each column is length = 1
        // * We have an IndexSet, so each column name is unique
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}

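/// Predicates that each refer to a single column, keyed by column name.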
#[derive(Clone)]
pub struct ColumnPredicates {
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalIoExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    pub is_sumwise_complete: bool,
}

// I want to be explicit here.
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}

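/// Wraps a child expression and injects constant columns into the
/// [`DataFrame`] before delegating evaluation to the child.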
pub struct PhysicalExprWithConstCols<T> {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
    fn schema(&self) -> &SchemaRef {
        self.child.schema()
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }
        self.child.evaluate_with_stat_df(&df)
    }
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_io(&df)
    }
}

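/// All predicate information attached to a scan.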
#[derive(Clone)]
pub struct ScanIOPredicate {
    /// The full filter predicate.
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// Predicate parts that each only refer to a single column.
    pub column_predicates: Arc<ColumnPredicates>,

    /// Predicate parts only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    pub hive_predicate_is_full_predicate: bool,
}

impl ScanIOPredicate {
    /// Informs the predicate that the given columns have a known constant value.
    ///
    /// The columns are removed from the live set and from the per-column
    /// predicates, and are instead injected as constants into the wrapped
    /// predicates.
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}

impl fmt::Debug for ScanIOPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("scan_io_predicate")
    }
}