// alopex_dataframe/lazy/optimizer.rs

1use std::collections::HashSet;
2
3use crate::expr::{Expr as E, Operator};
4use crate::lazy::{LogicalPlan, ProjectionKind};
5use crate::Expr;
6
7/// Optimizer that rewrites `LogicalPlan` (e.g. predicate/projection pushdown).
8pub struct Optimizer;
9
10impl Optimizer {
11    /// Optimize a `LogicalPlan` and return the rewritten plan.
12    pub fn optimize(plan: &LogicalPlan) -> LogicalPlan {
13        let plan = predicate_pushdown(plan.clone());
14        projection_pushdown(plan)
15    }
16}
17
18fn predicate_pushdown(plan: LogicalPlan) -> LogicalPlan {
19    match plan {
20        LogicalPlan::Filter { input, predicate } => {
21            let input = predicate_pushdown(*input);
22
23            match input {
24                LogicalPlan::Filter {
25                    input: inner,
26                    predicate: inner_predicate,
27                } => {
28                    let combined = and_expr(inner_predicate, predicate);
29                    predicate_pushdown(LogicalPlan::Filter {
30                        input: inner,
31                        predicate: combined,
32                    })
33                }
34                LogicalPlan::Projection { input, exprs, kind } => {
35                    if can_push_filter_through_projection(&predicate, &exprs, &kind) {
36                        predicate_pushdown(LogicalPlan::Projection {
37                            input: Box::new(LogicalPlan::Filter { input, predicate }),
38                            exprs,
39                            kind,
40                        })
41                    } else {
42                        LogicalPlan::Filter {
43                            input: Box::new(LogicalPlan::Projection { input, exprs, kind }),
44                            predicate,
45                        }
46                    }
47                }
48                LogicalPlan::CsvScan {
49                    path,
50                    predicate: existing,
51                    projection,
52                } => LogicalPlan::CsvScan {
53                    path,
54                    predicate: Some(match existing {
55                        Some(existing) => and_expr(existing, predicate),
56                        None => predicate,
57                    }),
58                    projection,
59                },
60                LogicalPlan::ParquetScan {
61                    path,
62                    predicate: existing,
63                    projection,
64                } => LogicalPlan::ParquetScan {
65                    path,
66                    predicate: Some(match existing {
67                        Some(existing) => and_expr(existing, predicate),
68                        None => predicate,
69                    }),
70                    projection,
71                },
72                other => LogicalPlan::Filter {
73                    input: Box::new(other),
74                    predicate,
75                },
76            }
77        }
78        LogicalPlan::Projection { input, exprs, kind } => LogicalPlan::Projection {
79            input: Box::new(predicate_pushdown(*input)),
80            exprs,
81            kind,
82        },
83        LogicalPlan::Aggregate {
84            input,
85            group_by,
86            aggs,
87        } => LogicalPlan::Aggregate {
88            input: Box::new(predicate_pushdown(*input)),
89            group_by,
90            aggs,
91        },
92        LogicalPlan::Join {
93            left,
94            right,
95            keys,
96            how,
97        } => LogicalPlan::Join {
98            left: Box::new(predicate_pushdown(*left)),
99            right: Box::new(predicate_pushdown(*right)),
100            keys,
101            how,
102        },
103        LogicalPlan::Sort { input, options } => LogicalPlan::Sort {
104            input: Box::new(predicate_pushdown(*input)),
105            options,
106        },
107        LogicalPlan::Slice {
108            input,
109            offset,
110            len,
111            from_end,
112        } => LogicalPlan::Slice {
113            input: Box::new(predicate_pushdown(*input)),
114            offset,
115            len,
116            from_end,
117        },
118        LogicalPlan::Unique { input, subset } => LogicalPlan::Unique {
119            input: Box::new(predicate_pushdown(*input)),
120            subset,
121        },
122        LogicalPlan::FillNull { input, fill } => LogicalPlan::FillNull {
123            input: Box::new(predicate_pushdown(*input)),
124            fill,
125        },
126        LogicalPlan::DropNulls { input, subset } => LogicalPlan::DropNulls {
127            input: Box::new(predicate_pushdown(*input)),
128            subset,
129        },
130        LogicalPlan::NullCount { input } => LogicalPlan::NullCount {
131            input: Box::new(predicate_pushdown(*input)),
132        },
133        other => other,
134    }
135}
136
137fn and_expr(left: Expr, right: Expr) -> Expr {
138    let mut conjuncts = Vec::new();
139    conjuncts.extend(flatten_and(left));
140    conjuncts.extend(flatten_and(right));
141    build_and(conjuncts)
142}
143
144fn flatten_and(expr: Expr) -> Vec<Expr> {
145    match expr {
146        E::BinaryOp {
147            left,
148            op: Operator::And,
149            right,
150        } => {
151            let mut out = flatten_and(*left);
152            out.extend(flatten_and(*right));
153            out
154        }
155        other => vec![other],
156    }
157}
158
159fn build_and(mut conjuncts: Vec<Expr>) -> Expr {
160    let first = conjuncts
161        .pop()
162        .expect("build_and must be called with non-empty conjuncts");
163    conjuncts
164        .into_iter()
165        .rev()
166        .fold(first, |acc, expr| E::BinaryOp {
167            left: Box::new(expr),
168            op: Operator::And,
169            right: Box::new(acc),
170        })
171}
172
173fn can_push_filter_through_projection(
174    predicate: &Expr,
175    exprs: &[Expr],
176    kind: &ProjectionKind,
177) -> bool {
178    let referenced = referenced_columns(predicate);
179
180    match kind {
181        ProjectionKind::Select => match projection_select_output_columns(exprs) {
182            OutputColumns::Some(cols) => referenced.is_subset(&cols),
183            OutputColumns::All | OutputColumns::Unknown => false,
184        },
185        ProjectionKind::WithColumns => {
186            let assigned = projection_assigned_columns(exprs);
187            !referenced.iter().any(|c| assigned.contains(c))
188        }
189    }
190}
191
192fn referenced_columns(expr: &Expr) -> HashSet<String> {
193    let mut out = HashSet::new();
194    collect_referenced_columns(expr, &mut out);
195    out
196}
197
198fn collect_referenced_columns(expr: &Expr, out: &mut HashSet<String>) {
199    match expr {
200        E::Column(name) => {
201            out.insert(name.clone());
202        }
203        E::Alias { expr, .. } => collect_referenced_columns(expr, out),
204        E::UnaryOp { expr, .. } => collect_referenced_columns(expr, out),
205        E::BinaryOp { left, right, .. } => {
206            collect_referenced_columns(left, out);
207            collect_referenced_columns(right, out);
208        }
209        E::Agg { expr, .. } => collect_referenced_columns(expr, out),
210        E::Literal(_) | E::Wildcard => {}
211    }
212}
213
/// Output column names of a `Select` projection, as far as they can be determined from its
/// expressions alone.
enum OutputColumns {
    /// The projection contains a wildcard.
    All,
    /// Every expression is a plain column or an alias; the set holds their output names.
    Some(HashSet<String>),
    /// At least one expression has an output name that cannot be determined here.
    Unknown,
}
219
220fn projection_select_output_columns(exprs: &[Expr]) -> OutputColumns {
221    let mut cols = HashSet::new();
222    for expr in exprs {
223        match expr {
224            E::Wildcard => return OutputColumns::All,
225            E::Alias { name, .. } => {
226                cols.insert(name.clone());
227            }
228            E::Column(name) => {
229                cols.insert(name.clone());
230            }
231            _ => return OutputColumns::Unknown,
232        }
233    }
234    OutputColumns::Some(cols)
235}
236
237fn projection_assigned_columns(exprs: &[Expr]) -> HashSet<String> {
238    let mut cols = HashSet::new();
239    for expr in exprs {
240        match expr {
241            E::Alias { name, .. } => {
242                cols.insert(name.clone());
243            }
244            E::Column(name) => {
245                cols.insert(name.clone());
246            }
247            _ => {}
248        }
249    }
250    cols
251}
252
253fn projection_pushdown(plan: LogicalPlan) -> LogicalPlan {
254    projection_pushdown_inner(plan, RequiredColumns::All).0
255}
256
/// The set of columns a plan node must produce for its parent.
#[derive(Debug, Clone)]
enum RequiredColumns {
    /// Every column is needed.
    All,
    /// Only the named columns are needed.
    Some(HashSet<String>),
}

impl RequiredColumns {
    /// Combines two requirements: `All` if either side is `All`, otherwise the union of both
    /// name sets.
    fn union(self, other: Self) -> Self {
        let (RequiredColumns::Some(mut names), RequiredColumns::Some(more)) = (self, other) else {
            return RequiredColumns::All;
        };
        names.extend(more);
        RequiredColumns::Some(names)
    }
}
274
275fn projection_pushdown_inner(
276    plan: LogicalPlan,
277    required: RequiredColumns,
278) -> (LogicalPlan, RequiredColumns) {
279    match plan {
280        LogicalPlan::Projection { input, exprs, kind } => match kind {
281            ProjectionKind::Select => {
282                let input_required = required_columns_for_select(&exprs);
283                let (new_input, _) = projection_pushdown_inner(*input, input_required);
284                (
285                    LogicalPlan::Projection {
286                        input: Box::new(new_input),
287                        exprs,
288                        kind: ProjectionKind::Select,
289                    },
290                    required,
291                )
292            }
293            ProjectionKind::WithColumns => {
294                // Conservative: keep required columns and any inputs to compute overwritten/new columns
295                let mut needed = HashSet::new();
296                if let RequiredColumns::Some(ref req) = required {
297                    needed.extend(req.clone());
298                }
299                for expr in &exprs {
300                    match expr {
301                        E::Alias { expr, .. } => needed.extend(referenced_columns(expr)),
302                        E::Column(_) => {}
303                        _ => {}
304                    }
305                }
306                let (new_input, _) =
307                    projection_pushdown_inner(*input, RequiredColumns::Some(needed));
308                (
309                    LogicalPlan::Projection {
310                        input: Box::new(new_input),
311                        exprs,
312                        kind: ProjectionKind::WithColumns,
313                    },
314                    required,
315                )
316            }
317        },
318        LogicalPlan::Filter { input, predicate } => {
319            let input_required = required
320                .clone()
321                .union(RequiredColumns::Some(referenced_columns(&predicate)));
322            let (new_input, _) = projection_pushdown_inner(*input, input_required);
323            (
324                LogicalPlan::Filter {
325                    input: Box::new(new_input),
326                    predicate,
327                },
328                required,
329            )
330        }
331        LogicalPlan::Aggregate {
332            input,
333            group_by,
334            aggs,
335        } => {
336            let mut needed = HashSet::new();
337            for e in group_by.iter().chain(aggs.iter()) {
338                needed.extend(referenced_columns(e));
339            }
340            let (new_input, _) = projection_pushdown_inner(*input, RequiredColumns::Some(needed));
341            (
342                LogicalPlan::Aggregate {
343                    input: Box::new(new_input),
344                    group_by,
345                    aggs,
346                },
347                required,
348            )
349        }
350        LogicalPlan::Join {
351            left,
352            right,
353            keys,
354            how,
355        } => {
356            let (new_left, _) = projection_pushdown_inner(*left, RequiredColumns::All);
357            let (new_right, _) = projection_pushdown_inner(*right, RequiredColumns::All);
358            (
359                LogicalPlan::Join {
360                    left: Box::new(new_left),
361                    right: Box::new(new_right),
362                    keys,
363                    how,
364                },
365                required,
366            )
367        }
368        LogicalPlan::Sort { input, options } => {
369            let (new_input, _) = projection_pushdown_inner(*input, required.clone());
370            (
371                LogicalPlan::Sort {
372                    input: Box::new(new_input),
373                    options,
374                },
375                required,
376            )
377        }
378        LogicalPlan::Slice {
379            input,
380            offset,
381            len,
382            from_end,
383        } => {
384            let (new_input, _) = projection_pushdown_inner(*input, required.clone());
385            (
386                LogicalPlan::Slice {
387                    input: Box::new(new_input),
388                    offset,
389                    len,
390                    from_end,
391                },
392                required,
393            )
394        }
395        LogicalPlan::Unique { input, subset } => {
396            let (new_input, _) = projection_pushdown_inner(*input, required.clone());
397            (
398                LogicalPlan::Unique {
399                    input: Box::new(new_input),
400                    subset,
401                },
402                required,
403            )
404        }
405        LogicalPlan::FillNull { input, fill } => {
406            let (new_input, _) = projection_pushdown_inner(*input, required.clone());
407            (
408                LogicalPlan::FillNull {
409                    input: Box::new(new_input),
410                    fill,
411                },
412                required,
413            )
414        }
415        LogicalPlan::DropNulls { input, subset } => {
416            let (new_input, _) = projection_pushdown_inner(*input, required.clone());
417            (
418                LogicalPlan::DropNulls {
419                    input: Box::new(new_input),
420                    subset,
421                },
422                required,
423            )
424        }
425        LogicalPlan::NullCount { input } => {
426            let (new_input, _) = projection_pushdown_inner(*input, RequiredColumns::All);
427            (
428                LogicalPlan::NullCount {
429                    input: Box::new(new_input),
430                },
431                RequiredColumns::All,
432            )
433        }
434        LogicalPlan::CsvScan {
435            path,
436            predicate,
437            projection,
438        } => {
439            let mut needed = match required {
440                RequiredColumns::All => None,
441                RequiredColumns::Some(s) => Some(s),
442            };
443            if let Some(pred) = &predicate {
444                let cols = referenced_columns(pred);
445                needed = Some(match needed {
446                    Some(mut s) => {
447                        s.extend(cols);
448                        s
449                    }
450                    None => cols,
451                });
452            }
453            (
454                LogicalPlan::CsvScan {
455                    path,
456                    predicate,
457                    projection: merge_projection(projection, needed),
458                },
459                RequiredColumns::All,
460            )
461        }
462        LogicalPlan::ParquetScan {
463            path,
464            predicate,
465            projection,
466        } => {
467            let mut needed = match required {
468                RequiredColumns::All => None,
469                RequiredColumns::Some(s) => Some(s),
470            };
471            if let Some(pred) = &predicate {
472                let cols = referenced_columns(pred);
473                needed = Some(match needed {
474                    Some(mut s) => {
475                        s.extend(cols);
476                        s
477                    }
478                    None => cols,
479                });
480            }
481            (
482                LogicalPlan::ParquetScan {
483                    path,
484                    predicate,
485                    projection: merge_projection(projection, needed),
486                },
487                RequiredColumns::All,
488            )
489        }
490        other => (other, RequiredColumns::All),
491    }
492}
493
494fn required_columns_for_select(exprs: &[Expr]) -> RequiredColumns {
495    let mut needed = HashSet::new();
496    for expr in exprs {
497        match expr {
498            E::Wildcard => return RequiredColumns::All,
499            other => needed.extend(referenced_columns(other)),
500        }
501    }
502    RequiredColumns::Some(needed)
503}
504
/// Merges a scan's existing projection with a set of additionally needed columns.
///
/// * `needed == None` means no restriction is requested; `existing` is returned unchanged.
/// * Otherwise the result lists the columns of `existing` first (original order, duplicates
///   removed), followed by the columns of `needed` not already present, in ascending order.
///
/// The needed columns are sorted because `HashSet` iteration order is unspecified; without
/// the sort, the same query could produce differently ordered projections from run to run.
fn merge_projection(
    existing: Option<Vec<String>>,
    needed: Option<HashSet<String>>,
) -> Option<Vec<String>> {
    let Some(needed) = needed else {
        return existing;
    };

    let mut out = existing.unwrap_or_default();
    let mut seen: HashSet<String> = HashSet::with_capacity(out.len() + needed.len());
    // Keep only the first occurrence of each existing column, preserving order.
    out.retain(|c| seen.insert(c.clone()));

    let mut extra: Vec<String> = needed.into_iter().filter(|c| !seen.contains(c)).collect();
    extra.sort_unstable();
    out.extend(extra);

    Some(out)
}
532
#[cfg(test)]
mod tests {
    use super::Optimizer;
    use crate::expr::{col, lit};
    use crate::lazy::LogicalPlan;
    use crate::lazy::ProjectionKind;

    /// A scan of `data.csv` with neither predicate nor projection pushed into it yet.
    fn csv_scan() -> LogicalPlan {
        LogicalPlan::CsvScan {
            path: "data.csv".into(),
            predicate: None,
            projection: None,
        }
    }

    #[test]
    fn predicate_pushdown_moves_filter_into_scan() {
        let plan = LogicalPlan::Filter {
            input: Box::new(csv_scan()),
            predicate: col("a").gt(lit(1_i64)),
        };

        let optimized = Optimizer::optimize(&plan);
        match optimized {
            LogicalPlan::CsvScan { predicate, .. } => assert!(predicate.is_some()),
            other => panic!("expected CsvScan, got {other:?}"),
        }
    }

    #[test]
    fn predicate_pushdown_combines_multiple_filters_with_and() {
        let plan = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Filter {
                input: Box::new(csv_scan()),
                predicate: col("a").gt(lit(1_i64)),
            }),
            predicate: col("b").lt(lit(10_i64)),
        };

        let optimized = Optimizer::optimize(&plan);
        match optimized {
            LogicalPlan::CsvScan {
                predicate: Some(p), ..
            } => {
                let s = format!("{p:?}");
                assert!(s.contains("And"));
            }
            other => panic!("expected CsvScan with predicate, got {other:?}"),
        }
    }

    #[test]
    fn predicate_pushdown_does_not_cross_select_when_column_not_selected() {
        let plan = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Projection {
                input: Box::new(csv_scan()),
                exprs: vec![col("a")],
                kind: ProjectionKind::Select,
            }),
            predicate: col("b").gt(lit(1_i64)),
        };

        let optimized = Optimizer::optimize(&plan);
        assert!(matches!(optimized, LogicalPlan::Filter { .. }));
    }

    #[test]
    fn projection_pushdown_sets_scan_projection() {
        let plan = LogicalPlan::Projection {
            input: Box::new(csv_scan()),
            exprs: vec![col("a"), col("b")],
            kind: ProjectionKind::Select,
        };

        let optimized = Optimizer::optimize(&plan);

        // Check the scan itself: the selected columns must have reached its projection.
        // The order is not asserted, so the columns are sorted before comparing.
        match &optimized {
            LogicalPlan::Projection { input, .. } => match &**input {
                LogicalPlan::CsvScan {
                    projection: Some(cols),
                    ..
                } => {
                    let mut cols = cols.clone();
                    cols.sort();
                    assert_eq!(cols, ["a", "b"]);
                }
                other => panic!("expected CsvScan with projection, got {other:?}"),
            },
            other => panic!("expected Projection, got {other:?}"),
        }

        let s = optimized.display();
        assert!(s.contains("projection"));
    }
}