1use std::collections::HashSet;
2
3use crate::expr::{Expr as E, Operator};
4use crate::lazy::{LogicalPlan, ProjectionKind};
5use crate::Expr;
6
7pub struct Optimizer;
9
10impl Optimizer {
11 pub fn optimize(plan: &LogicalPlan) -> LogicalPlan {
13 let plan = predicate_pushdown(plan.clone());
14 projection_pushdown(plan)
15 }
16}
17
18fn predicate_pushdown(plan: LogicalPlan) -> LogicalPlan {
19 match plan {
20 LogicalPlan::Filter { input, predicate } => {
21 let input = predicate_pushdown(*input);
22
23 match input {
24 LogicalPlan::Filter {
25 input: inner,
26 predicate: inner_predicate,
27 } => {
28 let combined = and_expr(inner_predicate, predicate);
29 predicate_pushdown(LogicalPlan::Filter {
30 input: inner,
31 predicate: combined,
32 })
33 }
34 LogicalPlan::Projection { input, exprs, kind } => {
35 if can_push_filter_through_projection(&predicate, &exprs, &kind) {
36 predicate_pushdown(LogicalPlan::Projection {
37 input: Box::new(LogicalPlan::Filter { input, predicate }),
38 exprs,
39 kind,
40 })
41 } else {
42 LogicalPlan::Filter {
43 input: Box::new(LogicalPlan::Projection { input, exprs, kind }),
44 predicate,
45 }
46 }
47 }
48 LogicalPlan::CsvScan {
49 path,
50 predicate: existing,
51 projection,
52 } => LogicalPlan::CsvScan {
53 path,
54 predicate: Some(match existing {
55 Some(existing) => and_expr(existing, predicate),
56 None => predicate,
57 }),
58 projection,
59 },
60 LogicalPlan::ParquetScan {
61 path,
62 predicate: existing,
63 projection,
64 } => LogicalPlan::ParquetScan {
65 path,
66 predicate: Some(match existing {
67 Some(existing) => and_expr(existing, predicate),
68 None => predicate,
69 }),
70 projection,
71 },
72 other => LogicalPlan::Filter {
73 input: Box::new(other),
74 predicate,
75 },
76 }
77 }
78 LogicalPlan::Projection { input, exprs, kind } => LogicalPlan::Projection {
79 input: Box::new(predicate_pushdown(*input)),
80 exprs,
81 kind,
82 },
83 LogicalPlan::Aggregate {
84 input,
85 group_by,
86 aggs,
87 } => LogicalPlan::Aggregate {
88 input: Box::new(predicate_pushdown(*input)),
89 group_by,
90 aggs,
91 },
92 LogicalPlan::Join {
93 left,
94 right,
95 keys,
96 how,
97 } => LogicalPlan::Join {
98 left: Box::new(predicate_pushdown(*left)),
99 right: Box::new(predicate_pushdown(*right)),
100 keys,
101 how,
102 },
103 LogicalPlan::Sort { input, options } => LogicalPlan::Sort {
104 input: Box::new(predicate_pushdown(*input)),
105 options,
106 },
107 LogicalPlan::Slice {
108 input,
109 offset,
110 len,
111 from_end,
112 } => LogicalPlan::Slice {
113 input: Box::new(predicate_pushdown(*input)),
114 offset,
115 len,
116 from_end,
117 },
118 LogicalPlan::Unique { input, subset } => LogicalPlan::Unique {
119 input: Box::new(predicate_pushdown(*input)),
120 subset,
121 },
122 LogicalPlan::FillNull { input, fill } => LogicalPlan::FillNull {
123 input: Box::new(predicate_pushdown(*input)),
124 fill,
125 },
126 LogicalPlan::DropNulls { input, subset } => LogicalPlan::DropNulls {
127 input: Box::new(predicate_pushdown(*input)),
128 subset,
129 },
130 LogicalPlan::NullCount { input } => LogicalPlan::NullCount {
131 input: Box::new(predicate_pushdown(*input)),
132 },
133 other => other,
134 }
135}
136
137fn and_expr(left: Expr, right: Expr) -> Expr {
138 let mut conjuncts = Vec::new();
139 conjuncts.extend(flatten_and(left));
140 conjuncts.extend(flatten_and(right));
141 build_and(conjuncts)
142}
143
144fn flatten_and(expr: Expr) -> Vec<Expr> {
145 match expr {
146 E::BinaryOp {
147 left,
148 op: Operator::And,
149 right,
150 } => {
151 let mut out = flatten_and(*left);
152 out.extend(flatten_and(*right));
153 out
154 }
155 other => vec![other],
156 }
157}
158
159fn build_and(mut conjuncts: Vec<Expr>) -> Expr {
160 let first = conjuncts
161 .pop()
162 .expect("build_and must be called with non-empty conjuncts");
163 conjuncts
164 .into_iter()
165 .rev()
166 .fold(first, |acc, expr| E::BinaryOp {
167 left: Box::new(expr),
168 op: Operator::And,
169 right: Box::new(acc),
170 })
171}
172
173fn can_push_filter_through_projection(
174 predicate: &Expr,
175 exprs: &[Expr],
176 kind: &ProjectionKind,
177) -> bool {
178 let referenced = referenced_columns(predicate);
179
180 match kind {
181 ProjectionKind::Select => match projection_select_output_columns(exprs) {
182 OutputColumns::Some(cols) => referenced.is_subset(&cols),
183 OutputColumns::All | OutputColumns::Unknown => false,
184 },
185 ProjectionKind::WithColumns => {
186 let assigned = projection_assigned_columns(exprs);
187 !referenced.iter().any(|c| assigned.contains(c))
188 }
189 }
190}
191
192fn referenced_columns(expr: &Expr) -> HashSet<String> {
193 let mut out = HashSet::new();
194 collect_referenced_columns(expr, &mut out);
195 out
196}
197
198fn collect_referenced_columns(expr: &Expr, out: &mut HashSet<String>) {
199 match expr {
200 E::Column(name) => {
201 out.insert(name.clone());
202 }
203 E::Alias { expr, .. } => collect_referenced_columns(expr, out),
204 E::UnaryOp { expr, .. } => collect_referenced_columns(expr, out),
205 E::BinaryOp { left, right, .. } => {
206 collect_referenced_columns(left, out);
207 collect_referenced_columns(right, out);
208 }
209 E::Agg { expr, .. } => collect_referenced_columns(expr, out),
210 E::Literal(_) | E::Wildcard => {}
211 }
212}
213
214enum OutputColumns {
215 All,
216 Some(HashSet<String>),
217 Unknown,
218}
219
220fn projection_select_output_columns(exprs: &[Expr]) -> OutputColumns {
221 let mut cols = HashSet::new();
222 for expr in exprs {
223 match expr {
224 E::Wildcard => return OutputColumns::All,
225 E::Alias { name, .. } => {
226 cols.insert(name.clone());
227 }
228 E::Column(name) => {
229 cols.insert(name.clone());
230 }
231 _ => return OutputColumns::Unknown,
232 }
233 }
234 OutputColumns::Some(cols)
235}
236
237fn projection_assigned_columns(exprs: &[Expr]) -> HashSet<String> {
238 let mut cols = HashSet::new();
239 for expr in exprs {
240 match expr {
241 E::Alias { name, .. } => {
242 cols.insert(name.clone());
243 }
244 E::Column(name) => {
245 cols.insert(name.clone());
246 }
247 _ => {}
248 }
249 }
250 cols
251}
252
253fn projection_pushdown(plan: LogicalPlan) -> LogicalPlan {
254 projection_pushdown_inner(plan, RequiredColumns::All).0
255}
256
/// Columns that ancestors need from a subplan's output.
#[derive(Debug, Clone)]
enum RequiredColumns {
    /// Every column must be kept; no pruning is possible.
    All,
    /// Only these named columns are read by ancestors.
    Some(HashSet<String>),
}

impl RequiredColumns {
    /// Combines two requirements. `All` absorbs whatever it is combined with;
    /// two explicit sets are merged.
    fn union(self, other: Self) -> Self {
        match (self, other) {
            (Self::Some(mut merged), Self::Some(extra)) => {
                merged.extend(extra);
                Self::Some(merged)
            }
            // At least one side needs every column.
            _ => Self::All,
        }
    }
}
274
275fn projection_pushdown_inner(
276 plan: LogicalPlan,
277 required: RequiredColumns,
278) -> (LogicalPlan, RequiredColumns) {
279 match plan {
280 LogicalPlan::Projection { input, exprs, kind } => match kind {
281 ProjectionKind::Select => {
282 let input_required = required_columns_for_select(&exprs);
283 let (new_input, _) = projection_pushdown_inner(*input, input_required);
284 (
285 LogicalPlan::Projection {
286 input: Box::new(new_input),
287 exprs,
288 kind: ProjectionKind::Select,
289 },
290 required,
291 )
292 }
293 ProjectionKind::WithColumns => {
294 let mut needed = HashSet::new();
296 if let RequiredColumns::Some(ref req) = required {
297 needed.extend(req.clone());
298 }
299 for expr in &exprs {
300 match expr {
301 E::Alias { expr, .. } => needed.extend(referenced_columns(expr)),
302 E::Column(_) => {}
303 _ => {}
304 }
305 }
306 let (new_input, _) =
307 projection_pushdown_inner(*input, RequiredColumns::Some(needed));
308 (
309 LogicalPlan::Projection {
310 input: Box::new(new_input),
311 exprs,
312 kind: ProjectionKind::WithColumns,
313 },
314 required,
315 )
316 }
317 },
318 LogicalPlan::Filter { input, predicate } => {
319 let input_required = required
320 .clone()
321 .union(RequiredColumns::Some(referenced_columns(&predicate)));
322 let (new_input, _) = projection_pushdown_inner(*input, input_required);
323 (
324 LogicalPlan::Filter {
325 input: Box::new(new_input),
326 predicate,
327 },
328 required,
329 )
330 }
331 LogicalPlan::Aggregate {
332 input,
333 group_by,
334 aggs,
335 } => {
336 let mut needed = HashSet::new();
337 for e in group_by.iter().chain(aggs.iter()) {
338 needed.extend(referenced_columns(e));
339 }
340 let (new_input, _) = projection_pushdown_inner(*input, RequiredColumns::Some(needed));
341 (
342 LogicalPlan::Aggregate {
343 input: Box::new(new_input),
344 group_by,
345 aggs,
346 },
347 required,
348 )
349 }
350 LogicalPlan::Join {
351 left,
352 right,
353 keys,
354 how,
355 } => {
356 let (new_left, _) = projection_pushdown_inner(*left, RequiredColumns::All);
357 let (new_right, _) = projection_pushdown_inner(*right, RequiredColumns::All);
358 (
359 LogicalPlan::Join {
360 left: Box::new(new_left),
361 right: Box::new(new_right),
362 keys,
363 how,
364 },
365 required,
366 )
367 }
368 LogicalPlan::Sort { input, options } => {
369 let (new_input, _) = projection_pushdown_inner(*input, required.clone());
370 (
371 LogicalPlan::Sort {
372 input: Box::new(new_input),
373 options,
374 },
375 required,
376 )
377 }
378 LogicalPlan::Slice {
379 input,
380 offset,
381 len,
382 from_end,
383 } => {
384 let (new_input, _) = projection_pushdown_inner(*input, required.clone());
385 (
386 LogicalPlan::Slice {
387 input: Box::new(new_input),
388 offset,
389 len,
390 from_end,
391 },
392 required,
393 )
394 }
395 LogicalPlan::Unique { input, subset } => {
396 let (new_input, _) = projection_pushdown_inner(*input, required.clone());
397 (
398 LogicalPlan::Unique {
399 input: Box::new(new_input),
400 subset,
401 },
402 required,
403 )
404 }
405 LogicalPlan::FillNull { input, fill } => {
406 let (new_input, _) = projection_pushdown_inner(*input, required.clone());
407 (
408 LogicalPlan::FillNull {
409 input: Box::new(new_input),
410 fill,
411 },
412 required,
413 )
414 }
415 LogicalPlan::DropNulls { input, subset } => {
416 let (new_input, _) = projection_pushdown_inner(*input, required.clone());
417 (
418 LogicalPlan::DropNulls {
419 input: Box::new(new_input),
420 subset,
421 },
422 required,
423 )
424 }
425 LogicalPlan::NullCount { input } => {
426 let (new_input, _) = projection_pushdown_inner(*input, RequiredColumns::All);
427 (
428 LogicalPlan::NullCount {
429 input: Box::new(new_input),
430 },
431 RequiredColumns::All,
432 )
433 }
434 LogicalPlan::CsvScan {
435 path,
436 predicate,
437 projection,
438 } => {
439 let mut needed = match required {
440 RequiredColumns::All => None,
441 RequiredColumns::Some(s) => Some(s),
442 };
443 if let Some(pred) = &predicate {
444 let cols = referenced_columns(pred);
445 needed = Some(match needed {
446 Some(mut s) => {
447 s.extend(cols);
448 s
449 }
450 None => cols,
451 });
452 }
453 (
454 LogicalPlan::CsvScan {
455 path,
456 predicate,
457 projection: merge_projection(projection, needed),
458 },
459 RequiredColumns::All,
460 )
461 }
462 LogicalPlan::ParquetScan {
463 path,
464 predicate,
465 projection,
466 } => {
467 let mut needed = match required {
468 RequiredColumns::All => None,
469 RequiredColumns::Some(s) => Some(s),
470 };
471 if let Some(pred) = &predicate {
472 let cols = referenced_columns(pred);
473 needed = Some(match needed {
474 Some(mut s) => {
475 s.extend(cols);
476 s
477 }
478 None => cols,
479 });
480 }
481 (
482 LogicalPlan::ParquetScan {
483 path,
484 predicate,
485 projection: merge_projection(projection, needed),
486 },
487 RequiredColumns::All,
488 )
489 }
490 other => (other, RequiredColumns::All),
491 }
492}
493
494fn required_columns_for_select(exprs: &[Expr]) -> RequiredColumns {
495 let mut needed = HashSet::new();
496 for expr in exprs {
497 match expr {
498 E::Wildcard => return RequiredColumns::All,
499 other => needed.extend(referenced_columns(other)),
500 }
501 }
502 RequiredColumns::Some(needed)
503}
504
/// Combines a scan's existing projection with the columns the rest of the plan
/// needs from it.
///
/// * `needed == None` means "all columns": `existing` is returned unchanged.
/// * An empty `needed` set is treated the same way. Narrowing a scan to zero
///   columns may lose its row count (e.g. for `count(*)` or a literal-only
///   select), depending on how the scan handles an empty projection.
/// * Otherwise the result is the existing projection (original order,
///   duplicates removed) followed by any further needed columns in sorted
///   order. Sorting makes the plan independent of `HashSet` iteration order.
fn merge_projection(
    existing: Option<Vec<String>>,
    needed: Option<HashSet<String>>,
) -> Option<Vec<String>> {
    let needed = match needed {
        Some(needed) if !needed.is_empty() => needed,
        _ => return existing,
    };

    let mut out = Vec::new();
    let mut seen = HashSet::new();
    for c in existing.into_iter().flatten() {
        if seen.insert(c.clone()) {
            out.push(c);
        }
    }

    let mut extra: Vec<String> = needed.into_iter().filter(|c| !seen.contains(c)).collect();
    extra.sort_unstable();
    out.extend(extra);

    Some(out)
}
532
#[cfg(test)]
mod tests {
    use super::Optimizer;
    use crate::expr::{col, lit};
    use crate::lazy::{LogicalPlan, ProjectionKind};

    /// A CSV scan of `data.csv` with no predicate and no projection yet.
    fn csv_scan() -> LogicalPlan {
        LogicalPlan::CsvScan {
            path: "data.csv".into(),
            predicate: None,
            projection: None,
        }
    }

    #[test]
    fn predicate_pushdown_moves_filter_into_scan() {
        let plan = LogicalPlan::Filter {
            input: Box::new(csv_scan()),
            predicate: col("a").gt(lit(1_i64)),
        };

        match Optimizer::optimize(&plan) {
            LogicalPlan::CsvScan { predicate, .. } => assert!(predicate.is_some()),
            other => panic!("expected CsvScan, got {other:?}"),
        }
    }

    #[test]
    fn predicate_pushdown_combines_multiple_filters_with_and() {
        let lower_bound = LogicalPlan::Filter {
            input: Box::new(csv_scan()),
            predicate: col("a").gt(lit(1_i64)),
        };
        let plan = LogicalPlan::Filter {
            input: Box::new(lower_bound),
            predicate: col("b").lt(lit(10_i64)),
        };

        match Optimizer::optimize(&plan) {
            LogicalPlan::CsvScan {
                predicate: Some(p), ..
            } => {
                let rendered = format!("{p:?}");
                assert!(rendered.contains("And"));
            }
            other => panic!("expected CsvScan with predicate, got {other:?}"),
        }
    }

    #[test]
    fn predicate_pushdown_does_not_cross_select_when_column_not_selected() {
        let select_a = LogicalPlan::Projection {
            input: Box::new(csv_scan()),
            exprs: vec![col("a")],
            kind: ProjectionKind::Select,
        };
        let plan = LogicalPlan::Filter {
            input: Box::new(select_a),
            predicate: col("b").gt(lit(1_i64)),
        };

        let optimized = Optimizer::optimize(&plan);
        assert!(matches!(optimized, LogicalPlan::Filter { .. }));
    }

    #[test]
    fn projection_pushdown_sets_scan_projection() {
        let plan = LogicalPlan::Projection {
            input: Box::new(csv_scan()),
            exprs: vec![col("a"), col("b")],
            kind: ProjectionKind::Select,
        };

        let optimized = Optimizer::optimize(&plan);
        let rendered = optimized.display();
        assert!(rendered.contains("projection"));
    }
}