Skip to main content

cynos_query/optimizer/
predicate_pushdown.rs

//! Predicate pushdown optimization pass.
//!
//! Pushes filter predicates down the plan tree, as close to the data sources
//! as possible, to reduce the amount of data processed.
//!
//! Key rules:
//! 1. Filters are pushed through Sort (sorting does not change which rows exist).
//! 2. Consecutive filters are merged into a single AND predicate.
//! 3. Filters are pushed into a join input when they reference only that input
//!    and the join type allows it (never into the null-supplying side of an outer join).
//! 4. Filters are never pushed through Aggregate or Limit (that would change results).
//! 5. Filters are not pushed through Project.
11
12use crate::ast::{Expr, JoinType};
13use crate::optimizer::OptimizerPass;
14use crate::planner::LogicalPlan;
15use alloc::boxed::Box;
16use alloc::string::String;
17use hashbrown::HashSet;
18
/// Optimizer pass that moves filter predicates toward the data sources of a
/// logical plan, reducing the amount of data processed by the operators
/// above them.
pub struct PredicatePushdown;
21
22impl OptimizerPass for PredicatePushdown {
23    fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
24        self.pushdown(plan)
25    }
26
27    fn name(&self) -> &'static str {
28        "predicate_pushdown"
29    }
30}
31
32impl PredicatePushdown {
33    fn pushdown(&self, plan: LogicalPlan) -> LogicalPlan {
34        match plan {
35            LogicalPlan::Filter { input, predicate } => {
36                let optimized_input = self.pushdown(*input);
37                self.try_push_filter(optimized_input, predicate)
38            }
39
40            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
41                input: Box::new(self.pushdown(*input)),
42                columns,
43            },
44
45            LogicalPlan::Join {
46                left,
47                right,
48                condition,
49                join_type,
50            } => LogicalPlan::Join {
51                left: Box::new(self.pushdown(*left)),
52                right: Box::new(self.pushdown(*right)),
53                condition,
54                join_type,
55            },
56
57            LogicalPlan::Aggregate {
58                input,
59                group_by,
60                aggregates,
61            } => LogicalPlan::Aggregate {
62                input: Box::new(self.pushdown(*input)),
63                group_by,
64                aggregates,
65            },
66
67            LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
68                input: Box::new(self.pushdown(*input)),
69                order_by,
70            },
71
72            LogicalPlan::Limit {
73                input,
74                limit,
75                offset,
76            } => LogicalPlan::Limit {
77                input: Box::new(self.pushdown(*input)),
78                limit,
79                offset,
80            },
81
82            LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
83                left: Box::new(self.pushdown(*left)),
84                right: Box::new(self.pushdown(*right)),
85            },
86
87            LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
88                left: Box::new(self.pushdown(*left)),
89                right: Box::new(self.pushdown(*right)),
90                all,
91            },
92
93            // Leaf nodes - no transformation
94            LogicalPlan::Scan { .. }
95            | LogicalPlan::IndexScan { .. }
96            | LogicalPlan::IndexGet { .. }
97            | LogicalPlan::IndexInGet { .. }
98            | LogicalPlan::GinIndexScan { .. }
99            | LogicalPlan::GinIndexScanMulti { .. }
100            | LogicalPlan::Empty => plan,
101        }
102    }
103
104    fn try_push_filter(&self, input: LogicalPlan, predicate: Expr) -> LogicalPlan {
105        match input {
106            // Push filter below projection if predicate doesn't reference projected columns
107            LogicalPlan::Project {
108                input: proj_input,
109                columns,
110            } => {
111                // For simplicity, we don't push through projection in this basic implementation
112                // A full implementation would check if the predicate can be evaluated before projection
113                LogicalPlan::Filter {
114                    input: Box::new(LogicalPlan::Project {
115                        input: proj_input,
116                        columns,
117                    }),
118                    predicate,
119                }
120            }
121
122            // Push filter into join if predicate references only one side
123            LogicalPlan::Join {
124                left,
125                right,
126                condition,
127                join_type,
128            } => {
129                self.push_filter_into_join(*left, *right, condition, join_type, predicate)
130            }
131
132            // Can't push filter below aggregate
133            LogicalPlan::Aggregate { .. } => LogicalPlan::Filter {
134                input: Box::new(input),
135                predicate,
136            },
137
138            // Push filter below sort
139            LogicalPlan::Sort {
140                input: sort_input,
141                order_by,
142            } => LogicalPlan::Sort {
143                input: Box::new(self.try_push_filter(*sort_input, predicate)),
144                order_by,
145            },
146
147            // Push filter below limit (careful - this changes semantics for LIMIT)
148            // For correctness, we don't push below LIMIT
149            LogicalPlan::Limit { .. } => LogicalPlan::Filter {
150                input: Box::new(input),
151                predicate,
152            },
153
154            // Filter on scan - keep as is
155            LogicalPlan::Scan { .. }
156            | LogicalPlan::IndexScan { .. }
157            | LogicalPlan::IndexGet { .. }
158            | LogicalPlan::IndexInGet { .. }
159            | LogicalPlan::GinIndexScan { .. }
160            | LogicalPlan::GinIndexScanMulti { .. } => LogicalPlan::Filter {
161                input: Box::new(input),
162                predicate,
163            },
164
165            // Merge consecutive filters
166            LogicalPlan::Filter {
167                input: inner_input,
168                predicate: inner_pred,
169            } => LogicalPlan::Filter {
170                input: inner_input,
171                predicate: Expr::and(inner_pred, predicate),
172            },
173
174            _ => LogicalPlan::Filter {
175                input: Box::new(input),
176                predicate,
177            },
178        }
179    }
180
181    /// Push filter into join based on which tables the predicate references.
182    fn push_filter_into_join(
183        &self,
184        left: LogicalPlan,
185        right: LogicalPlan,
186        condition: Expr,
187        join_type: JoinType,
188        predicate: Expr,
189    ) -> LogicalPlan {
190        // Extract tables referenced by each side of the join
191        let left_tables = self.extract_tables(&left);
192        let right_tables = self.extract_tables(&right);
193
194        // Extract tables referenced by the predicate
195        let pred_tables = self.extract_predicate_tables(&predicate);
196
197        // Check if predicate references only left side
198        let refs_left = pred_tables.iter().any(|t| left_tables.contains(t));
199        let refs_right = pred_tables.iter().any(|t| right_tables.contains(t));
200
201        match join_type {
202            JoinType::Inner => {
203                // For inner join, we can push to either side
204                if refs_left && !refs_right {
205                    // Push to left side
206                    LogicalPlan::Join {
207                        left: Box::new(self.try_push_filter(left, predicate)),
208                        right: Box::new(right),
209                        condition,
210                        join_type,
211                    }
212                } else if refs_right && !refs_left {
213                    // Push to right side
214                    LogicalPlan::Join {
215                        left: Box::new(left),
216                        right: Box::new(self.try_push_filter(right, predicate)),
217                        condition,
218                        join_type,
219                    }
220                } else {
221                    // References both sides or neither - keep above join
222                    LogicalPlan::Filter {
223                        input: Box::new(LogicalPlan::Join {
224                            left: Box::new(left),
225                            right: Box::new(right),
226                            condition,
227                            join_type,
228                        }),
229                        predicate,
230                    }
231                }
232            }
233
234            JoinType::LeftOuter => {
235                // For left outer join:
236                // - Can push predicates on LEFT side down (preserves NULL extension)
237                // - Cannot push predicates on RIGHT side (would filter out NULLs incorrectly)
238                if refs_left && !refs_right {
239                    LogicalPlan::Join {
240                        left: Box::new(self.try_push_filter(left, predicate)),
241                        right: Box::new(right),
242                        condition,
243                        join_type,
244                    }
245                } else {
246                    // Keep above join
247                    LogicalPlan::Filter {
248                        input: Box::new(LogicalPlan::Join {
249                            left: Box::new(left),
250                            right: Box::new(right),
251                            condition,
252                            join_type,
253                        }),
254                        predicate,
255                    }
256                }
257            }
258
259            JoinType::RightOuter => {
260                // For right outer join:
261                // - Can push predicates on RIGHT side down
262                // - Cannot push predicates on LEFT side
263                if refs_right && !refs_left {
264                    LogicalPlan::Join {
265                        left: Box::new(left),
266                        right: Box::new(self.try_push_filter(right, predicate)),
267                        condition,
268                        join_type,
269                    }
270                } else {
271                    LogicalPlan::Filter {
272                        input: Box::new(LogicalPlan::Join {
273                            left: Box::new(left),
274                            right: Box::new(right),
275                            condition,
276                            join_type,
277                        }),
278                        predicate,
279                    }
280                }
281            }
282
283            JoinType::FullOuter | JoinType::Cross => {
284                // For full outer join and cross join, cannot push predicates
285                LogicalPlan::Filter {
286                    input: Box::new(LogicalPlan::Join {
287                        left: Box::new(left),
288                        right: Box::new(right),
289                        condition,
290                        join_type,
291                    }),
292                    predicate,
293                }
294            }
295        }
296    }
297
298    /// Extract all table names referenced by a plan.
299    fn extract_tables(&self, plan: &LogicalPlan) -> HashSet<String> {
300        let mut tables = HashSet::new();
301        self.collect_tables(plan, &mut tables);
302        tables
303    }
304
305    fn collect_tables(&self, plan: &LogicalPlan, tables: &mut HashSet<String>) {
306        match plan {
307            LogicalPlan::Scan { table } => {
308                tables.insert(table.clone());
309            }
310            LogicalPlan::IndexScan { table, .. }
311            | LogicalPlan::IndexGet { table, .. }
312            | LogicalPlan::IndexInGet { table, .. }
313            | LogicalPlan::GinIndexScan { table, .. }
314            | LogicalPlan::GinIndexScanMulti { table, .. } => {
315                tables.insert(table.clone());
316            }
317            LogicalPlan::Filter { input, .. }
318            | LogicalPlan::Project { input, .. }
319            | LogicalPlan::Aggregate { input, .. }
320            | LogicalPlan::Sort { input, .. }
321            | LogicalPlan::Limit { input, .. } => {
322                self.collect_tables(input, tables);
323            }
324            LogicalPlan::Join { left, right, .. }
325            | LogicalPlan::CrossProduct { left, right }
326            | LogicalPlan::Union { left, right, .. } => {
327                self.collect_tables(left, tables);
328                self.collect_tables(right, tables);
329            }
330            LogicalPlan::Empty => {}
331        }
332    }
333
334    /// Extract all table names referenced by a predicate expression.
335    fn extract_predicate_tables(&self, expr: &Expr) -> HashSet<String> {
336        let mut tables = HashSet::new();
337        self.collect_expr_tables(expr, &mut tables);
338        tables
339    }
340
341    fn collect_expr_tables(&self, expr: &Expr, tables: &mut HashSet<String>) {
342        match expr {
343            Expr::Column(col) => {
344                tables.insert(col.table.clone());
345            }
346            Expr::BinaryOp { left, right, .. } => {
347                self.collect_expr_tables(left, tables);
348                self.collect_expr_tables(right, tables);
349            }
350            Expr::UnaryOp { expr, .. } => {
351                self.collect_expr_tables(expr, tables);
352            }
353            Expr::Function { args, .. } => {
354                for arg in args {
355                    self.collect_expr_tables(arg, tables);
356                }
357            }
358            Expr::Aggregate { expr, .. } => {
359                if let Some(e) = expr {
360                    self.collect_expr_tables(e, tables);
361                }
362            }
363            Expr::Between { expr, low, high } => {
364                self.collect_expr_tables(expr, tables);
365                self.collect_expr_tables(low, tables);
366                self.collect_expr_tables(high, tables);
367            }
368            Expr::In { expr, list } => {
369                self.collect_expr_tables(expr, tables);
370                for item in list {
371                    self.collect_expr_tables(item, tables);
372                }
373            }
374            Expr::Like { expr, .. } => {
375                self.collect_expr_tables(expr, tables);
376            }
377            Expr::NotBetween { expr, low, high } => {
378                self.collect_expr_tables(expr, tables);
379                self.collect_expr_tables(low, tables);
380                self.collect_expr_tables(high, tables);
381            }
382            Expr::NotIn { expr, list } => {
383                self.collect_expr_tables(expr, tables);
384                for item in list {
385                    self.collect_expr_tables(item, tables);
386                }
387            }
388            Expr::NotLike { expr, .. } => {
389                self.collect_expr_tables(expr, tables);
390            }
391            Expr::Match { expr, .. } => {
392                self.collect_expr_tables(expr, tables);
393            }
394            Expr::NotMatch { expr, .. } => {
395                self.collect_expr_tables(expr, tables);
396            }
397            Expr::Literal(_) => {}
398        }
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, SortOrder};

    /// Builds `users <join_type> orders ON users.id = orders.user_id` directly
    /// from the `Join` variant, so any join type can be exercised.
    fn users_join_orders(join_type: JoinType) -> LogicalPlan {
        LogicalPlan::Join {
            left: Box::new(LogicalPlan::scan("users")),
            right: Box::new(LogicalPlan::scan("orders")),
            condition: Expr::eq(
                Expr::column("users", "id", 0),
                Expr::column("orders", "user_id", 0),
            ),
            join_type,
        }
    }

    /// Predicate that references only the `users` table.
    fn users_only_predicate() -> Expr {
        Expr::eq(Expr::column("users", "active", 1), Expr::literal(true))
    }

    /// Predicate that references only the `orders` table.
    fn orders_only_predicate() -> Expr {
        Expr::gt(Expr::column("orders", "amount", 1), Expr::literal(100i64))
    }

    /// Asserts that `plan` is `Filter(Join(..))`, i.e. the predicate stayed above the join.
    fn assert_filter_above_join(plan: LogicalPlan) {
        match plan {
            LogicalPlan::Filter { input, .. } => assert!(
                matches!(*input, LogicalPlan::Join { .. }),
                "Expected Join below Filter, got {:?}",
                input
            ),
            other => panic!("Expected Filter above Join, got {:?}", other),
        }
    }

    #[test]
    fn test_predicate_pushdown_basic() {
        let pass = PredicatePushdown;

        // Filter on scan should stay as is
        let plan = LogicalPlan::filter(
            LogicalPlan::scan("users"),
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
        );

        let optimized = pass.optimize(plan);
        assert!(matches!(optimized, LogicalPlan::Filter { .. }));
    }

    #[test]
    fn test_predicate_pushdown_through_sort() {
        let pass = PredicatePushdown;

        // Filter above sort should be pushed below sort
        let plan = LogicalPlan::filter(
            LogicalPlan::sort(
                LogicalPlan::scan("users"),
                alloc::vec![(Expr::column("users", "name", 1), SortOrder::Asc)],
            ),
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
        );

        let optimized = pass.optimize(plan);

        // Should be Sort(Filter(Scan))
        match optimized {
            LogicalPlan::Sort { input, .. } => match *input {
                LogicalPlan::Filter { input, .. } => {
                    assert!(matches!(*input, LogicalPlan::Scan { .. }))
                }
                other => panic!("Expected Filter below Sort, got {:?}", other),
            },
            other => panic!("Expected Sort, got {:?}", other),
        }
    }

    #[test]
    fn test_merge_consecutive_filters() {
        let pass = PredicatePushdown;

        // Two consecutive filters should be merged
        let plan = LogicalPlan::filter(
            LogicalPlan::filter(
                LogicalPlan::scan("users"),
                Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
            ),
            Expr::eq(Expr::column("users", "active", 2), Expr::literal(true)),
        );

        let optimized = pass.optimize(plan);

        // Should be a single Filter with an AND predicate directly on the scan
        match optimized {
            LogicalPlan::Filter { input, predicate } => {
                assert!(matches!(*input, LogicalPlan::Scan { .. }));
                assert!(matches!(
                    predicate,
                    Expr::BinaryOp {
                        op: BinaryOp::And,
                        ..
                    }
                ));
            }
            other => panic!("Expected Filter, got {:?}", other),
        }
    }

    #[test]
    fn test_push_filter_into_inner_join_left() {
        let pass = PredicatePushdown;

        // Filter on left table should be pushed into left side of inner join
        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                Expr::eq(
                    Expr::column("users", "id", 0),
                    Expr::column("orders", "user_id", 0),
                ),
            ),
            users_only_predicate(),
        );

        let optimized = pass.optimize(plan);

        // Should be Join(Filter(Scan(users)), Scan(orders))
        match optimized {
            LogicalPlan::Join { left, right, .. } => {
                assert!(matches!(*left, LogicalPlan::Filter { .. }));
                assert!(matches!(*right, LogicalPlan::Scan { .. }));
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_push_filter_into_inner_join_right() {
        let pass = PredicatePushdown;

        // Filter on right table should be pushed into right side of inner join
        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                Expr::eq(
                    Expr::column("users", "id", 0),
                    Expr::column("orders", "user_id", 0),
                ),
            ),
            orders_only_predicate(),
        );

        let optimized = pass.optimize(plan);

        // Should be Join(Scan(users), Filter(Scan(orders)))
        match optimized {
            LogicalPlan::Join { left, right, .. } => {
                assert!(matches!(*left, LogicalPlan::Scan { .. }));
                assert!(matches!(*right, LogicalPlan::Filter { .. }));
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_filter_on_both_sides_stays_above_join() {
        let pass = PredicatePushdown;

        // Filter referencing both tables should stay above join
        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                Expr::eq(
                    Expr::column("users", "id", 0),
                    Expr::column("orders", "user_id", 0),
                ),
            ),
            Expr::gt(
                Expr::column("users", "balance", 2),
                Expr::column("orders", "amount", 1),
            ),
        );

        assert_filter_above_join(pass.optimize(plan));
    }

    #[test]
    fn test_left_join_push_to_left_only() {
        let pass = PredicatePushdown;

        // For left outer join, can only push predicates on left side
        let plan = LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                Expr::eq(
                    Expr::column("users", "id", 0),
                    Expr::column("orders", "user_id", 0),
                ),
            ),
            users_only_predicate(),
        );

        let optimized = pass.optimize(plan);

        // Should push to left side
        match optimized {
            LogicalPlan::Join {
                left, join_type, ..
            } => {
                assert_eq!(join_type, JoinType::LeftOuter);
                assert!(matches!(*left, LogicalPlan::Filter { .. }));
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_left_join_right_predicate_stays_above() {
        let pass = PredicatePushdown;

        // For left outer join, predicates on right side must stay above
        let plan = LogicalPlan::filter(
            LogicalPlan::left_join(
                LogicalPlan::scan("users"),
                LogicalPlan::scan("orders"),
                Expr::eq(
                    Expr::column("users", "id", 0),
                    Expr::column("orders", "user_id", 0),
                ),
            ),
            orders_only_predicate(),
        );

        assert_filter_above_join(pass.optimize(plan));
    }

    #[test]
    fn test_right_join_push_to_right_only() {
        let pass = PredicatePushdown;

        // For right outer join, predicates on the right (preserved) side are pushed
        let plan = LogicalPlan::filter(
            users_join_orders(JoinType::RightOuter),
            orders_only_predicate(),
        );

        match pass.optimize(plan) {
            LogicalPlan::Join {
                left,
                right,
                join_type,
                ..
            } => {
                assert_eq!(join_type, JoinType::RightOuter);
                assert!(matches!(*left, LogicalPlan::Scan { .. }));
                assert!(matches!(*right, LogicalPlan::Filter { .. }));
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_right_join_left_predicate_stays_above() {
        let pass = PredicatePushdown;

        // For right outer join, predicates on the left (null-supplying) side stay above
        let plan = LogicalPlan::filter(
            users_join_orders(JoinType::RightOuter),
            users_only_predicate(),
        );

        assert_filter_above_join(pass.optimize(plan));
    }

    #[test]
    fn test_full_outer_join_keeps_filter_above() {
        let pass = PredicatePushdown;

        // Both sides of a full outer join are null-supplying: nothing is pushed
        let plan = LogicalPlan::filter(
            users_join_orders(JoinType::FullOuter),
            users_only_predicate(),
        );

        assert_filter_above_join(pass.optimize(plan));
    }

    #[test]
    fn test_filter_pushed_through_sort_into_join() {
        let pass = PredicatePushdown;

        // Filter(Sort(Join)) should become Sort(Join(Filter(users), orders))
        let plan = LogicalPlan::filter(
            LogicalPlan::sort(
                users_join_orders(JoinType::Inner),
                alloc::vec![(Expr::column("users", "name", 2), SortOrder::Asc)],
            ),
            users_only_predicate(),
        );

        match pass.optimize(plan) {
            LogicalPlan::Sort { input, .. } => match *input {
                LogicalPlan::Join { left, right, .. } => {
                    assert!(matches!(*left, LogicalPlan::Filter { .. }));
                    assert!(matches!(*right, LogicalPlan::Scan { .. }));
                }
                other => panic!("Expected Join below Sort, got {:?}", other),
            },
            other => panic!("Expected Sort, got {:?}", other),
        }
    }

    #[test]
    fn test_push_filter_through_nested_joins() {
        let pass = PredicatePushdown;

        // Filter(users) over Join(Join(users, orders), products) should reach the users scan
        let plan = LogicalPlan::filter(
            LogicalPlan::inner_join(
                users_join_orders(JoinType::Inner),
                LogicalPlan::scan("products"),
                Expr::eq(
                    Expr::column("orders", "product_id", 2),
                    Expr::column("products", "id", 0),
                ),
            ),
            users_only_predicate(),
        );

        match pass.optimize(plan) {
            LogicalPlan::Join { left, right, .. } => {
                assert!(matches!(*right, LogicalPlan::Scan { .. }));
                match *left {
                    LogicalPlan::Join { left, right, .. } => {
                        assert!(matches!(*left, LogicalPlan::Filter { .. }));
                        assert!(matches!(*right, LogicalPlan::Scan { .. }));
                    }
                    other => panic!("Expected inner Join, got {:?}", other),
                }
            }
            other => panic!("Expected Join, got {:?}", other),
        }
    }

    #[test]
    fn test_extract_tables() {
        let pass = PredicatePushdown;

        let plan = LogicalPlan::inner_join(
            LogicalPlan::scan("users"),
            LogicalPlan::filter(
                LogicalPlan::scan("orders"),
                Expr::gt(Expr::column("orders", "amount", 0), Expr::literal(0i64)),
            ),
            Expr::eq(
                Expr::column("users", "id", 0),
                Expr::column("orders", "user_id", 1),
            ),
        );

        let tables = pass.extract_tables(&plan);
        assert!(tables.contains("users"));
        assert!(tables.contains("orders"));
        assert_eq!(tables.len(), 2);
    }

    #[test]
    fn test_extract_predicate_tables() {
        let pass = PredicatePushdown;

        let pred = Expr::and(
            Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64)),
            Expr::gt(
                Expr::column("orders", "amount", 0),
                Expr::column("products", "price", 0),
            ),
        );

        let tables = pass.extract_predicate_tables(&pred);
        assert!(tables.contains("users"));
        assert!(tables.contains("orders"));
        assert!(tables.contains("products"));
        assert_eq!(tables.len(), 3);
    }
}