1use core::fmt;
2use std::sync::Arc;
3
4use arrow::bitmap::Bitmap;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
7use polars_core::scalar::Scalar;
8use polars_core::schema::{Schema, SchemaRef};
9use polars_error::PolarsResult;
10use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
11use polars_expr::state::ExecutionState;
12use polars_io::predicates::{
13 ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
14};
15use polars_utils::pl_str::PlSmallStr;
16use polars_utils::{IdxSize, format_pl_smallstr};
17
/// A predicate that is pushed down into a scan, together with the metadata
/// needed to apply it at the various stages of reading.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The full filter expression evaluated over the scanned rows.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Columns the predicate references, i.e. the columns that must be
    /// loaded for the predicate to be evaluated.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// Optional expression evaluated over per-batch statistics columns
    /// (`{name}_min`, `{name}_max`, `{name}_nc` — naming visible in
    /// `with_constant_columns` below) to decide whether a whole batch can be
    /// skipped without reading it.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Per-column decomposition of the predicate (see
    /// [`PhysicalColumnPredicates`]).
    pub column_predicates: PhysicalColumnPredicates,

    /// NOTE(review): presumably the predicate restricted to hive-partition
    /// columns — confirm with callers. Cleared by `with_constant_columns`.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    /// Whether `hive_predicate` is equivalent to the full `predicate`.
    pub hive_predicate_is_full_predicate: bool,
}
41
42impl fmt::Debug for ScanPredicate {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.write_str("scan_predicate")
45 }
46}
47
/// Physical, per-column pieces of a scan predicate.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    /// Map from column name to that column's predicate expression plus an
    /// optional specialized (fast-path) form of it.
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
    /// NOTE(review): presumably means the per-column predicates combined are
    /// equivalent to the full predicate — confirm against the producer of
    /// this struct.
    pub is_sumwise_complete: bool,
}
54
/// Adapter that exposes a physical skip-batch expression through the
/// [`SkipBatchPredicate`] trait (impl at the bottom of this file).
struct SkipBatchPredicateHelper {
    /// Expression evaluated against a statistics `DataFrame`.
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    /// Schema of the statistics `DataFrame` this predicate expects.
    schema: SchemaRef,
}
60
/// Wraps a [`PhysicalExpr`], adding constant scalar columns to the input
/// `DataFrame` before delegating evaluation to the wrapped expression.
pub struct PhysicalExprWithConstCols {
    /// Columns materialized as full-height scalar columns at evaluation time.
    constants: Vec<(PlSmallStr, Scalar)>,
    /// The wrapped expression that is evaluated on the extended frame.
    child: Arc<dyn PhysicalExpr>,
}
66
67impl PhysicalExpr for PhysicalExprWithConstCols {
68 fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
69 let mut df = df.clone();
70 for (name, scalar) in &self.constants {
71 df.with_column(Column::new_scalar(
72 name.clone(),
73 scalar.clone(),
74 df.height(),
75 ))?;
76 }
77
78 self.child.evaluate(&df, state)
79 }
80
81 fn evaluate_on_groups<'a>(
82 &self,
83 df: &DataFrame,
84 groups: &'a GroupPositions,
85 state: &ExecutionState,
86 ) -> PolarsResult<AggregationContext<'a>> {
87 let mut df = df.clone();
88 for (name, scalar) in &self.constants {
89 df.with_column(Column::new_scalar(
90 name.clone(),
91 scalar.clone(),
92 df.height(),
93 ))?;
94 }
95
96 self.child.evaluate_on_groups(&df, groups, state)
97 }
98
99 fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
100 self.child.to_field(input_schema)
101 }
102 fn is_scalar(&self) -> bool {
103 self.child.is_scalar()
104 }
105}
106
107impl ScanPredicate {
108 pub fn with_constant_columns(
109 &self,
110 constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
111 ) -> Self {
112 let constant_columns = constant_columns.into_iter();
113
114 let mut live_columns = self.live_columns.as_ref().clone();
115 let mut skip_batch_predicate_constants =
116 Vec::with_capacity(if self.skip_batch_predicate.is_some() {
117 1 + constant_columns.size_hint().0 * 3
118 } else {
119 Default::default()
120 });
121
122 let predicate_constants = constant_columns
123 .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
124 if !live_columns.swap_remove(&name) {
125 return None;
126 }
127
128 if self.skip_batch_predicate.is_some() {
129 let mut null_count: Scalar = (0 as IdxSize).into();
130
131 if scalar.is_null() {
134 null_count.update(AnyValue::Null);
135 }
136
137 skip_batch_predicate_constants.extend([
138 (format_pl_smallstr!("{name}_min"), scalar.clone()),
139 (format_pl_smallstr!("{name}_max"), scalar.clone()),
140 (format_pl_smallstr!("{name}_nc"), null_count),
141 ]);
142 }
143
144 Some((name, scalar))
145 })
146 .collect();
147
148 let predicate = Arc::new(PhysicalExprWithConstCols {
149 constants: predicate_constants,
150 child: self.predicate.clone(),
151 });
152 let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
153 Arc::new(PhysicalExprWithConstCols {
154 constants: skip_batch_predicate_constants,
155 child: skp.clone(),
156 }) as _
157 });
158
159 Self {
160 predicate,
161 live_columns: Arc::new(live_columns),
162 skip_batch_predicate,
163 column_predicates: self.column_predicates.clone(), hive_predicate: None,
166 hive_predicate_is_full_predicate: false,
167 }
168 }
169
170 pub(crate) fn to_dyn_skip_batch_predicate(
172 &self,
173 schema: SchemaRef,
174 ) -> Option<Arc<dyn SkipBatchPredicate>> {
175 let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
176 Some(Arc::new(SkipBatchPredicateHelper {
177 skip_batch_predicate,
178 schema,
179 }))
180 }
181
182 pub fn to_io(
183 &self,
184 skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
185 schema: SchemaRef,
186 ) -> ScanIOPredicate {
187 ScanIOPredicate {
188 predicate: phys_expr_to_io_expr(self.predicate.clone()),
189 live_columns: self.live_columns.clone(),
190 skip_batch_predicate: skip_batch_predicate
191 .cloned()
192 .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
193 column_predicates: Arc::new(ColumnPredicates {
194 predicates: self
195 .column_predicates
196 .predicates
197 .iter()
198 .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
199 .collect(),
200 is_sumwise_complete: self.column_predicates.is_sumwise_complete,
201 }),
202 hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
203 hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
204 }
205 }
206}
207
208impl SkipBatchPredicate for SkipBatchPredicateHelper {
209 fn schema(&self) -> &SchemaRef {
210 &self.schema
211 }
212
213 fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
214 let array = self
215 .skip_batch_predicate
216 .evaluate(df, &Default::default())?;
217 let array = array.bool()?;
218 let array = array.downcast_as_array();
219
220 if let Some(validity) = array.validity() {
221 Ok(array.values() & validity)
222 } else {
223 Ok(array.values().clone())
224 }
225 }
226}