polars_mem_engine/scan_predicate/
mod.rs1pub mod functions;
2pub mod skip_files_mask;
3use core::fmt;
4use std::sync::Arc;
5
6use arrow::bitmap::Bitmap;
7pub use functions::{create_scan_predicate, initialize_scan_predicate};
8use polars_core::frame::DataFrame;
9use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
10use polars_core::scalar::Scalar;
11use polars_core::schema::{Schema, SchemaRef};
12use polars_error::PolarsResult;
13use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
14use polars_expr::state::ExecutionState;
15use polars_io::predicates::{
16 ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
17};
18use polars_utils::pl_str::PlSmallStr;
19use polars_utils::{IdxSize, format_pl_smallstr};
20
21#[derive(Clone)]
23pub struct ScanPredicate {
24 pub predicate: Arc<dyn PhysicalExpr>,
25
26 pub live_columns: Arc<PlIndexSet<PlSmallStr>>,
28
29 pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,
36
37 pub column_predicates: PhysicalColumnPredicates,
39
40 pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
42 pub hive_predicate_is_full_predicate: bool,
43}
44
45impl fmt::Debug for ScanPredicate {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.write_str("scan_predicate")
48 }
49}
50
51#[derive(Clone)]
52pub struct PhysicalColumnPredicates {
53 pub predicates:
54 PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
55 pub is_sumwise_complete: bool,
56}
57
58struct SkipBatchPredicateHelper {
60 skip_batch_predicate: Arc<dyn PhysicalExpr>,
61 schema: SchemaRef,
62}
63
64pub struct PhysicalExprWithConstCols {
66 constants: Vec<(PlSmallStr, Scalar)>,
67 child: Arc<dyn PhysicalExpr>,
68}
69
70impl PhysicalExpr for PhysicalExprWithConstCols {
71 fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
72 let mut df = df.clone();
73 for (name, scalar) in &self.constants {
74 df.with_column(Column::new_scalar(
75 name.clone(),
76 scalar.clone(),
77 df.height(),
78 ))?;
79 }
80
81 self.child.evaluate(&df, state)
82 }
83
84 fn evaluate_on_groups<'a>(
85 &self,
86 df: &DataFrame,
87 groups: &'a GroupPositions,
88 state: &ExecutionState,
89 ) -> PolarsResult<AggregationContext<'a>> {
90 let mut df = df.clone();
91 for (name, scalar) in &self.constants {
92 df.with_column(Column::new_scalar(
93 name.clone(),
94 scalar.clone(),
95 df.height(),
96 ))?;
97 }
98
99 self.child.evaluate_on_groups(&df, groups, state)
100 }
101
102 fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
103 self.child.to_field(input_schema)
104 }
105 fn is_scalar(&self) -> bool {
106 self.child.is_scalar()
107 }
108}
109
110impl ScanPredicate {
111 pub fn with_constant_columns(
112 &self,
113 constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
114 ) -> Self {
115 let constant_columns = constant_columns.into_iter();
116
117 let mut live_columns = self.live_columns.as_ref().clone();
118 let mut skip_batch_predicate_constants =
119 Vec::with_capacity(if self.skip_batch_predicate.is_some() {
120 1 + constant_columns.size_hint().0 * 3
121 } else {
122 Default::default()
123 });
124
125 let predicate_constants = constant_columns
126 .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
127 if !live_columns.swap_remove(&name) {
128 return None;
129 }
130
131 if self.skip_batch_predicate.is_some() {
132 let mut null_count: Scalar = (0 as IdxSize).into();
133
134 if scalar.is_null() {
137 null_count.update(AnyValue::Null);
138 }
139
140 skip_batch_predicate_constants.extend([
141 (format_pl_smallstr!("{name}_min"), scalar.clone()),
142 (format_pl_smallstr!("{name}_max"), scalar.clone()),
143 (format_pl_smallstr!("{name}_nc"), null_count),
144 ]);
145 }
146
147 Some((name, scalar))
148 })
149 .collect();
150
151 let predicate = Arc::new(PhysicalExprWithConstCols {
152 constants: predicate_constants,
153 child: self.predicate.clone(),
154 });
155 let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
156 Arc::new(PhysicalExprWithConstCols {
157 constants: skip_batch_predicate_constants,
158 child: skp.clone(),
159 }) as _
160 });
161
162 Self {
163 predicate,
164 live_columns: Arc::new(live_columns),
165 skip_batch_predicate,
166 column_predicates: self.column_predicates.clone(), hive_predicate: None,
169 hive_predicate_is_full_predicate: false,
170 }
171 }
172
173 pub(crate) fn to_dyn_skip_batch_predicate(
175 &self,
176 schema: SchemaRef,
177 ) -> Option<Arc<dyn SkipBatchPredicate>> {
178 let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
179 Some(Arc::new(SkipBatchPredicateHelper {
180 skip_batch_predicate,
181 schema,
182 }))
183 }
184
185 pub fn to_io(
186 &self,
187 skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
188 schema: SchemaRef,
189 ) -> ScanIOPredicate {
190 ScanIOPredicate {
191 predicate: phys_expr_to_io_expr(self.predicate.clone()),
192 live_columns: self.live_columns.clone(),
193 skip_batch_predicate: skip_batch_predicate
194 .cloned()
195 .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
196 column_predicates: Arc::new(ColumnPredicates {
197 predicates: self
198 .column_predicates
199 .predicates
200 .iter()
201 .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
202 .collect(),
203 is_sumwise_complete: self.column_predicates.is_sumwise_complete,
204 }),
205 hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
206 hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
207 }
208 }
209}
210
211impl SkipBatchPredicate for SkipBatchPredicateHelper {
212 fn schema(&self) -> &SchemaRef {
213 &self.schema
214 }
215
216 fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
217 let array = self
218 .skip_batch_predicate
219 .evaluate(df, &Default::default())?;
220 let array = array.bool()?.rechunk();
221 let array = array.downcast_as_array();
222
223 let array = if let Some(validity) = array.validity() {
224 array.values() & validity
225 } else {
226 array.values().clone()
227 };
228
229 if array.len() == 1 && df.height() != 0 {
232 return Ok(Bitmap::new_with_value(array.get_bit(0), df.height()));
233 }
234
235 assert_eq!(array.len(), df.height());
236 Ok(array)
237 }
238}