// cynos_query/optimizer/implicit_joins.rs

//! Implicit joins pass - converts CrossProduct + Filter patterns to Join nodes.
//!
//! This pass identifies patterns where a Filter with a join predicate sits
//! directly above a CrossProduct, and converts them to proper inner Join nodes.
//!
//! Example:
//! ```text
//! Filter(a.id = b.a_id)       =>       Join(a.id = b.a_id)
//!        |                              /            \
//!   CrossProduct                     Scan(a)       Scan(b)
//!    /        \
//! Scan(a)   Scan(b)
//! ```
//!
//! This transformation is important because:
//! 1. Join nodes can use optimized join algorithms (hash join, merge join)
//! 2. It enables further optimizations like index join selection
18
19use crate::ast::{Expr, JoinType};
20use crate::optimizer::OptimizerPass;
21use crate::planner::LogicalPlan;
22use alloc::boxed::Box;
23
/// Optimizer pass that converts CrossProduct + Filter patterns to Join nodes.
///
/// The pass carries no state, so it is a unit struct; the derives make it
/// printable, freely copyable and constructible through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ImplicitJoinsPass;
26
27impl OptimizerPass for ImplicitJoinsPass {
28    fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
29        self.traverse(plan)
30    }
31
32    fn name(&self) -> &'static str {
33        "implicit_joins"
34    }
35}
36
37impl ImplicitJoinsPass {
38    /// Recursively traverses the plan tree and converts implicit joins.
39    fn traverse(&self, plan: LogicalPlan) -> LogicalPlan {
40        match plan {
41            LogicalPlan::Filter { input, predicate } => {
42                // First, recursively optimize the input
43                let optimized_input = self.traverse(*input);
44
45                // Check if this is a join predicate over a cross product
46                if let LogicalPlan::CrossProduct { left, right } = &optimized_input {
47                    if self.is_join_predicate(&predicate, left, right) {
48                        // Convert to Join
49                        return LogicalPlan::Join {
50                            left: left.clone(),
51                            right: right.clone(),
52                            condition: predicate,
53                            join_type: JoinType::Inner,
54                        };
55                    }
56                }
57
58                // Not a join pattern, keep as Filter
59                LogicalPlan::Filter {
60                    input: Box::new(optimized_input),
61                    predicate,
62                }
63            }
64
65            LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
66                left: Box::new(self.traverse(*left)),
67                right: Box::new(self.traverse(*right)),
68            },
69
70            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
71                input: Box::new(self.traverse(*input)),
72                columns,
73            },
74
75            LogicalPlan::Join {
76                left,
77                right,
78                condition,
79                join_type,
80            } => LogicalPlan::Join {
81                left: Box::new(self.traverse(*left)),
82                right: Box::new(self.traverse(*right)),
83                condition,
84                join_type,
85            },
86
87            LogicalPlan::Aggregate {
88                input,
89                group_by,
90                aggregates,
91            } => LogicalPlan::Aggregate {
92                input: Box::new(self.traverse(*input)),
93                group_by,
94                aggregates,
95            },
96
97            LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
98                input: Box::new(self.traverse(*input)),
99                order_by,
100            },
101
102            LogicalPlan::Limit {
103                input,
104                limit,
105                offset,
106            } => LogicalPlan::Limit {
107                input: Box::new(self.traverse(*input)),
108                limit,
109                offset,
110            },
111
112            LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
113                left: Box::new(self.traverse(*left)),
114                right: Box::new(self.traverse(*right)),
115                all,
116            },
117
118            // Leaf nodes - no transformation needed
119            plan @ (LogicalPlan::Scan { .. }
120            | LogicalPlan::IndexScan { .. }
121            | LogicalPlan::IndexGet { .. }
122            | LogicalPlan::IndexInGet { .. }
123            | LogicalPlan::GinIndexScan { .. }
124            | LogicalPlan::GinIndexScanMulti { .. }
125            | LogicalPlan::Empty) => plan,
126        }
127    }
128
129    /// Checks if the predicate is a join predicate that references both sides
130    /// of the cross product.
131    fn is_join_predicate(
132        &self,
133        predicate: &Expr,
134        left: &LogicalPlan,
135        right: &LogicalPlan,
136    ) -> bool {
137        let left_tables = self.collect_tables(left);
138        let right_tables = self.collect_tables(right);
139        let predicate_tables = self.collect_predicate_tables(predicate);
140
141        // A join predicate must reference at least one table from each side
142        let refs_left = predicate_tables.iter().any(|t| left_tables.contains(t));
143        let refs_right = predicate_tables.iter().any(|t| right_tables.contains(t));
144
145        refs_left && refs_right
146    }
147
148    /// Collects all table names referenced in a plan.
149    fn collect_tables(&self, plan: &LogicalPlan) -> alloc::vec::Vec<alloc::string::String> {
150        let mut tables = alloc::vec::Vec::new();
151        self.collect_tables_recursive(plan, &mut tables);
152        tables
153    }
154
155    fn collect_tables_recursive(
156        &self,
157        plan: &LogicalPlan,
158        tables: &mut alloc::vec::Vec<alloc::string::String>,
159    ) {
160        match plan {
161            LogicalPlan::Scan { table } => tables.push(table.clone()),
162            LogicalPlan::IndexScan { table, .. }
163            | LogicalPlan::IndexGet { table, .. }
164            | LogicalPlan::IndexInGet { table, .. }
165            | LogicalPlan::GinIndexScan { table, .. }
166            | LogicalPlan::GinIndexScanMulti { table, .. } => {
167                tables.push(table.clone())
168            }
169            LogicalPlan::Filter { input, .. }
170            | LogicalPlan::Project { input, .. }
171            | LogicalPlan::Aggregate { input, .. }
172            | LogicalPlan::Sort { input, .. }
173            | LogicalPlan::Limit { input, .. } => {
174                self.collect_tables_recursive(input, tables);
175            }
176            LogicalPlan::Join { left, right, .. }
177            | LogicalPlan::CrossProduct { left, right }
178            | LogicalPlan::Union { left, right, .. } => {
179                self.collect_tables_recursive(left, tables);
180                self.collect_tables_recursive(right, tables);
181            }
182            LogicalPlan::Empty => {}
183        }
184    }
185
186    /// Collects all table names referenced in a predicate expression.
187    fn collect_predicate_tables(&self, expr: &Expr) -> alloc::vec::Vec<alloc::string::String> {
188        let mut tables = alloc::vec::Vec::new();
189        self.collect_expr_tables(expr, &mut tables);
190        tables
191    }
192
193    fn collect_expr_tables(
194        &self,
195        expr: &Expr,
196        tables: &mut alloc::vec::Vec<alloc::string::String>,
197    ) {
198        match expr {
199            Expr::Column(col_ref) => {
200                if !tables.contains(&col_ref.table) {
201                    tables.push(col_ref.table.clone());
202                }
203            }
204            Expr::BinaryOp { left, right, .. } => {
205                self.collect_expr_tables(left, tables);
206                self.collect_expr_tables(right, tables);
207            }
208            Expr::UnaryOp { expr, .. } => {
209                self.collect_expr_tables(expr, tables);
210            }
211            Expr::Aggregate { expr, .. } => {
212                if let Some(e) = expr {
213                    self.collect_expr_tables(e, tables);
214                }
215            }
216            Expr::Literal(_) => {}
217            // Handle other expression types
218            Expr::Function { args, .. } => {
219                for arg in args {
220                    self.collect_expr_tables(arg, tables);
221                }
222            }
223            Expr::Between { expr, low, high } => {
224                self.collect_expr_tables(expr, tables);
225                self.collect_expr_tables(low, tables);
226                self.collect_expr_tables(high, tables);
227            }
228            Expr::In { expr, list } => {
229                self.collect_expr_tables(expr, tables);
230                for item in list {
231                    self.collect_expr_tables(item, tables);
232                }
233            }
234            Expr::Like { expr, .. } => {
235                self.collect_expr_tables(expr, tables);
236            }
237            Expr::NotBetween { expr, low, high } => {
238                self.collect_expr_tables(expr, tables);
239                self.collect_expr_tables(low, tables);
240                self.collect_expr_tables(high, tables);
241            }
242            Expr::NotIn { expr, list } => {
243                self.collect_expr_tables(expr, tables);
244                for item in list {
245                    self.collect_expr_tables(item, tables);
246                }
247            }
248            Expr::NotLike { expr, .. } => {
249                self.collect_expr_tables(expr, tables);
250            }
251            Expr::Match { expr, .. } => {
252                self.collect_expr_tables(expr, tables);
253            }
254            Expr::NotMatch { expr, .. } => {
255                self.collect_expr_tables(expr, tables);
256            }
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;

    /// Builds `CrossProduct(Scan(a), Scan(b))`.
    fn cross_a_b() -> LogicalPlan {
        LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"))
    }

    /// Builds `a.id = b.a_id`, a predicate that references both tables.
    fn a_id_eq_b_a_id() -> Expr {
        Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0))
    }

    /// Builds `a.id = 1`, a predicate that references table `a` only.
    fn a_id_eq_one() -> Expr {
        Expr::eq(Expr::column("a", "id", 0), Expr::literal(1i64))
    }

    #[test]
    fn test_cross_product_with_join_predicate() {
        // Input: Filter(a.id = b.a_id) -> CrossProduct(Scan(a), Scan(b))
        let plan = LogicalPlan::filter(cross_a_b(), a_id_eq_b_a_id());

        // Expected: Join(a.id = b.a_id, Scan(a), Scan(b))
        match ImplicitJoinsPass.optimize(plan) {
            LogicalPlan::Join {
                left,
                right,
                join_type,
                ..
            } => {
                assert!(matches!(*left, LogicalPlan::Scan { table } if table == "a"));
                assert!(matches!(*right, LogicalPlan::Scan { table } if table == "b"));
                assert!(matches!(join_type, JoinType::Inner));
            }
            _ => panic!("expected Filter over CrossProduct to become a Join"),
        }
    }

    #[test]
    fn test_cross_product_with_non_join_predicate() {
        // Input: Filter(a.id = 1) -> CrossProduct(Scan(a), Scan(b))
        // The predicate references only one table, so it is not a join predicate.
        let plan = LogicalPlan::filter(cross_a_b(), a_id_eq_one());

        // Expected: still Filter -> CrossProduct
        match ImplicitJoinsPass.optimize(plan) {
            LogicalPlan::Filter { input, .. } => {
                assert!(matches!(*input, LogicalPlan::CrossProduct { .. }));
            }
            _ => panic!("expected the plan to remain a Filter"),
        }
    }

    #[test]
    fn test_filter_without_cross_product() {
        // Input: Filter(id = 1) -> Scan(a)
        let plan = LogicalPlan::filter(LogicalPlan::scan("a"), a_id_eq_one());

        // Expected: unchanged shape, Filter -> Scan
        match ImplicitJoinsPass.optimize(plan) {
            LogicalPlan::Filter { input, .. } => {
                assert!(matches!(*input, LogicalPlan::Scan { .. }));
            }
            _ => panic!("expected the plan to remain a Filter"),
        }
    }

    #[test]
    fn test_nested_cross_products_with_join() {
        // Input: Filter(a.id = b.a_id) -> CrossProduct(CrossProduct(Scan(a), Scan(b)), Scan(c))
        // Both referenced tables live in the left subtree of the outer cross
        // product, so the predicate does not join the outer product's two sides.
        let outer_cross = LogicalPlan::cross_product(cross_a_b(), LogicalPlan::scan("c"));
        let plan = LogicalPlan::filter(outer_cross, a_id_eq_b_a_id());

        // Expected: the outer shape stays Filter -> CrossProduct
        match ImplicitJoinsPass.optimize(plan) {
            LogicalPlan::Filter { input, .. } => {
                assert!(matches!(*input, LogicalPlan::CrossProduct { .. }));
            }
            _ => panic!("expected the plan to remain a Filter"),
        }
    }

    #[test]
    fn test_is_join_predicate() {
        let pass = ImplicitJoinsPass;
        let (left, right) = (LogicalPlan::scan("a"), LogicalPlan::scan("b"));

        // a.id = b.a_id touches both sides.
        assert!(pass.is_join_predicate(&a_id_eq_b_a_id(), &left, &right));

        // a.id = 1 touches the left side only.
        assert!(!pass.is_join_predicate(&a_id_eq_one(), &left, &right));

        // b.name = 'test' touches the right side only.
        let b_name_eq_test = Expr::eq(Expr::column("b", "name", 1), Expr::literal("test"));
        assert!(!pass.is_join_predicate(&b_name_eq_test, &left, &right));
    }

    #[test]
    fn test_collect_tables() {
        // CrossProduct(Scan(a), CrossProduct(Scan(b), Scan(c)))
        let plan = LogicalPlan::cross_product(
            LogicalPlan::scan("a"),
            LogicalPlan::cross_product(LogicalPlan::scan("b"), LogicalPlan::scan("c")),
        );

        let tables = ImplicitJoinsPass.collect_tables(&plan);

        assert_eq!(tables.len(), 3);
        for expected in ["a", "b", "c"] {
            assert!(tables.iter().any(|name| name.as_str() == expected));
        }
    }
}
388}