1use polars_utils::idx_vec::UnitVec;
2use polars_utils::unitvec;
3
4use super::*;
5use crate::constants::MAP_LIST_NAME;
6
7impl AExpr {
8 pub(crate) fn is_leaf(&self) -> bool {
9 matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
10 }
11
12 pub(crate) fn is_col(&self) -> bool {
13 matches!(self, AExpr::Column(_))
14 }
15
16 pub(crate) fn is_elementwise_top_level(&self) -> bool {
18 use AExpr::*;
19
20 match self {
21 AnonymousFunction { options, .. } => options.is_elementwise(),
22
23 Function { options, .. } => options.is_elementwise(),
24
25 Literal(v) => v.is_scalar(),
26
27 Eval { variant, .. } => match variant {
28 EvalVariant::List => true,
29 EvalVariant::Cumulative { min_samples: _ } => false,
30 },
31
32 BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,
33
34 Agg { .. }
35 | Explode { .. }
36 | Filter { .. }
37 | Gather { .. }
38 | Len
39 | Slice { .. }
40 | Sort { .. }
41 | SortBy { .. }
42 | Window { .. } => false,
43 }
44 }
45
46 pub(crate) fn does_not_modify_top_level(&self) -> bool {
47 match self {
48 AExpr::Column(_) => true,
49 AExpr::Function { function, .. } => {
50 matches!(function, IRFunctionExpr::SetSortedFlag(_))
51 },
52 _ => false,
53 }
54 }
55}
56
57fn property_and_traverse<F>(stack: &mut UnitVec<Node>, ae: &AExpr, property: F) -> bool
59where
60 F: Fn(&AExpr) -> bool,
61{
62 if !property(ae) {
63 return false;
64 }
65 ae.inputs_rev(stack);
66 true
67}
68
69fn property_rec<F>(node: Node, expr_arena: &Arena<AExpr>, property: F) -> bool
70where
71 F: Fn(&mut UnitVec<Node>, &AExpr, &Arena<AExpr>) -> bool,
72{
73 let mut stack = unitvec![];
74 let mut ae = expr_arena.get(node);
75
76 loop {
77 if !property(&mut stack, ae, expr_arena) {
78 return false;
79 }
80
81 let Some(node) = stack.pop() else {
82 break;
83 };
84
85 ae = expr_arena.get(node);
86 }
87
88 true
89}
90
91fn does_not_modify(stack: &mut UnitVec<Node>, ae: &AExpr, _expr_arena: &Arena<AExpr>) -> bool {
94 property_and_traverse(stack, ae, |ae| ae.does_not_modify_top_level())
95}
96
/// Returns `true` when `does_not_modify_top_level` holds for the expression at `node` and for
/// every expression reached by following inputs from it.
pub fn does_not_modify_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
    property_rec(node, expr_arena, does_not_modify)
}
100
101pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
106 use AExpr::*;
107
108 if !ae.is_elementwise_top_level() {
109 return false;
110 }
111
112 match ae {
113 #[cfg(feature = "is_in")]
116 Function {
117 function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),
118 input,
119 ..
120 } => (|| {
121 if let Some(rhs) = input.get(1) {
122 assert_eq!(input.len(), 2); let rhs = rhs.node();
124
125 if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
126 stack.extend([input[0].node()]);
127 return;
128 }
129 };
130
131 ae.inputs_rev(stack);
132 })(),
133 _ => ae.inputs_rev(stack),
134 }
135
136 true
137}
138
139pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
140where
141 Node: From<&'a N>,
142{
143 nodes
144 .iter()
145 .all(|n| is_elementwise_rec(n.into(), expr_arena))
146}
147
/// Returns `true` when `is_elementwise` accepts the expression at `node` and every expression
/// that `is_elementwise` queues for inspection while walking down from it.
pub fn is_elementwise_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
    property_rec(node, expr_arena, is_elementwise)
}
152
153pub fn is_elementwise_rec_no_cat_cast<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
156 let mut stack = unitvec![];
157
158 loop {
159 if !is_elementwise(&mut stack, ae, expr_arena) {
160 return false;
161 }
162
163 #[cfg(feature = "dtype-categorical")]
164 {
165 if let AExpr::Cast {
166 dtype: DataType::Categorical(..),
167 ..
168 } = ae
169 {
170 return false;
171 }
172 }
173
174 let Some(node) = stack.pop() else {
175 break;
176 };
177
178 ae = expr_arena.get(node);
179 }
180
181 true
182}
183
/// Classification of an expression used to decide whether it blocks pushdown
/// (see `blocks_pushdown`). `update_with_expr` only ever moves a value from `Pushable` towards
/// `Fallible` or `Barrier`, and a `Barrier` is never changed again.
#[derive(Debug, Clone)]
pub enum ExprPushdownGroup {
    /// Never blocks pushdown.
    Pushable,
    /// Blocks pushdown only when the caller asks to maintain errors. Assigned by
    /// `update_with_expr` for the specific fallible expressions it recognises.
    Fallible,
    /// Always blocks pushdown. Assigned by `update_with_expr` when an expression is rejected by
    /// `is_elementwise`.
    Barrier,
}
203
204impl ExprPushdownGroup {
205 pub fn update_with_expr(
210 &mut self,
211 stack: &mut UnitVec<Node>,
212 ae: &AExpr,
213 expr_arena: &Arena<AExpr>,
214 ) -> &mut Self {
215 match self {
216 ExprPushdownGroup::Pushable | ExprPushdownGroup::Fallible => {
217 if match ae {
219 AExpr::Function {
222 function: IRFunctionExpr::ListExpr(IRListFunction::Get(false)),
223 ..
224 } => true,
225
226 #[cfg(feature = "list_gather")]
227 AExpr::Function {
228 function: IRFunctionExpr::ListExpr(IRListFunction::Gather(false)),
229 ..
230 } => true,
231
232 #[cfg(feature = "dtype-array")]
233 AExpr::Function {
234 function: IRFunctionExpr::ArrayExpr(IRArrayFunction::Get(false)),
235 ..
236 } => true,
237
238 #[cfg(all(feature = "strings", feature = "temporal"))]
239 AExpr::Function {
240 input,
241 function:
242 IRFunctionExpr::StringExpr(IRStringFunction::Strptime(_, strptime_options)),
243 ..
244 } => {
245 debug_assert!(input.len() <= 2);
246
247 debug_assert!(matches!(
249 input.get(1).map(|x| expr_arena.get(x.node())),
250 Some(AExpr::Literal(_)) | None
251 ));
252
253 match input.first().map(|x| expr_arena.get(x.node())) {
254 Some(AExpr::Literal(_)) | None => false,
255 _ => strptime_options.strict,
256 }
257 },
258 #[cfg(feature = "python")]
259 AExpr::AnonymousFunction {
262 options, fmt_str, ..
263 } if options.flags.contains(FunctionFlags::APPLY_LIST)
264 && fmt_str.as_ref().as_str() == MAP_LIST_NAME =>
265 {
266 return self;
267 },
268
269 AExpr::Cast {
270 expr,
271 dtype: _,
272 options: CastOptions::Strict,
273 } => !matches!(expr_arena.get(*expr), AExpr::Literal(_)),
274
275 _ => false,
276 } {
277 *self = ExprPushdownGroup::Fallible;
278 }
279
280 if !is_elementwise(stack, ae, expr_arena) {
282 *self = ExprPushdownGroup::Barrier
283 }
284 },
285
286 ExprPushdownGroup::Barrier => {},
287 }
288
289 self
290 }
291
292 pub fn update_with_expr_rec<'a>(
293 &mut self,
294 mut ae: &'a AExpr,
295 expr_arena: &'a Arena<AExpr>,
296 scratch: Option<&mut UnitVec<Node>>,
297 ) -> &mut Self {
298 let mut local_scratch = unitvec![];
299 let stack = scratch.unwrap_or(&mut local_scratch);
300
301 loop {
302 self.update_with_expr(stack, ae, expr_arena);
303
304 if let ExprPushdownGroup::Barrier = self {
305 return self;
306 }
307
308 let Some(node) = stack.pop() else {
309 break;
310 };
311
312 ae = expr_arena.get(node);
313 }
314
315 self
316 }
317
318 pub fn blocks_pushdown(&self, maintain_errors: bool) -> bool {
319 match self {
320 ExprPushdownGroup::Barrier => true,
321 ExprPushdownGroup::Fallible => maintain_errors,
322 ExprPushdownGroup::Pushable => false,
323 }
324 }
325}
326
327pub fn can_pre_agg_exprs(
328 exprs: &[ExprIR],
329 expr_arena: &Arena<AExpr>,
330 _input_schema: &Schema,
331) -> bool {
332 exprs
333 .iter()
334 .all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
335}
336
337pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
340 let aexpr = expr_arena.get(agg);
341
342 match aexpr {
343 AExpr::Len => true,
344 AExpr::Column(_) | AExpr::Literal(_) => false,
345 AExpr::Agg(_) => {
347 let has_aggregation =
348 |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));
349
350 let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
354 use AExpr::*;
355 match ae {
356 #[cfg(feature = "dtype-struct")]
358 Agg(IRAggExpr::Mean(_)) => {
359 matches!(
362 expr_arena
363 .get(agg)
364 .get_type(_input_schema, Context::Default, expr_arena)
365 .map(|dt| { dt.is_primitive_numeric() }),
366 Ok(true)
367 )
368 },
369 Agg(agg_e) => {
371 matches!(
372 agg_e,
373 IRAggExpr::Min { .. }
374 | IRAggExpr::Max { .. }
375 | IRAggExpr::Sum(_)
376 | IRAggExpr::Last(_)
377 | IRAggExpr::First(_)
378 | IRAggExpr::Count(_, true)
379 )
380 },
381 Function { input, options, .. } => {
382 options.is_elementwise()
383 && input.len() == 1
384 && !has_aggregation(input[0].node())
385 },
386 BinaryExpr { left, right, .. } => {
387 !has_aggregation(*left) && !has_aggregation(*right)
388 },
389 Ternary {
390 truthy,
391 falsy,
392 predicate,
393 ..
394 } => {
395 !has_aggregation(*truthy)
396 && !has_aggregation(*falsy)
397 && !has_aggregation(*predicate)
398 },
399 Literal(lv) => lv.is_scalar(),
400 Column(_) | Len | Cast { .. } => true,
401 _ => false,
402 }
403 });
404
405 #[cfg(feature = "object")]
406 {
407 for name in aexpr_to_leaf_names(agg, expr_arena) {
408 let dtype = _input_schema.get(&name).unwrap();
409
410 if let DataType::Object(_) = dtype {
411 return false;
412 }
413 }
414 }
415 can_partition
416 },
417 _ => false,
418 }
419}