1use crate::ast::{Expr, JoinType};
20use crate::optimizer::OptimizerPass;
21use crate::planner::LogicalPlan;
22use alloc::boxed::Box;
23
/// Optimizer pass that turns a `Filter` sitting directly on top of a
/// `CrossProduct` into an inner `Join` when the filter predicate references
/// tables from both sides of the cross product.
///
/// The pass carries no state, so it is cheap to copy and can be built with
/// `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ImplicitJoinsPass;
26
27impl OptimizerPass for ImplicitJoinsPass {
28 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
29 self.traverse(plan)
30 }
31
32 fn name(&self) -> &'static str {
33 "implicit_joins"
34 }
35}
36
37impl ImplicitJoinsPass {
38 fn traverse(&self, plan: LogicalPlan) -> LogicalPlan {
40 match plan {
41 LogicalPlan::Filter { input, predicate } => {
42 let optimized_input = self.traverse(*input);
44
45 if let LogicalPlan::CrossProduct { left, right } = &optimized_input {
47 if self.is_join_predicate(&predicate, left, right) {
48 return LogicalPlan::Join {
50 left: left.clone(),
51 right: right.clone(),
52 condition: predicate,
53 join_type: JoinType::Inner,
54 };
55 }
56 }
57
58 LogicalPlan::Filter {
60 input: Box::new(optimized_input),
61 predicate,
62 }
63 }
64
65 LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
66 left: Box::new(self.traverse(*left)),
67 right: Box::new(self.traverse(*right)),
68 },
69
70 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
71 input: Box::new(self.traverse(*input)),
72 columns,
73 },
74
75 LogicalPlan::Join {
76 left,
77 right,
78 condition,
79 join_type,
80 } => LogicalPlan::Join {
81 left: Box::new(self.traverse(*left)),
82 right: Box::new(self.traverse(*right)),
83 condition,
84 join_type,
85 },
86
87 LogicalPlan::Aggregate {
88 input,
89 group_by,
90 aggregates,
91 } => LogicalPlan::Aggregate {
92 input: Box::new(self.traverse(*input)),
93 group_by,
94 aggregates,
95 },
96
97 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
98 input: Box::new(self.traverse(*input)),
99 order_by,
100 },
101
102 LogicalPlan::Limit {
103 input,
104 limit,
105 offset,
106 } => LogicalPlan::Limit {
107 input: Box::new(self.traverse(*input)),
108 limit,
109 offset,
110 },
111
112 LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
113 left: Box::new(self.traverse(*left)),
114 right: Box::new(self.traverse(*right)),
115 all,
116 },
117
118 plan @ (LogicalPlan::Scan { .. }
120 | LogicalPlan::IndexScan { .. }
121 | LogicalPlan::IndexGet { .. }
122 | LogicalPlan::IndexInGet { .. }
123 | LogicalPlan::GinIndexScan { .. }
124 | LogicalPlan::GinIndexScanMulti { .. }
125 | LogicalPlan::Empty) => plan,
126 }
127 }
128
129 fn is_join_predicate(
132 &self,
133 predicate: &Expr,
134 left: &LogicalPlan,
135 right: &LogicalPlan,
136 ) -> bool {
137 let left_tables = self.collect_tables(left);
138 let right_tables = self.collect_tables(right);
139 let predicate_tables = self.collect_predicate_tables(predicate);
140
141 let refs_left = predicate_tables.iter().any(|t| left_tables.contains(t));
143 let refs_right = predicate_tables.iter().any(|t| right_tables.contains(t));
144
145 refs_left && refs_right
146 }
147
148 fn collect_tables(&self, plan: &LogicalPlan) -> alloc::vec::Vec<alloc::string::String> {
150 let mut tables = alloc::vec::Vec::new();
151 self.collect_tables_recursive(plan, &mut tables);
152 tables
153 }
154
155 fn collect_tables_recursive(
156 &self,
157 plan: &LogicalPlan,
158 tables: &mut alloc::vec::Vec<alloc::string::String>,
159 ) {
160 match plan {
161 LogicalPlan::Scan { table } => tables.push(table.clone()),
162 LogicalPlan::IndexScan { table, .. }
163 | LogicalPlan::IndexGet { table, .. }
164 | LogicalPlan::IndexInGet { table, .. }
165 | LogicalPlan::GinIndexScan { table, .. }
166 | LogicalPlan::GinIndexScanMulti { table, .. } => {
167 tables.push(table.clone())
168 }
169 LogicalPlan::Filter { input, .. }
170 | LogicalPlan::Project { input, .. }
171 | LogicalPlan::Aggregate { input, .. }
172 | LogicalPlan::Sort { input, .. }
173 | LogicalPlan::Limit { input, .. } => {
174 self.collect_tables_recursive(input, tables);
175 }
176 LogicalPlan::Join { left, right, .. }
177 | LogicalPlan::CrossProduct { left, right }
178 | LogicalPlan::Union { left, right, .. } => {
179 self.collect_tables_recursive(left, tables);
180 self.collect_tables_recursive(right, tables);
181 }
182 LogicalPlan::Empty => {}
183 }
184 }
185
186 fn collect_predicate_tables(&self, expr: &Expr) -> alloc::vec::Vec<alloc::string::String> {
188 let mut tables = alloc::vec::Vec::new();
189 self.collect_expr_tables(expr, &mut tables);
190 tables
191 }
192
193 fn collect_expr_tables(
194 &self,
195 expr: &Expr,
196 tables: &mut alloc::vec::Vec<alloc::string::String>,
197 ) {
198 match expr {
199 Expr::Column(col_ref) => {
200 if !tables.contains(&col_ref.table) {
201 tables.push(col_ref.table.clone());
202 }
203 }
204 Expr::BinaryOp { left, right, .. } => {
205 self.collect_expr_tables(left, tables);
206 self.collect_expr_tables(right, tables);
207 }
208 Expr::UnaryOp { expr, .. } => {
209 self.collect_expr_tables(expr, tables);
210 }
211 Expr::Aggregate { expr, .. } => {
212 if let Some(e) = expr {
213 self.collect_expr_tables(e, tables);
214 }
215 }
216 Expr::Literal(_) => {}
217 Expr::Function { args, .. } => {
219 for arg in args {
220 self.collect_expr_tables(arg, tables);
221 }
222 }
223 Expr::Between { expr, low, high } => {
224 self.collect_expr_tables(expr, tables);
225 self.collect_expr_tables(low, tables);
226 self.collect_expr_tables(high, tables);
227 }
228 Expr::In { expr, list } => {
229 self.collect_expr_tables(expr, tables);
230 for item in list {
231 self.collect_expr_tables(item, tables);
232 }
233 }
234 Expr::Like { expr, .. } => {
235 self.collect_expr_tables(expr, tables);
236 }
237 Expr::NotBetween { expr, low, high } => {
238 self.collect_expr_tables(expr, tables);
239 self.collect_expr_tables(low, tables);
240 self.collect_expr_tables(high, tables);
241 }
242 Expr::NotIn { expr, list } => {
243 self.collect_expr_tables(expr, tables);
244 for item in list {
245 self.collect_expr_tables(item, tables);
246 }
247 }
248 Expr::NotLike { expr, .. } => {
249 self.collect_expr_tables(expr, tables);
250 }
251 Expr::Match { expr, .. } => {
252 self.collect_expr_tables(expr, tables);
253 }
254 Expr::NotMatch { expr, .. } => {
255 self.collect_expr_tables(expr, tables);
256 }
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;

    /// Builds the plan `a × b` from two plain scans.
    fn cross_a_b() -> LogicalPlan {
        LogicalPlan::cross_product(LogicalPlan::scan("a"), LogicalPlan::scan("b"))
    }

    /// Builds the predicate `a.id = b.a_id`, which references both tables.
    fn a_id_eq_b_a_id() -> Expr {
        Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0))
    }

    /// Builds the predicate `a.id = 1`, which references table `a` only.
    fn a_id_eq_one() -> Expr {
        Expr::eq(Expr::column("a", "id", 0), Expr::literal(1i64))
    }

    #[test]
    fn test_cross_product_with_join_predicate() {
        let plan = LogicalPlan::filter(cross_a_b(), a_id_eq_b_a_id());

        let result = ImplicitJoinsPass.optimize(plan);

        // The filter over the cross product must collapse into an inner join.
        assert!(matches!(result, LogicalPlan::Join { .. }));
        if let LogicalPlan::Join {
            left,
            right,
            join_type,
            ..
        } = result
        {
            assert!(matches!(*left, LogicalPlan::Scan { table } if table == "a"));
            assert!(matches!(*right, LogicalPlan::Scan { table } if table == "b"));
            assert!(matches!(join_type, JoinType::Inner));
        }
    }

    #[test]
    fn test_cross_product_with_non_join_predicate() {
        let plan = LogicalPlan::filter(cross_a_b(), a_id_eq_one());

        let result = ImplicitJoinsPass.optimize(plan);

        // A single-table predicate leaves the filter and cross product alone.
        assert!(matches!(result, LogicalPlan::Filter { .. }));
        if let LogicalPlan::Filter { input, .. } = result {
            assert!(matches!(*input, LogicalPlan::CrossProduct { .. }));
        }
    }

    #[test]
    fn test_filter_without_cross_product() {
        let plan = LogicalPlan::filter(LogicalPlan::scan("a"), a_id_eq_one());

        let result = ImplicitJoinsPass.optimize(plan);

        // With no cross product underneath there is nothing to rewrite.
        assert!(matches!(result, LogicalPlan::Filter { .. }));
        if let LogicalPlan::Filter { input, .. } = result {
            assert!(matches!(*input, LogicalPlan::Scan { .. }));
        }
    }

    #[test]
    fn test_nested_cross_products_with_join() {
        // (a × b) × c filtered by a predicate that only touches the left side
        // of the outer cross product.
        let outer_cross = LogicalPlan::cross_product(cross_a_b(), LogicalPlan::scan("c"));
        let plan = LogicalPlan::filter(outer_cross, a_id_eq_b_a_id());

        let result = ImplicitJoinsPass.optimize(plan);

        assert!(matches!(result, LogicalPlan::Filter { .. }));
        if let LogicalPlan::Filter { input, .. } = result {
            assert!(matches!(*input, LogicalPlan::CrossProduct { .. }));
        }
    }

    #[test]
    fn test_is_join_predicate() {
        let pass = ImplicitJoinsPass;
        let (left, right) = (LogicalPlan::scan("a"), LogicalPlan::scan("b"));

        // References both sides: a join predicate.
        assert!(pass.is_join_predicate(&a_id_eq_b_a_id(), &left, &right));

        // References only the left side.
        assert!(!pass.is_join_predicate(&a_id_eq_one(), &left, &right));

        // References only the right side.
        let filter_pred2 = Expr::eq(Expr::column("b", "name", 1), Expr::literal("test"));
        assert!(!pass.is_join_predicate(&filter_pred2, &left, &right));
    }

    #[test]
    fn test_collect_tables() {
        let b_cross_c = LogicalPlan::cross_product(LogicalPlan::scan("b"), LogicalPlan::scan("c"));
        let plan = LogicalPlan::cross_product(LogicalPlan::scan("a"), b_cross_c);

        let tables = ImplicitJoinsPass.collect_tables(&plan);

        assert_eq!(tables.len(), 3);
        for name in ["a", "b", "c"] {
            assert!(tables.contains(&name.into()));
        }
    }
}