1use core::fmt;
2use std::sync::Arc;
3
4use arrow::bitmap::Bitmap;
5use polars_core::frame::DataFrame;
6use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
7use polars_core::scalar::Scalar;
8use polars_core::schema::{Schema, SchemaRef};
9use polars_error::PolarsResult;
10use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
11use polars_expr::state::ExecutionState;
12use polars_io::predicates::{
13 ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr,
14};
15use polars_utils::pl_str::PlSmallStr;
16use polars_utils::{IdxSize, format_pl_smallstr};
17
/// A predicate applied while scanning a dataset, together with the
/// specialized/derived forms of it used at different stages of the scan.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The full physical predicate expression.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// The set of column names the predicate refers to ("live" columns).
    /// `with_constant_columns` removes a column from this set once it has
    /// been bound to a constant value.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// Optional predicate evaluated over per-batch statistics columns
    /// (`{col}_min`, `{col}_max`, `{col}_nc`; see `with_constant_columns`)
    /// that decides whether a whole batch can be skipped without reading it.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Per-column decomposition of the predicate (see
    /// `PhysicalColumnPredicates`).
    pub column_predicates: PhysicalColumnPredicates,

    /// Optional predicate restricted to hive-partition columns.
    /// NOTE(review): semantics inferred from the name — confirm with callers.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    /// True when `hive_predicate` is equivalent to the full `predicate`.
    pub hive_predicate_is_full_predicate: bool,
}
41
42impl fmt::Debug for ScanPredicate {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.write_str("scan_predicate")
45 }
46}
47
/// The per-column decomposition of a scan predicate.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    /// For each column: the physical predicate restricted to that column,
    /// plus an optional specialized fast-path form of it.
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    /// True when the column predicates together are equivalent to the full
    /// predicate. NOTE(review): inferred from the name — confirm against the
    /// producers of this struct.
    pub is_sumwise_complete: bool,
}
59
/// Adapts a physical skip-batch expression to the `SkipBatchPredicate`
/// trait by pairing it with the schema of the statistics `DataFrame` it
/// is evaluated against.
struct SkipBatchPredicateHelper {
    /// Boolean expression evaluated over the per-batch statistics columns.
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    /// Schema of the statistics `DataFrame` the predicate expects.
    schema: SchemaRef,
}
65
/// Wraps a physical expression and injects constant-valued columns into the
/// input `DataFrame` before delegating evaluation to the child expression.
pub struct PhysicalExprWithConstCols {
    /// `(column name, value)` pairs materialized as full-height scalar
    /// columns before the child is evaluated.
    constants: Vec<(PlSmallStr, Scalar)>,
    child: Arc<dyn PhysicalExpr>,
}
71
72impl PhysicalExpr for PhysicalExprWithConstCols {
73 fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
74 let mut df = df.clone();
75 for (name, scalar) in &self.constants {
76 df.with_column(Column::new_scalar(
77 name.clone(),
78 scalar.clone(),
79 df.height(),
80 ))?;
81 }
82
83 self.child.evaluate(&df, state)
84 }
85
86 fn evaluate_on_groups<'a>(
87 &self,
88 df: &DataFrame,
89 groups: &'a GroupPositions,
90 state: &ExecutionState,
91 ) -> PolarsResult<AggregationContext<'a>> {
92 let mut df = df.clone();
93 for (name, scalar) in &self.constants {
94 df.with_column(Column::new_scalar(
95 name.clone(),
96 scalar.clone(),
97 df.height(),
98 ))?;
99 }
100
101 self.child.evaluate_on_groups(&df, groups, state)
102 }
103
104 fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
105 self.child.to_field(input_schema)
106 }
107 fn is_scalar(&self) -> bool {
108 self.child.is_scalar()
109 }
110}
111
impl ScanPredicate {
    /// Bind a set of columns to constant values in every contained predicate.
    ///
    /// Each `(name, scalar)` pair whose name is in `live_columns` is removed
    /// from the live set and instead injected as a constant column (via
    /// `PhysicalExprWithConstCols`) before the predicate is evaluated. When a
    /// skip-batch predicate is present, the matching per-column statistics
    /// (`{name}_min`, `{name}_max`, `{name}_nc`) are injected into it as
    /// constants too. The result carries no hive predicate.
    pub fn with_constant_columns(
        &self,
        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
    ) -> Self {
        let constant_columns = constant_columns.into_iter();

        let mut live_columns = self.live_columns.as_ref().clone();
        // Each constant column contributes exactly 3 stat entries
        // (min, max, null count). NOTE(review): the extra `1 +` looks like
        // slack rather than a real element — confirm intent.
        let mut skip_batch_predicate_constants =
            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
                1 + constant_columns.size_hint().0 * 3
            } else {
                Default::default()
            });

        let predicate_constants = constant_columns
            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
                // Column not referenced by the predicate: nothing to bind.
                // (`swap_remove` also drops it from the live set when found.)
                if !live_columns.swap_remove(&name) {
                    return None;
                }

                if self.skip_batch_predicate.is_some() {
                    // A non-null constant column contains zero nulls.
                    let mut null_count: Scalar = (0 as IdxSize).into();

                    if scalar.is_null() {
                        // A null constant makes the whole column null; the
                        // height is unknown here, so the count is set to a
                        // null value rather than a number.
                        null_count.update(AnyValue::Null);
                    }

                    // Constant min == max == the scalar itself.
                    skip_batch_predicate_constants.extend([
                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
                        (format_pl_smallstr!("{name}_nc"), null_count),
                    ]);
                }

                Some((name, scalar))
            })
            .collect();

        let predicate = Arc::new(PhysicalExprWithConstCols {
            constants: predicate_constants,
            child: self.predicate.clone(),
        });
        // `as _` coerces Arc<PhysicalExprWithConstCols> to Arc<dyn PhysicalExpr>.
        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
            Arc::new(PhysicalExprWithConstCols {
                constants: skip_batch_predicate_constants,
                child: skp.clone(),
            }) as _
        });

        Self {
            predicate,
            live_columns: Arc::new(live_columns),
            skip_batch_predicate,
            column_predicates: self.column_predicates.clone(),
            // Constant-binding invalidates the hive specialization; drop it.
            hive_predicate: None,
            hive_predicate_is_full_predicate: false,
        }
    }

    /// Wrap the skip-batch expression (if any) together with the statistics
    /// `schema` into a `SkipBatchPredicate` trait object. Returns `None`
    /// when there is no skip-batch predicate.
    pub(crate) fn to_dyn_skip_batch_predicate(
        &self,
        schema: SchemaRef,
    ) -> Option<Arc<dyn SkipBatchPredicate>> {
        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
        Some(Arc::new(SkipBatchPredicateHelper {
            skip_batch_predicate,
            schema,
        }))
    }

    /// Convert this predicate into the IO-layer representation, translating
    /// every physical expression via `phys_expr_to_io_expr`.
    ///
    /// An externally supplied `skip_batch_predicate` takes precedence;
    /// otherwise one is derived from `self` with the given `schema`.
    pub fn to_io(
        &self,
        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
            live_columns: self.live_columns.clone(),
            skip_batch_predicate: skip_batch_predicate
                .cloned()
                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
            column_predicates: Arc::new(ColumnPredicates {
                predicates: self
                    .column_predicates
                    .predicates
                    .iter()
                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
                    .collect(),
                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
            }),
            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
        }
    }
}
212
213impl SkipBatchPredicate for SkipBatchPredicateHelper {
214 fn schema(&self) -> &SchemaRef {
215 &self.schema
216 }
217
218 fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
219 let array = self
220 .skip_batch_predicate
221 .evaluate(df, &Default::default())?;
222 let array = array.bool()?;
223 let array = array.downcast_as_array();
224
225 if let Some(validity) = array.validity() {
226 Ok(array.values() & validity)
227 } else {
228 Ok(array.values().clone())
229 }
230 }
231}