1use std::borrow::Cow;
2
3use super::*;
4
5pub struct IRBuilder<'a> {
6 root: Node,
7 expr_arena: &'a mut Arena<AExpr>,
8 lp_arena: &'a mut Arena<IR>,
9}
10
11impl<'a> IRBuilder<'a> {
12 pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
13 IRBuilder {
14 root,
15 expr_arena,
16 lp_arena,
17 }
18 }
19
20 pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
21 let root = lp_arena.add(lp);
22 IRBuilder {
23 root,
24 expr_arena,
25 lp_arena,
26 }
27 }
28
29 pub fn add_alp(self, lp: IR) -> Self {
30 let node = self.lp_arena.add(lp);
31 IRBuilder::new(node, self.expr_arena, self.lp_arena)
32 }
33
34 pub fn add_alp_optimize_exprs<F>(self, f: F) -> PolarsResult<Self>
36 where
37 F: FnOnce(Node) -> IR,
38 {
39 let lp = f(self.root);
40 let ir_name = lp.name();
41
42 let b = self.add_alp(lp);
43
44 let mut conversion_optimizer = ConversionOptimizer::new(true, true, true);
46 conversion_optimizer.fill_scratch(b.lp_arena.get(b.root).exprs(), b.expr_arena);
47 conversion_optimizer
48 .optimize_exprs(b.expr_arena, b.lp_arena, b.root, false)
49 .map_err(|e| e.context(format!("optimizing '{ir_name}' failed").into()))?;
50
51 Ok(b)
52 }
53
54 pub fn add_expr(&mut self, expr: Expr) -> PolarsResult<ExprIR> {
56 let schema = self.lp_arena.get(self.root).schema(self.lp_arena);
57 let mut ctx = ExprToIRContext::new(self.expr_arena, &schema);
58 to_expr_ir(expr, &mut ctx)
59 }
60
61 pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
62 if exprs.is_empty() {
64 self
65 } else {
66 let input_schema = self.schema();
67 let schema = expr_irs_to_schema(&exprs, &input_schema, self.expr_arena);
68
69 let lp = IR::Select {
70 expr: exprs,
71 input: self.root,
72 schema: Arc::new(schema),
73 options,
74 };
75 let node = self.lp_arena.add(lp);
76 IRBuilder::new(node, self.expr_arena, self.lp_arena)
77 }
78 }
79
80 pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>
81 where
82 I: IntoIterator<Item = N>,
83 N: Into<Node>,
84 I::IntoIter: ExactSizeIterator,
85 {
86 let names = nodes
87 .into_iter()
88 .map(|node| match self.expr_arena.get(node.into()) {
89 AExpr::Column(name) => name,
90 _ => unreachable!(),
91 });
92 if names.size_hint().0 == 0 {
94 Ok(self)
95 } else {
96 let input_schema = self.schema();
97 let mut count = 0;
98 let schema = names
99 .map(|name| {
100 let dtype = input_schema.try_get(name)?;
101 count += 1;
102 Ok(Field::new(name.clone(), dtype.clone()))
103 })
104 .collect::<PolarsResult<Schema>>()?;
105
106 polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
107
108 let lp = IR::SimpleProjection {
109 input: self.root,
110 columns: Arc::new(schema),
111 };
112 let node = self.lp_arena.add(lp);
113 Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
114 }
115 }
116
117 pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>
118 where
119 I: IntoIterator<Item = S>,
120 I::IntoIter: ExactSizeIterator,
121 S: Into<PlSmallStr>,
122 {
123 let names = names.into_iter();
124 if names.size_hint().0 == 0 {
126 Ok(self)
127 } else {
128 let input_schema = self.schema();
129 let mut count = 0;
130 let schema = names
131 .map(|name| {
132 let name: PlSmallStr = name.into();
133 let dtype = input_schema.try_get(name.as_str())?;
134 count += 1;
135 Ok(Field::new(name, dtype.clone()))
136 })
137 .collect::<PolarsResult<Schema>>()?;
138
139 polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
140
141 let lp = IR::SimpleProjection {
142 input: self.root,
143 columns: Arc::new(schema),
144 };
145 let node = self.lp_arena.add(lp);
146 Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
147 }
148 }
149
150 pub fn drop<I, S>(self, names: I) -> Self
151 where
152 I: IntoIterator<Item = S>,
153 I::IntoIter: ExactSizeIterator,
154 S: Into<PlSmallStr>,
155 {
156 let names = names.into_iter();
157 if names.size_hint().0 == 0 {
159 self
160 } else {
161 let mut schema = self.schema().as_ref().as_ref().clone();
162
163 for name in names {
164 let name: PlSmallStr = name.into();
165 schema.remove(&name);
166 }
167
168 let lp = IR::SimpleProjection {
169 input: self.root,
170 columns: Arc::new(schema),
171 };
172 let node = self.lp_arena.add(lp);
173 IRBuilder::new(node, self.expr_arena, self.lp_arena)
174 }
175 }
176
177 pub fn sort(
178 self,
179 by_column: Vec<ExprIR>,
180 slice: Option<(i64, usize)>,
181 sort_options: SortMultipleOptions,
182 ) -> Self {
183 let ir = IR::Sort {
184 input: self.root,
185 by_column,
186 slice,
187 sort_options,
188 };
189 let node = self.lp_arena.add(ir);
190 IRBuilder::new(node, self.expr_arena, self.lp_arena)
191 }
192
193 pub fn node(self) -> Node {
194 self.root
195 }
196
197 pub fn build(self) -> IR {
198 if self.root.0 == self.lp_arena.len() {
199 self.lp_arena.pop().unwrap()
200 } else {
201 self.lp_arena.take(self.root)
202 }
203 }
204
205 pub fn schema(&'a self) -> Cow<'a, SchemaRef> {
206 self.lp_arena.get(self.root).schema(self.lp_arena)
207 }
208
209 pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
210 let schema = self.schema();
211 let mut new_schema = (**schema).clone();
212
213 let hstack_schema = expr_irs_to_schema(&exprs, &schema, self.expr_arena);
214 new_schema.merge(hstack_schema);
215
216 let lp = IR::HStack {
217 input: self.root,
218 exprs,
219 schema: Arc::new(new_schema),
220 options,
221 };
222 self.add_alp(lp)
223 }
224
225 pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self
226 where
227 I: IntoIterator<Item = J>,
228 {
229 let schema = self.schema();
230 let mut new_schema = (**schema).clone();
231
232 let iter = exprs.into_iter();
233 let mut expr_irs = Vec::with_capacity(iter.size_hint().0);
234 for node in iter {
235 let node = node.into();
236 let field = self
237 .expr_arena
238 .get(node)
239 .to_field(&schema, self.expr_arena)
240 .unwrap();
241
242 expr_irs.push(
243 ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
244 .with_dtype(field.dtype.clone()),
245 );
246 new_schema.with_column(field.name().clone(), field.dtype().clone());
247 }
248
249 let lp = IR::HStack {
250 input: self.root,
251 exprs: expr_irs,
252 schema: Arc::new(new_schema),
253 options,
254 };
255 self.add_alp(lp)
256 }
257
258 pub fn explode(self, columns: Arc<[PlSmallStr]>) -> Self {
260 let lp = IR::MapFunction {
261 input: self.root,
262 function: FunctionIR::Explode {
263 columns,
264 schema: Default::default(),
265 },
266 };
267 self.add_alp(lp)
268 }
269
270 pub fn group_by(
271 self,
272 keys: Vec<ExprIR>,
273 aggs: Vec<ExprIR>,
274 apply: Option<PlanCallback<DataFrame, DataFrame>>,
275 maintain_order: bool,
276 options: Arc<GroupbyOptions>,
277 ) -> Self {
278 let current_schema = self.schema();
279 let mut schema = expr_irs_to_schema(&keys, ¤t_schema, self.expr_arena);
280
281 #[cfg(feature = "dynamic_group_by")]
282 {
283 if let Some(options) = options.rolling.as_ref() {
284 let name = &options.index_column;
285 let dtype = current_schema.get(name).unwrap();
286 schema.with_column(name.clone(), dtype.clone());
287 } else if let Some(options) = options.dynamic.as_ref() {
288 let name = &options.index_column;
289 let dtype = current_schema.get(name).unwrap();
290 if options.include_boundaries {
291 schema.with_column("_lower_boundary".into(), dtype.clone());
292 schema.with_column("_upper_boundary".into(), dtype.clone());
293 }
294 schema.with_column(name.clone(), dtype.clone());
295 }
296 }
297
298 let mut aggs_schema = expr_irs_to_schema(&aggs, ¤t_schema, self.expr_arena);
299
300 debug_assert!(aggs_schema.len() == aggs.len());
302 for ((_name, dtype), expr) in aggs_schema.iter_mut().zip(&aggs) {
303 if !expr.is_scalar(self.expr_arena) {
304 *dtype = dtype.clone().implode();
305 }
306 }
307
308 schema.merge(aggs_schema);
309
310 let lp = IR::GroupBy {
311 input: self.root,
312 keys,
313 aggs,
314 schema: Arc::new(schema),
315 apply,
316 maintain_order,
317 options,
318 };
319 self.add_alp(lp)
320 }
321
322 pub fn join(
323 self,
324 other: Node,
325 left_on: Vec<ExprIR>,
326 right_on: Vec<ExprIR>,
327 options: Arc<JoinOptionsIR>,
328 ) -> Self {
329 let schema_left = self.schema();
330 let schema_right = self.lp_arena.get(other).schema(self.lp_arena);
331
332 let schema = det_join_schema(
333 &schema_left,
334 &schema_right,
335 &left_on,
336 &right_on,
337 &options,
338 self.expr_arena,
339 )
340 .unwrap();
341
342 let lp = IR::Join {
343 input_left: self.root,
344 input_right: other,
345 schema,
346 left_on,
347 right_on,
348 options,
349 };
350
351 self.add_alp(lp)
352 }
353
354 #[cfg(feature = "pivot")]
355 pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {
356 let lp = IR::MapFunction {
357 input: self.root,
358 function: FunctionIR::Unpivot {
359 args,
360 schema: Default::default(),
361 },
362 };
363 self.add_alp(lp)
364 }
365
366 pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {
367 let lp = IR::MapFunction {
368 input: self.root,
369 function: FunctionIR::RowIndex {
370 name,
371 offset,
372 schema: Default::default(),
373 },
374 };
375 self.add_alp(lp)
376 }
377}