Skip to main content

polars_expr/
planner.rs

1use polars_core::prelude::*;
2use polars_plan::constants::{get_literal_name, get_pl_element_name, get_pl_structfields_name};
3use polars_plan::prelude::expr_ir::ExprIR;
4use polars_plan::prelude::*;
5use recursive::recursive;
6
7use crate::dispatch::{function_expr_to_groups_udf, function_expr_to_udf};
8use crate::expressions as phys_expr;
9use crate::expressions::*;
10use crate::reduce::GroupedReduction;
11
12pub fn get_expr_depth_limit() -> PolarsResult<u16> {
13    let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") {
14        let v = d
15            .parse::<u64>()
16            .map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?;
17        u16::try_from(v).unwrap_or(0)
18    } else {
19        512
20    };
21    Ok(depth)
22}
23
/// Checker that accepts every expression: it ignores both the expression index
/// and the conversion state and always returns `Ok(())`.
///
/// Used by the `create_physical_expressions_from_*` wrappers that do not need
/// per-expression validation.
fn ok_checker(_i: usize, _state: &ExpressionConversionState) -> PolarsResult<()> {
    Ok(())
}
27
28pub fn create_physical_expressions_from_irs(
29    exprs: &[ExprIR],
30    expr_arena: &mut Arena<AExpr>,
31    schema: &SchemaRef,
32    state: &mut ExpressionConversionState,
33) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
34    create_physical_expressions_check_state(exprs, expr_arena, schema, state, ok_checker)
35}
36
37pub(crate) fn create_physical_expressions_check_state<F>(
38    exprs: &[ExprIR],
39    expr_arena: &mut Arena<AExpr>,
40    schema: &SchemaRef,
41    state: &mut ExpressionConversionState,
42    checker: F,
43) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
44where
45    F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
46{
47    exprs
48        .iter()
49        .enumerate()
50        .map(|(i, e)| {
51            state.reset();
52            let out = create_physical_expr(e, expr_arena, schema, state);
53            checker(i, state)?;
54            out
55        })
56        .collect()
57}
58
59pub(crate) fn create_physical_expressions_from_nodes(
60    exprs: &[Node],
61    expr_arena: &mut Arena<AExpr>,
62    schema: &SchemaRef,
63    state: &mut ExpressionConversionState,
64) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
65    create_physical_expressions_from_nodes_check_state(exprs, expr_arena, schema, state, ok_checker)
66}
67
68pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
69    exprs: &[Node],
70    expr_arena: &mut Arena<AExpr>,
71    schema: &SchemaRef,
72    state: &mut ExpressionConversionState,
73    checker: F,
74) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
75where
76    F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
77{
78    exprs
79        .iter()
80        .enumerate()
81        .map(|(i, e)| {
82            state.reset();
83            let out = create_physical_expr_inner(*e, expr_arena, schema, state);
84            checker(i, state)?;
85            out
86        })
87        .collect()
88}
89
/// State threaded through the conversion of logical expressions into physical
/// expressions.
///
/// `allow_threading` and `has_windows` persist across expressions; `local` is
/// cleared by `reset` before each expression is converted.
#[derive(Copy, Clone)]
pub struct ExpressionConversionState {
    // Settings per context.
    // They remain active between expressions.
    /// Forwarded to the physical expressions that take a threading flag.
    pub allow_threading: bool,
    /// Set by `set_window`; never cleared by `reset`.
    pub has_windows: bool,
    // Settings per expression.
    // Those are reset for every expression.
    local: LocalConversionState,
}
101
/// Per-expression flags of `ExpressionConversionState`.
///
/// The `Default` value has both flags set to `false`.
#[derive(Copy, Clone, Default)]
struct LocalConversionState {
    /// Set by `ExpressionConversionState::set_window`.
    has_window: bool,
    /// Set when a `Literal` node is converted; read by the `BinaryExpr` and
    /// `Ternary` conversions.
    has_lit: bool,
}
107
108impl ExpressionConversionState {
109    pub fn new(allow_threading: bool) -> Self {
110        Self {
111            allow_threading,
112            has_windows: false,
113            local: LocalConversionState {
114                ..Default::default()
115            },
116        }
117    }
118
119    fn reset(&mut self) {
120        self.local = LocalConversionState::default();
121    }
122
123    fn set_window(&mut self) {
124        self.has_windows = true;
125        self.local.has_window = true;
126    }
127}
128
129pub fn create_physical_expr(
130    expr_ir: &ExprIR,
131    expr_arena: &mut Arena<AExpr>,
132    schema: &SchemaRef, // Schema of the input.
133    state: &mut ExpressionConversionState,
134) -> PolarsResult<Arc<dyn PhysicalExpr>> {
135    let phys_expr = create_physical_expr_inner(expr_ir.node(), expr_arena, schema, state)?;
136
137    if let Some(name) = expr_ir.get_alias() {
138        Ok(Arc::new(AliasExpr::new(
139            phys_expr,
140            name.clone(),
141            node_to_expr(expr_ir.node(), expr_arena),
142        )))
143    } else {
144        Ok(phys_expr)
145    }
146}
147
/// Recursively converts the `AExpr` stored at `expression` in `expr_arena`
/// into a physical expression.
///
/// Each `AExpr` variant is mapped onto its physical counterpart; child nodes
/// are converted by recursive calls. `state` is updated along the way: the
/// window arms (`Rolling`, `Over`) call `set_window`, the `Literal` arm sets
/// `has_lit`, and the `Ternary` arm (as well as the list-conversion helpers)
/// call `reset`.
///
/// The `MinBy` / `MaxBy` arm adds new nodes to `expr_arena`.
///
/// # Errors
///
/// Propagates errors from resolving output fields (`to_field`), from
/// converting child expressions, and from materializing anonymous functions.
/// The `Over` arm returns a `ComputeError` when no apply column can be
/// determined for the window function.
///
/// # Panics
///
/// Panics if a `MinBy` / `MaxBy` function does not have exactly two inputs, or
/// if the materialized function of an `AnonymousAgg` cannot be downcast to
/// `Box<dyn GroupedReduction>`.
// NOTE(review): `#[recursive]` comes from the `recursive` crate; presumably it
// protects deep recursion against stack overflow — confirm in that crate's docs.
#[recursive]
fn create_physical_expr_inner(
    expression: Node,
    expr_arena: &mut Arena<AExpr>,
    schema: &SchemaRef, // Schema of the input.
    state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
    use AExpr::*;

    let aexpr = expr_arena.get(expression);
    // Match on an owned clone so the arms own the variant's fields while the
    // arena is passed on mutably to the recursive calls.
    match aexpr.clone() {
        Len => Ok(Arc::new(phys_expr::CountExpr::new())),
        #[cfg(feature = "dynamic_group_by")]
        Rolling {
            function,
            index_column,
            period,
            offset,
            closed_window,
        } => {
            // Resolve the output field from the borrowed node before the arena
            // is used mutably below.
            let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
            let index_column = create_physical_expr_inner(index_column, expr_arena, schema, state)?;

            state.set_window();
            let phys_function = create_physical_expr_inner(function, expr_arena, schema, state)?;
            let expr = node_to_expr(expression, expr_arena);

            // set again as the state can be reset
            state.set_window();
            Ok(Arc::new(RollingExpr {
                phys_function,
                index_column,
                period,
                offset,
                closed_window,
                expr,
                output_field,
            }))
        },
        Over {
            function,
            partition_by,
            order_by,
            mapping,
        } => {
            let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
            state.set_window();
            let phys_function = create_physical_expr_inner(function, expr_arena, schema, state)?;

            // Convert the optional `order_by` expression and remember whether
            // it is elementwise.
            let mut order_by_is_elementwise = false;
            let order_by = order_by
                .map(|(node, options)| {
                    order_by_is_elementwise |= is_elementwise_rec(node, expr_arena);
                    PolarsResult::Ok((
                        create_physical_expr_inner(node, expr_arena, schema, state)?,
                        options,
                    ))
                })
                .transpose()?;

            let expr = node_to_expr(expression, expr_arena);

            // set again as the state can be reset
            state.set_window();
            let all_group_by_are_elementwise = partition_by
                .iter()
                .all(|n| is_elementwise_rec(*n, expr_arena));
            // NOTE(review): `create_physical_expressions_from_nodes` calls
            // `state.reset()` for every node, which clears `local.has_window`
            // again when `partition_by` is non-empty; `has_windows` is not
            // affected. Confirm this is intended.
            let group_by =
                create_physical_expressions_from_nodes(&partition_by, expr_arena, schema, state)?;
            let mut apply_columns = aexpr_to_leaf_names(function, expr_arena);
            // sort and then dedup removes consecutive duplicates == all duplicates
            apply_columns.sort();
            apply_columns.dedup();

            // Without any leaf column name, fall back to a placeholder name
            // depending on what the window function contains (a literal, `Len`
            // or `Element`); if none of these is present, report an error.
            if apply_columns.is_empty() {
                if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) {
                    apply_columns.push(get_literal_name())
                } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) {
                    apply_columns.push(PlSmallStr::from_static("len"))
                } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Element)) {
                    apply_columns.push(PlSmallStr::from_static("element"))
                } else {
                    let e = node_to_expr(function, expr_arena);
                    polars_bail!(
                        ComputeError:
                        "cannot apply a window function, did not find a root column; \
                        this is likely due to a syntax error in this expression: {:?}", e
                    );
                }
            }

            // Check if the branches have an aggregation
            // when(a > sum)
            // then (foo)
            // otherwise(bar - sum)
            let mut has_arity = false;
            let mut agg_col = false;
            for (_, e) in expr_arena.iter(function) {
                match e {
                    AExpr::Ternary { .. } | AExpr::BinaryExpr { .. } => {
                        has_arity = true;
                    },
                    AExpr::Agg(_) => {
                        agg_col = true;
                    },
                    // Functions flagged as returning a scalar count as
                    // aggregations here as well.
                    AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => {
                        if options.flags.returns_scalar() {
                            agg_col = true;
                        }
                    },
                    _ => {},
                }
            }
            // True when the function contains both a binary/ternary node and an
            // aggregation.
            let has_different_group_sources = has_arity && agg_col;

            Ok(Arc::new(WindowExpr {
                group_by,
                order_by,
                apply_columns,
                phys_function,
                mapping,
                expr,
                has_different_group_sources,
                output_field,

                order_by_is_elementwise,
                all_group_by_are_elementwise,
            }))
        },
        Literal(value) => {
            // Record the literal so the enclosing `BinaryExpr` / `Ternary`
            // conversion can see it.
            state.local.has_lit = true;
            Ok(Arc::new(LiteralExpr::new(
                value.clone(),
                node_to_expr(expression, expr_arena),
            )))
        },
        BinaryExpr { left, op, right } => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let lhs = create_physical_expr_inner(left, expr_arena, schema, state)?;
            let rhs = create_physical_expr_inner(right, expr_arena, schema, state)?;
            // `state.local.has_lit` is read after both operands have been
            // converted.
            Ok(Arc::new(phys_expr::BinaryExpr::new(
                lhs,
                op,
                rhs,
                node_to_expr(expression, expr_arena),
                state.local.has_lit,
                state.allow_threading,
                is_scalar,
                output_field,
            )))
        },
        Column(column) => Ok(Arc::new(ColumnExpr::new(
            column.clone(),
            node_to_expr(expression, expr_arena),
            schema.clone(),
        ))),
        Element => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            Ok(Arc::new(ElementExpr::new(output_field)))
        },
        #[cfg(feature = "dtype-struct")]
        StructField(field) => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            Ok(Arc::new(FieldExpr::new(
                field.clone(),
                node_to_expr(expression, expr_arena),
                output_field,
            )))
        },
        Sort { expr, options } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            Ok(Arc::new(SortExpr::new(
                phys_expr,
                options,
                node_to_expr(expression, expr_arena),
            )))
        },
        Gather {
            expr,
            idx,
            returns_scalar,
            null_on_oob,
        } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            let phys_idx = create_physical_expr_inner(idx, expr_arena, schema, state)?;
            Ok(Arc::new(GatherExpr {
                phys_expr,
                idx: phys_idx,
                expr: node_to_expr(expression, expr_arena),
                returns_scalar,
                null_on_oob,
            }))
        },
        SortBy {
            expr,
            by,
            sort_options,
        } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            let phys_by = create_physical_expressions_from_nodes(&by, expr_arena, schema, state)?;
            Ok(Arc::new(SortByExpr::new(
                phys_expr,
                phys_by,
                node_to_expr(expression, expr_arena),
                sort_options.clone(),
            )))
        },
        Filter { input, by } => {
            let phys_input = create_physical_expr_inner(input, expr_arena, schema, state)?;
            let phys_by = create_physical_expr_inner(by, expr_arena, schema, state)?;
            Ok(Arc::new(FilterExpr::new(
                phys_input,
                phys_by,
                node_to_expr(expression, expr_arena),
            )))
        },
        Agg(agg) => {
            // Only the first input of the aggregation is converted here.
            let expr = agg.get_input().first();
            let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            let allow_threading = state.allow_threading;

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            // Special case: Quantile supports multiple inputs.
            // TODO refactor to FunctionExpr.
            if let IRAggExpr::Quantile {
                quantile, method, ..
            } = agg
            {
                let quantile = create_physical_expr_inner(quantile, expr_arena, schema, state)?;
                return Ok(Arc::new(AggQuantileExpr::new(input, quantile, method)));
            }

            let groupby = GroupByMethod::from(agg.clone());

            let agg_type = AggregationType {
                groupby,
                allow_threading,
            };

            Ok(Arc::new(AggregationExpr::new(
                input,
                agg_type,
                output_field,
            )))
        },
        // `ArgMin` / `ArgMax` are converted to an `AggregationExpr` rather than
        // going through the generic `Function` arm further below.
        Function {
            input,
            function: function @ (IRFunctionExpr::ArgMin | IRFunctionExpr::ArgMax),
            options: _,
        } => {
            let phys_input =
                create_physical_expr_inner(input[0].node(), expr_arena, schema, state)?;

            // Keep the resolved output name but use `IDX_DTYPE` as the dtype.
            let mut output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            output_field = Field::new(output_field.name().clone(), IDX_DTYPE.clone());

            let groupby = match function {
                IRFunctionExpr::ArgMin => GroupByMethod::ArgMin,
                IRFunctionExpr::ArgMax => GroupByMethod::ArgMax,
                _ => unreachable!(), // guaranteed by pattern
            };

            let agg_type = AggregationType {
                groupby,
                allow_threading: state.allow_threading,
            };

            Ok(Arc::new(AggregationExpr::new(
                phys_input,
                agg_type,
                output_field,
            )))
        },
        // `MinBy` / `MaxBy` are rewritten as `Gather(input, ArgMin/ArgMax(by))`:
        // the two new nodes are added to the arena and then converted.
        Function {
            input: inputs,
            function: function @ (IRFunctionExpr::MinBy | IRFunctionExpr::MaxBy),
            options: _,
        } => {
            assert!(inputs.len() == 2);
            let input = inputs[0].node();
            let by = inputs[1].node();
            let arg_fn = match function {
                IRFunctionExpr::MinBy => IRFunctionExpr::ArgMin,
                IRFunctionExpr::MaxBy => IRFunctionExpr::ArgMax,
                _ => unreachable!(), // guaranteed by pattern
            };

            // Despite the `arg_min` names, this holds `ArgMax` for `MaxBy`.
            let arg_min_aexpr = AExpr::Function {
                input: vec![ExprIR::from_node(by, expr_arena)],
                function: arg_fn,
                options: FunctionOptions::aggregation(),
            };
            let arg_min = expr_arena.add(arg_min_aexpr);
            let gather_aexpr = AExpr::Gather {
                expr: input,
                idx: arg_min,
                returns_scalar: true,
                null_on_oob: false,
            };
            let gather = expr_arena.add(gather_aexpr);

            return create_physical_expr_inner(gather, expr_arena, schema, state);
        },
        Cast {
            expr,
            dtype,
            options,
        } => {
            let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            Ok(Arc::new(CastExpr {
                input: phys_expr,
                dtype: dtype.clone(),
                expr: node_to_expr(expression, expr_arena),
                options,
            }))
        },
        Ternary {
            predicate,
            truthy,
            falsy,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            // The per-expression state is reset before each of the three
            // sub-expressions, so `has_lit` read afterwards reflects only that
            // sub-expression. `lit_count` is the number of sub-expressions for
            // which it was set.
            let mut lit_count = 0u8;
            state.reset();
            let predicate = create_physical_expr_inner(predicate, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            state.reset();
            let truthy = create_physical_expr_inner(truthy, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            state.reset();
            let falsy = create_physical_expr_inner(falsy, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            // NOTE(review): `state.allow_threading && lit_count < 2` is
            // presumably the "run in parallel" flag of `TernaryExpr::new` —
            // confirm against its signature.
            Ok(Arc::new(TernaryExpr::new(
                predicate,
                truthy,
                falsy,
                node_to_expr(expression, expr_arena),
                state.allow_threading && lit_count < 2,
                is_scalar,
            )))
        },
        AExpr::AnonymousAgg {
            input,
            fmt_str: _,
            function,
        } => {
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let inputs = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
            // The `unwrap` panics when the materialized function is not a
            // `Box<dyn GroupedReduction>`.
            let grouped_reduction = function
                .clone()
                .materialize()?
                .as_any()
                .downcast_ref::<Box<dyn GroupedReduction>>()
                .unwrap()
                .new_empty();

            Ok(Arc::new(AnonymousAggregationExpr::new(
                inputs,
                grouped_reduction,
                output_field,
            )))
        },
        AnonymousFunction {
            input,
            function,
            options,
            fmt_str: _,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let input = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;

            let function = function.clone().materialize()?;
            let function = function.into_inner().as_column_udf();

            // The final `true` sits in the position where the generic
            // `Function` arm passes `is_fallible`.
            Ok(Arc::new(ApplyExpr::new(
                input,
                SpecialEq::new(function),
                None,
                node_to_expr(expression, expr_arena),
                options,
                state.allow_threading,
                schema.clone(),
                output_field,
                is_scalar,
                true,
            )))
        },
        Eval {
            expr,
            evaluation,
            variant,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let evaluation_is_scalar = is_scalar_ae(evaluation, expr_arena);
            let evaluation_is_elementwise = is_elementwise_rec(evaluation, expr_arena);
            // @NOTE: This is actually also something the downstream apply code should care about.
            let mut pd_group = ExprPushdownGroup::Pushable;
            pd_group.update_with_expr_rec(expr_arena.get(evaluation), expr_arena, None);
            let evaluation_is_fallible = matches!(pd_group, ExprPushdownGroup::Fallible);

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let input_field = expr_arena
                .get(expr)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;

            // The evaluation is converted against a copy of the input schema
            // in which the element name maps to the element dtype of the input.
            let element_dtype = variant.element_dtype(&input_field.dtype)?;
            let mut eval_schema = schema.as_ref().clone();
            eval_schema.insert(get_pl_element_name(), element_dtype.clone());
            let evaluation =
                create_physical_expr_inner(evaluation, expr_arena, &Arc::new(eval_schema), state)?;

            Ok(Arc::new(EvalExpr::new(
                expr,
                evaluation,
                variant,
                node_to_expr(expression, expr_arena),
                output_field,
                is_scalar,
                evaluation_is_scalar,
                evaluation_is_elementwise,
                evaluation_is_fallible,
            )))
        },
        #[cfg(feature = "dtype-struct")]
        StructEval { expr, evaluation } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;
            let input_field = expr_arena
                .get(expr)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;

            // The evaluation expressions are converted against a copy of the
            // input schema in which the struct-fields name maps to the dtype of
            // the input expression.
            let mut eval_schema = schema.as_ref().clone();
            eval_schema.insert(get_pl_structfields_name(), input_field.dtype().clone());
            let eval_schema = Arc::new(eval_schema);

            let evaluation = evaluation
                .iter()
                .map(|e| create_physical_expr(e, expr_arena, &eval_schema, state))
                .collect::<PolarsResult<Vec<_>>>()?;

            Ok(Arc::new(StructEvalExpr::new(
                input,
                evaluation,
                node_to_expr(expression, expr_arena),
                output_field,
                is_scalar,
                state.allow_threading,
            )))
        },
        // Generic function case; the `ArgMin`/`ArgMax` and `MinBy`/`MaxBy`
        // functions are handled by the more specific arms above.
        Function {
            input,
            function,
            options,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            let input = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
            let is_fallible = expr_arena.get(expression).is_fallible_top_level(expr_arena);

            Ok(Arc::new(ApplyExpr::new(
                input,
                function_expr_to_udf(function.clone()),
                function_expr_to_groups_udf(&function),
                node_to_expr(expression, expr_arena),
                options,
                state.allow_threading,
                schema.clone(),
                output_field,
                is_scalar,
                is_fallible,
            )))
        },

        Slice {
            input,
            offset,
            length,
        } => {
            let input = create_physical_expr_inner(input, expr_arena, schema, state)?;
            let offset = create_physical_expr_inner(offset, expr_arena, schema, state)?;
            let length = create_physical_expr_inner(length, expr_arena, schema, state)?;
            Ok(Arc::new(SliceExpr {
                input,
                offset,
                length,
                expr: node_to_expr(expression, expr_arena),
            }))
        },
        Explode { expr, options } => {
            let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
            // Explode is expressed as an `ApplyExpr` whose UDF calls
            // `Column::explode` on its single input column.
            let function = SpecialEq::new(Arc::new(
                move |c: &mut [polars_core::frame::column::Column]| c[0].explode(options),
            ) as Arc<dyn ColumnsUdf>);

            let output_field = expr_arena
                .get(expression)
                .to_field(&ToFieldContext::new(expr_arena, schema))?;

            // The two trailing `false` values sit in the positions where the
            // generic `Function` arm passes `is_scalar` and `is_fallible`.
            Ok(Arc::new(ApplyExpr::new(
                vec![input],
                function,
                None,
                node_to_expr(expression, expr_arena),
                FunctionOptions::groupwise(),
                state.allow_threading,
                schema.clone(),
                output_field,
                false,
                false,
            )))
        },
    }
}