1use core::fmt;
2use std::sync::Arc;
3
4use arrow::bitmap::Bitmap;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
7use polars_core::scalar::Scalar;
8use polars_core::schema::{Schema, SchemaRef};
9use polars_error::PolarsResult;
10use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
11use polars_expr::state::ExecutionState;
12use polars_io::predicates::{
13 ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr,
14};
15use polars_utils::pl_str::PlSmallStr;
16use polars_utils::{IdxSize, format_pl_smallstr};
17
18#[derive(Clone)]
20pub struct ScanPredicate {
21 pub predicate: Arc<dyn PhysicalExpr>,
22
23 pub live_columns: Arc<PlIndexSet<PlSmallStr>>,
25
26 pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,
33
34 pub column_predicates: PhysicalColumnPredicates,
36}
37
38impl fmt::Debug for ScanPredicate {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.write_str("scan_predicate")
41 }
42}
43
44#[derive(Clone)]
45pub struct PhysicalColumnPredicates {
46 pub predicates: PlHashMap<
47 PlSmallStr,
48 (
49 Arc<dyn PhysicalExpr>,
50 Option<SpecializedColumnPredicateExpr>,
51 ),
52 >,
53 pub is_sumwise_complete: bool,
54}
55
56struct SkipBatchPredicateHelper {
58 skip_batch_predicate: Arc<dyn PhysicalExpr>,
59 schema: SchemaRef,
60}
61
62pub struct PhysicalExprWithConstCols {
64 constants: Vec<(PlSmallStr, Scalar)>,
65 child: Arc<dyn PhysicalExpr>,
66}
67
68impl PhysicalExpr for PhysicalExprWithConstCols {
69 fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
70 let mut df = df.clone();
71 for (name, scalar) in &self.constants {
72 df.with_column(Column::new_scalar(
73 name.clone(),
74 scalar.clone(),
75 df.height(),
76 ))?;
77 }
78
79 self.child.evaluate(&df, state)
80 }
81
82 fn evaluate_on_groups<'a>(
83 &self,
84 df: &DataFrame,
85 groups: &'a GroupPositions,
86 state: &ExecutionState,
87 ) -> PolarsResult<AggregationContext<'a>> {
88 let mut df = df.clone();
89 for (name, scalar) in &self.constants {
90 df.with_column(Column::new_scalar(
91 name.clone(),
92 scalar.clone(),
93 df.height(),
94 ))?;
95 }
96
97 self.child.evaluate_on_groups(&df, groups, state)
98 }
99
100 fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
101 self.child.to_field(input_schema)
102 }
103 fn is_scalar(&self) -> bool {
104 self.child.is_scalar()
105 }
106}
107
108impl ScanPredicate {
109 pub fn with_constant_columns(
110 &self,
111 constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
112 ) -> Self {
113 let constant_columns = constant_columns.into_iter();
114
115 let mut live_columns = self.live_columns.as_ref().clone();
116 let mut skip_batch_predicate_constants = Vec::with_capacity(
117 self.skip_batch_predicate
118 .is_some()
119 .then_some(1 + constant_columns.size_hint().0 * 3)
120 .unwrap_or_default(),
121 );
122
123 let predicate_constants = constant_columns
124 .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
125 if !live_columns.swap_remove(&name) {
126 return None;
127 }
128
129 if self.skip_batch_predicate.is_some() {
130 let mut null_count: Scalar = (0 as IdxSize).into();
131
132 if scalar.is_null() {
135 null_count.update(AnyValue::Null);
136 }
137
138 skip_batch_predicate_constants.extend([
139 (format_pl_smallstr!("{name}_min"), scalar.clone()),
140 (format_pl_smallstr!("{name}_max"), scalar.clone()),
141 (format_pl_smallstr!("{name}_nc"), null_count),
142 ]);
143 }
144
145 Some((name, scalar))
146 })
147 .collect();
148
149 let predicate = Arc::new(PhysicalExprWithConstCols {
150 constants: predicate_constants,
151 child: self.predicate.clone(),
152 });
153 let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
154 Arc::new(PhysicalExprWithConstCols {
155 constants: skip_batch_predicate_constants,
156 child: skp.clone(),
157 }) as _
158 });
159
160 Self {
161 predicate,
162 live_columns: Arc::new(live_columns),
163 skip_batch_predicate,
164 column_predicates: self.column_predicates.clone(), }
167 }
168
169 pub(crate) fn to_dyn_skip_batch_predicate(
171 &self,
172 schema: SchemaRef,
173 ) -> Option<Arc<dyn SkipBatchPredicate>> {
174 let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
175 Some(Arc::new(SkipBatchPredicateHelper {
176 skip_batch_predicate,
177 schema,
178 }))
179 }
180
181 pub fn to_io(
182 &self,
183 skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
184 schema: SchemaRef,
185 ) -> ScanIOPredicate {
186 ScanIOPredicate {
187 predicate: phys_expr_to_io_expr(self.predicate.clone()),
188 live_columns: self.live_columns.clone(),
189 skip_batch_predicate: skip_batch_predicate
190 .cloned()
191 .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
192 column_predicates: Arc::new(ColumnPredicates {
193 predicates: self
194 .column_predicates
195 .predicates
196 .iter()
197 .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
198 .collect(),
199 is_sumwise_complete: self.column_predicates.is_sumwise_complete,
200 }),
201 }
202 }
203}
204
205impl SkipBatchPredicate for SkipBatchPredicateHelper {
206 fn schema(&self) -> &SchemaRef {
207 &self.schema
208 }
209
210 fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
211 let array = self
212 .skip_batch_predicate
213 .evaluate(df, &Default::default())?;
214 let array = array.bool()?;
215 let array = array.downcast_as_array();
216
217 if let Some(validity) = array.validity() {
218 Ok(array.values() & validity)
219 } else {
220 Ok(array.values().clone())
221 }
222 }
223}