1use crate::optimizer::OptimizerPass;
21use crate::planner::LogicalPlan;
22use alloc::boxed::Box;
23use alloc::vec::Vec;
24
/// Optimizer pass that restructures chains of nested `LogicalPlan::CrossProduct` nodes.
///
/// The pass carries no state, so the common value traits are derived to make it cheap to
/// copy, default-construct, and print in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct CrossProductPass;
27
28impl OptimizerPass for CrossProductPass {
29 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
30 self.traverse(plan)
31 }
32
33 fn name(&self) -> &'static str {
34 "cross_product"
35 }
36}
37
38impl CrossProductPass {
39 fn traverse(&self, plan: LogicalPlan) -> LogicalPlan {
41 match plan {
42 LogicalPlan::CrossProduct { left, right } => {
43 let mut tables = Vec::new();
45 self.collect_cross_product_children(*left, &mut tables);
46 self.collect_cross_product_children(*right, &mut tables);
47
48 if tables.len() > 2 {
50 self.build_binary_cross_product(tables)
51 } else if tables.len() == 2 {
52 LogicalPlan::CrossProduct {
53 left: Box::new(self.traverse(tables.remove(0))),
54 right: Box::new(self.traverse(tables.remove(0))),
55 }
56 } else if tables.len() == 1 {
57 self.traverse(tables.remove(0))
58 } else {
59 LogicalPlan::Empty
60 }
61 }
62
63 LogicalPlan::Filter { input, predicate } => LogicalPlan::Filter {
64 input: Box::new(self.traverse(*input)),
65 predicate,
66 },
67
68 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
69 input: Box::new(self.traverse(*input)),
70 columns,
71 },
72
73 LogicalPlan::Join {
74 left,
75 right,
76 condition,
77 join_type,
78 } => LogicalPlan::Join {
79 left: Box::new(self.traverse(*left)),
80 right: Box::new(self.traverse(*right)),
81 condition,
82 join_type,
83 },
84
85 LogicalPlan::Aggregate {
86 input,
87 group_by,
88 aggregates,
89 } => LogicalPlan::Aggregate {
90 input: Box::new(self.traverse(*input)),
91 group_by,
92 aggregates,
93 },
94
95 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
96 input: Box::new(self.traverse(*input)),
97 order_by,
98 },
99
100 LogicalPlan::Limit {
101 input,
102 limit,
103 offset,
104 } => LogicalPlan::Limit {
105 input: Box::new(self.traverse(*input)),
106 limit,
107 offset,
108 },
109
110 LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
111 left: Box::new(self.traverse(*left)),
112 right: Box::new(self.traverse(*right)),
113 all,
114 },
115
116 plan @ (LogicalPlan::Scan { .. }
118 | LogicalPlan::IndexScan { .. }
119 | LogicalPlan::IndexGet { .. }
120 | LogicalPlan::IndexInGet { .. }
121 | LogicalPlan::GinIndexScan { .. }
122 | LogicalPlan::GinIndexScanMulti { .. }
123 | LogicalPlan::Empty) => plan,
124 }
125 }
126
127 fn collect_cross_product_children(&self, plan: LogicalPlan, children: &mut Vec<LogicalPlan>) {
130 match plan {
131 LogicalPlan::CrossProduct { left, right } => {
132 self.collect_cross_product_children(*left, children);
133 self.collect_cross_product_children(*right, children);
134 }
135 other => children.push(other),
136 }
137 }
138
139 fn build_binary_cross_product(&self, mut tables: Vec<LogicalPlan>) -> LogicalPlan {
142 tables = tables.into_iter().map(|t| self.traverse(t)).collect();
144
145 while tables.len() > 1 {
147 let mut new_level = Vec::new();
148 let mut i = 0;
149 while i < tables.len() {
150 if i + 1 < tables.len() {
151 new_level.push(LogicalPlan::CrossProduct {
152 left: Box::new(tables[i].clone()),
153 right: Box::new(tables[i + 1].clone()),
154 });
155 i += 2;
156 } else {
157 new_level.push(tables[i].clone());
159 i += 1;
160 }
161 }
162 tables = new_level;
163 }
164
165 tables.pop().unwrap_or(LogicalPlan::Empty)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks `plan` and counts the nodes for which `is_hit` returns true.
    ///
    /// The walk descends through cross products, joins, unions, filters, projections,
    /// aggregates, sorts and limits; any other node is treated as having no children.
    fn count_nodes(plan: &LogicalPlan, is_hit: fn(&LogicalPlan) -> bool) -> usize {
        let below = match plan {
            LogicalPlan::CrossProduct { left, right }
            | LogicalPlan::Join { left, right, .. }
            | LogicalPlan::Union { left, right, .. } => {
                count_nodes(left, is_hit) + count_nodes(right, is_hit)
            }
            LogicalPlan::Filter { input, .. }
            | LogicalPlan::Project { input, .. }
            | LogicalPlan::Aggregate { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Limit { input, .. } => count_nodes(input, is_hit),
            _ => 0,
        };
        usize::from(is_hit(plan)) + below
    }

    /// Number of `CrossProduct` nodes reachable in `plan`.
    fn count_cross_products(plan: &LogicalPlan) -> usize {
        count_nodes(plan, |node| {
            matches!(node, LogicalPlan::CrossProduct { .. })
        })
    }

    /// Number of `Scan` nodes reachable in `plan`.
    fn count_scans(plan: &LogicalPlan) -> usize {
        count_nodes(plan, |node| matches!(node, LogicalPlan::Scan { .. }))
    }

    #[test]
    fn test_two_table_cross_product_unchanged() {
        let two_way = LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"));

        let optimized = CrossProductPass.optimize(two_way);

        assert!(matches!(optimized, LogicalPlan::CrossProduct { .. }));
        assert_eq!(count_cross_products(&optimized), 1);
        assert_eq!(count_scans(&optimized), 2);
    }

    #[test]
    fn test_three_table_cross_product() {
        // Left-deep input: (a x b) x c
        let a_b = LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"));
        let three_way = LogicalPlan::cross_product(a_b, LogicalPlan::scan("c"));

        let optimized = CrossProductPass.optimize(three_way);

        // Three operands always need exactly two binary nodes.
        assert_eq!(count_cross_products(&optimized), 2);
        assert_eq!(count_scans(&optimized), 3);
    }

    #[test]
    fn test_four_table_cross_product() {
        // Left-deep input: ((a x b) x c) x d
        let a_b = LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"));
        let a_b_c = LogicalPlan::cross_product(a_b, LogicalPlan::scan("c"));
        let four_way = LogicalPlan::cross_product(a_b_c, LogicalPlan::scan("d"));

        let optimized = CrossProductPass.optimize(four_way);

        // Four operands always need exactly three binary nodes.
        assert_eq!(count_cross_products(&optimized), 3);
        assert_eq!(count_scans(&optimized), 4);
    }

    #[test]
    fn test_cross_product_with_filter() {
        // Left-deep (a x b) x c sitting underneath a filter.
        let a_b = LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"));
        let three_way = LogicalPlan::cross_product(a_b, LogicalPlan::scan("c"));
        let predicate = crate::ast::Expr::eq(
            crate::ast::Expr::column("a", "id", 0),
            crate::ast::Expr::literal(1i64),
        );
        let filtered = LogicalPlan::filter(three_way, predicate);

        let optimized = CrossProductPass.optimize(filtered);

        // The filter must stay on top, with the rewritten cross products below it.
        assert!(matches!(optimized, LogicalPlan::Filter { .. }));
        if let LogicalPlan::Filter { input, .. } = optimized {
            assert_eq!(count_cross_products(&input), 2);
            assert_eq!(count_scans(&input), 3);
        }
    }

    #[test]
    fn test_single_table_no_cross_product() {
        let lone_scan = LogicalPlan::scan("a");

        let optimized = CrossProductPass.optimize(lone_scan);

        assert!(matches!(optimized, LogicalPlan::Scan { .. }));
    }
}