polars_plan/plans/builder_ir.rs

1use std::borrow::Cow;
2
3use super::*;
4
/// Helper for appending IR nodes on top of an existing node of an IR arena.
///
/// Holds the current root node together with mutable borrows of the expression arena and
/// the IR arena. Builder methods add new IR nodes to `lp_arena` and move the root to them.
pub struct IRBuilder<'a> {
    /// Node in `lp_arena` that new IR is built on top of.
    root: Node,
    /// Arena holding the expressions referenced by the IR nodes.
    expr_arena: &'a mut Arena<AExpr>,
    /// Arena the IR nodes are stored in.
    lp_arena: &'a mut Arena<IR>,
}
10
11impl<'a> IRBuilder<'a> {
12    pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
13        IRBuilder {
14            root,
15            expr_arena,
16            lp_arena,
17        }
18    }
19
20    pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
21        let root = lp_arena.add(lp);
22        IRBuilder {
23            root,
24            expr_arena,
25            lp_arena,
26        }
27    }
28
29    pub fn add_alp(self, lp: IR) -> Self {
30        let node = self.lp_arena.add(lp);
31        IRBuilder::new(node, self.expr_arena, self.lp_arena)
32    }
33
34    /// Adds IR and runs optimizations on its expressions (simplify, coerce, type-check).
35    pub fn add_alp_optimize_exprs<F>(self, f: F) -> PolarsResult<Self>
36    where
37        F: FnOnce(Node) -> IR,
38    {
39        let lp = f(self.root);
40        let ir_name = lp.name();
41
42        let b = self.add_alp(lp);
43
44        // Run the optimizer
45        let mut conversion_optimizer = ConversionOptimizer::new(true, true, true);
46        conversion_optimizer.fill_scratch(b.lp_arena.get(b.root).exprs(), b.expr_arena);
47        conversion_optimizer
48            .optimize_exprs(b.expr_arena, b.lp_arena, b.root, false)
49            .map_err(|e| e.context(format!("optimizing '{ir_name}' failed").into()))?;
50
51        Ok(b)
52    }
53
54    /// An escape hatch to add an `Expr`. Working with IR is preferred.
55    pub fn add_expr(&mut self, expr: Expr) -> PolarsResult<ExprIR> {
56        let schema = self.lp_arena.get(self.root).schema(self.lp_arena);
57        let mut ctx = ExprToIRContext::new(self.expr_arena, &schema);
58        to_expr_ir(expr, &mut ctx)
59    }
60
61    pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
62        // if len == 0, no projection has to be done. This is a select all operation.
63        if exprs.is_empty() {
64            self
65        } else {
66            let input_schema = self.schema();
67            let schema = expr_irs_to_schema(&exprs, &input_schema, self.expr_arena);
68
69            let lp = IR::Select {
70                expr: exprs,
71                input: self.root,
72                schema: Arc::new(schema),
73                options,
74            };
75            let node = self.lp_arena.add(lp);
76            IRBuilder::new(node, self.expr_arena, self.lp_arena)
77        }
78    }
79
80    pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>
81    where
82        I: IntoIterator<Item = N>,
83        N: Into<Node>,
84        I::IntoIter: ExactSizeIterator,
85    {
86        let names = nodes
87            .into_iter()
88            .map(|node| match self.expr_arena.get(node.into()) {
89                AExpr::Column(name) => name,
90                _ => unreachable!(),
91            });
92        // This is a duplication of `project_simple` because we already borrow self.expr_arena :/
93        if names.size_hint().0 == 0 {
94            Ok(self)
95        } else {
96            let input_schema = self.schema();
97            let mut count = 0;
98            let schema = names
99                .map(|name| {
100                    let dtype = input_schema.try_get(name)?;
101                    count += 1;
102                    Ok(Field::new(name.clone(), dtype.clone()))
103                })
104                .collect::<PolarsResult<Schema>>()?;
105
106            polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
107
108            let lp = IR::SimpleProjection {
109                input: self.root,
110                columns: Arc::new(schema),
111            };
112            let node = self.lp_arena.add(lp);
113            Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
114        }
115    }
116
117    pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>
118    where
119        I: IntoIterator<Item = S>,
120        I::IntoIter: ExactSizeIterator,
121        S: Into<PlSmallStr>,
122    {
123        let names = names.into_iter();
124        // if len == 0, no projection has to be done. This is a select all operation.
125        if names.size_hint().0 == 0 {
126            Ok(self)
127        } else {
128            let input_schema = self.schema();
129            let mut count = 0;
130            let schema = names
131                .map(|name| {
132                    let name: PlSmallStr = name.into();
133                    let dtype = input_schema.try_get(name.as_str())?;
134                    count += 1;
135                    Ok(Field::new(name, dtype.clone()))
136                })
137                .collect::<PolarsResult<Schema>>()?;
138
139            polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
140
141            let lp = IR::SimpleProjection {
142                input: self.root,
143                columns: Arc::new(schema),
144            };
145            let node = self.lp_arena.add(lp);
146            Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
147        }
148    }
149
150    pub fn drop<I, S>(self, names: I) -> Self
151    where
152        I: IntoIterator<Item = S>,
153        I::IntoIter: ExactSizeIterator,
154        S: Into<PlSmallStr>,
155    {
156        let names = names.into_iter();
157        // if len == 0, no projection has to be done. This is a select all operation.
158        if names.size_hint().0 == 0 {
159            self
160        } else {
161            let mut schema = self.schema().as_ref().as_ref().clone();
162
163            for name in names {
164                let name: PlSmallStr = name.into();
165                schema.remove(&name);
166            }
167
168            let lp = IR::SimpleProjection {
169                input: self.root,
170                columns: Arc::new(schema),
171            };
172            let node = self.lp_arena.add(lp);
173            IRBuilder::new(node, self.expr_arena, self.lp_arena)
174        }
175    }
176
177    pub fn sort(
178        self,
179        by_column: Vec<ExprIR>,
180        slice: Option<(i64, usize)>,
181        sort_options: SortMultipleOptions,
182    ) -> Self {
183        let ir = IR::Sort {
184            input: self.root,
185            by_column,
186            slice,
187            sort_options,
188        };
189        let node = self.lp_arena.add(ir);
190        IRBuilder::new(node, self.expr_arena, self.lp_arena)
191    }
192
193    pub fn node(self) -> Node {
194        self.root
195    }
196
197    pub fn build(self) -> IR {
198        if self.root.0 == self.lp_arena.len() {
199            self.lp_arena.pop().unwrap()
200        } else {
201            self.lp_arena.take(self.root)
202        }
203    }
204
205    pub fn schema(&'a self) -> Cow<'a, SchemaRef> {
206        self.lp_arena.get(self.root).schema(self.lp_arena)
207    }
208
209    pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
210        let schema = self.schema();
211        let mut new_schema = (**schema).clone();
212
213        let hstack_schema = expr_irs_to_schema(&exprs, &schema, self.expr_arena);
214        new_schema.merge(hstack_schema);
215
216        let lp = IR::HStack {
217            input: self.root,
218            exprs,
219            schema: Arc::new(new_schema),
220            options,
221        };
222        self.add_alp(lp)
223    }
224
225    pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self
226    where
227        I: IntoIterator<Item = J>,
228    {
229        let schema = self.schema();
230        let mut new_schema = (**schema).clone();
231
232        let iter = exprs.into_iter();
233        let mut expr_irs = Vec::with_capacity(iter.size_hint().0);
234        for node in iter {
235            let node = node.into();
236            let field = self
237                .expr_arena
238                .get(node)
239                .to_field(&schema, self.expr_arena)
240                .unwrap();
241
242            expr_irs.push(
243                ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
244                    .with_dtype(field.dtype.clone()),
245            );
246            new_schema.with_column(field.name().clone(), field.dtype().clone());
247        }
248
249        let lp = IR::HStack {
250            input: self.root,
251            exprs: expr_irs,
252            schema: Arc::new(new_schema),
253            options,
254        };
255        self.add_alp(lp)
256    }
257
258    // call this if the schema needs to be updated
259    pub fn explode(self, columns: Arc<[PlSmallStr]>) -> Self {
260        let lp = IR::MapFunction {
261            input: self.root,
262            function: FunctionIR::Explode {
263                columns,
264                schema: Default::default(),
265            },
266        };
267        self.add_alp(lp)
268    }
269
270    pub fn group_by(
271        self,
272        keys: Vec<ExprIR>,
273        aggs: Vec<ExprIR>,
274        apply: Option<PlanCallback<DataFrame, DataFrame>>,
275        maintain_order: bool,
276        options: Arc<GroupbyOptions>,
277    ) -> Self {
278        let current_schema = self.schema();
279        let mut schema = expr_irs_to_schema(&keys, &current_schema, self.expr_arena);
280
281        #[cfg(feature = "dynamic_group_by")]
282        {
283            if let Some(options) = options.rolling.as_ref() {
284                let name = &options.index_column;
285                let dtype = current_schema.get(name).unwrap();
286                schema.with_column(name.clone(), dtype.clone());
287            } else if let Some(options) = options.dynamic.as_ref() {
288                let name = &options.index_column;
289                let dtype = current_schema.get(name).unwrap();
290                if options.include_boundaries {
291                    schema.with_column("_lower_boundary".into(), dtype.clone());
292                    schema.with_column("_upper_boundary".into(), dtype.clone());
293                }
294                schema.with_column(name.clone(), dtype.clone());
295            }
296        }
297
298        let mut aggs_schema = expr_irs_to_schema(&aggs, &current_schema, self.expr_arena);
299
300        // Coerce aggregation column(s) into List unless not needed (auto-implode)
301        debug_assert!(aggs_schema.len() == aggs.len());
302        for ((_name, dtype), expr) in aggs_schema.iter_mut().zip(&aggs) {
303            if !expr.is_scalar(self.expr_arena) {
304                *dtype = dtype.clone().implode();
305            }
306        }
307
308        schema.merge(aggs_schema);
309
310        let lp = IR::GroupBy {
311            input: self.root,
312            keys,
313            aggs,
314            schema: Arc::new(schema),
315            apply,
316            maintain_order,
317            options,
318        };
319        self.add_alp(lp)
320    }
321
322    pub fn join(
323        self,
324        other: Node,
325        left_on: Vec<ExprIR>,
326        right_on: Vec<ExprIR>,
327        options: Arc<JoinOptionsIR>,
328    ) -> Self {
329        let schema_left = self.schema();
330        let schema_right = self.lp_arena.get(other).schema(self.lp_arena);
331
332        let schema = det_join_schema(
333            &schema_left,
334            &schema_right,
335            &left_on,
336            &right_on,
337            &options,
338            self.expr_arena,
339        )
340        .unwrap();
341
342        let lp = IR::Join {
343            input_left: self.root,
344            input_right: other,
345            schema,
346            left_on,
347            right_on,
348            options,
349        };
350
351        self.add_alp(lp)
352    }
353
354    #[cfg(feature = "pivot")]
355    pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {
356        let lp = IR::MapFunction {
357            input: self.root,
358            function: FunctionIR::Unpivot {
359                args,
360                schema: Default::default(),
361            },
362        };
363        self.add_alp(lp)
364    }
365
366    pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {
367        let lp = IR::MapFunction {
368            input: self.root,
369            function: FunctionIR::RowIndex {
370                name,
371                offset,
372                schema: Default::default(),
373            },
374        };
375        self.add_alp(lp)
376    }
377}