Skip to main content

oxigdal_query/optimizer/rules/
join_reordering.rs

1//! Join reordering optimization rule.
2//!
3//! Reorders inner joins to minimize intermediate result sizes using a
4//! greedy algorithm with heuristic cost estimation.
5//!
6//! Only inner and cross joins are reordered. Outer joins (LEFT, RIGHT, FULL)
7//! preserve their original order since their semantics depend on operand position.
8//!
9//! The greedy algorithm at each step selects the pair of relations whose join
10//! produces the smallest estimated intermediate result, based on:
11//! - Heuristic table size estimates (default 10000 rows for base tables)
12//! - Predicate selectivity estimates (equality: 0.1, range: 0.33, etc.)
13//! - Hash join cost model (build + probe + output)
14
15use crate::error::Result;
16use crate::parser::ast::*;
17use std::collections::HashSet;
18
19use super::{
20    OptimizationRule, collect_table_aliases, combine_predicates_with_and, extract_predicates,
21    get_predicate_tables,
22};
23
/// Join reordering rule.
///
/// Rewrites chains of `INNER`/`CROSS` joins in a statement's `FROM` clause into
/// the order chosen by a greedy, heuristic cost model. Outer joins keep their
/// position, but inner-join chains nested beneath them are still visited and
/// reordered.
pub struct JoinReordering;
26
27impl OptimizationRule for JoinReordering {
28    fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
29        if let Some(from) = stmt.from.take() {
30            stmt.from = Some(reorder_join_tree(from));
31        }
32        Ok(stmt)
33    }
34}
35
/// A component in the join reordering algorithm: a table reference with
/// metadata used for cost estimation.
///
/// Components start as the leaves of a flattened inner-join chain and are
/// merged pairwise by `greedy_join_order` until a single component remains.
struct JoinComponent {
    /// The table reference (base table, subquery, or already-optimized subtree).
    table_ref: TableReference,
    /// Set of all table names/aliases within this component. For leaves it comes
    /// from `collect_table_aliases`; for merged components it is the union of both
    /// inputs. Used to decide which predicates can be evaluated at a join.
    table_names: HashSet<String>,
    /// Estimated number of output rows (heuristic constants, not statistics).
    estimated_rows: f64,
}
46
47/// Recursively reorder inner join chains in the table reference tree.
48fn reorder_join_tree(table_ref: TableReference) -> TableReference {
49    let is_inner = matches!(
50        &table_ref,
51        TableReference::Join { join_type, .. }
52            if *join_type == JoinType::Inner || *join_type == JoinType::Cross
53    );
54
55    if is_inner {
56        // Flatten the inner join chain into components and predicates
57        let mut components: Vec<JoinComponent> = Vec::new();
58        let mut predicates: Vec<Expr> = Vec::new();
59        flatten_inner_join_chain(table_ref, &mut components, &mut predicates);
60
61        if components.len() <= 1 {
62            return components
63                .into_iter()
64                .next()
65                .map(|c| c.table_ref)
66                .unwrap_or(TableReference::Table {
67                    name: String::new(),
68                    alias: None,
69                });
70        }
71
72        // Recursively optimize each leaf component
73        for comp in &mut components {
74            let old_ref = std::mem::replace(
75                &mut comp.table_ref,
76                TableReference::Table {
77                    name: String::new(),
78                    alias: None,
79                },
80            );
81            comp.table_ref = reorder_join_tree(old_ref);
82            comp.estimated_rows = heuristic_row_estimate(&comp.table_ref);
83        }
84
85        greedy_join_order(components, predicates)
86    } else {
87        match table_ref {
88            TableReference::Join {
89                left,
90                right,
91                join_type,
92                on,
93            } => TableReference::Join {
94                left: Box::new(reorder_join_tree(*left)),
95                right: Box::new(reorder_join_tree(*right)),
96                join_type,
97                on,
98            },
99            other => other,
100        }
101    }
102}
103
104/// Flatten a chain of inner/cross joins into base components and predicates.
105fn flatten_inner_join_chain(
106    table_ref: TableReference,
107    components: &mut Vec<JoinComponent>,
108    predicates: &mut Vec<Expr>,
109) {
110    let is_inner = matches!(
111        &table_ref,
112        TableReference::Join { join_type, .. }
113            if *join_type == JoinType::Inner || *join_type == JoinType::Cross
114    );
115
116    if is_inner {
117        if let TableReference::Join {
118            left, right, on, ..
119        } = table_ref
120        {
121            flatten_inner_join_chain(*left, components, predicates);
122            flatten_inner_join_chain(*right, components, predicates);
123
124            if let Some(on_expr) = on {
125                let mut preds = Vec::new();
126                extract_predicates(&on_expr, &mut preds);
127                predicates.extend(preds);
128            }
129        }
130    } else {
131        // Leaf: base table, subquery, or non-inner join subtree
132        let table_names = collect_table_aliases(&table_ref);
133        let estimated_rows = heuristic_row_estimate(&table_ref);
134        components.push(JoinComponent {
135            table_ref,
136            table_names,
137            estimated_rows,
138        });
139    }
140}
141
142/// Estimate the number of rows a table reference produces, using heuristics.
143fn heuristic_row_estimate(table_ref: &TableReference) -> f64 {
144    match table_ref {
145        TableReference::Table { .. } => 10_000.0,
146        TableReference::Join {
147            left,
148            right,
149            join_type,
150            on,
151        } => {
152            let left_rows = heuristic_row_estimate(left);
153            let right_rows = heuristic_row_estimate(right);
154
155            let selectivity = if let Some(on_expr) = on {
156                let mut preds = Vec::new();
157                extract_predicates(on_expr, &mut preds);
158                heuristic_selectivity(&preds)
159            } else {
160                1.0
161            };
162
163            match join_type {
164                JoinType::Inner | JoinType::Cross => {
165                    (left_rows * right_rows * selectivity).max(1.0)
166                }
167                JoinType::Left => left_rows.max(1.0),
168                JoinType::Right => right_rows.max(1.0),
169                JoinType::Full => (left_rows + right_rows).max(1.0),
170            }
171        }
172        TableReference::Subquery { .. } => 1_000.0,
173    }
174}
175
176/// Estimate selectivity of a single predicate expression using heuristics.
177///
178/// Selectivity ranges from 0.0 (filters everything) to 1.0 (filters nothing).
179/// These values are based on standard database textbook defaults:
180/// - Equality: 1/10 (assumes ~10 distinct values)
181/// - Range: 1/3
182/// - LIKE: 1/10
183/// - IS NULL: 1/20
184pub(crate) fn heuristic_single_selectivity(pred: &Expr) -> f64 {
185    match pred {
186        Expr::BinaryOp { left, op, right } => match op {
187            BinaryOperator::Eq => 0.1,
188            BinaryOperator::NotEq => 0.9,
189            BinaryOperator::Lt
190            | BinaryOperator::LtEq
191            | BinaryOperator::Gt
192            | BinaryOperator::GtEq => 0.33,
193            BinaryOperator::And => {
194                heuristic_single_selectivity(left) * heuristic_single_selectivity(right)
195            }
196            BinaryOperator::Or => {
197                let l = heuristic_single_selectivity(left);
198                let r = heuristic_single_selectivity(right);
199                l + r - l * r
200            }
201            BinaryOperator::Like => 0.1,
202            BinaryOperator::NotLike => 0.9,
203            _ => 0.5,
204        },
205        Expr::IsNull(_) => 0.05,
206        Expr::IsNotNull(_) => 0.95,
207        Expr::InList { list, negated, .. } => {
208            let sel = (list.len() as f64 * 0.1).min(0.9);
209            if *negated { 1.0 - sel } else { sel }
210        }
211        Expr::Between { negated, .. } => {
212            if *negated {
213                0.75
214            } else {
215                0.25
216            }
217        }
218        _ => 0.5,
219    }
220}
221
222/// Combined selectivity of multiple predicates (assuming independence).
223pub(crate) fn heuristic_selectivity(predicates: &[Expr]) -> f64 {
224    if predicates.is_empty() {
225        return 1.0;
226    }
227    predicates
228        .iter()
229        .map(heuristic_single_selectivity)
230        .product::<f64>()
231        .max(0.0001)
232}
233
234/// Estimate the cost of joining two components, considering applicable predicates.
235///
236/// Cost model:
237/// - Find predicates that reference both sides (cross-component predicates)
238/// - Estimate output rows = left_rows * right_rows * selectivity
239/// - Total cost = hash_build + hash_probe + output_materialization
240fn estimate_pair_join_cost(
241    left: &JoinComponent,
242    right: &JoinComponent,
243    all_predicates: &[Expr],
244) -> f64 {
245    // Find predicates that reference tables from both sides
246    let applicable: Vec<&Expr> = all_predicates
247        .iter()
248        .filter(|pred| {
249            let tables = get_predicate_tables(pred);
250            !tables.is_empty()
251                && tables.iter().any(|t| left.table_names.contains(t))
252                && tables.iter().any(|t| right.table_names.contains(t))
253        })
254        .collect();
255
256    let selectivity = if applicable.is_empty() {
257        1.0 // No cross-component predicates = cross join
258    } else {
259        applicable
260            .iter()
261            .map(|p| heuristic_single_selectivity(p))
262            .product::<f64>()
263            .max(0.0001)
264    };
265
266    let output_rows = left.estimated_rows * right.estimated_rows * selectivity;
267
268    // Hash join cost: build from smaller side, probe from larger side
269    let (build_rows, probe_rows) = if left.estimated_rows <= right.estimated_rows {
270        (left.estimated_rows, right.estimated_rows)
271    } else {
272        (right.estimated_rows, left.estimated_rows)
273    };
274
275    let build_cost = build_rows * 10.0;
276    let probe_cost = probe_rows * 5.0;
277    let output_cost = output_rows * 2.0;
278
279    build_cost + probe_cost + output_cost
280}
281
282/// Greedy join ordering: iteratively merge the cheapest pair until one tree remains.
283///
284/// At each iteration:
285/// 1. Evaluate cost for every pair of remaining components
286/// 2. Pick the pair with minimum estimated cost
287/// 3. Merge them into a single component with a JOIN node
288/// 4. Assign applicable predicates as the ON condition
289/// 5. Repeat until only one component remains
290fn greedy_join_order(
291    mut components: Vec<JoinComponent>,
292    mut all_predicates: Vec<Expr>,
293) -> TableReference {
294    // Validate we have components to reorder
295    if components.is_empty() {
296        return TableReference::Table {
297            name: String::new(),
298            alias: None,
299        };
300    }
301
302    if components.len() == 1 {
303        let mut result = components
304            .into_iter()
305            .next()
306            .map(|c| c.table_ref)
307            .unwrap_or(TableReference::Table {
308                name: String::new(),
309                alias: None,
310            });
311
312        // Apply any remaining predicates
313        if !all_predicates.is_empty() {
314            if let TableReference::Join { ref mut on, .. } = result {
315                let remaining = super::combine_predicates_with_and(all_predicates);
316                *on = match (on.take(), remaining) {
317                    (Some(existing), Some(new_pred)) => Some(Expr::BinaryOp {
318                        left: Box::new(existing),
319                        op: BinaryOperator::And,
320                        right: Box::new(new_pred),
321                    }),
322                    (Some(existing), None) => Some(existing),
323                    (None, some_pred) => some_pred,
324                };
325            }
326        }
327
328        return result;
329    }
330
331    while components.len() > 1 {
332        // Find the pair with minimum join cost
333        let mut best_i = 0;
334        let mut best_j = 1;
335        let mut best_cost = f64::MAX;
336
337        for i in 0..components.len() {
338            for j in (i + 1)..components.len() {
339                let cost = estimate_pair_join_cost(&components[i], &components[j], &all_predicates);
340                if cost < best_cost {
341                    best_cost = cost;
342                    best_i = i;
343                    best_j = j;
344                }
345            }
346        }
347
348        // Remove the two components (larger index first to avoid invalidation)
349        let right_comp = components.remove(best_j);
350        let left_comp = components.remove(best_i);
351
352        // Partition predicates: applicable to this join vs. remaining
353        let merged_tables: HashSet<String> = left_comp
354            .table_names
355            .iter()
356            .chain(right_comp.table_names.iter())
357            .cloned()
358            .collect();
359
360        let mut join_preds = Vec::new();
361        let mut remaining_preds = Vec::new();
362
363        for pred in all_predicates {
364            let tables = get_predicate_tables(&pred);
365            if !tables.is_empty() && tables.iter().all(|t| merged_tables.contains(t)) {
366                join_preds.push(pred);
367            } else {
368                remaining_preds.push(pred);
369            }
370        }
371        all_predicates = remaining_preds;
372
373        // Estimate output size for the merged component
374        let selectivity = heuristic_selectivity(&join_preds);
375        let output_rows =
376            (left_comp.estimated_rows * right_comp.estimated_rows * selectivity).max(1.0);
377
378        let on_condition = combine_predicates_with_and(join_preds);
379
380        components.push(JoinComponent {
381            table_ref: TableReference::Join {
382                left: Box::new(left_comp.table_ref),
383                right: Box::new(right_comp.table_ref),
384                join_type: JoinType::Inner,
385                on: on_condition,
386            },
387            table_names: merged_tables,
388            estimated_rows: output_rows,
389        });
390    }
391
392    let mut result = components
393        .into_iter()
394        .next()
395        .map(|c| c.table_ref)
396        .unwrap_or(TableReference::Table {
397            name: String::new(),
398            alias: None,
399        });
400
401    // Apply any remaining predicates to the outermost join's ON condition
402    if !all_predicates.is_empty() {
403        if let TableReference::Join { ref mut on, .. } = result {
404            let remaining = combine_predicates_with_and(all_predicates);
405            *on = match (on.take(), remaining) {
406                (Some(existing), Some(new_pred)) => Some(Expr::BinaryOp {
407                    left: Box::new(existing),
408                    op: BinaryOperator::And,
409                    right: Box::new(new_pred),
410                }),
411                (Some(existing), None) => Some(existing),
412                (None, some_pred) => some_pred,
413            };
414        }
415    }
416
417    result
418}
419
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    /// Build `SELECT * FROM <from>` with every other clause empty.
    fn select_star_from(from: TableReference) -> SelectStatement {
        SelectStatement {
            projection: vec![SelectItem::Wildcard],
            from: Some(from),
            selection: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// A base table reference with an optional alias.
    fn table(name: &str, alias: Option<&str>) -> TableReference {
        TableReference::Table {
            name: name.to_string(),
            alias: alias.map(String::from),
        }
    }

    /// A (possibly table-qualified) column expression.
    fn column(table: Option<&str>, name: &str) -> Expr {
        Expr::Column {
            table: table.map(String::from),
            name: name.to_string(),
        }
    }

    /// A binary expression `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// A join node with the given type and optional ON condition.
    fn join(
        left: TableReference,
        right: TableReference,
        join_type: JoinType,
        on: Option<Expr>,
    ) -> TableReference {
        TableReference::Join {
            left: Box::new(left),
            right: Box::new(right),
            join_type,
            on,
        }
    }

    /// Apply the rule, failing the test if it returns an error.
    fn run(stmt: SelectStatement) -> SelectStatement {
        JoinReordering
            .apply(stmt)
            .expect("Join reordering should succeed")
    }

    #[test]
    fn test_join_reorder_preserves_outer_join() {
        let on = binary(
            column(Some("a"), "id"),
            BinaryOperator::Eq,
            column(Some("b"), "id"),
        );
        let from = join(table("a", None), table("b", None), JoinType::Left, Some(on));

        let result = run(select_star_from(from));

        // LEFT join should be preserved (not reordered)
        let Some(TableReference::Join { join_type, .. }) = &result.from else {
            panic!("FROM should contain a join");
        };
        assert_eq!(*join_type, JoinType::Left);
    }

    #[test]
    fn test_join_reorder_three_inner_tables() {
        // A INNER JOIN B ON a.id = b.id INNER JOIN C ON b.id = c.id
        let a_b = join(
            table("a", Some("a")),
            table("b", Some("b")),
            JoinType::Inner,
            Some(binary(
                column(Some("a"), "id"),
                BinaryOperator::Eq,
                column(Some("b"), "id"),
            )),
        );
        let from = join(
            a_b,
            table("c", Some("c")),
            JoinType::Inner,
            Some(binary(
                column(Some("b"), "id"),
                BinaryOperator::Eq,
                column(Some("c"), "id"),
            )),
        );

        let result = run(select_star_from(from));

        // All three tables should still be present in the result
        let Some(from) = result.from.as_ref() else {
            panic!("FROM should exist");
        };
        let aliases = collect_table_aliases(from);
        for expected in ["a", "b", "c"] {
            assert!(aliases.contains(expected), "Table {expected} missing");
        }
    }

    #[test]
    fn test_join_reorder_single_table() {
        let result = run(select_star_from(table("users", None)));

        // Single table should be unchanged
        assert!(matches!(
            &result.from,
            Some(TableReference::Table { name, .. }) if name == "users"
        ));
    }

    #[test]
    fn test_heuristic_selectivity_values() {
        // Equality predicate
        let eq_pred = binary(
            column(None, "a"),
            BinaryOperator::Eq,
            Expr::Literal(Literal::Integer(1)),
        );
        assert!((heuristic_single_selectivity(&eq_pred) - 0.1).abs() < 0.001);

        // Range predicate
        let lt_pred = binary(
            column(None, "a"),
            BinaryOperator::Lt,
            Expr::Literal(Literal::Integer(10)),
        );
        assert!((heuristic_single_selectivity(&lt_pred) - 0.33).abs() < 0.001);

        // IS NULL
        let null_pred = Expr::IsNull(Box::new(column(None, "a")));
        assert!((heuristic_single_selectivity(&null_pred) - 0.05).abs() < 0.001);

        // Combined AND
        let combined = heuristic_selectivity(&[eq_pred, lt_pred]);
        assert!((combined - 0.033).abs() < 0.001);

        // Empty predicates
        let empty: Vec<Expr> = Vec::new();
        assert!((heuristic_selectivity(&empty) - 1.0).abs() < 0.001);
    }
}