1use crate::ast::{BinaryOp, Expr, JoinType};
19use crate::context::ExecutionContext;
20use crate::planner::PhysicalPlan;
21use alloc::boxed::Box;
22use alloc::string::String;
23
24pub struct IndexJoinPass<'a> {
26 ctx: &'a ExecutionContext,
27}
28
29impl<'a> IndexJoinPass<'a> {
30 pub fn new(ctx: &'a ExecutionContext) -> Self {
32 Self { ctx }
33 }
34
35 pub fn optimize(&self, plan: PhysicalPlan) -> PhysicalPlan {
37 self.traverse(plan)
38 }
39
40 fn traverse(&self, plan: PhysicalPlan) -> PhysicalPlan {
41 match plan {
42 PhysicalPlan::HashJoin {
44 left,
45 right,
46 condition,
47 join_type,
48 } => {
49 let left = self.traverse(*left);
50 let right = self.traverse(*right);
51
52 if join_type != JoinType::Inner || !condition.is_equi_join() {
54 return PhysicalPlan::HashJoin {
55 left: Box::new(left),
56 right: Box::new(right),
57 condition,
58 join_type,
59 };
60 }
61
62 if let Some((outer, inner_table, inner_index)) =
64 self.find_index_join_candidate(&left, &right, &condition)
65 {
66 return PhysicalPlan::IndexNestedLoopJoin {
67 outer: Box::new(outer),
68 inner_table,
69 inner_index,
70 condition,
71 join_type,
72 };
73 }
74
75 PhysicalPlan::HashJoin {
76 left: Box::new(left),
77 right: Box::new(right),
78 condition,
79 join_type,
80 }
81 }
82
83 PhysicalPlan::NestedLoopJoin {
85 left,
86 right,
87 condition,
88 join_type,
89 } => {
90 let left = self.traverse(*left);
91 let right = self.traverse(*right);
92
93 if join_type != JoinType::Inner || !condition.is_equi_join() {
95 return PhysicalPlan::NestedLoopJoin {
96 left: Box::new(left),
97 right: Box::new(right),
98 condition,
99 join_type,
100 };
101 }
102
103 if let Some((outer, inner_table, inner_index)) =
105 self.find_index_join_candidate(&left, &right, &condition)
106 {
107 return PhysicalPlan::IndexNestedLoopJoin {
108 outer: Box::new(outer),
109 inner_table,
110 inner_index,
111 condition,
112 join_type,
113 };
114 }
115
116 PhysicalPlan::NestedLoopJoin {
117 left: Box::new(left),
118 right: Box::new(right),
119 condition,
120 join_type,
121 }
122 }
123
124 PhysicalPlan::Filter { input, predicate } => PhysicalPlan::Filter {
126 input: Box::new(self.traverse(*input)),
127 predicate,
128 },
129
130 PhysicalPlan::Project { input, columns } => PhysicalPlan::Project {
131 input: Box::new(self.traverse(*input)),
132 columns,
133 },
134
135 PhysicalPlan::SortMergeJoin {
136 left,
137 right,
138 condition,
139 join_type,
140 } => PhysicalPlan::SortMergeJoin {
141 left: Box::new(self.traverse(*left)),
142 right: Box::new(self.traverse(*right)),
143 condition,
144 join_type,
145 },
146
147 PhysicalPlan::HashAggregate {
148 input,
149 group_by,
150 aggregates,
151 } => PhysicalPlan::HashAggregate {
152 input: Box::new(self.traverse(*input)),
153 group_by,
154 aggregates,
155 },
156
157 PhysicalPlan::Sort { input, order_by } => PhysicalPlan::Sort {
158 input: Box::new(self.traverse(*input)),
159 order_by,
160 },
161
162 PhysicalPlan::Limit {
163 input,
164 limit,
165 offset,
166 } => PhysicalPlan::Limit {
167 input: Box::new(self.traverse(*input)),
168 limit,
169 offset,
170 },
171
172 PhysicalPlan::CrossProduct { left, right } => PhysicalPlan::CrossProduct {
173 left: Box::new(self.traverse(*left)),
174 right: Box::new(self.traverse(*right)),
175 },
176
177 PhysicalPlan::NoOp { input } => PhysicalPlan::NoOp {
178 input: Box::new(self.traverse(*input)),
179 },
180
181 PhysicalPlan::TopN {
182 input,
183 order_by,
184 limit,
185 offset,
186 } => PhysicalPlan::TopN {
187 input: Box::new(self.traverse(*input)),
188 order_by,
189 limit,
190 offset,
191 },
192
193 plan @ (PhysicalPlan::TableScan { .. }
195 | PhysicalPlan::IndexScan { .. }
196 | PhysicalPlan::IndexGet { .. }
197 | PhysicalPlan::IndexInGet { .. }
198 | PhysicalPlan::IndexNestedLoopJoin { .. }
199 | PhysicalPlan::Empty | PhysicalPlan::GinIndexScan { .. } | PhysicalPlan::GinIndexScanMulti { .. }) => plan,
200 }
201 }
202
203 fn find_index_join_candidate(
206 &self,
207 left: &PhysicalPlan,
208 right: &PhysicalPlan,
209 condition: &Expr,
210 ) -> Option<(PhysicalPlan, String, String)> {
211 let (left_col, right_col) = self.extract_join_columns(condition)?;
213
214 if let Some((table, index)) = self.get_indexed_table_scan(right, &right_col) {
216 return Some((left.clone(), table, index));
217 }
218
219 if let Some((table, index)) = self.get_indexed_table_scan(left, &left_col) {
221 return Some((right.clone(), table, index));
222 }
223
224 None
225 }
226
227 fn extract_join_columns(&self, condition: &Expr) -> Option<(String, String)> {
229 match condition {
230 Expr::BinaryOp {
231 left,
232 op: BinaryOp::Eq,
233 right,
234 } => {
235 let left_col = self.extract_column_name(left)?;
236 let right_col = self.extract_column_name(right)?;
237 Some((left_col, right_col))
238 }
239 _ => None,
240 }
241 }
242
243 fn extract_column_name(&self, expr: &Expr) -> Option<String> {
245 match expr {
246 Expr::Column(col_ref) => Some(col_ref.column.clone()),
247 _ => None,
248 }
249 }
250
251 fn get_indexed_table_scan(
254 &self,
255 plan: &PhysicalPlan,
256 column: &str,
257 ) -> Option<(String, String)> {
258 match plan {
259 PhysicalPlan::TableScan { table } => {
260 let index = self.ctx.find_index(table, &[column])?;
262 Some((table.clone(), index.name.clone()))
263 }
264 _ => None,
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;
    use crate::context::{IndexInfo, TableStats};

    /// Catalog with table `a` (100 rows, no indexes) and table `b`
    /// (1000 rows, index `idx_a_id` on column `a_id`).
    fn create_test_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::new();

        ctx.register_table(
            "a",
            TableStats {
                row_count: 100,
                is_sorted: false,
                indexes: alloc::vec![],
            },
        );

        ctx.register_table(
            "b",
            TableStats {
                row_count: 1000,
                is_sorted: false,
                indexes: alloc::vec![IndexInfo::new(
                    "idx_a_id",
                    alloc::vec!["a_id".into()],
                    false
                )],
            },
        );

        ctx
    }

    /// The join condition `a.id = b.a_id` used by most tests.
    fn a_id_eq_b_a_id() -> Expr {
        Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0))
    }

    /// Builds a `HashJoin` of two table scans.
    fn hash_join(left: &str, right: &str, condition: Expr, join_type: JoinType) -> PhysicalPlan {
        PhysicalPlan::HashJoin {
            left: Box::new(PhysicalPlan::table_scan(left)),
            right: Box::new(PhysicalPlan::table_scan(right)),
            condition,
            join_type,
        }
    }

    /// Unwraps an `IndexNestedLoopJoin` into `(outer, inner_table, inner_index)`,
    /// panicking on any other plan node.
    fn expect_index_join(plan: PhysicalPlan) -> (PhysicalPlan, String, String) {
        match plan {
            PhysicalPlan::IndexNestedLoopJoin {
                outer,
                inner_table,
                inner_index,
                ..
            } => (*outer, inner_table, inner_index),
            _ => panic!("expected IndexNestedLoopJoin"),
        }
    }

    #[test]
    fn test_hash_join_to_index_join() {
        let ctx = create_test_context();
        let pass = IndexJoinPass::new(&ctx);

        let plan = hash_join("a", "b", a_id_eq_b_a_id(), JoinType::Inner);

        let (outer, inner_table, inner_index) = expect_index_join(pass.optimize(plan));
        assert!(matches!(outer, PhysicalPlan::TableScan { ref table } if table == "a"));
        assert_eq!(inner_table, "b");
        assert_eq!(inner_index, "idx_a_id");
    }

    #[test]
    fn test_nested_loop_join_to_index_join() {
        let ctx = create_test_context();
        let pass = IndexJoinPass::new(&ctx);

        let plan = PhysicalPlan::NestedLoopJoin {
            left: Box::new(PhysicalPlan::table_scan("a")),
            right: Box::new(PhysicalPlan::table_scan("b")),
            condition: a_id_eq_b_a_id(),
            join_type: JoinType::Inner,
        };

        let (_, inner_table, inner_index) = expect_index_join(pass.optimize(plan));
        assert_eq!(inner_table, "b");
        assert_eq!(inner_index, "idx_a_id");
    }

    #[test]
    fn test_index_on_left_input_uses_right_as_outer() {
        let ctx = create_test_context();
        let pass = IndexJoinPass::new(&ctx);

        let plan = hash_join(
            "b",
            "a",
            Expr::eq(Expr::column("b", "a_id", 0), Expr::column("a", "id", 0)),
            JoinType::Inner,
        );

        let (outer, inner_table, inner_index) = expect_index_join(pass.optimize(plan));
        assert!(matches!(outer, PhysicalPlan::TableScan { ref table } if table == "a"));
        assert_eq!(inner_table, "b");
        assert_eq!(inner_index, "idx_a_id");
    }

    #[test]
    fn test_no_index_remains_hash_join() {
        let ctx = ExecutionContext::new();
        let pass = IndexJoinPass::new(&ctx);

        let plan = hash_join("a", "b", a_id_eq_b_a_id(), JoinType::Inner);

        assert!(matches!(pass.optimize(plan), PhysicalPlan::HashJoin { .. }));
    }

    #[test]
    fn test_outer_join_not_optimized() {
        let ctx = create_test_context();
        let pass = IndexJoinPass::new(&ctx);

        let plan = hash_join("a", "b", a_id_eq_b_a_id(), JoinType::LeftOuter);

        assert!(matches!(pass.optimize(plan), PhysicalPlan::HashJoin { .. }));
    }

    #[test]
    fn test_non_equi_join_not_optimized() {
        let ctx = create_test_context();
        let pass = IndexJoinPass::new(&ctx);

        let plan = hash_join(
            "a",
            "b",
            Expr::gt(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0)),
            JoinType::Inner,
        );

        assert!(matches!(pass.optimize(plan), PhysicalPlan::HashJoin { .. }));
    }

    #[test]
    fn test_nested_joins() {
        let ctx = create_test_context();
        let pass = IndexJoinPass::new(&ctx);

        let inner_join = hash_join("a", "b", a_id_eq_b_a_id(), JoinType::Inner);

        // Table `c` is not registered, so the outer join cannot use an index.
        let outer_join = PhysicalPlan::HashJoin {
            left: Box::new(inner_join),
            right: Box::new(PhysicalPlan::table_scan("c")),
            condition: Expr::eq(Expr::column("b", "id", 0), Expr::column("c", "b_id", 0)),
            join_type: JoinType::Inner,
        };

        match pass.optimize(outer_join) {
            PhysicalPlan::HashJoin { left, .. } => {
                let (_, inner_table, inner_index) = expect_index_join(*left);
                assert_eq!(inner_table, "b");
                assert_eq!(inner_index, "idx_a_id");
            }
            _ => panic!("expected the outer join to remain a HashJoin"),
        }
    }
}