// cynos_query/optimizer/cross_product.rs
//! Cross product pass - converts multi-way cross products to a binary tree structure.
//!
//! A cross product over more than two inputs is represented in the plan as a
//! chain of nested `CrossProduct` nodes. This pass flattens such a chain into
//! its inputs (in left-to-right order) and rebuilds it as a tree of binary
//! cross products by pairing adjacent inputs level by level. With an odd number
//! of inputs, the last input of a level is carried up unpaired, so the tree has
//! minimal height but is not always perfectly balanced.
//!
//! Example:
//! ```text
//! CrossProduct(A, B, C, D)    =>    CrossProduct
//!                                    /        \
//!                              CrossProduct  CrossProduct
//!                               /    \        /    \
//!                              A      B      C      D
//! ```
//!
//! This transformation is necessary because:
//! 1. The execution engine expects binary cross products
//! 2. It enables subsequent passes (like ImplicitJoinsPass) to convert
//!    cross products to joins more effectively
19
20use crate::optimizer::OptimizerPass;
21use crate::planner::LogicalPlan;
22use alloc::boxed::Box;
23use alloc::vec::Vec;
24
/// Optimizer pass that rewrites chains of nested cross products into a tree
/// of binary cross products.
///
/// The pass carries no state (it is a unit struct), so one value can be
/// reused for any number of plans.
#[derive(Debug, Clone, Copy, Default)]
pub struct CrossProductPass;
27
28impl OptimizerPass for CrossProductPass {
29    fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
30        self.traverse(plan)
31    }
32
33    fn name(&self) -> &'static str {
34        "cross_product"
35    }
36}
37
38impl CrossProductPass {
39    /// Recursively traverses the plan tree and transforms cross products.
40    fn traverse(&self, plan: LogicalPlan) -> LogicalPlan {
41        match plan {
42            LogicalPlan::CrossProduct { left, right } => {
43                // First collect all tables from nested cross products
44                let mut tables = Vec::new();
45                self.collect_cross_product_children(*left, &mut tables);
46                self.collect_cross_product_children(*right, &mut tables);
47
48                // If we have more than 2 tables, restructure into binary tree
49                if tables.len() > 2 {
50                    self.build_binary_cross_product(tables)
51                } else if tables.len() == 2 {
52                    LogicalPlan::CrossProduct {
53                        left: Box::new(self.traverse(tables.remove(0))),
54                        right: Box::new(self.traverse(tables.remove(0))),
55                    }
56                } else if tables.len() == 1 {
57                    self.traverse(tables.remove(0))
58                } else {
59                    LogicalPlan::Empty
60                }
61            }
62
63            LogicalPlan::Filter { input, predicate } => LogicalPlan::Filter {
64                input: Box::new(self.traverse(*input)),
65                predicate,
66            },
67
68            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
69                input: Box::new(self.traverse(*input)),
70                columns,
71            },
72
73            LogicalPlan::Join {
74                left,
75                right,
76                condition,
77                join_type,
78            } => LogicalPlan::Join {
79                left: Box::new(self.traverse(*left)),
80                right: Box::new(self.traverse(*right)),
81                condition,
82                join_type,
83            },
84
85            LogicalPlan::Aggregate {
86                input,
87                group_by,
88                aggregates,
89            } => LogicalPlan::Aggregate {
90                input: Box::new(self.traverse(*input)),
91                group_by,
92                aggregates,
93            },
94
95            LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
96                input: Box::new(self.traverse(*input)),
97                order_by,
98            },
99
100            LogicalPlan::Limit {
101                input,
102                limit,
103                offset,
104            } => LogicalPlan::Limit {
105                input: Box::new(self.traverse(*input)),
106                limit,
107                offset,
108            },
109
110            LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
111                left: Box::new(self.traverse(*left)),
112                right: Box::new(self.traverse(*right)),
113                all,
114            },
115
116            // Leaf nodes - no transformation needed
117            plan @ (LogicalPlan::Scan { .. }
118            | LogicalPlan::IndexScan { .. }
119            | LogicalPlan::IndexGet { .. }
120            | LogicalPlan::IndexInGet { .. }
121            | LogicalPlan::GinIndexScan { .. }
122            | LogicalPlan::GinIndexScanMulti { .. }
123            | LogicalPlan::Empty) => plan,
124        }
125    }
126
127    /// Collects all children from nested cross products.
128    /// Non-cross-product nodes are added directly to the list.
129    fn collect_cross_product_children(&self, plan: LogicalPlan, children: &mut Vec<LogicalPlan>) {
130        match plan {
131            LogicalPlan::CrossProduct { left, right } => {
132                self.collect_cross_product_children(*left, children);
133                self.collect_cross_product_children(*right, children);
134            }
135            other => children.push(other),
136        }
137    }
138
139    /// Builds a balanced binary tree of cross products from a list of tables.
140    /// Uses left-to-right pairing: ((A × B) × (C × D))
141    fn build_binary_cross_product(&self, mut tables: Vec<LogicalPlan>) -> LogicalPlan {
142        // Recursively optimize each table first
143        tables = tables.into_iter().map(|t| self.traverse(t)).collect();
144
145        // Build binary tree by pairing adjacent tables
146        while tables.len() > 1 {
147            let mut new_level = Vec::new();
148            let mut i = 0;
149            while i < tables.len() {
150                if i + 1 < tables.len() {
151                    new_level.push(LogicalPlan::CrossProduct {
152                        left: Box::new(tables[i].clone()),
153                        right: Box::new(tables[i + 1].clone()),
154                    });
155                    i += 2;
156                } else {
157                    // Odd number of tables - carry the last one up
158                    new_level.push(tables[i].clone());
159                    i += 1;
160                }
161            }
162            tables = new_level;
163        }
164
165        tables.pop().unwrap_or(LogicalPlan::Empty)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts `CrossProduct` nodes anywhere in `plan`.
    fn count_cross_products(plan: &LogicalPlan) -> usize {
        match plan {
            LogicalPlan::CrossProduct { left, right } => {
                1 + count_cross_products(left) + count_cross_products(right)
            }
            LogicalPlan::Filter { input, .. }
            | LogicalPlan::Project { input, .. }
            | LogicalPlan::Aggregate { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Limit { input, .. } => count_cross_products(input),
            LogicalPlan::Join { left, right, .. } | LogicalPlan::Union { left, right, .. } => {
                count_cross_products(left) + count_cross_products(right)
            }
            _ => 0,
        }
    }

    /// Counts `Scan` leaves reachable through the operators matched below.
    fn count_scans(plan: &LogicalPlan) -> usize {
        match plan {
            LogicalPlan::Scan { .. } => 1,
            LogicalPlan::CrossProduct { left, right } => count_scans(left) + count_scans(right),
            LogicalPlan::Filter { input, .. }
            | LogicalPlan::Project { input, .. }
            | LogicalPlan::Aggregate { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Limit { input, .. } => count_scans(input),
            LogicalPlan::Join { left, right, .. } | LogicalPlan::Union { left, right, .. } => {
                count_scans(left) + count_scans(right)
            }
            _ => 0,
        }
    }

    /// Returns the `(left, right)` children of a `CrossProduct` node.
    ///
    /// Panics (failing the calling test) if `plan` is any other operator.
    fn cross_children(plan: &LogicalPlan) -> (&LogicalPlan, &LogicalPlan) {
        match plan {
            LogicalPlan::CrossProduct { left, right } => (&**left, &**right),
            _ => panic!("expected a CrossProduct node"),
        }
    }

    fn is_cross(plan: &LogicalPlan) -> bool {
        matches!(plan, LogicalPlan::CrossProduct { .. })
    }

    fn is_scan(plan: &LogicalPlan) -> bool {
        matches!(plan, LogicalPlan::Scan { .. })
    }

    #[test]
    fn test_two_table_cross_product_unchanged() {
        let pass = CrossProductPass;
        let plan = LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"));

        let result = pass.optimize(plan);

        assert_eq!(count_cross_products(&result), 1);
        assert_eq!(count_scans(&result), 2);
        let (left, right) = cross_children(&result);
        assert!(is_scan(left) && is_scan(right));
    }

    #[test]
    fn test_three_table_cross_product() {
        let pass = CrossProductPass;

        // Create: CrossProduct(CrossProduct(A, B), C)
        let plan = LogicalPlan::cross_product(
            LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b")),
            LogicalPlan::scan("c"),
        );

        let result = pass.optimize(plan);

        // 3 leaves pair as ((A × B) × C): C is carried up unpaired.
        assert_eq!(count_cross_products(&result), 2);
        assert_eq!(count_scans(&result), 3);
        let (left, right) = cross_children(&result);
        assert!(is_cross(left));
        assert!(is_scan(right));
    }

    #[test]
    fn test_four_table_cross_product() {
        let pass = CrossProductPass;

        // Create a left-deep chain: CrossProduct(CrossProduct(CrossProduct(A, B), C), D)
        let plan = LogicalPlan::cross_product(
            LogicalPlan::cross_product(
                LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b")),
                LogicalPlan::scan("c"),
            ),
            LogicalPlan::scan("d"),
        );

        let result = pass.optimize(plan);

        // Node counts alone match the left-deep input, so also check the shape:
        // ((A × B) × (C × D)) has a cross product on both sides of the root.
        assert_eq!(count_cross_products(&result), 3);
        assert_eq!(count_scans(&result), 4);
        let (left, right) = cross_children(&result);
        assert!(is_cross(left) && is_cross(right));
        let (a, b) = cross_children(left);
        let (c, d) = cross_children(right);
        assert!(is_scan(a) && is_scan(b) && is_scan(c) && is_scan(d));
    }

    #[test]
    fn test_right_deep_four_table_cross_product() {
        let pass = CrossProductPass;

        // Create a right-deep chain: CrossProduct(A, CrossProduct(B, CrossProduct(C, D)))
        let plan = LogicalPlan::cross_product(
            LogicalPlan::scan("a"),
            LogicalPlan::cross_product(
                LogicalPlan::scan("b"),
                LogicalPlan::cross_product(LogicalPlan::scan("c"), LogicalPlan::scan("d")),
            ),
        );

        let result = pass.optimize(plan);

        assert_eq!(count_cross_products(&result), 3);
        assert_eq!(count_scans(&result), 4);
        let (left, right) = cross_children(&result);
        assert!(is_cross(left) && is_cross(right));
    }

    #[test]
    fn test_five_table_cross_product() {
        let pass = CrossProductPass;

        // Create: CrossProduct(CrossProduct(CrossProduct(CrossProduct(A, B), C), D), E)
        let plan = LogicalPlan::cross_product(
            LogicalPlan::cross_product(
                LogicalPlan::cross_product(
                    LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b")),
                    LogicalPlan::scan("c"),
                ),
                LogicalPlan::scan("d"),
            ),
            LogicalPlan::scan("e"),
        );

        let result = pass.optimize(plan);

        // Expected shape: (((A × B) × (C × D)) × E) - the odd input E is carried up.
        assert_eq!(count_cross_products(&result), 4);
        assert_eq!(count_scans(&result), 5);
        let (left, right) = cross_children(&result);
        assert!(is_scan(right));
        let (ab, cd) = cross_children(left);
        assert!(is_cross(ab) && is_cross(cd));
    }

    #[test]
    fn test_cross_product_with_filter() {
        let pass = CrossProductPass;

        // Create: Filter(CrossProduct(A, B, C))
        let cross = LogicalPlan::cross_product(
            LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b")),
            LogicalPlan::scan("c"),
        );
        let plan = LogicalPlan::filter(
            cross,
            crate::ast::Expr::eq(
                crate::ast::Expr::column("a", "id", 0),
                crate::ast::Expr::literal(1i64),
            ),
        );

        let result = pass.optimize(plan);

        // Filter should be preserved at the root, cross products restructured below it.
        match result {
            LogicalPlan::Filter { input, .. } => {
                assert_eq!(count_cross_products(&input), 2);
                assert_eq!(count_scans(&input), 3);
            }
            _ => panic!("expected the Filter to remain at the root"),
        }
    }

    #[test]
    fn test_single_table_no_cross_product() {
        let pass = CrossProductPass;
        let plan = LogicalPlan::scan("a");

        let result = pass.optimize(plan);

        assert!(matches!(result, LogicalPlan::Scan { .. }));
    }
}