polars_plan/plans/conversion/expr_to_ir.rs

1use super::*;
2use crate::plans::conversion::functions::convert_functions;
3
4pub fn to_expr_ir(expr: Expr, arena: &mut Arena<AExpr>, schema: &Schema) -> PolarsResult<ExprIR> {
5    let (node, output_name) = to_aexpr_impl(expr, arena, schema)?;
6    Ok(ExprIR::new(node, OutputName::Alias(output_name)))
7}
8
9pub fn to_expr_ir_materialized_lit(
10    expr: Expr,
11    arena: &mut Arena<AExpr>,
12    schema: &Schema,
13) -> PolarsResult<ExprIR> {
14    let (node, output_name) = to_aexpr_impl_materialized_lit(expr, arena, schema)?;
15    Ok(ExprIR::new(node, OutputName::Alias(output_name)))
16}
17
18pub(super) fn to_expr_irs(
19    input: Vec<Expr>,
20    arena: &mut Arena<AExpr>,
21    schema: &Schema,
22) -> PolarsResult<Vec<ExprIR>> {
23    input
24        .into_iter()
25        .map(|e| to_expr_ir(e, arena, schema))
26        .collect()
27}
28
29fn to_aexpr_impl_materialized_lit(
30    expr: Expr,
31    arena: &mut Arena<AExpr>,
32    schema: &Schema,
33) -> PolarsResult<(Node, PlSmallStr)> {
34    // Already convert `Lit Float and Lit Int` expressions that are not used in a binary / function expression.
35    // This means they can be materialized immediately
36    let e = match expr {
37        Expr::Literal(lv @ LiteralValue::Dyn(_)) => Expr::Literal(lv.materialize()),
38        Expr::Alias(inner, name) if matches!(&*inner, Expr::Literal(LiteralValue::Dyn(_))) => {
39            let Expr::Literal(lv) = &*inner else {
40                unreachable!()
41            };
42            Expr::Alias(Arc::new(Expr::Literal(lv.clone().materialize())), name)
43        },
44        e => e,
45    };
46    to_aexpr_impl(e, arena, schema)
47}
48
49/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation.
50#[recursive]
51pub(super) fn to_aexpr_impl(
52    expr: Expr,
53    arena: &mut Arena<AExpr>,
54    schema: &Schema,
55) -> PolarsResult<(Node, PlSmallStr)> {
56    let owned = Arc::unwrap_or_clone;
57    let (v, output_name) = match expr {
58        Expr::Explode { input, skip_empty } => {
59            let (expr, output_name) = to_aexpr_impl(owned(input), arena, schema)?;
60            (AExpr::Explode { expr, skip_empty }, output_name)
61        },
62        Expr::Alias(e, name) => return Ok((to_aexpr_impl(owned(e), arena, schema)?.0, name)),
63        Expr::Literal(lv) => {
64            let output_name = lv.output_column_name().clone();
65            (AExpr::Literal(lv), output_name)
66        },
67        Expr::Column(name) => (AExpr::Column(name.clone()), name),
68        Expr::BinaryExpr { left, op, right } => {
69            let (l, output_name) = to_aexpr_impl(owned(left), arena, schema)?;
70            let (r, _) = to_aexpr_impl(owned(right), arena, schema)?;
71            (
72                AExpr::BinaryExpr {
73                    left: l,
74                    op,
75                    right: r,
76                },
77                output_name,
78            )
79        },
80        Expr::Cast {
81            expr,
82            dtype,
83            options,
84        } => {
85            let (expr, output_name) = to_aexpr_impl(owned(expr), arena, schema)?;
86            (
87                AExpr::Cast {
88                    expr,
89                    dtype: dtype.into_datatype(schema)?,
90                    options,
91                },
92                output_name,
93            )
94        },
95        Expr::Gather {
96            expr,
97            idx,
98            returns_scalar,
99        } => {
100            let (expr, output_name) = to_aexpr_impl(owned(expr), arena, schema)?;
101            let (idx, _) = to_aexpr_impl_materialized_lit(owned(idx), arena, schema)?;
102            (
103                AExpr::Gather {
104                    expr,
105                    idx,
106                    returns_scalar,
107                },
108                output_name,
109            )
110        },
111        Expr::Sort { expr, options } => {
112            let (expr, output_name) = to_aexpr_impl(owned(expr), arena, schema)?;
113            (AExpr::Sort { expr, options }, output_name)
114        },
115        Expr::SortBy {
116            expr,
117            by,
118            sort_options,
119        } => {
120            let (expr, output_name) = to_aexpr_impl(owned(expr), arena, schema)?;
121            let by = by
122                .into_iter()
123                .map(|e| Ok(to_aexpr_impl(e, arena, schema)?.0))
124                .collect::<PolarsResult<_>>()?;
125
126            (
127                AExpr::SortBy {
128                    expr,
129                    by,
130                    sort_options,
131                },
132                output_name,
133            )
134        },
135        Expr::Filter { input, by } => {
136            let (input, output_name) = to_aexpr_impl(owned(input), arena, schema)?;
137            let (by, _) = to_aexpr_impl(owned(by), arena, schema)?;
138            (AExpr::Filter { input, by }, output_name)
139        },
140        Expr::Agg(agg) => {
141            let (a_agg, output_name) = match agg {
142                AggExpr::Min {
143                    input,
144                    propagate_nans,
145                } => {
146                    let (input, output_name) =
147                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
148                    (
149                        IRAggExpr::Min {
150                            input,
151                            propagate_nans,
152                        },
153                        output_name,
154                    )
155                },
156                AggExpr::Max {
157                    input,
158                    propagate_nans,
159                } => {
160                    let (input, output_name) =
161                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
162                    (
163                        IRAggExpr::Max {
164                            input,
165                            propagate_nans,
166                        },
167                        output_name,
168                    )
169                },
170                AggExpr::Median(input) => {
171                    let (input, output_name) =
172                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
173                    (IRAggExpr::Median(input), output_name)
174                },
175                AggExpr::NUnique(input) => {
176                    let (input, output_name) =
177                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
178                    (IRAggExpr::NUnique(input), output_name)
179                },
180                AggExpr::First(input) => {
181                    let (input, output_name) =
182                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
183                    (IRAggExpr::First(input), output_name)
184                },
185                AggExpr::Last(input) => {
186                    let (input, output_name) =
187                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
188                    (IRAggExpr::Last(input), output_name)
189                },
190                AggExpr::Mean(input) => {
191                    let (input, output_name) =
192                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
193                    (IRAggExpr::Mean(input), output_name)
194                },
195                AggExpr::Implode(input) => {
196                    let (input, output_name) =
197                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
198                    (IRAggExpr::Implode(input), output_name)
199                },
200                AggExpr::Count(input, include_nulls) => {
201                    let (input, output_name) =
202                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
203                    (IRAggExpr::Count(input, include_nulls), output_name)
204                },
205                AggExpr::Quantile {
206                    expr,
207                    quantile,
208                    method,
209                } => {
210                    let (expr, output_name) =
211                        to_aexpr_impl_materialized_lit(owned(expr), arena, schema)?;
212                    let (quantile, _) =
213                        to_aexpr_impl_materialized_lit(owned(quantile), arena, schema)?;
214                    (
215                        IRAggExpr::Quantile {
216                            expr,
217                            quantile,
218                            method,
219                        },
220                        output_name,
221                    )
222                },
223                AggExpr::Sum(input) => {
224                    let (input, output_name) =
225                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
226                    (IRAggExpr::Sum(input), output_name)
227                },
228                AggExpr::Std(input, ddof) => {
229                    let (input, output_name) =
230                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
231                    (IRAggExpr::Std(input, ddof), output_name)
232                },
233                AggExpr::Var(input, ddof) => {
234                    let (input, output_name) =
235                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
236                    (IRAggExpr::Var(input, ddof), output_name)
237                },
238                AggExpr::AggGroups(input) => {
239                    let (input, output_name) =
240                        to_aexpr_impl_materialized_lit(owned(input), arena, schema)?;
241                    (IRAggExpr::AggGroups(input), output_name)
242                },
243            };
244            (AExpr::Agg(a_agg), output_name)
245        },
246        Expr::Ternary {
247            predicate,
248            truthy,
249            falsy,
250        } => {
251            let (p, _) = to_aexpr_impl_materialized_lit(owned(predicate), arena, schema)?;
252            let (t, output_name) = to_aexpr_impl(owned(truthy), arena, schema)?;
253            let (f, _) = to_aexpr_impl(owned(falsy), arena, schema)?;
254            (
255                AExpr::Ternary {
256                    predicate: p,
257                    truthy: t,
258                    falsy: f,
259                },
260                output_name,
261            )
262        },
263        Expr::AnonymousFunction {
264            input,
265            function,
266            output_type,
267            options,
268            fmt_str,
269        } => {
270            let e = to_expr_irs(input, arena, schema)?;
271            let output_name = if e.is_empty() {
272                fmt_str.as_ref().clone()
273            } else {
274                e[0].output_name().clone()
275            };
276
277            let function = function.materialize()?;
278            let output_type = output_type.materialize()?;
279            function.as_ref().resolve_dsl(schema)?;
280            output_type.as_ref().resolve_dsl(schema)?;
281
282            (
283                AExpr::AnonymousFunction {
284                    input: e,
285                    function: LazySerde::Deserialized(function),
286                    output_type: LazySerde::Deserialized(output_type),
287                    options,
288                    fmt_str,
289                },
290                output_name,
291            )
292        },
293        Expr::Function { input, function } => {
294            return convert_functions(input, function, arena, schema);
295        },
296        Expr::Window {
297            function,
298            partition_by,
299            order_by,
300            options,
301        } => {
302            let (function, output_name) = to_aexpr_impl(owned(function), arena, schema)?;
303            let order_by = if let Some((e, options)) = order_by {
304                Some((to_aexpr_impl(owned(e.clone()), arena, schema)?.0, options))
305            } else {
306                None
307            };
308
309            (
310                AExpr::Window {
311                    function,
312                    partition_by: partition_by
313                        .into_iter()
314                        .map(|e| Ok(to_aexpr_impl_materialized_lit(e, arena, schema)?.0))
315                        .collect::<PolarsResult<_>>()?,
316                    order_by,
317                    options,
318                },
319                output_name,
320            )
321        },
322        Expr::Slice {
323            input,
324            offset,
325            length,
326        } => {
327            let (input, output_name) = to_aexpr_impl(owned(input), arena, schema)?;
328            let (offset, _) = to_aexpr_impl_materialized_lit(owned(offset), arena, schema)?;
329            let (length, _) = to_aexpr_impl_materialized_lit(owned(length), arena, schema)?;
330            (
331                AExpr::Slice {
332                    input,
333                    offset,
334                    length,
335                },
336                output_name,
337            )
338        },
339        Expr::Eval {
340            expr,
341            evaluation,
342            variant,
343        } => {
344            let (expr, output_name) = to_aexpr_impl(owned(expr), arena, schema)?;
345            let expr_dtype = arena.get(expr).to_dtype(schema, Context::Default, arena)?;
346            let element_dtype = variant.element_dtype(&expr_dtype)?;
347            let evaluation_schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
348            let (evaluation, _) = to_aexpr_impl(owned(evaluation), arena, &evaluation_schema)?;
349
350            match variant {
351                EvalVariant::List => {
352                    for (_, e) in ArenaExprIter::iter(&&*arena, evaluation) {
353                        match e {
354                            #[cfg(feature = "dtype-categorical")]
355                            AExpr::Cast {
356                                dtype: DataType::Categorical(_, _) | DataType::Enum(_, _),
357                                ..
358                            } => {
359                                polars_bail!(
360                                    ComputeError: "casting to categorical not allowed in `list.eval`"
361                                )
362                            },
363                            AExpr::Column(name) => {
364                                polars_ensure!(
365                                    name.is_empty(),
366                                    ComputeError:
367                                    "named columns are not allowed in `list.eval`; consider using `element` or `col(\"\")`"
368                                );
369                            },
370                            _ => {},
371                        }
372                    }
373                },
374                EvalVariant::Cumulative { .. } => {
375                    polars_ensure!(
376                        is_scalar_ae(evaluation, arena),
377                        InvalidOperation: "`cumulative_eval` is not allowed with non-scalar output"
378                    )
379                },
380            }
381
382            (
383                AExpr::Eval {
384                    expr,
385                    evaluation,
386                    variant,
387                },
388                output_name,
389            )
390        },
391        Expr::Len => (AExpr::Len, get_len_name()),
392        #[cfg(feature = "dtype-struct")]
393        e @ Expr::Field(_) => {
394            polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e)
395        },
396        e @ Expr::IndexColumn(_)
397        | e @ Expr::Wildcard
398        | e @ Expr::Nth(_)
399        | e @ Expr::SubPlan { .. }
400        | e @ Expr::KeepName(_)
401        | e @ Expr::Exclude(_, _)
402        | e @ Expr::RenameAlias { .. }
403        | e @ Expr::Columns { .. }
404        | e @ Expr::DtypeColumn { .. }
405        | e @ Expr::Selector(_) => {
406            polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e)
407        },
408    };
409    Ok((arena.add(v), output_name))
410}