Skip to main content

polars_expr/
planner.rs

1use polars_core::prelude::*;
2use polars_plan::constants::{get_literal_name, get_pl_element_name, get_pl_structfields_name};
3use polars_plan::prelude::expr_ir::ExprIR;
4use polars_plan::prelude::*;
5use recursive::recursive;
6
7use crate::dispatch::{function_expr_to_groups_udf, function_expr_to_udf};
8use crate::expressions as phys_expr;
9use crate::expressions::*;
10use crate::reduce::GroupedReduction;
11
12pub fn get_expr_depth_limit() -> PolarsResult<u16> {
13    let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") {
14        let v = d
15            .parse::<u64>()
16            .map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?;
17        u16::try_from(v).unwrap_or(0)
18    } else {
19        512
20    };
21    Ok(depth)
22}
23
/// Checker that accepts everything: both arguments are ignored and `Ok(())` is returned.
///
/// Used as the `checker` argument by the wrappers in this file that do not need a
/// per-expression check of the conversion state.
fn ok_checker(_i: usize, _state: &ExpressionConversionState) -> PolarsResult<()> {
    Ok(())
}
27
/// Converts every `ExprIR` in `exprs` into a physical expression.
///
/// Delegates to `create_physical_expressions_check_state` with `ok_checker`, so no
/// additional per-expression state check is performed.
pub fn create_physical_expressions_from_irs(
    exprs: &[ExprIR],
    expr_arena: &mut Arena<AExpr>,
    schema: &SchemaRef,
    state: &mut ExpressionConversionState,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
    create_physical_expressions_check_state(exprs, expr_arena, schema, state, ok_checker)
}
36
37pub(crate) fn create_physical_expressions_check_state<F>(
38    exprs: &[ExprIR],
39    expr_arena: &mut Arena<AExpr>,
40    schema: &SchemaRef,
41    state: &mut ExpressionConversionState,
42    checker: F,
43) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
44where
45    F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
46{
47    exprs
48        .iter()
49        .enumerate()
50        .map(|(i, e)| {
51            state.reset();
52            let out = create_physical_expr(e, expr_arena, schema, state);
53            checker(i, state)?;
54            out
55        })
56        .collect()
57}
58
/// Converts every arena `Node` in `exprs` into a physical expression.
///
/// Delegates to `create_physical_expressions_from_nodes_check_state` with `ok_checker`,
/// so no additional per-expression state check is performed.
pub(crate) fn create_physical_expressions_from_nodes(
    exprs: &[Node],
    expr_arena: &mut Arena<AExpr>,
    schema: &SchemaRef,
    state: &mut ExpressionConversionState,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
    create_physical_expressions_from_nodes_check_state(exprs, expr_arena, schema, state, ok_checker)
}
67
68pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
69    exprs: &[Node],
70    expr_arena: &mut Arena<AExpr>,
71    schema: &SchemaRef,
72    state: &mut ExpressionConversionState,
73    checker: F,
74) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
75where
76    F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
77{
78    exprs
79        .iter()
80        .enumerate()
81        .map(|(i, e)| {
82            state.reset();
83            let out = create_physical_expr_inner(*e, expr_arena, schema, state);
84            checker(i, state)?;
85            out
86        })
87        .collect()
88}
89
/// State carried along while converting expressions into physical expressions.
///
/// It combines settings that persist across all expressions of a context with
/// per-expression flags (`local`) that are cleared by `reset`.
#[derive(Copy, Clone)]
pub struct ExpressionConversionState {
    // Settings per context.
    // They remain active between expressions.
    pub allow_threading: bool,
    pub has_windows: bool,
    // Settings per expression.
    // Those are reset for every expression.
    local: LocalConversionState,
}

/// Per-expression flags; both default to `false`.
#[derive(Copy, Clone, Default)]
struct LocalConversionState {
    has_window: bool,
    has_lit: bool,
}

impl ExpressionConversionState {
    /// Creates a state with the given threading setting, no windows seen so far and
    /// default (all `false`) per-expression flags.
    pub fn new(allow_threading: bool) -> Self {
        let local = LocalConversionState::default();
        Self {
            allow_threading,
            has_windows: false,
            local,
        }
    }

    /// Clears the per-expression flags; the per-context settings are left untouched.
    fn reset(&mut self) {
        self.local = Default::default();
    }

    /// Records a window expression in both the per-expression and the per-context flag.
    fn set_window(&mut self) {
        self.local.has_window = true;
        self.has_windows = true;
    }
}
128
129pub fn create_physical_expr(
130    expr_ir: &ExprIR,
131    expr_arena: &mut Arena<AExpr>,
132    schema: &SchemaRef, // Schema of the input.
133    state: &mut ExpressionConversionState,
134) -> PolarsResult<Arc<dyn PhysicalExpr>> {
135    let phys_expr = create_physical_expr_inner(expr_ir.node(), expr_arena, schema, state)?;
136
137    if let Some(name) = expr_ir.get_alias() {
138        Ok(Arc::new(AliasExpr::new(
139            phys_expr,
140            name.clone(),
141            node_to_expr(expr_ir.node(), expr_arena),
142        )))
143    } else {
144        Ok(phys_expr)
145    }
146}
147
/// Recursively converts the `AExpr` stored at `expression` in `expr_arena` into a
/// physical expression.
///
/// The function matches on the node's variant, converts the child nodes first where
/// needed, and builds the physical expression type that corresponds to the variant.
///
/// * `expression` - arena node to convert.
/// * `expr_arena` - arena holding the `AExpr` nodes.
/// * `schema` - schema of the input.
/// * `state` - conversion state; `allow_threading` is read from it, and the window /
///   literal flags are updated while converting.
///
/// # Errors
/// Errors from `to_field`, from converting child nodes and from materializing anonymous
/// functions are propagated. The `Over` arm additionally fails with a `ComputeError`
/// when no root column, literal, `Len` or `Element` is found in the window function.
///
/// # Panics
/// * `MinBy` / `MaxBy` with a number of inputs other than two (`assert!`).
/// * `AnonymousAgg` whose materialized function does not downcast to
///   `Box<dyn GroupedReduction>` (`unwrap`).
/// * `ArgMin` / `ArgMax` with an empty `input` (indexing `input[0]`).
#[recursive]
fn create_physical_expr_inner(
    expression: Node,
    expr_arena: &mut Arena<AExpr>,
    schema: &SchemaRef, // Schema of the input.
    state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
    use AExpr::*;

    let aexpr = expr_arena.get(expression);
    // Match on a clone so that the arm bindings are owned values rather than borrows
    // into the arena.
    match aexpr.clone() {
        Len => Ok(Arc::new(phys_expr::LenExpr::new())),
        #[cfg(feature = "dynamic_group_by")]
        Rolling {
            function,
            index_column,
            period,
            offset,
            closed_window,
        } => {
            let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
            let index_column = create_physical_expr_inner(index_column, expr_arena, schema, state)?;

            state.set_window();
            let phys_function = create_physical_expr_inner(function, expr_arena, schema, state)?;
            let expr = node_to_expr(expression, expr_arena);

            // set again as the state can be reset
            state.set_window();
            Ok(Arc::new(RollingExpr {
                phys_function,
                index_column,
                period,
                offset,
                closed_window,
                expr,
                output_field,
            }))
        },
        Over {
            function,
            partition_by,
            order_by,
            mapping,
        } => {
            let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
            state.set_window();
            let phys_function = create_physical_expr_inner(function, expr_arena, schema, state)?;

            // Convert the optional `order_by` node and record whether it is elementwise.
            let mut order_by_is_elementwise = false;
            let order_by = order_by
                .map(|(node, options)| {
                    order_by_is_elementwise |= is_elementwise_rec(node, expr_arena);
                    PolarsResult::Ok((
                        create_physical_expr_inner(node, expr_arena, schema, state)?,
                        options,
                    ))
                })
                .transpose()?;

            let expr = node_to_expr(expression, expr_arena);

            // set again as the state can be reset
            state.set_window();
            let all_group_by_are_elementwise = partition_by
                .iter()
                .all(|n| is_elementwise_rec(*n, expr_arena));
            let group_by =
                create_physical_expressions_from_nodes(&partition_by, expr_arena, schema, state)?;
            let mut apply_columns = aexpr_to_leaf_names(function, expr_arena);
            // sort and then dedup removes consecutive duplicates == all duplicates
            apply_columns.sort();
            apply_columns.dedup();

            // The window function has no leaf column names: fall back to a placeholder
            // name for a literal, `len` or `element`, or bail out if none is present.
            if apply_columns.is_empty() {
                if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) {
                    apply_columns.push(get_literal_name())
                } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) {
                    apply_columns.push(PlSmallStr::from_static("len"))
                } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Element)) {
                    apply_columns.push(PlSmallStr::from_static("element"))
                } else {
                    let e = node_to_expr(function, expr_arena);
                    polars_bail!(
                        ComputeError:
                        "cannot apply a window function, did not find a root column; \
                        this is likely due to a syntax error in this expression: {:?}", e
                    );
                }
            }

            // Check if the branches have an aggregation
            // when(a > sum)
            // then (foo)
            // otherwise(bar - sum)
            let mut has_arity = false;
            let mut agg_col = false;
            for (_, e) in expr_arena.iter(function) {
                match e {
                    AExpr::Ternary { .. } | AExpr::BinaryExpr { .. } => {
                        has_arity = true;
                    },
                    AExpr::Agg(_) => {
                        agg_col = true;
                    },
                    // Functions flagged as returning a scalar are counted like aggregations.
                    AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. }
                        if options.flags.returns_scalar() =>
                    {
                        agg_col = true;
                    },
                    _ => {},
                }
            }
            // True when the window function contains both a ternary/binary node and an
            // aggregation (or scalar-returning function).
            let has_different_group_sources = has_arity && agg_col;

            Ok(Arc::new(WindowExpr {
                group_by,
                order_by,
                apply_columns,
                phys_function,
                mapping,
                expr,
                has_different_group_sources,
                output_field,

                order_by_is_elementwise,
                all_group_by_are_elementwise,
            }))
        },
        Literal(value) => {
            // Record that the current expression contains a literal; read back by the
            // `BinaryExpr` and `Ternary` arms.
            state.local.has_lit = true;
            Ok(Arc::new(LiteralExpr::new(
                value.clone(),
                node_to_expr(expression, expr_arena),
            )))
        },
        BinaryExpr { left, op, right } => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let lhs = create_physical_expr_inner(left, expr_arena, schema, state)?;
            let rhs = create_physical_expr_inner(right, expr_arena, schema, state)?;
            // `has_lit` is read after both operands were converted, so it reflects any
            // literal seen since the last reset of the local state.
            Ok(Arc::new(phys_expr::BinaryExpr::new(
                lhs,
                op,
                rhs,
                node_to_expr(expression, expr_arena),
                state.local.has_lit,
                state.allow_threading,
                is_scalar,
                output_field,
            )))
        },
        Column(column) => Ok(Arc::new(ColumnExpr::new(
            column.clone(),
            node_to_expr(expression, expr_arena),
            schema.clone(),
        ))),
        Element => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            Ok(Arc::new(ElementExpr::new(output_field)))
        },
        #[cfg(feature = "dtype-struct")]
        StructField(field) => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            Ok(Arc::new(FieldExpr::new(
                field.clone(),
                node_to_expr(expression, expr_arena),
                output_field,
            )))
        },
        Sort { expr, options } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            Ok(Arc::new(SortExpr::new(
                phys_expr,
                options,
                node_to_expr(expression, expr_arena),
            )))
        },
        Gather {
            expr,
            idx,
            returns_scalar,
            null_on_oob,
        } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            let phys_idx = create_physical_expr_inner(idx, expr_arena, schema, state)?;
            Ok(Arc::new(GatherExpr {
                phys_expr,
                idx: phys_idx,
                expr: node_to_expr(expression, expr_arena),
                returns_scalar,
                null_on_oob,
            }))
        },
        SortBy {
            expr,
            by,
            sort_options,
        } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            // Note: this helper resets the local state before converting each `by` node.
            let phys_by = create_physical_expressions_from_nodes(&by, expr_arena, schema, state)?;
            Ok(Arc::new(SortByExpr::new(
                phys_expr,
                phys_by,
                node_to_expr(expression, expr_arena),
                sort_options.clone(),
            )))
        },
        Filter { input, by } => {
            let phys_input = create_physical_expr_inner(input, expr_arena, schema, state)?;
            let phys_by = create_physical_expr_inner(by, expr_arena, schema, state)?;
            Ok(Arc::new(FilterExpr::new(
                phys_input,
                phys_by,
                node_to_expr(expression, expr_arena),
            )))
        },
        Agg(agg) => {
            // Only the first input of the aggregation is converted here.
            let expr = agg.get_input().first();
            let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            let allow_threading = state.allow_threading;

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let groupby = GroupByMethod::from(agg.clone());
            let agg_type = AggregationType {
                groupby,
                allow_threading,
            };

            Ok(Arc::new(AggregationExpr::new(
                input,
                agg_type,
                output_field,
            )))
        },
        // `ArgMin` / `ArgMax` are built as aggregation expressions rather than as a
        // generic `ApplyExpr`; this arm must stay before the catch-all `Function` arm.
        Function {
            input,
            function: function @ (IRFunctionExpr::ArgMin | IRFunctionExpr::ArgMax),
            options: _,
        } => {
            let phys_input =
                create_physical_expr_inner(input[0].node(), expr_arena, schema, state)?;

            let mut output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            // Keep the resolved name but force the dtype to the index dtype.
            output_field = Field::new(output_field.name().clone(), IDX_DTYPE.clone());

            let groupby = match function {
                IRFunctionExpr::ArgMin => GroupByMethod::ArgMin,
                IRFunctionExpr::ArgMax => GroupByMethod::ArgMax,
                _ => unreachable!(), // guaranteed by pattern
            };

            let agg_type = AggregationType {
                groupby,
                allow_threading: state.allow_threading,
            };

            Ok(Arc::new(AggregationExpr::new(
                phys_input,
                agg_type,
                output_field,
            )))
        },
        // `MinBy` / `MaxBy` take exactly two inputs: the values and the `by` expression.
        // This arm must stay before the catch-all `Function` arm.
        Function {
            input: inputs,
            function: function @ (IRFunctionExpr::MinBy | IRFunctionExpr::MaxBy),
            options: _,
        } => {
            assert!(inputs.len() == 2);
            // Pick the constructor matching the function variant.
            let new_minmax_by = match function {
                IRFunctionExpr::MinBy => AggMinMaxByExpr::new_min_by,
                IRFunctionExpr::MaxBy => AggMinMaxByExpr::new_max_by,
                _ => unreachable!(), // guaranteed by pattern
            };
            let input = create_physical_expr_inner(inputs[0].node(), expr_arena, schema, state)?;
            let by = create_physical_expr_inner(inputs[1].node(), expr_arena, schema, state)?;
            return Ok(Arc::new(new_minmax_by(input, by)));
        },
        Cast {
            expr,
            dtype,
            options,
        } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            Ok(Arc::new(CastExpr {
                input: phys_expr,
                dtype: dtype.clone(),
                expr: node_to_expr(expression, expr_arena),
                options,
            }))
        },
        Ternary {
            predicate,
            truthy,
            falsy,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            // Count how many of the three branches contain a literal. The local state is
            // reset before each branch so that `has_lit` reflects that branch only.
            let mut lit_count = 0u8;
            state.reset();
            let predicate = create_physical_expr_inner(predicate, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            state.reset();
            let truthy = create_physical_expr_inner(truthy, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            state.reset();
            let falsy = create_physical_expr_inner(falsy, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            // The flag passed below is only true when threading is allowed and fewer
            // than two branches contain a literal.
            Ok(Arc::new(TernaryExpr::new(
                predicate,
                truthy,
                falsy,
                node_to_expr(expression, expr_arena),
                state.allow_threading && lit_count < 2,
                is_scalar,
            )))
        },
        AExpr::AnonymousAgg {
            input,
            fmt_str: _,
            function,
        } => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let inputs = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
            // Materialize the function and downcast it to a grouped reduction; the
            // `unwrap` panics if the materialized value has a different concrete type.
            let grouped_reduction = function
                .clone()
                .materialize()?
                .as_any()
                .downcast_ref::<Box<dyn GroupedReduction>>()
                .unwrap()
                .new_empty();

            Ok(Arc::new(AnonymousAggregationExpr::new(
                inputs,
                grouped_reduction,
                output_field,
            )))
        },
        AnonymousFunction {
            input,
            function,
            options,
            fmt_str: _,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let input = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;

            let function = function.clone().materialize()?;
            let function = function.into_inner().as_column_udf();

            // No groups UDF is supplied (`None`). The trailing `true` sits in the
            // position where the generic `Function` arm passes `is_fallible`.
            Ok(Arc::new(ApplyExpr::new(
                input,
                SpecialEq::new(function),
                None,
                node_to_expr(expression, expr_arena),
                options,
                state.allow_threading,
                schema.clone(),
                output_field,
                is_scalar,
                true,
            )))
        },
        Eval {
            expr,
            evaluation,
            variant,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let evaluation_is_scalar = is_scalar_ae(evaluation, expr_arena);
            let evaluation_is_elementwise = is_elementwise_rec(evaluation, expr_arena);
            // @NOTE: This is actually also something the downstream apply code should care about.
            let mut pd_group = ExprPushdownGroup::Pushable;
            pd_group.update_with_expr_rec(expr_arena.get(evaluation), expr_arena, None);
            let evaluation_is_fallible = matches!(pd_group, ExprPushdownGroup::Fallible);

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let input_field = expr_arena
                .get(expr)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;

            // The evaluation is converted against a copy of the input schema extended
            // with an entry for the element, typed with the variant's element dtype.
            let element_dtype = variant.element_dtype(&input_field.dtype)?;
            let mut eval_schema = schema.as_ref().clone();
            eval_schema.insert(get_pl_element_name(), element_dtype.clone());
            let evaluation =
                create_physical_expr_inner(evaluation, expr_arena, &Arc::new(eval_schema), state)?;

            Ok(Arc::new(EvalExpr::new(
                expr,
                evaluation,
                variant,
                node_to_expr(expression, expr_arena),
                output_field,
                is_scalar,
                evaluation_is_scalar,
                evaluation_is_elementwise,
                evaluation_is_fallible,
            )))
        },
        #[cfg(feature = "dtype-struct")]
        StructEval { expr, evaluation } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let input_field = expr_arena
                .get(expr)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;

            // The evaluation expressions are converted against a copy of the input
            // schema extended with an entry typed with the input's dtype.
            let mut eval_schema = schema.as_ref().clone();
            eval_schema.insert(get_pl_structfields_name(), input_field.dtype().clone());
            let eval_schema = Arc::new(eval_schema);

            let evaluation = evaluation
                .iter()
                .map(|e| create_physical_expr(e, expr_arena, &eval_schema, state))
                .collect::<PolarsResult<Vec<_>>>()?;

            Ok(Arc::new(StructEvalExpr::new(
                input,
                evaluation,
                node_to_expr(expression, expr_arena),
                output_field,
                is_scalar,
                state.allow_threading,
            )))
        },
        // Catch-all for every `Function` not handled by the more specific arms above.
        Function {
            input,
            function,
            options,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let input = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
            let is_fallible = expr_arena.get(expression).is_fallible_top_level(expr_arena);

            Ok(Arc::new(ApplyExpr::new(
                input,
                function_expr_to_udf(function.clone()),
                function_expr_to_groups_udf(&function),
                node_to_expr(expression, expr_arena),
                options,
                state.allow_threading,
                schema.clone(),
                output_field,
                is_scalar,
                is_fallible,
            )))
        },

        Slice {
            input,
            offset,
            length,
        } => {
            let input = create_physical_expr_inner(input, expr_arena, schema, state)?;
            let offset = create_physical_expr_inner(offset, expr_arena, schema, state)?;
            let length = create_physical_expr_inner(length, expr_arena, schema, state)?;
            Ok(Arc::new(SliceExpr {
                input,
                offset,
                length,
                expr: node_to_expr(expression, expr_arena),
            }))
        },
        Explode { expr, options } => {
            let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            // Explode is expressed as an `ApplyExpr` around a closure that explodes the
            // first (and only) input column.
            let function = SpecialEq::new(Arc::new(
                move |c: &mut [polars_core::frame::column::Column]| c[0].explode(options),
            ) as Arc<dyn ColumnsUdf>);

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            // The two trailing `false` values sit in the positions where the generic
            // `Function` arm passes `is_scalar` and `is_fallible`.
            Ok(Arc::new(ApplyExpr::new(
                vec![input],
                function,
                None,
                node_to_expr(expression, expr_arena),
                FunctionOptions::groupwise(),
                state.allow_threading,
                schema.clone(),
                output_field,
                false,
                false,
            )))
        },
    }
}