1use crate::ast::{Expr, JoinType};
25use crate::context::ExecutionContext;
26use crate::optimizer::OptimizerPass;
27use crate::planner::LogicalPlan;
28use alloc::boxed::Box;
29use alloc::string::String;
30use alloc::vec::Vec;
31
32pub struct JoinReorder {
34 context: Option<ExecutionContext>,
36}
37
38impl Default for JoinReorder {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl JoinReorder {
45 pub fn new() -> Self {
48 Self { context: None }
49 }
50
51 pub fn with_context(context: ExecutionContext) -> Self {
53 Self {
54 context: Some(context),
55 }
56 }
57}
58
59impl OptimizerPass for JoinReorder {
60 fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
61 self.reorder(plan)
62 }
63
64 fn name(&self) -> &'static str {
65 "join_reorder"
66 }
67}
68
69#[derive(Clone, Debug)]
71struct JoinNode {
72 plan: LogicalPlan,
74 cardinality: usize,
76 tables: Vec<String>,
78}
79
80#[derive(Clone, Debug)]
82struct JoinCondition {
83 condition: Expr,
85 left_tables: Vec<String>,
87 right_tables: Vec<String>,
89}
90
91impl JoinReorder {
92 fn reorder(&self, plan: LogicalPlan) -> LogicalPlan {
93 match plan {
94 LogicalPlan::Join {
95 left,
96 right,
97 condition,
98 join_type,
99 } => {
100 let optimized_left = self.reorder(*left);
102 let optimized_right = self.reorder(*right);
103
104 if join_type != JoinType::Inner {
106 return LogicalPlan::Join {
107 left: Box::new(optimized_left),
108 right: Box::new(optimized_right),
109 condition,
110 join_type,
111 };
112 }
113
114 let mut nodes = Vec::new();
116 let mut conditions = Vec::new();
117
118 self.collect_join_nodes(
119 &LogicalPlan::Join {
120 left: Box::new(optimized_left),
121 right: Box::new(optimized_right),
122 condition,
123 join_type,
124 },
125 &mut nodes,
126 &mut conditions,
127 );
128
129 if nodes.len() <= 2 {
131 if nodes.len() == 2 && !conditions.is_empty() {
133 let (left_node, right_node) = self.order_two_nodes(nodes);
134 return LogicalPlan::Join {
135 left: Box::new(left_node.plan),
136 right: Box::new(right_node.plan),
137 condition: conditions.into_iter().next().unwrap().condition,
138 join_type: JoinType::Inner,
139 };
140 }
141 if let Some(node) = nodes.into_iter().next() {
143 return node.plan;
144 }
145 return LogicalPlan::Empty;
146 }
147
148 self.greedy_reorder(nodes, conditions)
150 }
151
152 LogicalPlan::Filter { input, predicate } => LogicalPlan::Filter {
153 input: Box::new(self.reorder(*input)),
154 predicate,
155 },
156
157 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
158 input: Box::new(self.reorder(*input)),
159 columns,
160 },
161
162 LogicalPlan::Aggregate {
163 input,
164 group_by,
165 aggregates,
166 } => LogicalPlan::Aggregate {
167 input: Box::new(self.reorder(*input)),
168 group_by,
169 aggregates,
170 },
171
172 LogicalPlan::Sort { input, order_by } => LogicalPlan::Sort {
173 input: Box::new(self.reorder(*input)),
174 order_by,
175 },
176
177 LogicalPlan::Limit {
178 input,
179 limit,
180 offset,
181 } => LogicalPlan::Limit {
182 input: Box::new(self.reorder(*input)),
183 limit,
184 offset,
185 },
186
187 LogicalPlan::CrossProduct { left, right } => LogicalPlan::CrossProduct {
188 left: Box::new(self.reorder(*left)),
189 right: Box::new(self.reorder(*right)),
190 },
191
192 LogicalPlan::Union { left, right, all } => LogicalPlan::Union {
193 left: Box::new(self.reorder(*left)),
194 right: Box::new(self.reorder(*right)),
195 all,
196 },
197
198 LogicalPlan::Scan { .. }
200 | LogicalPlan::IndexScan { .. }
201 | LogicalPlan::IndexGet { .. }
202 | LogicalPlan::IndexInGet { .. }
203 | LogicalPlan::GinIndexScan { .. }
204 | LogicalPlan::GinIndexScanMulti { .. }
205 | LogicalPlan::Empty => plan,
206 }
207 }
208
209 fn collect_join_nodes(
211 &self,
212 plan: &LogicalPlan,
213 nodes: &mut Vec<JoinNode>,
214 conditions: &mut Vec<JoinCondition>,
215 ) {
216 match plan {
217 LogicalPlan::Join {
218 left,
219 right,
220 condition,
221 join_type: JoinType::Inner,
222 } => {
223 self.collect_join_nodes(left, nodes, conditions);
225 self.collect_join_nodes(right, nodes, conditions);
226
227 let (left_tables, right_tables) = self.extract_condition_tables(condition);
229 conditions.push(JoinCondition {
230 condition: condition.clone(),
231 left_tables,
232 right_tables,
233 });
234 }
235
236 _ => {
238 let tables = self.extract_plan_tables(plan);
239 let cardinality = self.estimate_cardinality(plan, &tables);
240 nodes.push(JoinNode {
241 plan: plan.clone(),
242 cardinality,
243 tables,
244 });
245 }
246 }
247 }
248
249 fn order_two_nodes(&self, mut nodes: Vec<JoinNode>) -> (JoinNode, JoinNode) {
251 if nodes.len() != 2 {
252 panic!("Expected exactly 2 nodes");
253 }
254 let second = nodes.pop().unwrap();
255 let first = nodes.pop().unwrap();
256
257 if first.cardinality <= second.cardinality {
258 (first, second)
259 } else {
260 (second, first)
261 }
262 }
263
264 fn greedy_reorder(&self, mut nodes: Vec<JoinNode>, conditions: Vec<JoinCondition>) -> LogicalPlan {
267 if nodes.is_empty() {
268 return LogicalPlan::Empty;
269 }
270
271 if nodes.len() == 1 {
272 return nodes.pop().unwrap().plan;
273 }
274
275 nodes.sort_by_key(|n| n.cardinality);
277
278 let mut result_node = nodes.remove(0);
280 let mut used_conditions: Vec<bool> = alloc::vec![false; conditions.len()];
281
282 while !nodes.is_empty() {
283 let (best_idx, best_condition_idx) =
285 self.find_best_join(&result_node, &nodes, &conditions, &used_conditions);
286
287 if best_condition_idx.is_none() {
289 let mut found_idx = None;
291 let mut found_cond_idx = None;
292
293 for (i, node) in nodes.iter().enumerate() {
294 for (j, cond) in conditions.iter().enumerate() {
295 if used_conditions[j] {
296 continue;
297 }
298 let result_has_left = cond.left_tables.iter().any(|t| result_node.tables.contains(t));
300 let result_has_right = cond.right_tables.iter().any(|t| result_node.tables.contains(t));
301 let node_has_left = cond.left_tables.iter().any(|t| node.tables.contains(t));
302 let node_has_right = cond.right_tables.iter().any(|t| node.tables.contains(t));
303
304 if (result_has_left && node_has_right) || (result_has_right && node_has_left) {
305 found_idx = Some(i);
306 found_cond_idx = Some(j);
307 break;
308 }
309 }
310 if found_idx.is_some() {
311 break;
312 }
313 }
314
315 if let (Some(idx), Some(cond_idx)) = (found_idx, found_cond_idx) {
316 let next_node = nodes.remove(idx);
317 used_conditions[cond_idx] = true;
318
319 let new_plan = LogicalPlan::Join {
320 left: Box::new(result_node.plan),
321 right: Box::new(next_node.plan),
322 condition: conditions[cond_idx].condition.clone(),
323 join_type: JoinType::Inner,
324 };
325
326 let mut new_tables = result_node.tables;
327 new_tables.extend(next_node.tables);
328
329 result_node = JoinNode {
330 plan: new_plan,
331 cardinality: self.estimate_join_cardinality(
332 result_node.cardinality,
333 next_node.cardinality,
334 ),
335 tables: new_tables,
336 };
337 continue;
338 }
339
340 let next_node = nodes.remove(0);
342 let new_plan = LogicalPlan::Join {
343 left: Box::new(result_node.plan),
344 right: Box::new(next_node.plan),
345 condition: Expr::literal(true),
346 join_type: JoinType::Inner,
347 };
348
349 let mut new_tables = result_node.tables;
350 new_tables.extend(next_node.tables);
351
352 result_node = JoinNode {
353 plan: new_plan,
354 cardinality: self.estimate_join_cardinality(
355 result_node.cardinality,
356 next_node.cardinality,
357 ),
358 tables: new_tables,
359 };
360 continue;
361 }
362
363 let next_node = nodes.remove(best_idx);
364
365 let condition = if let Some(cond_idx) = best_condition_idx {
367 used_conditions[cond_idx] = true;
368 conditions[cond_idx].condition.clone()
369 } else {
370 Expr::literal(true)
373 };
374
375 let new_plan = LogicalPlan::Join {
377 left: Box::new(result_node.plan),
378 right: Box::new(next_node.plan),
379 condition,
380 join_type: JoinType::Inner,
381 };
382
383 let mut new_tables = result_node.tables;
385 new_tables.extend(next_node.tables);
386
387 result_node = JoinNode {
388 plan: new_plan,
389 cardinality: self.estimate_join_cardinality(
391 result_node.cardinality,
392 next_node.cardinality,
393 ),
394 tables: new_tables,
395 };
396 }
397
398 let mut final_plan = result_node.plan;
401 for (i, cond) in conditions.iter().enumerate() {
402 if !used_conditions[i] {
403 final_plan = LogicalPlan::Filter {
405 input: Box::new(final_plan),
406 predicate: cond.condition.clone(),
407 };
408 }
409 }
410
411 final_plan
412 }
413
414 fn find_best_join(
416 &self,
417 current: &JoinNode,
418 candidates: &[JoinNode],
419 conditions: &[JoinCondition],
420 used_conditions: &[bool],
421 ) -> (usize, Option<usize>) {
422 let mut best_idx = 0;
423 let mut best_condition_idx = None;
424 let mut best_score = usize::MAX;
425
426 for (i, candidate) in candidates.iter().enumerate() {
427 let condition_idx = self.find_applicable_condition(
429 ¤t.tables,
430 &candidate.tables,
431 conditions,
432 used_conditions,
433 );
434
435 let score = candidate.cardinality;
437
438 let adjusted_score = if condition_idx.is_some() {
440 score
441 } else {
442 score.saturating_mul(10) };
444
445 if adjusted_score < best_score {
446 best_score = adjusted_score;
447 best_idx = i;
448 best_condition_idx = condition_idx;
449 }
450 }
451
452 (best_idx, best_condition_idx)
453 }
454
455 fn find_applicable_condition(
457 &self,
458 left_tables: &[String],
459 right_tables: &[String],
460 conditions: &[JoinCondition],
461 used_conditions: &[bool],
462 ) -> Option<usize> {
463 for (i, cond) in conditions.iter().enumerate() {
464 if used_conditions[i] {
465 continue;
466 }
467
468 let left_matches = cond
470 .left_tables
471 .iter()
472 .any(|t| left_tables.contains(t) || right_tables.contains(t));
473 let right_matches = cond
474 .right_tables
475 .iter()
476 .any(|t| left_tables.contains(t) || right_tables.contains(t));
477
478 if left_matches && right_matches {
479 return Some(i);
480 }
481 }
482 None
483 }
484
485 fn extract_condition_tables(&self, condition: &Expr) -> (Vec<String>, Vec<String>) {
487 match condition {
488 Expr::BinaryOp { left, right, .. } => {
489 let left_tables = self.extract_expr_tables(left);
490 let right_tables = self.extract_expr_tables(right);
491 (left_tables, right_tables)
492 }
493 _ => (Vec::new(), Vec::new()),
494 }
495 }
496
497 fn extract_expr_tables(&self, expr: &Expr) -> Vec<String> {
499 let mut tables = Vec::new();
500 self.collect_expr_tables(expr, &mut tables);
501 tables
502 }
503
504 fn collect_expr_tables(&self, expr: &Expr, tables: &mut Vec<String>) {
505 match expr {
506 Expr::Column(col) => {
507 if !tables.contains(&col.table) {
508 tables.push(col.table.clone());
509 }
510 }
511 Expr::BinaryOp { left, right, .. } => {
512 self.collect_expr_tables(left, tables);
513 self.collect_expr_tables(right, tables);
514 }
515 Expr::UnaryOp { expr, .. } => {
516 self.collect_expr_tables(expr, tables);
517 }
518 _ => {}
519 }
520 }
521
522 fn extract_plan_tables(&self, plan: &LogicalPlan) -> Vec<String> {
524 let mut tables = Vec::new();
525 self.collect_plan_tables(plan, &mut tables);
526 tables
527 }
528
529 fn collect_plan_tables(&self, plan: &LogicalPlan, tables: &mut Vec<String>) {
530 match plan {
531 LogicalPlan::Scan { table } => {
532 tables.push(table.clone());
533 }
534 LogicalPlan::IndexScan { table, .. }
535 | LogicalPlan::IndexGet { table, .. }
536 | LogicalPlan::IndexInGet { table, .. }
537 | LogicalPlan::GinIndexScan { table, .. }
538 | LogicalPlan::GinIndexScanMulti { table, .. } => {
539 tables.push(table.clone());
540 }
541 LogicalPlan::Filter { input, .. }
542 | LogicalPlan::Project { input, .. }
543 | LogicalPlan::Aggregate { input, .. }
544 | LogicalPlan::Sort { input, .. }
545 | LogicalPlan::Limit { input, .. } => {
546 self.collect_plan_tables(input, tables);
547 }
548 LogicalPlan::Join { left, right, .. }
549 | LogicalPlan::CrossProduct { left, right }
550 | LogicalPlan::Union { left, right, .. } => {
551 self.collect_plan_tables(left, tables);
552 self.collect_plan_tables(right, tables);
553 }
554 LogicalPlan::Empty => {}
555 }
556 }
557
558 fn estimate_cardinality(&self, plan: &LogicalPlan, tables: &[String]) -> usize {
560 if let Some(ctx) = &self.context {
562 if tables.len() == 1 {
563 let count = ctx.row_count(&tables[0]);
564 if count > 0 {
565 return count;
566 }
567 }
568 }
569
570 match plan {
572 LogicalPlan::Scan { .. } => 1000, LogicalPlan::IndexGet { .. } => 1, LogicalPlan::IndexInGet { keys, .. } => keys.len(), LogicalPlan::IndexScan { .. } => 100, LogicalPlan::Filter { input, .. } => {
577 self.estimate_cardinality(input, tables) / 10
579 }
580 LogicalPlan::Limit { limit, .. } => *limit,
581 _ => 1000,
582 }
583 }
584
585 fn estimate_join_cardinality(&self, left_card: usize, right_card: usize) -> usize {
587 let product = left_card.saturating_mul(right_card);
590 core::cmp::max(product / 10, 1)
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::TableStats;

    /// Builds a context with three tables of 100, 1 000 and 10 000 rows.
    fn create_test_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::new();

        ctx.register_table(
            "small",
            TableStats {
                row_count: 100,
                is_sorted: false,
                indexes: alloc::vec![],
            },
        );

        ctx.register_table(
            "medium",
            TableStats {
                row_count: 1000,
                is_sorted: false,
                indexes: alloc::vec![],
            },
        );

        ctx.register_table(
            "large",
            TableStats {
                row_count: 10000,
                is_sorted: false,
                indexes: alloc::vec![],
            },
        );

        ctx
    }

    /// Returns the table of the scan reached by always following the left input
    /// of joins, or `None` if that path ends in anything other than a `Scan`.
    fn leftmost_table(plan: &LogicalPlan) -> Option<&str> {
        match plan {
            LogicalPlan::Join { left, .. } => leftmost_table(left),
            LogicalPlan::Scan { table } => Some(table.as_str()),
            _ => None,
        }
    }

    #[test]
    fn test_join_reorder_basic() {
        let pass = JoinReorder::new();

        let plan = LogicalPlan::Join {
            left: Box::new(LogicalPlan::scan("a")),
            right: Box::new(LogicalPlan::scan("b")),
            condition: Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0)),
            join_type: JoinType::Inner,
        };

        let optimized = pass.optimize(plan);
        assert!(matches!(optimized, LogicalPlan::Join { .. }));
        // Without statistics both inputs tie, so the original order is kept.
        assert_eq!(leftmost_table(&optimized), Some("a"));
    }

    #[test]
    fn test_join_reorder_with_context() {
        let ctx = create_test_context();
        let pass = JoinReorder::with_context(ctx);

        let plan = LogicalPlan::Join {
            left: Box::new(LogicalPlan::Join {
                left: Box::new(LogicalPlan::scan("large")),
                right: Box::new(LogicalPlan::scan("medium")),
                condition: Expr::eq(
                    Expr::column("large", "id", 0),
                    Expr::column("medium", "large_id", 0),
                ),
                join_type: JoinType::Inner,
            }),
            right: Box::new(LogicalPlan::scan("small")),
            condition: Expr::eq(
                Expr::column("medium", "id", 0),
                Expr::column("small", "medium_id", 0),
            ),
            join_type: JoinType::Inner,
        };

        let optimized = pass.optimize(plan);

        assert!(matches!(optimized, LogicalPlan::Join { .. }));
        // The smallest table starts the left-deep tree; the largest joins last.
        assert_eq!(leftmost_table(&optimized), Some("small"));
        if let LogicalPlan::Join { right, .. } = &optimized {
            assert!(matches!(
                right.as_ref(),
                LogicalPlan::Scan { table } if table.as_str() == "large"
            ));
        } else {
            panic!("Expected Join");
        }
    }

    #[test]
    fn test_outer_join_not_reordered() {
        let pass = JoinReorder::new();

        let plan = LogicalPlan::Join {
            left: Box::new(LogicalPlan::scan("a")),
            right: Box::new(LogicalPlan::scan("b")),
            condition: Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0)),
            join_type: JoinType::LeftOuter,
        };

        let optimized = pass.optimize(plan);

        if let LogicalPlan::Join { join_type, .. } = optimized {
            assert_eq!(join_type, JoinType::LeftOuter);
        } else {
            panic!("Expected Join");
        }
    }

    #[test]
    fn test_nested_inner_joins() {
        let ctx = create_test_context();
        let pass = JoinReorder::with_context(ctx);

        let plan = LogicalPlan::Join {
            left: Box::new(LogicalPlan::Join {
                left: Box::new(LogicalPlan::scan("large")),
                right: Box::new(LogicalPlan::scan("small")),
                condition: Expr::eq(
                    Expr::column("large", "id", 0),
                    Expr::column("small", "large_id", 0),
                ),
                join_type: JoinType::Inner,
            }),
            right: Box::new(LogicalPlan::scan("medium")),
            condition: Expr::eq(
                Expr::column("small", "id", 0),
                Expr::column("medium", "small_id", 0),
            ),
            join_type: JoinType::Inner,
        };

        let optimized = pass.optimize(plan);
        assert!(matches!(optimized, LogicalPlan::Join { .. }));
        assert_eq!(leftmost_table(&optimized), Some("small"));
    }

    #[test]
    fn test_single_table_unchanged() {
        let pass = JoinReorder::new();

        let optimized = pass.optimize(LogicalPlan::scan("users"));

        assert!(matches!(optimized, LogicalPlan::Scan { .. }));
        assert_eq!(leftmost_table(&optimized), Some("users"));
    }

    #[test]
    fn test_filter_preserved() {
        let pass = JoinReorder::new();

        let plan = LogicalPlan::filter(
            LogicalPlan::Join {
                left: Box::new(LogicalPlan::scan("a")),
                right: Box::new(LogicalPlan::scan("b")),
                condition: Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "a_id", 0)),
                join_type: JoinType::Inner,
            },
            Expr::gt(Expr::column("a", "value", 1), Expr::literal(100i64)),
        );

        let optimized = pass.optimize(plan);

        // The filter stays on top and its join input is still a join.
        assert!(matches!(
            &optimized,
            LogicalPlan::Filter { input, .. } if matches!(input.as_ref(), LogicalPlan::Join { .. })
        ));
    }

    #[test]
    fn test_extract_condition_tables() {
        let pass = JoinReorder::new();

        let condition = Expr::eq(
            Expr::column("users", "id", 0),
            Expr::column("orders", "user_id", 0),
        );

        let (left, right) = pass.extract_condition_tables(&condition);
        assert!(left.contains(&"users".into()));
        assert!(right.contains(&"orders".into()));
    }

    #[test]
    fn test_estimate_cardinality() {
        let ctx = create_test_context();
        let pass = JoinReorder::with_context(ctx);

        let plan = LogicalPlan::scan("small");
        let tables = pass.extract_plan_tables(&plan);
        let card = pass.estimate_cardinality(&plan, &tables);

        assert_eq!(card, 100);
    }
}