1use crate::error::Result;
16use crate::parser::ast::*;
17use std::collections::HashSet;
18
19use super::{
20 OptimizationRule, collect_table_aliases, combine_predicates_with_and, extract_predicates,
21 get_predicate_tables,
22};
23
/// Optimization rule that reorders runs of INNER/CROSS joins in the `FROM`
/// clause using a greedy, heuristic (statistics-free) cost model.
///
/// Outer joins (LEFT/RIGHT/FULL) are never moved or re-sided; only the joins
/// nested inside them are visited.
pub struct JoinReordering;
26
27impl OptimizationRule for JoinReordering {
28 fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
29 if let Some(from) = stmt.from.take() {
30 stmt.from = Some(reorder_join_tree(from));
31 }
32 Ok(stmt)
33 }
34}
35
/// One operand of a flattened INNER/CROSS join chain, as seen by the greedy
/// join-ordering pass.
struct JoinComponent {
    /// Subtree for this operand: a base table, a subquery, an outer join, or
    /// (after a greedy merge) an inner join built by `greedy_join_order`.
    table_ref: TableReference,
    /// Table names/aliases available in `table_ref` (from
    /// `collect_table_aliases`, or the union of both inputs after a merge);
    /// used to decide which predicates can be evaluated at a join.
    table_names: HashSet<String>,
    /// Heuristic row-count estimate fed to the pairwise cost model.
    estimated_rows: f64,
}
46
47fn reorder_join_tree(table_ref: TableReference) -> TableReference {
49 let is_inner = matches!(
50 &table_ref,
51 TableReference::Join { join_type, .. }
52 if *join_type == JoinType::Inner || *join_type == JoinType::Cross
53 );
54
55 if is_inner {
56 let mut components: Vec<JoinComponent> = Vec::new();
58 let mut predicates: Vec<Expr> = Vec::new();
59 flatten_inner_join_chain(table_ref, &mut components, &mut predicates);
60
61 if components.len() <= 1 {
62 return components
63 .into_iter()
64 .next()
65 .map(|c| c.table_ref)
66 .unwrap_or(TableReference::Table {
67 name: String::new(),
68 alias: None,
69 });
70 }
71
72 for comp in &mut components {
74 let old_ref = std::mem::replace(
75 &mut comp.table_ref,
76 TableReference::Table {
77 name: String::new(),
78 alias: None,
79 },
80 );
81 comp.table_ref = reorder_join_tree(old_ref);
82 comp.estimated_rows = heuristic_row_estimate(&comp.table_ref);
83 }
84
85 greedy_join_order(components, predicates)
86 } else {
87 match table_ref {
88 TableReference::Join {
89 left,
90 right,
91 join_type,
92 on,
93 } => TableReference::Join {
94 left: Box::new(reorder_join_tree(*left)),
95 right: Box::new(reorder_join_tree(*right)),
96 join_type,
97 on,
98 },
99 other => other,
100 }
101 }
102}
103
104fn flatten_inner_join_chain(
106 table_ref: TableReference,
107 components: &mut Vec<JoinComponent>,
108 predicates: &mut Vec<Expr>,
109) {
110 let is_inner = matches!(
111 &table_ref,
112 TableReference::Join { join_type, .. }
113 if *join_type == JoinType::Inner || *join_type == JoinType::Cross
114 );
115
116 if is_inner {
117 if let TableReference::Join {
118 left, right, on, ..
119 } = table_ref
120 {
121 flatten_inner_join_chain(*left, components, predicates);
122 flatten_inner_join_chain(*right, components, predicates);
123
124 if let Some(on_expr) = on {
125 let mut preds = Vec::new();
126 extract_predicates(&on_expr, &mut preds);
127 predicates.extend(preds);
128 }
129 }
130 } else {
131 let table_names = collect_table_aliases(&table_ref);
133 let estimated_rows = heuristic_row_estimate(&table_ref);
134 components.push(JoinComponent {
135 table_ref,
136 table_names,
137 estimated_rows,
138 });
139 }
140}
141
142fn heuristic_row_estimate(table_ref: &TableReference) -> f64 {
144 match table_ref {
145 TableReference::Table { .. } => 10_000.0,
146 TableReference::Join {
147 left,
148 right,
149 join_type,
150 on,
151 } => {
152 let left_rows = heuristic_row_estimate(left);
153 let right_rows = heuristic_row_estimate(right);
154
155 let selectivity = if let Some(on_expr) = on {
156 let mut preds = Vec::new();
157 extract_predicates(on_expr, &mut preds);
158 heuristic_selectivity(&preds)
159 } else {
160 1.0
161 };
162
163 match join_type {
164 JoinType::Inner | JoinType::Cross => {
165 (left_rows * right_rows * selectivity).max(1.0)
166 }
167 JoinType::Left => left_rows.max(1.0),
168 JoinType::Right => right_rows.max(1.0),
169 JoinType::Full => (left_rows + right_rows).max(1.0),
170 }
171 }
172 TableReference::Subquery { .. } => 1_000.0,
173 }
174}
175
176pub(crate) fn heuristic_single_selectivity(pred: &Expr) -> f64 {
185 match pred {
186 Expr::BinaryOp { left, op, right } => match op {
187 BinaryOperator::Eq => 0.1,
188 BinaryOperator::NotEq => 0.9,
189 BinaryOperator::Lt
190 | BinaryOperator::LtEq
191 | BinaryOperator::Gt
192 | BinaryOperator::GtEq => 0.33,
193 BinaryOperator::And => {
194 heuristic_single_selectivity(left) * heuristic_single_selectivity(right)
195 }
196 BinaryOperator::Or => {
197 let l = heuristic_single_selectivity(left);
198 let r = heuristic_single_selectivity(right);
199 l + r - l * r
200 }
201 BinaryOperator::Like => 0.1,
202 BinaryOperator::NotLike => 0.9,
203 _ => 0.5,
204 },
205 Expr::IsNull(_) => 0.05,
206 Expr::IsNotNull(_) => 0.95,
207 Expr::InList { list, negated, .. } => {
208 let sel = (list.len() as f64 * 0.1).min(0.9);
209 if *negated { 1.0 - sel } else { sel }
210 }
211 Expr::Between { negated, .. } => {
212 if *negated {
213 0.75
214 } else {
215 0.25
216 }
217 }
218 _ => 0.5,
219 }
220}
221
222pub(crate) fn heuristic_selectivity(predicates: &[Expr]) -> f64 {
224 if predicates.is_empty() {
225 return 1.0;
226 }
227 predicates
228 .iter()
229 .map(heuristic_single_selectivity)
230 .product::<f64>()
231 .max(0.0001)
232}
233
234fn estimate_pair_join_cost(
241 left: &JoinComponent,
242 right: &JoinComponent,
243 all_predicates: &[Expr],
244) -> f64 {
245 let applicable: Vec<&Expr> = all_predicates
247 .iter()
248 .filter(|pred| {
249 let tables = get_predicate_tables(pred);
250 !tables.is_empty()
251 && tables.iter().any(|t| left.table_names.contains(t))
252 && tables.iter().any(|t| right.table_names.contains(t))
253 })
254 .collect();
255
256 let selectivity = if applicable.is_empty() {
257 1.0 } else {
259 applicable
260 .iter()
261 .map(|p| heuristic_single_selectivity(p))
262 .product::<f64>()
263 .max(0.0001)
264 };
265
266 let output_rows = left.estimated_rows * right.estimated_rows * selectivity;
267
268 let (build_rows, probe_rows) = if left.estimated_rows <= right.estimated_rows {
270 (left.estimated_rows, right.estimated_rows)
271 } else {
272 (right.estimated_rows, left.estimated_rows)
273 };
274
275 let build_cost = build_rows * 10.0;
276 let probe_cost = probe_rows * 5.0;
277 let output_cost = output_rows * 2.0;
278
279 build_cost + probe_cost + output_cost
280}
281
282fn greedy_join_order(
291 mut components: Vec<JoinComponent>,
292 mut all_predicates: Vec<Expr>,
293) -> TableReference {
294 if components.is_empty() {
296 return TableReference::Table {
297 name: String::new(),
298 alias: None,
299 };
300 }
301
302 if components.len() == 1 {
303 let mut result = components
304 .into_iter()
305 .next()
306 .map(|c| c.table_ref)
307 .unwrap_or(TableReference::Table {
308 name: String::new(),
309 alias: None,
310 });
311
312 if !all_predicates.is_empty() {
314 if let TableReference::Join { ref mut on, .. } = result {
315 let remaining = super::combine_predicates_with_and(all_predicates);
316 *on = match (on.take(), remaining) {
317 (Some(existing), Some(new_pred)) => Some(Expr::BinaryOp {
318 left: Box::new(existing),
319 op: BinaryOperator::And,
320 right: Box::new(new_pred),
321 }),
322 (Some(existing), None) => Some(existing),
323 (None, some_pred) => some_pred,
324 };
325 }
326 }
327
328 return result;
329 }
330
331 while components.len() > 1 {
332 let mut best_i = 0;
334 let mut best_j = 1;
335 let mut best_cost = f64::MAX;
336
337 for i in 0..components.len() {
338 for j in (i + 1)..components.len() {
339 let cost = estimate_pair_join_cost(&components[i], &components[j], &all_predicates);
340 if cost < best_cost {
341 best_cost = cost;
342 best_i = i;
343 best_j = j;
344 }
345 }
346 }
347
348 let right_comp = components.remove(best_j);
350 let left_comp = components.remove(best_i);
351
352 let merged_tables: HashSet<String> = left_comp
354 .table_names
355 .iter()
356 .chain(right_comp.table_names.iter())
357 .cloned()
358 .collect();
359
360 let mut join_preds = Vec::new();
361 let mut remaining_preds = Vec::new();
362
363 for pred in all_predicates {
364 let tables = get_predicate_tables(&pred);
365 if !tables.is_empty() && tables.iter().all(|t| merged_tables.contains(t)) {
366 join_preds.push(pred);
367 } else {
368 remaining_preds.push(pred);
369 }
370 }
371 all_predicates = remaining_preds;
372
373 let selectivity = heuristic_selectivity(&join_preds);
375 let output_rows =
376 (left_comp.estimated_rows * right_comp.estimated_rows * selectivity).max(1.0);
377
378 let on_condition = combine_predicates_with_and(join_preds);
379
380 components.push(JoinComponent {
381 table_ref: TableReference::Join {
382 left: Box::new(left_comp.table_ref),
383 right: Box::new(right_comp.table_ref),
384 join_type: JoinType::Inner,
385 on: on_condition,
386 },
387 table_names: merged_tables,
388 estimated_rows: output_rows,
389 });
390 }
391
392 let mut result = components
393 .into_iter()
394 .next()
395 .map(|c| c.table_ref)
396 .unwrap_or(TableReference::Table {
397 name: String::new(),
398 alias: None,
399 });
400
401 if !all_predicates.is_empty() {
403 if let TableReference::Join { ref mut on, .. } = result {
404 let remaining = combine_predicates_with_and(all_predicates);
405 *on = match (on.take(), remaining) {
406 (Some(existing), Some(new_pred)) => Some(Expr::BinaryOp {
407 left: Box::new(existing),
408 op: BinaryOperator::And,
409 right: Box::new(new_pred),
410 }),
411 (Some(existing), None) => Some(existing),
412 (None, some_pred) => some_pred,
413 };
414 }
415 }
416
417 result
418}
419
#[cfg(test)]
// Tests are expected to panic on failure, so these lints don't apply here.
#[allow(clippy::unwrap_used)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    /// Base-table reference with an optional alias.
    fn table(name: &str, alias: Option<&str>) -> TableReference {
        TableReference::Table {
            name: name.to_string(),
            alias: alias.map(str::to_string),
        }
    }

    /// Column reference, optionally qualified by a table name/alias.
    fn column(qualifier: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: qualifier.map(str::to_string),
            name: name.to_string(),
        }
    }

    /// Binary expression `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Equality predicate `lhs.id = rhs.id`.
    fn ids_equal(lhs: &str, rhs: &str) -> Expr {
        binary(
            column(Some(lhs), "id"),
            BinaryOperator::Eq,
            column(Some(rhs), "id"),
        )
    }

    /// Join node between two table references.
    fn join(
        left: TableReference,
        right: TableReference,
        join_type: JoinType,
        on: Option<Expr>,
    ) -> TableReference {
        TableReference::Join {
            left: Box::new(left),
            right: Box::new(right),
            join_type,
            on,
        }
    }

    /// `SELECT * FROM <from>` with every other clause empty.
    fn select_star_from(from: TableReference) -> SelectStatement {
        SelectStatement {
            projection: vec![SelectItem::Wildcard],
            from: Some(from),
            selection: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// Applies the rule and asserts that it succeeded.
    fn run_rule(stmt: SelectStatement) -> SelectStatement {
        let outcome = JoinReordering.apply(stmt);
        assert!(outcome.is_ok(), "Join reordering should succeed");
        outcome.expect("Join reordering should succeed")
    }

    #[test]
    fn test_join_reorder_preserves_outer_join() {
        let stmt = select_star_from(join(
            table("a", None),
            table("b", None),
            JoinType::Left,
            Some(ids_equal("a", "b")),
        ));

        let result = run_rule(stmt);

        let Some(TableReference::Join { join_type, .. }) = &result.from else {
            panic!("FROM should contain a join");
        };
        assert_eq!(*join_type, JoinType::Left);
    }

    #[test]
    fn test_join_reorder_three_inner_tables() {
        // (a JOIN b ON a.id = b.id) JOIN c ON b.id = c.id
        let a_join_b = join(
            table("a", Some("a")),
            table("b", Some("b")),
            JoinType::Inner,
            Some(ids_equal("a", "b")),
        );
        let stmt = select_star_from(join(
            a_join_b,
            table("c", Some("c")),
            JoinType::Inner,
            Some(ids_equal("b", "c")),
        ));

        let result = run_rule(stmt);

        let Some(from) = result.from.as_ref() else {
            panic!("FROM should exist");
        };
        let aliases = collect_table_aliases(from);
        for alias in ["a", "b", "c"] {
            assert!(aliases.contains(alias), "Table {alias} missing");
        }
    }

    #[test]
    fn test_join_reorder_single_table() {
        let result = run_rule(select_star_from(table("users", None)));

        assert!(matches!(
            &result.from,
            Some(TableReference::Table { name, .. }) if name == "users"
        ));
    }

    #[test]
    fn test_heuristic_selectivity_values() {
        let approx = |actual: f64, expected: f64| (actual - expected).abs() < 0.001;

        let eq_pred = binary(
            column(None, "a"),
            BinaryOperator::Eq,
            Expr::Literal(Literal::Integer(1)),
        );
        assert!(approx(heuristic_single_selectivity(&eq_pred), 0.1));

        let lt_pred = binary(
            column(None, "a"),
            BinaryOperator::Lt,
            Expr::Literal(Literal::Integer(10)),
        );
        assert!(approx(heuristic_single_selectivity(&lt_pred), 0.33));

        let null_pred = Expr::IsNull(Box::new(column(None, "a")));
        assert!(approx(heuristic_single_selectivity(&null_pred), 0.05));

        // 0.1 * 0.33
        assert!(approx(heuristic_selectivity(&[eq_pred, lt_pred]), 0.033));

        assert!(approx(heuristic_selectivity(&[]), 1.0));
    }
}