polars_expr/
planner.rs

1use polars_core::prelude::*;
2use polars_plan::prelude::expr_ir::ExprIR;
3use polars_plan::prelude::*;
4
5use crate::expressions as phys_expr;
6use crate::expressions::*;
7
8pub fn get_expr_depth_limit() -> PolarsResult<u16> {
9    let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") {
10        let v = d
11            .parse::<u64>()
12            .map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?;
13        u16::try_from(v).unwrap_or(0)
14    } else {
15        512
16    };
17    Ok(depth)
18}
19
20fn ok_checker(_state: &ExpressionConversionState) -> PolarsResult<()> {
21    Ok(())
22}
23
24pub fn create_physical_expressions_from_irs(
25    exprs: &[ExprIR],
26    context: Context,
27    expr_arena: &Arena<AExpr>,
28    schema: &SchemaRef,
29    state: &mut ExpressionConversionState,
30) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
31    create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker)
32}
33
34pub(crate) fn create_physical_expressions_check_state<F>(
35    exprs: &[ExprIR],
36    context: Context,
37    expr_arena: &Arena<AExpr>,
38    schema: &SchemaRef,
39    state: &mut ExpressionConversionState,
40    checker: F,
41) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
42where
43    F: Fn(&ExpressionConversionState) -> PolarsResult<()>,
44{
45    exprs
46        .iter()
47        .map(|e| {
48            state.reset();
49            let out = create_physical_expr(e, context, expr_arena, schema, state);
50            checker(state)?;
51            out
52        })
53        .collect()
54}
55
56pub(crate) fn create_physical_expressions_from_nodes(
57    exprs: &[Node],
58    context: Context,
59    expr_arena: &Arena<AExpr>,
60    schema: &SchemaRef,
61    state: &mut ExpressionConversionState,
62) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
63    create_physical_expressions_from_nodes_check_state(
64        exprs, context, expr_arena, schema, state, ok_checker,
65    )
66}
67
68pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
69    exprs: &[Node],
70    context: Context,
71    expr_arena: &Arena<AExpr>,
72    schema: &SchemaRef,
73    state: &mut ExpressionConversionState,
74    checker: F,
75) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
76where
77    F: Fn(&ExpressionConversionState) -> PolarsResult<()>,
78{
79    exprs
80        .iter()
81        .map(|e| {
82            state.reset();
83            let out = create_physical_expr_inner(*e, context, expr_arena, schema, state);
84            checker(state)?;
85            out
86        })
87        .collect()
88}
89
90#[derive(Copy, Clone)]
91pub struct ExpressionConversionState {
92    // settings per context
93    // they remain activate between
94    // expressions
95    pub allow_threading: bool,
96    pub has_windows: bool,
97    // settings per expression
98    // those are reset every expression
99    local: LocalConversionState,
100    depth_limit: u16,
101}
102
103#[derive(Copy, Clone)]
104struct LocalConversionState {
105    has_implode: bool,
106    has_window: bool,
107    has_lit: bool,
108    // Max depth an expression may have.
109    // 0 is unlimited.
110    depth_limit: u16,
111}
112
113impl Default for LocalConversionState {
114    fn default() -> Self {
115        Self {
116            has_lit: false,
117            has_implode: false,
118            has_window: false,
119            depth_limit: 500,
120        }
121    }
122}
123
124impl ExpressionConversionState {
125    pub fn new(allow_threading: bool, depth_limit: u16) -> Self {
126        Self {
127            depth_limit,
128            allow_threading,
129            has_windows: false,
130            local: LocalConversionState {
131                depth_limit,
132                ..Default::default()
133            },
134        }
135    }
136    fn reset(&mut self) {
137        self.local = LocalConversionState {
138            depth_limit: self.depth_limit,
139            ..Default::default()
140        }
141    }
142
143    fn has_implode(&self) -> bool {
144        self.local.has_implode
145    }
146
147    fn set_window(&mut self) {
148        self.has_windows = true;
149        self.local.has_window = true;
150    }
151
152    fn check_depth(&mut self) {
153        if self.local.depth_limit > 0 {
154            self.local.depth_limit -= 1;
155
156            if self.local.depth_limit == 0 {
157                let depth = get_expr_depth_limit().unwrap();
158                polars_warn!(format!("encountered expression deeper than {depth} elements; this may overflow the stack, consider refactoring"))
159            }
160        }
161    }
162}
163
164pub fn create_physical_expr(
165    expr_ir: &ExprIR,
166    ctxt: Context,
167    expr_arena: &Arena<AExpr>,
168    schema: &SchemaRef,
169    state: &mut ExpressionConversionState,
170) -> PolarsResult<Arc<dyn PhysicalExpr>> {
171    let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?;
172
173    if let Some(name) = expr_ir.get_alias() {
174        Ok(Arc::new(AliasExpr::new(
175            phys_expr,
176            name.clone(),
177            node_to_expr(expr_ir.node(), expr_arena),
178        )))
179    } else {
180        Ok(phys_expr)
181    }
182}
183
184fn create_physical_expr_inner(
185    expression: Node,
186    ctxt: Context,
187    expr_arena: &Arena<AExpr>,
188    schema: &SchemaRef,
189    state: &mut ExpressionConversionState,
190) -> PolarsResult<Arc<dyn PhysicalExpr>> {
191    use AExpr::*;
192
193    state.check_depth();
194
195    match expr_arena.get(expression) {
196        Len => Ok(Arc::new(phys_expr::CountExpr::new())),
197        Window {
198            mut function,
199            partition_by,
200            order_by,
201            options,
202        } => {
203            state.set_window();
204            let phys_function = create_physical_expr_inner(
205                function,
206                Context::Aggregation,
207                expr_arena,
208                schema,
209                state,
210            )?;
211
212            let order_by = order_by
213                .map(|(node, options)| {
214                    PolarsResult::Ok((
215                        create_physical_expr_inner(
216                            node,
217                            Context::Aggregation,
218                            expr_arena,
219                            schema,
220                            state,
221                        )?,
222                        options,
223                    ))
224                })
225                .transpose()?;
226
227            let mut out_name = None;
228            if let Alias(expr, name) = expr_arena.get(function) {
229                function = *expr;
230                out_name = Some(name.clone());
231            };
232            let function_expr = node_to_expr(function, expr_arena);
233            let expr = node_to_expr(expression, expr_arena);
234
235            // set again as the state can be reset
236            state.set_window();
237            match options {
238                WindowType::Over(mapping) => {
239                    // TODO! Order by
240                    let group_by = create_physical_expressions_from_nodes(
241                        partition_by,
242                        Context::Aggregation,
243                        expr_arena,
244                        schema,
245                        state,
246                    )?;
247                    let mut apply_columns = aexpr_to_leaf_names(function, expr_arena);
248                    // sort and then dedup removes consecutive duplicates == all duplicates
249                    apply_columns.sort();
250                    apply_columns.dedup();
251
252                    if apply_columns.is_empty() {
253                        if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) {
254                            apply_columns.push(PlSmallStr::from_static("literal"))
255                        } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) {
256                            apply_columns.push(PlSmallStr::from_static("len"))
257                        } else {
258                            let e = node_to_expr(function, expr_arena);
259                            polars_bail!(
260                                ComputeError:
261                                "cannot apply a window function, did not find a root column; \
262                                this is likely due to a syntax error in this expression: {:?}", e
263                            );
264                        }
265                    }
266
267                    Ok(Arc::new(WindowExpr {
268                        group_by,
269                        order_by,
270                        apply_columns,
271                        out_name,
272                        function: function_expr,
273                        phys_function,
274                        mapping: *mapping,
275                        expr,
276                    }))
277                },
278                #[cfg(feature = "dynamic_group_by")]
279                WindowType::Rolling(options) => Ok(Arc::new(RollingExpr {
280                    function: function_expr,
281                    phys_function,
282                    out_name,
283                    options: options.clone(),
284                    expr,
285                })),
286            }
287        },
288        Literal(value) => {
289            state.local.has_lit = true;
290            Ok(Arc::new(LiteralExpr::new(
291                value.clone(),
292                node_to_expr(expression, expr_arena),
293            )))
294        },
295        BinaryExpr { left, op, right } => {
296            let is_scalar = is_scalar_ae(expression, expr_arena);
297            let lhs = create_physical_expr_inner(*left, ctxt, expr_arena, schema, state)?;
298            let rhs = create_physical_expr_inner(*right, ctxt, expr_arena, schema, state)?;
299            Ok(Arc::new(phys_expr::BinaryExpr::new(
300                lhs,
301                *op,
302                rhs,
303                node_to_expr(expression, expr_arena),
304                state.local.has_lit,
305                state.allow_threading,
306                is_scalar,
307            )))
308        },
309        Column(column) => Ok(Arc::new(ColumnExpr::new(
310            column.clone(),
311            node_to_expr(expression, expr_arena),
312            schema.clone(),
313        ))),
314        Sort { expr, options } => {
315            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
316            Ok(Arc::new(SortExpr::new(
317                phys_expr,
318                *options,
319                node_to_expr(expression, expr_arena),
320            )))
321        },
322        Gather {
323            expr,
324            idx,
325            returns_scalar,
326        } => {
327            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
328            let phys_idx = create_physical_expr_inner(*idx, ctxt, expr_arena, schema, state)?;
329            Ok(Arc::new(GatherExpr {
330                phys_expr,
331                idx: phys_idx,
332                expr: node_to_expr(expression, expr_arena),
333                returns_scalar: *returns_scalar,
334            }))
335        },
336        SortBy {
337            expr,
338            by,
339            sort_options,
340        } => {
341            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
342            let phys_by =
343                create_physical_expressions_from_nodes(by, ctxt, expr_arena, schema, state)?;
344            Ok(Arc::new(SortByExpr::new(
345                phys_expr,
346                phys_by,
347                node_to_expr(expression, expr_arena),
348                sort_options.clone(),
349            )))
350        },
351        Filter { input, by } => {
352            let phys_input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
353            let phys_by = create_physical_expr_inner(*by, ctxt, expr_arena, schema, state)?;
354            Ok(Arc::new(FilterExpr::new(
355                phys_input,
356                phys_by,
357                node_to_expr(expression, expr_arena),
358            )))
359        },
360        Agg(agg) => {
361            let expr = agg.get_input().first();
362            let input = create_physical_expr_inner(expr, ctxt, expr_arena, schema, state)?;
363            polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by an aggregation is not allowed");
364            state.local.has_implode |= matches!(agg, IRAggExpr::Implode(_));
365            let allow_threading = state.allow_threading;
366
367            match ctxt {
368                Context::Default if !matches!(agg, IRAggExpr::Quantile { .. }) => {
369                    use {GroupByMethod as GBM, IRAggExpr as I};
370
371                    let groupby = match agg {
372                        I::Min { propagate_nans, .. } if *propagate_nans => GBM::NanMin,
373                        I::Min { .. } => GBM::Min,
374                        I::Max { propagate_nans, .. } if *propagate_nans => GBM::NanMax,
375                        I::Max { .. } => GBM::Max,
376                        I::Median(_) => GBM::Median,
377                        I::NUnique(_) => GBM::NUnique,
378                        I::First(_) => GBM::First,
379                        I::Last(_) => GBM::Last,
380                        I::Mean(_) => GBM::Mean,
381                        I::Implode(_) => GBM::Implode,
382                        I::Quantile { .. } => unreachable!(),
383                        I::Sum(_) => GBM::Sum,
384                        I::Count(_, include_nulls) => GBM::Count {
385                            include_nulls: *include_nulls,
386                        },
387                        I::Std(_, ddof) => GBM::Std(*ddof),
388                        I::Var(_, ddof) => GBM::Var(*ddof),
389                        I::AggGroups(_) => {
390                            polars_bail!(InvalidOperation: "agg groups expression only supported in aggregation context")
391                        },
392                    };
393
394                    let agg_type = AggregationType {
395                        groupby,
396                        allow_threading,
397                    };
398
399                    Ok(Arc::new(AggregationExpr::new(input, agg_type, None)))
400                },
401                _ => {
402                    if let IRAggExpr::Quantile {
403                        quantile,
404                        method: interpol,
405                        ..
406                    } = agg
407                    {
408                        let quantile =
409                            create_physical_expr_inner(*quantile, ctxt, expr_arena, schema, state)?;
410                        return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol)));
411                    }
412
413                    let field = expr_arena.get(expression).to_field(
414                        schema,
415                        Context::Aggregation,
416                        expr_arena,
417                    )?;
418
419                    let groupby = GroupByMethod::from(agg.clone());
420                    let agg_type = AggregationType {
421                        groupby,
422                        allow_threading: false,
423                    };
424                    Ok(Arc::new(AggregationExpr::new(input, agg_type, Some(field))))
425                },
426            }
427        },
428        Cast {
429            expr,
430            dtype,
431            options,
432        } => {
433            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
434            Ok(Arc::new(CastExpr {
435                input: phys_expr,
436                dtype: dtype.clone(),
437                expr: node_to_expr(expression, expr_arena),
438                options: *options,
439                inlined_eval: Default::default(),
440            }))
441        },
442        Ternary {
443            predicate,
444            truthy,
445            falsy,
446        } => {
447            let is_scalar = is_scalar_ae(expression, expr_arena);
448            let mut lit_count = 0u8;
449            state.reset();
450            let predicate =
451                create_physical_expr_inner(*predicate, ctxt, expr_arena, schema, state)?;
452            lit_count += state.local.has_lit as u8;
453            state.reset();
454            let truthy = create_physical_expr_inner(*truthy, ctxt, expr_arena, schema, state)?;
455            lit_count += state.local.has_lit as u8;
456            state.reset();
457            let falsy = create_physical_expr_inner(*falsy, ctxt, expr_arena, schema, state)?;
458            lit_count += state.local.has_lit as u8;
459            Ok(Arc::new(TernaryExpr::new(
460                predicate,
461                truthy,
462                falsy,
463                node_to_expr(expression, expr_arena),
464                state.allow_threading && lit_count < 2,
465                is_scalar,
466            )))
467        },
468        AnonymousFunction {
469            input,
470            function,
471            output_type: _,
472            options,
473        } => {
474            let is_scalar = is_scalar_ae(expression, expr_arena);
475            let output_field = expr_arena
476                .get(expression)
477                .to_field(schema, ctxt, expr_arena)?;
478
479            let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
480                && matches!(options.collect_groups, ApplyOptions::GroupWise);
481            // Will be reset in the function so get that here.
482            let has_window = state.local.has_window;
483            let input = create_physical_expressions_check_state(
484                input,
485                ctxt,
486                expr_arena,
487                schema,
488                state,
489                |state| {
490                    polars_ensure!(!((is_reducing_aggregation || has_window) && state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by an aggregation is not allowed");
491                    Ok(())
492                },
493            )?;
494
495            Ok(Arc::new(ApplyExpr::new(
496                input,
497                function.clone().materialize()?,
498                node_to_expr(expression, expr_arena),
499                *options,
500                state.allow_threading,
501                schema.clone(),
502                output_field,
503                is_scalar,
504            )))
505        },
506        Function {
507            input,
508            function,
509            options,
510        } => {
511            let is_scalar = is_scalar_ae(expression, expr_arena);
512            let output_field = expr_arena
513                .get(expression)
514                .to_field(schema, ctxt, expr_arena)?;
515            let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
516                && matches!(options.collect_groups, ApplyOptions::GroupWise);
517            // Will be reset in the function so get that here.
518            let has_window = state.local.has_window;
519            let input = create_physical_expressions_check_state(
520                input,
521                ctxt,
522                expr_arena,
523                schema,
524                state,
525                |state| {
526                    polars_ensure!(!((is_reducing_aggregation || has_window) && state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by an aggregation is not allowed");
527                    Ok(())
528                },
529            )?;
530
531            Ok(Arc::new(ApplyExpr::new(
532                input,
533                function.clone().into(),
534                node_to_expr(expression, expr_arena),
535                *options,
536                state.allow_threading,
537                schema.clone(),
538                output_field,
539                is_scalar,
540            )))
541        },
542        Slice {
543            input,
544            offset,
545            length,
546        } => {
547            let input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
548            let offset = create_physical_expr_inner(*offset, ctxt, expr_arena, schema, state)?;
549            let length = create_physical_expr_inner(*length, ctxt, expr_arena, schema, state)?;
550            polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by a slice during aggregation is not allowed");
551            Ok(Arc::new(SliceExpr {
552                input,
553                offset,
554                length,
555                expr: node_to_expr(expression, expr_arena),
556            }))
557        },
558        Explode(expr) => {
559            let input = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
560            let function = SpecialEq::new(Arc::new(
561                move |c: &mut [polars_core::frame::column::Column]| c[0].explode().map(Some),
562            ) as Arc<dyn ColumnsUdf>);
563
564            let field = expr_arena
565                .get(expression)
566                .to_field(schema, ctxt, expr_arena)?;
567
568            Ok(Arc::new(ApplyExpr::new(
569                vec![input],
570                function,
571                node_to_expr(expression, expr_arena),
572                FunctionOptions {
573                    collect_groups: ApplyOptions::GroupWise,
574                    ..Default::default()
575                },
576                state.allow_threading,
577                schema.clone(),
578                field,
579                false,
580            )))
581        },
582        Alias(input, name) => {
583            let phys_expr = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
584            Ok(Arc::new(AliasExpr::new(
585                phys_expr,
586                name.clone(),
587                node_to_expr(*input, expr_arena),
588            )))
589        },
590    }
591}