1use crate::ast::{Expr, JoinType};
13use crate::optimizer::OptimizerPass;
14use crate::planner::LogicalPlan;
15use alloc::boxed::Box;
16use alloc::string::String;
17use hashbrown::HashSet;
18
/// Optimizer pass that moves `Filter` predicates down the logical plan,
/// toward the scans that produce the columns they reference.
///
/// The pass is stateless, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct PredicatePushdown;
21
22impl OptimizerPass for PredicatePushdown {
23 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
24 self.pushdown(plan)
25 }
26
27 fn name(&self) -> &'static str {
28 "predicate_pushdown"
29 }
30}
31
32impl PredicatePushdown {
33 fn pushdown(&self, plan: LogicalPlan) -> LogicalPlan {
34 match plan {
35 LogicalPlan::Filter { input, predicate } => {
36 let optimized_input = self.pushdown(*input);
37 self.try_push_filter(optimized_input, predicate)
38 }
39
40 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
41 input: Box::new(self.pushdown(*input)),
42 columns,
43 },
44
45 LogicalPlan::Join {
46 left,
47 right,
48 condition,
49 join_type,
50 } => LogicalPlan::Join {
51 left: Box::new(self.pushdown(*left)),
52 right: Box::new(self.pushdown(*right)),
53 condition,
54 join_type,
55 },
56
57 LogicalPlan::Aggregate {
58 input,
59 group_by,
60 aggregates,
61 } => LogicalPlan::Aggregate {
62 input: Box::new(self.pushdown(*input)),
63 group_by,
64 aggregates,
65 },
66
67 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
68 input: Box::new(self.pushdown(*input)),
69 order_by,
70 },
71
72 LogicalPlan::Limit {
73 input,
74 limit,
75 offset,
76 } => LogicalPlan::Limit {
77 input: Box::new(self.pushdown(*input)),
78 limit,
79 offset,
80 },
81
82 LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
83 left: Box::new(self.pushdown(*left)),
84 right: Box::new(self.pushdown(*right)),
85 },
86
87 LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
88 left: Box::new(self.pushdown(*left)),
89 right: Box::new(self.pushdown(*right)),
90 all,
91 },
92
93 LogicalPlan::Scan { .. }
95 | LogicalPlan::IndexScan { .. }
96 | LogicalPlan::IndexGet { .. }
97 | LogicalPlan::IndexInGet { .. }
98 | LogicalPlan::GinIndexScan { .. }
99 | LogicalPlan::GinIndexScanMulti { .. }
100 | LogicalPlan::Empty => plan,
101 }
102 }
103
104 fn try_push_filter(&self, input: LogicalPlan, predicate: Expr) -> LogicalPlan {
105 match input {
106 LogicalPlan::Project {
108 input: proj_input,
109 columns,
110 } => {
111 LogicalPlan::Filter {
114 input: Box::new(LogicalPlan::Project {
115 input: proj_input,
116 columns,
117 }),
118 predicate,
119 }
120 }
121
122 LogicalPlan::Join {
124 left,
125 right,
126 condition,
127 join_type,
128 } => {
129 self.push_filter_into_join(*left, *right, condition, join_type, predicate)
130 }
131
132 LogicalPlan::Aggregate { .. } => LogicalPlan::Filter {
134 input: Box::new(input),
135 predicate,
136 },
137
138 LogicalPlan::Sort {
140 input: sort_input,
141 order_by,
142 } => LogicalPlan::Sort {
143 input: Box::new(self.try_push_filter(*sort_input, predicate)),
144 order_by,
145 },
146
147 LogicalPlan::Limit { .. } => LogicalPlan::Filter {
150 input: Box::new(input),
151 predicate,
152 },
153
154 LogicalPlan::Scan { .. }
156 | LogicalPlan::IndexScan { .. }
157 | LogicalPlan::IndexGet { .. }
158 | LogicalPlan::IndexInGet { .. }
159 | LogicalPlan::GinIndexScan { .. }
160 | LogicalPlan::GinIndexScanMulti { .. } => LogicalPlan::Filter {
161 input: Box::new(input),
162 predicate,
163 },
164
165 LogicalPlan::Filter {
167 input: inner_input,
168 predicate: inner_pred,
169 } => LogicalPlan::Filter {
170 input: inner_input,
171 predicate: Expr::and(inner_pred, predicate),
172 },
173
174 _ => LogicalPlan::Filter {
175 input: Box::new(input),
176 predicate,
177 },
178 }
179 }
180
181 fn push_filter_into_join(
183 &self,
184 left: LogicalPlan,
185 right: LogicalPlan,
186 condition: Expr,
187 join_type: JoinType,
188 predicate: Expr,
189 ) -> LogicalPlan {
190 let left_tables = self.extract_tables(&left);
192 let right_tables = self.extract_tables(&right);
193
194 let pred_tables = self.extract_predicate_tables(&predicate);
196
197 let refs_left = pred_tables.iter().any(|t| left_tables.contains(t));
199 let refs_right = pred_tables.iter().any(|t| right_tables.contains(t));
200
201 match join_type {
202 JoinType::Inner => {
203 if refs_left && !refs_right {
205 LogicalPlan::Join {
207 left: Box::new(self.try_push_filter(left, predicate)),
208 right: Box::new(right),
209 condition,
210 join_type,
211 }
212 } else if refs_right && !refs_left {
213 LogicalPlan::Join {
215 left: Box::new(left),
216 right: Box::new(self.try_push_filter(right, predicate)),
217 condition,
218 join_type,
219 }
220 } else {
221 LogicalPlan::Filter {
223 input: Box::new(LogicalPlan::Join {
224 left: Box::new(left),
225 right: Box::new(right),
226 condition,
227 join_type,
228 }),
229 predicate,
230 }
231 }
232 }
233
234 JoinType::LeftOuter => {
235 if refs_left && !refs_right {
239 LogicalPlan::Join {
240 left: Box::new(self.try_push_filter(left, predicate)),
241 right: Box::new(right),
242 condition,
243 join_type,
244 }
245 } else {
246 LogicalPlan::Filter {
248 input: Box::new(LogicalPlan::Join {
249 left: Box::new(left),
250 right: Box::new(right),
251 condition,
252 join_type,
253 }),
254 predicate,
255 }
256 }
257 }
258
259 JoinType::RightOuter => {
260 if refs_right && !refs_left {
264 LogicalPlan::Join {
265 left: Box::new(left),
266 right: Box::new(self.try_push_filter(right, predicate)),
267 condition,
268 join_type,
269 }
270 } else {
271 LogicalPlan::Filter {
272 input: Box::new(LogicalPlan::Join {
273 left: Box::new(left),
274 right: Box::new(right),
275 condition,
276 join_type,
277 }),
278 predicate,
279 }
280 }
281 }
282
283 JoinType::FullOuter | JoinType::Cross => {
284 LogicalPlan::Filter {
286 input: Box::new(LogicalPlan::Join {
287 left: Box::new(left),
288 right: Box::new(right),
289 condition,
290 join_type,
291 }),
292 predicate,
293 }
294 }
295 }
296 }
297
298 fn extract_tables(&self, plan: &LogicalPlan) -> HashSet<String> {
300 let mut tables = HashSet::new();
301 self.collect_tables(plan, &mut tables);
302 tables
303 }
304
305 fn collect_tables(&self, plan: &LogicalPlan, tables: &mut HashSet<String>) {
306 match plan {
307 LogicalPlan::Scan { table } => {
308 tables.insert(table.clone());
309 }
310 LogicalPlan::IndexScan { table, .. }
311 | LogicalPlan::IndexGet { table, .. }
312 | LogicalPlan::IndexInGet { table, .. }
313 | LogicalPlan::GinIndexScan { table, .. }
314 | LogicalPlan::GinIndexScanMulti { table, .. } => {
315 tables.insert(table.clone());
316 }
317 LogicalPlan::Filter { input, .. }
318 | LogicalPlan::Project { input, .. }
319 | LogicalPlan::Aggregate { input, .. }
320 | LogicalPlan::Sort { input, .. }
321 | LogicalPlan::Limit { input, .. } => {
322 self.collect_tables(input, tables);
323 }
324 LogicalPlan::Join { left, right, .. }
325 | LogicalPlan::CrossProduct { left, right }
326 | LogicalPlan::Union { left, right, .. } => {
327 self.collect_tables(left, tables);
328 self.collect_tables(right, tables);
329 }
330 LogicalPlan::Empty => {}
331 }
332 }
333
334 fn extract_predicate_tables(&self, expr: &Expr) -> HashSet<String> {
336 let mut tables = HashSet::new();
337 self.collect_expr_tables(expr, &mut tables);
338 tables
339 }
340
341 fn collect_expr_tables(&self, expr: &Expr, tables: &mut HashSet<String>) {
342 match expr {
343 Expr::Column(col) => {
344 tables.insert(col.table.clone());
345 }
346 Expr::BinaryOp { left, right, .. } => {
347 self.collect_expr_tables(left, tables);
348 self.collect_expr_tables(right, tables);
349 }
350 Expr::UnaryOp { expr, .. } => {
351 self.collect_expr_tables(expr, tables);
352 }
353 Expr::Function { args, .. } => {
354 for arg in args {
355 self.collect_expr_tables(arg, tables);
356 }
357 }
358 Expr::Aggregate { expr, .. } => {
359 if let Some(e) = expr {
360 self.collect_expr_tables(e, tables);
361 }
362 }
363 Expr::Between { expr, low, high } => {
364 self.collect_expr_tables(expr, tables);
365 self.collect_expr_tables(low, tables);
366 self.collect_expr_tables(high, tables);
367 }
368 Expr::In { expr, list } => {
369 self.collect_expr_tables(expr, tables);
370 for item in list {
371 self.collect_expr_tables(item, tables);
372 }
373 }
374 Expr::Like { expr, .. } => {
375 self.collect_expr_tables(expr, tables);
376 }
377 Expr::NotBetween { expr, low, high } => {
378 self.collect_expr_tables(expr, tables);
379 self.collect_expr_tables(low, tables);
380 self.collect_expr_tables(high, tables);
381 }
382 Expr::NotIn { expr, list } => {
383 self.collect_expr_tables(expr, tables);
384 for item in list {
385 self.collect_expr_tables(item, tables);
386 }
387 }
388 Expr::NotLike { expr, .. } => {
389 self.collect_expr_tables(expr, tables);
390 }
391 Expr::Match { expr, .. } => {
392 self.collect_expr_tables(expr, tables);
393 }
394 Expr::NotMatch { expr, .. } => {
395 self.collect_expr_tables(expr, tables);
396 }
397 Expr::Literal(_) => {}
398 }
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, SortOrder};

    /// `users.id = orders.user_id`, the join condition shared by the join tests.
    fn users_orders_condition() -> Expr {
        Expr::eq(
            Expr::column("users", "id", 0),
            Expr::column("orders", "user_id", 0),
        )
    }

    /// Builds `users <join_type> orders ON users.id = orders.user_id` directly,
    /// so join types without a convenience constructor can be tested too.
    fn users_orders_join(join_type: JoinType) -> LogicalPlan {
        LogicalPlan::Join {
            left: Box::new(LogicalPlan::scan("users")),
            right: Box::new(LogicalPlan::scan("orders")),
            condition: users_orders_condition(),
            join_type,
        }
    }

    /// True when `plan` is a bare scan of table `expected`.
    fn is_scan_of(plan: &LogicalPlan, expected: &str) -> bool {
        matches!(plan, LogicalPlan::Scan { table } if table.as_str() == expected)
    }

    /// True when `plan` is a `Filter` sitting directly on a scan of `expected`.
    fn is_filter_over_scan(plan: &LogicalPlan, expected: &str) -> bool {
        matches!(plan, LogicalPlan::Filter { input, .. } if is_scan_of(input, expected))
    }

    /// Asserts that `plan` is a `Filter` directly above a join of type `expected`.
    fn assert_filter_above_join(plan: LogicalPlan, expected: JoinType) {
        match plan {
            LogicalPlan::Filter { input, .. } => match *input {
                LogicalPlan::Join { join_type, .. } => assert_eq!(join_type, expected),
                other => panic!("Expected Join under Filter, got {:?}", other),
            },
            other => panic!("Expected Filter, got {:?}", other),
        }
    }

    #[test]
    fn test_predicate_pushdown_basic() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::scan("users"),
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
        );

        let optimized = pass.optimize(plan);
        // Nothing to push past: the filter stays directly on the scan.
        assert!(is_filter_over_scan(&optimized, "users"), "got {:?}", optimized);
    }

    #[test]
    fn test_predicate_pushdown_through_sort() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::sort(
                LogicalPlan::scan("users"),
                alloc::vec![(Expr::column("users", "name", 1), SortOrder::Asc)],
            ),
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
        );

        match pass.optimize(plan) {
            LogicalPlan::Sort { input, .. } => {
                assert!(is_filter_over_scan(&input, "users"), "got {:?}", input);
            }
            other => panic!("Expected Sort, got {:?}", other),
        }
    }

    #[test]
    fn test_merge_consecutive_filters() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::filter(
                LogicalPlan::scan("users"),
                Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
            ),
            Expr::eq(Expr::column("users", "active", 2), Expr::literal(true)),
        );

        match pass.optimize(plan) {
            LogicalPlan::Filter { input, predicate } => {
                assert!(is_scan_of(&input, "users"), "input: {:?}", input);
                assert!(matches!(
                    predicate,
                    Expr::BinaryOp {
                        op: BinaryOp::And,
                        ..
                    }
                ));
            }
            other => panic!("Expected Filter, got {:?}", other),
        }
    }

    #[test]
    fn test_push_filter_into_inner_join_left() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        );

        match pass.optimize(plan) {
            LogicalPlan::Join { left, right, .. } => {
                assert!(is_filter_over_scan(&left, "users"), "left: {:?}", left);
                assert!(is_scan_of(&right, "orders"), "right: {:?}", right);
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_push_filter_into_inner_join_right() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64)),
        );

        match pass.optimize(plan) {
            LogicalPlan::Join { left, right, .. } => {
                assert!(is_scan_of(&left, "users"), "left: {:?}", left);
                assert!(is_filter_over_scan(&right, "orders"), "right: {:?}", right);
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_filter_on_both_sides_stays_above_join() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            Expr::gt(
                Expr::column("users", "balance", 2),
                Expr::column("orders", "amount", 1),
            ),
        );

        assert_filter_above_join(pass.optimize(plan), JoinType::Inner);
    }

    #[test]
    fn test_left_join_push_to_left_only() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        );

        match pass.optimize(plan) {
            LogicalPlan::Join {
                left,
                right,
                join_type,
                ..
            } => {
                assert_eq!(join_type, JoinType::LeftOuter);
                assert!(is_filter_over_scan(&left, "users"), "left: {:?}", left);
                assert!(is_scan_of(&right, "orders"), "right: {:?}", right);
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_left_join_right_predicate_stays_above() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                users_orders_condition(),
            ),
            Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64)),
        );

        assert_filter_above_join(pass.optimize(plan), JoinType::LeftOuter);
    }

    #[test]
    fn test_right_join_push_to_right_only() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            users_orders_join(JoinType::RightOuter),
            Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64)),
        );

        match pass.optimize(plan) {
            LogicalPlan::Join {
                left,
                right,
                join_type,
                ..
            } => {
                assert_eq!(join_type, JoinType::RightOuter);
                assert!(is_scan_of(&left, "users"), "left: {:?}", left);
                assert!(is_filter_over_scan(&right, "orders"), "right: {:?}", right);
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_right_join_left_predicate_stays_above() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            users_orders_join(JoinType::RightOuter),
            Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        );

        assert_filter_above_join(pass.optimize(plan), JoinType::RightOuter);
    }

    #[test]
    fn test_full_outer_join_predicate_stays_above() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::filter(
            users_orders_join(JoinType::FullOuter),
            Expr::eq(Expr::column("users", "active", 1), Expr::literal(true)),
        );

        assert_filter_above_join(pass.optimize(plan), JoinType::FullOuter);
    }

    #[test]
    fn test_extract_tables() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::inner_join(
            LogicalPlan::scan("users"),
            LogicalPlan::filter(
                LogicalPlan::scan("orders"),
                Expr::gt(Expr::column("orders", "amount", 0), Expr::literal(0i64)),
            ),
            Expr::eq(
                Expr::column("users", "id", 0),
                Expr::column("orders", "user_id", 1),
            ),
        );

        let tables = pass.extract_tables(&plan);
        assert!(tables.contains("users"));
        assert!(tables.contains("orders"));
        assert_eq!(tables.len(), 2);
    }

    #[test]
    fn test_extract_predicate_tables() {
        let pass = PredicatePushdown;

        let pred = Expr::and(
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
            Expr::gt(
                Expr::column("orders", "amount", 0),
                Expr::column("products", "price", 0),
            ),
        );

        let tables = pass.extract_predicate_tables(&pred);
        assert!(tables.contains("users"));
        assert!(tables.contains("orders"));
        assert!(tables.contains("products"));
        assert_eq!(tables.len(), 3);
    }
}