1use std::fmt;
2
3use arrow::array::Array;
4use arrow::bitmap::{Bitmap, BitmapBuilder};
5use polars_core::prelude::*;
6#[cfg(feature = "parquet")]
7use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange};
8use polars_utils::format_pl_smallstr;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
/// An expression that can be evaluated during I/O (e.g. while scanning a
/// file) to produce a filter mask over a [`DataFrame`].
pub trait PhysicalIoExpr: Send + Sync {
    /// Evaluate this expression on `df` and return the resulting [`Series`]
    /// (callers in this module expect a boolean mask).
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}
17
/// A recognizable, specialized form of a column predicate that enables fast
/// paths (e.g. pushing an equality down into the Parquet decoder).
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicateExpr {
    // `col == scalar`; a null scalar never matches (see `to_eq_scalar`).
    Eq(Scalar),
    // `col == scalar` where null compares equal to null (`eq_missing`).
    EqMissing(Scalar),
}
23
/// A predicate that applies to a single column, carrying both the generic
/// physical expression and an optional specialized form.
#[derive(Clone)]
pub struct ColumnPredicateExpr {
    // Name of the column the predicate is evaluated against.
    column_name: PlSmallStr,
    // Logical dtype of that column.
    dtype: DataType,
    // Fast-path form of the predicate, if one was recognized.
    specialized: Option<SpecializedColumnPredicateExpr>,
    // The generic, always-correct expression to fall back on.
    expr: Arc<dyn PhysicalIoExpr>,
}
31
32impl ColumnPredicateExpr {
33 pub fn new(
34 column_name: PlSmallStr,
35 dtype: DataType,
36 expr: Arc<dyn PhysicalIoExpr>,
37 specialized: Option<SpecializedColumnPredicateExpr>,
38 ) -> Self {
39 Self {
40 column_name,
41 dtype,
42 specialized,
43 expr,
44 }
45 }
46
47 pub fn is_eq_scalar(&self) -> bool {
48 self.to_eq_scalar().is_some()
49 }
50 pub fn to_eq_scalar(&self) -> Option<&Scalar> {
51 match &self.specialized {
52 Some(SpecializedColumnPredicateExpr::Eq(sc)) if !sc.is_null() => Some(sc),
53 Some(SpecializedColumnPredicateExpr::EqMissing(sc)) => Some(sc),
54 _ => None,
55 }
56 }
57}
58
#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    /// Evaluate the predicate over a decoded Parquet values array, appending
    /// one bit per row (set = row survives) to `bm`.
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // NOTE(review): this asserts `values` carries no validity with set
        // bits; nulls are apparently handled separately via `evaluate_null`
        // — confirm the exact validity convention against the caller.
        assert!(values.validity().is_none_or(|v| v.set_bits() == 0));

        // Wrap the raw arrow array in a single-column DataFrame so the
        // generic physical expression can be evaluated over it.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        // SAFETY: exactly one column whose length equals the stated height.
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                // A null in the mask counts as "filtered out": AND the mask
                // values with their validity.
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }

    /// Whether an all-null value passes the predicate. Evaluates the
    /// expression once on a single-row, fully-null frame.
    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        // SAFETY: exactly one column of height 1.
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        // A null mask result is treated as "does not pass".
        true_mask.get(0).unwrap_or(false)
    }

    /// The equality scalar in Parquet form, if this predicate is an equality
    /// and the scalar has a Parquet physical representation.
    fn to_equals_scalar(&self) -> Option<ParquetScalar> {
        self.to_eq_scalar()
            .and_then(|s| cast_to_parquet_scalar(s.clone()))
    }

    // Range pushdown is not implemented for this expression type.
    fn to_range_scalar(&self) -> Option<ParquetScalarRange> {
        None
    }
}
104
/// Convert a Polars [`Scalar`] into a [`ParquetScalar`] that the Parquet
/// reader can use for predicate pushdown.
///
/// Returns `None` for values with no direct Parquet physical representation
/// (categoricals/enums, nested types, ...), in which case the caller must
/// fall back to evaluating the full expression.
#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        // BUGFIX: `AnyValue::Date` is gated on the `dtype-date` feature, not
        // `dtype-time`. The old `dtype-time` gate either failed to compile
        // (dtype-time enabled without dtype-date) or silently dropped the
        // Date fast path into the `_ => return None` arm.
        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        // Time is stored as its underlying i64 representation.
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // Categorical/enum physical values are dictionary-local and cannot
        // be compared against Parquet statistics directly.
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _, _)
        | A::CategoricalOwned(_, _, _)
        | A::Enum(_, _, _)
        | A::EnumOwned(_, _, _) => return None,

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}
149
/// Filter `df` in place with `predicate`, if one is given and the frame has
/// any columns. `parallel` selects the parallel or sequential filter kernel.
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    let Some(predicate) = predicate else {
        return Ok(());
    };
    if df.get_columns().is_empty() {
        return Ok(());
    }

    let s = predicate.evaluate_io(df)?;
    let mask = s.bool().expect("filter predicates was not of type boolean");

    *df = if parallel {
        df.filter(mask)?
    } else {
        df._filter_seq(mask)?
    };
    Ok(())
}
168
/// Statistics for a single column, used for predicate-based pruning.
///
/// Each optional `Series` holds per-batch statistic values; a null entry
/// means the statistic is unknown for that batch (see `null_count`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    // Name and dtype of the column these statistics describe.
    field: Field,
    // Null count per batch, if known.
    null_count: Option<Series>,
    // Minimum value per batch, if known.
    min_value: Option<Series>,
    // Maximum value per batch, if known.
    max_value: Option<Series>,
}
184
185impl ColumnStats {
186 pub fn new(
188 field: Field,
189 null_count: Option<Series>,
190 min_value: Option<Series>,
191 max_value: Option<Series>,
192 ) -> Self {
193 Self {
194 field,
195 null_count,
196 min_value,
197 max_value,
198 }
199 }
200
201 pub fn from_field(field: Field) -> Self {
203 Self {
204 field,
205 null_count: None,
206 min_value: None,
207 max_value: None,
208 }
209 }
210
211 pub fn from_column_literal(s: Series) -> Self {
213 debug_assert_eq!(s.len(), 1);
214 Self {
215 field: s.field().into_owned(),
216 null_count: None,
217 min_value: Some(s.clone()),
218 max_value: Some(s),
219 }
220 }
221
222 pub fn field_name(&self) -> &PlSmallStr {
223 self.field.name()
224 }
225
226 pub fn dtype(&self) -> &DataType {
228 self.field.dtype()
229 }
230
231 pub fn get_null_count_state(&self) -> Option<&Series> {
233 self.null_count.as_ref()
234 }
235
236 pub fn get_min_state(&self) -> Option<&Series> {
238 self.min_value.as_ref()
239 }
240
241 pub fn get_max_state(&self) -> Option<&Series> {
243 self.max_value.as_ref()
244 }
245
246 pub fn null_count(&self) -> Option<usize> {
248 match self.dtype() {
249 #[cfg(feature = "dtype-struct")]
250 DataType::Struct(_) => None,
251 _ => {
252 let s = self.get_null_count_state()?;
253 if s.null_count() != s.len() {
255 s.sum().ok()
256 } else {
257 None
258 }
259 },
260 }
261 }
262
263 pub fn to_min_max(&self) -> Option<Series> {
265 let min_val = self.get_min_state()?;
266 let max_val = self.get_max_state()?;
267 let dtype = self.dtype();
268
269 if !use_min_max(dtype) {
270 return None;
271 }
272
273 let mut min_max_values = min_val.clone();
274 min_max_values.append(max_val).unwrap();
275 if min_max_values.null_count() > 0 {
276 None
277 } else {
278 Some(min_max_values)
279 }
280 }
281
282 pub fn to_min(&self) -> Option<&Series> {
286 let min_val = self.min_value.as_ref()?;
288 let dtype = min_val.dtype();
289
290 if !use_min_max(dtype) || min_val.len() != 1 {
291 return None;
292 }
293
294 if min_val.null_count() > 0 {
295 None
296 } else {
297 Some(min_val)
298 }
299 }
300
301 pub fn to_max(&self) -> Option<&Series> {
305 let max_val = self.max_value.as_ref()?;
307 let dtype = max_val.dtype();
308
309 if !use_min_max(dtype) || max_val.len() != 1 {
310 return None;
311 }
312
313 if max_val.null_count() > 0 {
314 None
315 } else {
316 Some(max_val)
317 }
318 }
319}
320
321fn use_min_max(dtype: &DataType) -> bool {
323 dtype.is_primitive_numeric()
324 || dtype.is_temporal()
325 || matches!(
326 dtype,
327 DataType::String | DataType::Binary | DataType::Boolean
328 )
329}
330
/// Statistics for one column of a single batch, as consumed by
/// `SkipBatchPredicate::can_skip_batch`.
pub struct ColumnStatistics {
    pub dtype: DataType,
    // Minimum value; `AnyValue::Null` when unknown.
    pub min: AnyValue<'static>,
    // Maximum value; `AnyValue::Null` when unknown.
    pub max: AnyValue<'static>,
    // Null count, or `None` when unknown.
    pub null_count: Option<IdxSize>,
}
337
/// A predicate over batch-level statistics that decides whether an entire
/// batch can be skipped without reading its data.
pub trait SkipBatchPredicate: Send + Sync {
    /// Schema of the columns the predicate may reference.
    fn schema(&self) -> &SchemaRef;

    /// Returns `true` when the batch described by `statistics` can be
    /// skipped entirely.
    ///
    /// Builds a single-row DataFrame with a `len` column plus, per live
    /// column `c`, the columns `{c}_min`, `{c}_max` and `{c}_nc` (null
    /// count), then delegates to [`Self::evaluate_with_stat_df`]. Missing
    /// statistics are encoded as nulls.
    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        // One `len` column + 3 statistic columns per live column.
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            // No entry for this column means all statistics are unknown.
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY: all columns are scalar columns of height 1.
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    /// Evaluate the predicate on a statistics DataFrame laid out as described
    /// in [`Self::can_skip_batch`]; one result bit per row.
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}
387
/// Per-column sub-predicates extracted from a full predicate expression.
#[derive(Clone)]
pub struct ColumnPredicates {
    // For each column: the expression restricted to that column, plus an
    // optional specialized (e.g. equality) form.
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalIoExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    // NOTE(review): presumably true when the per-column predicates together
    // are equivalent to the full predicate — confirm at the construction site.
    pub is_sumwise_complete: bool,
}
399
400#[allow(clippy::derivable_impls)]
402impl Default for ColumnPredicates {
403 fn default() -> Self {
404 Self {
405 predicates: PlHashMap::default(),
406 is_sumwise_complete: false,
407 }
408 }
409}
410
/// Wraps a child expression/predicate and injects constant scalar columns
/// into every DataFrame before delegating to the child.
pub struct PhysicalExprWithConstCols<T> {
    // (column name, constant value) pairs to materialize before evaluation.
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}
415
416impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
417 fn schema(&self) -> &SchemaRef {
418 self.child.schema()
419 }
420
421 fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
422 let mut df = df.clone();
423 for (name, scalar) in self.constants.iter() {
424 df.with_column(Column::new_scalar(
425 name.clone(),
426 scalar.clone(),
427 df.height(),
428 ))?;
429 }
430 self.child.evaluate_with_stat_df(&df)
431 }
432}
433
434impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
435 fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
436 let mut df = df.clone();
437 for (name, scalar) in self.constants.iter() {
438 df.with_column(Column::new_scalar(
439 name.clone(),
440 scalar.clone(),
441 df.height(),
442 ))?;
443 }
444
445 self.child.evaluate_io(&df)
446 }
447}
448
/// All predicate information a scan needs to prune batches and filter rows.
#[derive(Clone)]
pub struct ScanIOPredicate {
    // The full predicate expression, evaluated row-wise.
    pub predicate: Arc<dyn PhysicalIoExpr>,

    // Columns referenced by the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    // Optional predicate over batch statistics that decides whether a whole
    // batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    // Per-column sub-predicates (see `ColumnPredicates`).
    pub column_predicates: Arc<ColumnPredicates>,

    // NOTE(review): presumably the part of the predicate that only touches
    // hive-partition columns — confirm at the construction site.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    // Whether `hive_predicate` is equivalent to the full `predicate`.
    pub hive_predicate_is_full_predicate: bool,
}
467
impl ScanIOPredicate {
    /// Bake externally-known constant columns (e.g. partition values) into
    /// this predicate: the columns are removed from the live set and the
    /// child expressions are wrapped so the constants are materialized at
    /// evaluation time.
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        // The constant columns no longer need to be read from the source.
        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            // The skip-batch statistics frame expects `{c}_min` / `{c}_max` /
            // `{c}_nc` columns; synthesize them from the constant value.
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                // For a null constant the null count equals the (here
                // unknown) batch size, so it is encoded as null; otherwise
                // it is zero.
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        // Per-column predicates on constant columns are no longer needed.
        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        // Wrap the full predicate so the constants exist during evaluation.
        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}
511
512impl fmt::Debug for ScanIOPredicate {
513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514 f.write_str("scan_io_predicate")
515 }
516}