polars_expr/
planner.rs

1use polars_core::prelude::*;
2use polars_plan::constants::PL_ELEMENT_NAME;
3use polars_plan::prelude::expr_ir::ExprIR;
4use polars_plan::prelude::*;
5use recursive::recursive;
6
7use crate::dispatch::{function_expr_to_groups_udf, function_expr_to_udf};
8use crate::expressions as phys_expr;
9use crate::expressions::*;
10
11pub fn get_expr_depth_limit() -> PolarsResult<u16> {
12    let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") {
13        let v = d
14            .parse::<u64>()
15            .map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?;
16        u16::try_from(v).unwrap_or(0)
17    } else {
18        512
19    };
20    Ok(depth)
21}
22
23fn ok_checker(_i: usize, _state: &ExpressionConversionState) -> PolarsResult<()> {
24    Ok(())
25}
26
27pub fn create_physical_expressions_from_irs(
28    exprs: &[ExprIR],
29    context: Context,
30    expr_arena: &Arena<AExpr>,
31    schema: &SchemaRef,
32    state: &mut ExpressionConversionState,
33) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
34    create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker)
35}
36
37pub(crate) fn create_physical_expressions_check_state<F>(
38    exprs: &[ExprIR],
39    context: Context,
40    expr_arena: &Arena<AExpr>,
41    schema: &SchemaRef,
42    state: &mut ExpressionConversionState,
43    checker: F,
44) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
45where
46    F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
47{
48    exprs
49        .iter()
50        .enumerate()
51        .map(|(i, e)| {
52            state.reset();
53            let out = create_physical_expr(e, context, expr_arena, schema, state);
54            checker(i, state)?;
55            out
56        })
57        .collect()
58}
59
60pub(crate) fn create_physical_expressions_from_nodes(
61    exprs: &[Node],
62    context: Context,
63    expr_arena: &Arena<AExpr>,
64    schema: &SchemaRef,
65    state: &mut ExpressionConversionState,
66) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
67    create_physical_expressions_from_nodes_check_state(
68        exprs, context, expr_arena, schema, state, ok_checker,
69    )
70}
71
72pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
73    exprs: &[Node],
74    context: Context,
75    expr_arena: &Arena<AExpr>,
76    schema: &SchemaRef,
77    state: &mut ExpressionConversionState,
78    checker: F,
79) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
80where
81    F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
82{
83    exprs
84        .iter()
85        .enumerate()
86        .map(|(i, e)| {
87            state.reset();
88            let out = create_physical_expr_inner(*e, context, expr_arena, schema, state);
89            checker(i, state)?;
90            out
91        })
92        .collect()
93}
94
/// State threaded through the conversion of expression IR into physical expressions.
#[derive(Copy, Clone)]
pub struct ExpressionConversionState {
    // Settings per context: they remain active between expressions.
    /// Forwarded to the created physical expressions as their `allow_threading` flag.
    pub allow_threading: bool,
    /// Set once any window expression has been converted; not cleared by `reset`.
    pub has_windows: bool,
    // Settings per expression: those are reset before every expression.
    local: LocalConversionState,
}

/// Flags recorded while converting a single expression; cleared by
/// `ExpressionConversionState::reset`.
#[derive(Copy, Clone, Default)]
struct LocalConversionState {
    /// An `implode` aggregation has been converted.
    has_implode: bool,
    /// A window expression has been converted.
    has_window: bool,
    /// A literal has been converted.
    has_lit: bool,
}

impl ExpressionConversionState {
    /// Creates a state with no windows recorded and all per-expression flags cleared.
    pub fn new(allow_threading: bool) -> Self {
        Self {
            allow_threading,
            has_windows: false,
            local: LocalConversionState::default(),
        }
    }

    /// Clears the per-expression flags; `allow_threading` and `has_windows` are kept.
    fn reset(&mut self) {
        self.local = LocalConversionState::default();
    }

    /// Whether an `implode` aggregation was recorded for the current expression.
    fn has_implode(&self) -> bool {
        self.local.has_implode
    }

    /// Records that a window expression was converted, both for the whole context and for
    /// the current expression.
    fn set_window(&mut self) {
        self.has_windows = true;
        self.local.has_window = true;
    }
}
138
139pub fn create_physical_expr(
140    expr_ir: &ExprIR,
141    ctxt: Context,
142    expr_arena: &Arena<AExpr>,
143    schema: &SchemaRef,
144    state: &mut ExpressionConversionState,
145) -> PolarsResult<Arc<dyn PhysicalExpr>> {
146    let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?;
147
148    if let Some(name) = expr_ir.get_alias() {
149        Ok(Arc::new(AliasExpr::new(
150            phys_expr,
151            name.clone(),
152            node_to_expr(expr_ir.node(), expr_arena),
153        )))
154    } else {
155        Ok(phys_expr)
156    }
157}
158
159#[recursive]
160fn create_physical_expr_inner(
161    expression: Node,
162    ctxt: Context,
163    expr_arena: &Arena<AExpr>,
164    schema: &SchemaRef,
165    state: &mut ExpressionConversionState,
166) -> PolarsResult<Arc<dyn PhysicalExpr>> {
167    use AExpr::*;
168
169    match expr_arena.get(expression) {
170        Len => Ok(Arc::new(phys_expr::CountExpr::new())),
171        aexpr @ Window {
172            function,
173            partition_by,
174            order_by,
175            options,
176        } => {
177            let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
178            let function = *function;
179            state.set_window();
180            let phys_function =
181                create_physical_expr_inner(function, Context::Default, expr_arena, schema, state)?;
182
183            let order_by = order_by
184                .map(|(node, options)| {
185                    PolarsResult::Ok((
186                        create_physical_expr_inner(
187                            node,
188                            Context::Default,
189                            expr_arena,
190                            schema,
191                            state,
192                        )?,
193                        options,
194                    ))
195                })
196                .transpose()?;
197
198            let expr = node_to_expr(expression, expr_arena);
199
200            // set again as the state can be reset
201            state.set_window();
202            match options {
203                WindowType::Over(mapping) => {
204                    // TODO! Order by
205                    let group_by = create_physical_expressions_from_nodes(
206                        partition_by,
207                        Context::Default,
208                        expr_arena,
209                        schema,
210                        state,
211                    )?;
212                    let mut apply_columns = aexpr_to_leaf_names(function, expr_arena);
213                    // sort and then dedup removes consecutive duplicates == all duplicates
214                    apply_columns.sort();
215                    apply_columns.dedup();
216
217                    if apply_columns.is_empty() {
218                        if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) {
219                            apply_columns.push(PlSmallStr::from_static("literal"))
220                        } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) {
221                            apply_columns.push(PlSmallStr::from_static("len"))
222                        } else {
223                            let e = node_to_expr(function, expr_arena);
224                            polars_bail!(
225                                ComputeError:
226                                "cannot apply a window function, did not find a root column; \
227                                this is likely due to a syntax error in this expression: {:?}", e
228                            );
229                        }
230                    }
231
232                    // Check if the branches have an aggregation
233                    // when(a > sum)
234                    // then (foo)
235                    // otherwise(bar - sum)
236                    let mut has_arity = false;
237                    let mut agg_col = false;
238                    for (_, e) in expr_arena.iter(function) {
239                        match e {
240                            AExpr::Ternary { .. } | AExpr::BinaryExpr { .. } => {
241                                has_arity = true;
242                            },
243                            AExpr::Agg(_) => {
244                                agg_col = true;
245                            },
246                            AExpr::Function { options, .. }
247                            | AExpr::AnonymousFunction { options, .. } => {
248                                if options.flags.returns_scalar() {
249                                    agg_col = true;
250                                }
251                            },
252                            _ => {},
253                        }
254                    }
255                    let has_different_group_sources = has_arity && agg_col;
256
257                    Ok(Arc::new(WindowExpr {
258                        group_by,
259                        order_by,
260                        apply_columns,
261                        phys_function,
262                        mapping: *mapping,
263                        expr,
264                        has_different_group_sources,
265                        output_field,
266                    }))
267                },
268                #[cfg(feature = "dynamic_group_by")]
269                WindowType::Rolling(options) => Ok(Arc::new(RollingExpr {
270                    phys_function,
271                    options: options.clone(),
272                    expr,
273                    output_field,
274                })),
275            }
276        },
277        Literal(value) => {
278            state.local.has_lit = true;
279            Ok(Arc::new(LiteralExpr::new(
280                value.clone(),
281                node_to_expr(expression, expr_arena),
282            )))
283        },
284        BinaryExpr { left, op, right } => {
285            let output_field = expr_arena
286                .get(expression)
287                .to_field(&ToFieldContext::new(expr_arena, schema))?;
288            let is_scalar = is_scalar_ae(expression, expr_arena);
289            let lhs = create_physical_expr_inner(*left, ctxt, expr_arena, schema, state)?;
290            let rhs = create_physical_expr_inner(*right, ctxt, expr_arena, schema, state)?;
291            Ok(Arc::new(phys_expr::BinaryExpr::new(
292                lhs,
293                *op,
294                rhs,
295                node_to_expr(expression, expr_arena),
296                state.local.has_lit,
297                state.allow_threading,
298                is_scalar,
299                output_field,
300            )))
301        },
302        Column(column) => Ok(Arc::new(ColumnExpr::new(
303            column.clone(),
304            node_to_expr(expression, expr_arena),
305            schema.clone(),
306        ))),
307        Element => Ok(Arc::new(ColumnExpr::new(
308            PL_ELEMENT_NAME.clone(),
309            node_to_expr(expression, expr_arena),
310            schema.clone(),
311        ))),
312        Sort { expr, options } => {
313            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
314            Ok(Arc::new(SortExpr::new(
315                phys_expr,
316                *options,
317                node_to_expr(expression, expr_arena),
318            )))
319        },
320        Gather {
321            expr,
322            idx,
323            returns_scalar,
324        } => {
325            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
326            let phys_idx = create_physical_expr_inner(*idx, ctxt, expr_arena, schema, state)?;
327            Ok(Arc::new(GatherExpr {
328                phys_expr,
329                idx: phys_idx,
330                expr: node_to_expr(expression, expr_arena),
331                returns_scalar: *returns_scalar,
332            }))
333        },
334        SortBy {
335            expr,
336            by,
337            sort_options,
338        } => {
339            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
340            let phys_by =
341                create_physical_expressions_from_nodes(by, ctxt, expr_arena, schema, state)?;
342            Ok(Arc::new(SortByExpr::new(
343                phys_expr,
344                phys_by,
345                node_to_expr(expression, expr_arena),
346                sort_options.clone(),
347            )))
348        },
349        Filter { input, by } => {
350            let phys_input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
351            let phys_by = create_physical_expr_inner(*by, ctxt, expr_arena, schema, state)?;
352            Ok(Arc::new(FilterExpr::new(
353                phys_input,
354                phys_by,
355                node_to_expr(expression, expr_arena),
356            )))
357        },
358        Agg(agg) => {
359            let expr = agg.get_input().first();
360            let input = create_physical_expr_inner(expr, ctxt, expr_arena, schema, state)?;
361            polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by an aggregation is not allowed");
362            state.local.has_implode |= matches!(agg, IRAggExpr::Implode(_));
363            let allow_threading = state.allow_threading;
364
365            match ctxt {
366                Context::Default if !matches!(agg, IRAggExpr::Quantile { .. }) => {
367                    use {GroupByMethod as GBM, IRAggExpr as I};
368
369                    let output_field = expr_arena
370                        .get(expression)
371                        .to_field(&ToFieldContext::new(expr_arena, schema))?;
372                    let groupby = match agg {
373                        I::Min { propagate_nans, .. } if *propagate_nans => GBM::NanMin,
374                        I::Min { .. } => GBM::Min,
375                        I::Max { propagate_nans, .. } if *propagate_nans => GBM::NanMax,
376                        I::Max { .. } => GBM::Max,
377                        I::Median(_) => GBM::Median,
378                        I::NUnique(_) => GBM::NUnique,
379                        I::First(_) => GBM::First,
380                        I::Last(_) => GBM::Last,
381                        I::Item { allow_empty, .. } => GBM::Item {
382                            allow_empty: *allow_empty,
383                        },
384                        I::Mean(_) => GBM::Mean,
385                        I::Implode(_) => GBM::Implode,
386                        I::Quantile { .. } => unreachable!(),
387                        I::Sum(_) => GBM::Sum,
388                        I::Count {
389                            input: _,
390                            include_nulls,
391                        } => GBM::Count {
392                            include_nulls: *include_nulls,
393                        },
394                        I::Std(_, ddof) => GBM::Std(*ddof),
395                        I::Var(_, ddof) => GBM::Var(*ddof),
396                        I::AggGroups(_) => {
397                            polars_bail!(InvalidOperation: "agg groups expression only supported in aggregation context")
398                        },
399                    };
400
401                    let agg_type = AggregationType {
402                        groupby,
403                        allow_threading,
404                    };
405
406                    Ok(Arc::new(AggregationExpr::new(
407                        input,
408                        agg_type,
409                        output_field,
410                    )))
411                },
412                _ => {
413                    if let IRAggExpr::Quantile {
414                        quantile,
415                        method: interpol,
416                        ..
417                    } = agg
418                    {
419                        let quantile =
420                            create_physical_expr_inner(*quantile, ctxt, expr_arena, schema, state)?;
421                        return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol)));
422                    }
423
424                    let mut output_field = expr_arena
425                        .get(expression)
426                        .to_field(&ToFieldContext::new(expr_arena, schema))?;
427
428                    if matches!(ctxt, Context::Aggregation) && !is_scalar_ae(expression, expr_arena)
429                    {
430                        output_field.coerce(output_field.dtype.clone().implode());
431                    }
432
433                    let groupby = GroupByMethod::from(agg.clone());
434                    let agg_type = AggregationType {
435                        groupby,
436                        allow_threading: false,
437                    };
438                    Ok(Arc::new(AggregationExpr::new(
439                        input,
440                        agg_type,
441                        output_field,
442                    )))
443                },
444            }
445        },
446        Cast {
447            expr,
448            dtype,
449            options,
450        } => {
451            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
452            Ok(Arc::new(CastExpr {
453                input: phys_expr,
454                dtype: dtype.clone(),
455                expr: node_to_expr(expression, expr_arena),
456                options: *options,
457            }))
458        },
459        Ternary {
460            predicate,
461            truthy,
462            falsy,
463        } => {
464            let is_scalar = is_scalar_ae(expression, expr_arena);
465            let mut lit_count = 0u8;
466            state.reset();
467            let predicate =
468                create_physical_expr_inner(*predicate, ctxt, expr_arena, schema, state)?;
469            lit_count += state.local.has_lit as u8;
470            state.reset();
471            let truthy = create_physical_expr_inner(*truthy, ctxt, expr_arena, schema, state)?;
472            lit_count += state.local.has_lit as u8;
473            state.reset();
474            let falsy = create_physical_expr_inner(*falsy, ctxt, expr_arena, schema, state)?;
475            lit_count += state.local.has_lit as u8;
476            Ok(Arc::new(TernaryExpr::new(
477                predicate,
478                truthy,
479                falsy,
480                node_to_expr(expression, expr_arena),
481                state.allow_threading && lit_count < 2,
482                is_scalar,
483            )))
484        },
485        AnonymousFunction {
486            input,
487            function,
488            options,
489            fmt_str: _,
490        } => {
491            let is_scalar = is_scalar_ae(expression, expr_arena);
492            let output_field = expr_arena
493                .get(expression)
494                .to_field(&ToFieldContext::new(expr_arena, schema))?;
495
496            let input =
497                create_physical_expressions_from_irs(input, ctxt, expr_arena, schema, state)?;
498
499            let function = function.clone().materialize()?;
500            let function = function.into_inner().as_column_udf();
501
502            Ok(Arc::new(ApplyExpr::new(
503                input,
504                SpecialEq::new(function),
505                None,
506                node_to_expr(expression, expr_arena),
507                *options,
508                state.allow_threading,
509                schema.clone(),
510                output_field,
511                is_scalar,
512                true,
513            )))
514        },
515        Eval {
516            expr,
517            evaluation,
518            variant,
519        } => {
520            let is_scalar = is_scalar_ae(expression, expr_arena);
521            let evaluation_is_scalar = is_scalar_ae(*evaluation, expr_arena);
522            let evaluation_is_elementwise = is_elementwise_rec(*evaluation, expr_arena);
523            // @NOTE: This is actually also something the downstream apply code should care about.
524            let mut pd_group = ExprPushdownGroup::Pushable;
525            pd_group.update_with_expr_rec(expr_arena.get(*evaluation), expr_arena, None);
526            let evaluation_is_fallible = matches!(pd_group, ExprPushdownGroup::Fallible);
527
528            let output_field = expr_arena
529                .get(expression)
530                .to_field(&ToFieldContext::new(expr_arena, schema))?;
531            let input_field = expr_arena
532                .get(*expr)
533                .to_field(&ToFieldContext::new(expr_arena, schema))?;
534            let expr =
535                create_physical_expr_inner(*expr, Context::Default, expr_arena, schema, state)?;
536
537            let element_dtype = variant.element_dtype(&input_field.dtype)?;
538            let mut eval_schema = schema.as_ref().clone();
539            eval_schema.insert(PL_ELEMENT_NAME.clone(), element_dtype.clone());
540            let evaluation = create_physical_expr_inner(
541                *evaluation,
542                // @Hack. Since EvalVariant::Array uses `evaluate_on_groups` to determine the
543                // output and that expects to be outputting a list, we need to pretend like we are
544                // aggregating here.
545                //
546                // EvalVariant::List also has this problem but that has a List datatype, so that
547                // goes wrong by pure change and some black magic.
548                if matches!(variant, EvalVariant::Array { .. }) && !evaluation_is_elementwise {
549                    Context::Aggregation
550                } else {
551                    Context::Default
552                },
553                expr_arena,
554                &Arc::new(eval_schema),
555                state,
556            )?;
557
558            Ok(Arc::new(EvalExpr::new(
559                expr,
560                evaluation,
561                *variant,
562                node_to_expr(expression, expr_arena),
563                output_field,
564                is_scalar,
565                evaluation_is_scalar,
566                evaluation_is_elementwise,
567                evaluation_is_fallible,
568            )))
569        },
570        Function {
571            input,
572            function,
573            options,
574        } => {
575            let is_scalar = is_scalar_ae(expression, expr_arena);
576            let output_field = expr_arena
577                .get(expression)
578                .to_field(&ToFieldContext::new(expr_arena, schema))?;
579            let input =
580                create_physical_expressions_from_irs(input, ctxt, expr_arena, schema, state)?;
581            let is_fallible = expr_arena.get(expression).is_fallible_top_level(expr_arena);
582
583            Ok(Arc::new(ApplyExpr::new(
584                input,
585                function_expr_to_udf(function.clone()),
586                function_expr_to_groups_udf(function),
587                node_to_expr(expression, expr_arena),
588                *options,
589                state.allow_threading,
590                schema.clone(),
591                output_field,
592                is_scalar,
593                is_fallible,
594            )))
595        },
596        Slice {
597            input,
598            offset,
599            length,
600        } => {
601            let input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
602            let offset = create_physical_expr_inner(*offset, ctxt, expr_arena, schema, state)?;
603            let length = create_physical_expr_inner(*length, ctxt, expr_arena, schema, state)?;
604            polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)),
605                InvalidOperation: "'implode' followed by a slice during aggregation is not allowed");
606            Ok(Arc::new(SliceExpr {
607                input,
608                offset,
609                length,
610                expr: node_to_expr(expression, expr_arena),
611            }))
612        },
613        Explode { expr, skip_empty } => {
614            let input = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
615            let skip_empty = *skip_empty;
616            let function = SpecialEq::new(Arc::new(
617                move |c: &mut [polars_core::frame::column::Column]| c[0].explode(skip_empty),
618            ) as Arc<dyn ColumnsUdf>);
619
620            let output_field = expr_arena
621                .get(expression)
622                .to_field(&ToFieldContext::new(expr_arena, schema))?;
623
624            Ok(Arc::new(ApplyExpr::new(
625                vec![input],
626                function,
627                None,
628                node_to_expr(expression, expr_arena),
629                FunctionOptions::groupwise(),
630                state.allow_threading,
631                schema.clone(),
632                output_field,
633                false,
634                false,
635            )))
636        },
637        AnonymousStreamingAgg { .. } => {
638            polars_bail!(ComputeError: "anonymous agg not supported in in-memory engine")
639        },
640    }
641}