Skip to main content

polyglot_sql/optimizer/
eliminate_joins.rs

1//! Join Elimination Module
2//!
//! This module removes unused joins from SQL queries. A LEFT JOIN can be
//! eliminated when no columns from the joined table are referenced outside
//! its own ON clause.
5//!
6//! Ported from sqlglot's optimizer/eliminate_joins.py
7
8use crate::expressions::*;
9use crate::scope::traverse_scope;
10use crate::scope::ColumnRef;
11use std::collections::HashMap;
12
13/// Remove unused joins from an expression.
14///
15/// A LEFT JOIN can be eliminated when no columns from the joined table are
16/// referenced in the SELECT list, WHERE clause, GROUP BY, HAVING, ORDER BY,
17/// or any other part of the query outside the JOIN's own ON clause.
18///
19/// Semi and anti joins are never eliminated because they affect the result
20/// set cardinality even when no columns are selected from them.
21///
22/// If the scope contains unqualified columns, we conservatively skip
23/// elimination since we cannot determine which source an unqualified
24/// column belongs to.
25///
26/// # Example
27///
28/// ```sql
29/// -- Before:
30/// SELECT x.a FROM x LEFT JOIN y ON x.b = y.b
31/// -- After:
32/// SELECT x.a FROM x
33/// ```
34///
35/// # Arguments
36/// * `expression` - The expression to optimize
37///
38/// # Returns
39/// The optimized expression with unnecessary joins removed
40pub fn eliminate_joins(expression: Expression) -> Expression {
41    let scopes = traverse_scope(&expression);
42
43    // Collect (source_alias, join_index) pairs to remove across all scopes.
44    // We gather them first and then apply removals so that scope analysis
45    // (which borrows the expression immutably) is finished before we mutate.
46    let mut removals: Vec<JoinRemoval> = Vec::new();
47
48    for mut scope in scopes {
49        // If there are unqualified columns we cannot safely determine which
50        // source they belong to, so skip this scope.
51        if !scope.unqualified_columns().is_empty() {
52            continue;
53        }
54
55        let select = match &scope.expression {
56            Expression::Select(s) => s.clone(),
57            _ => continue,
58        };
59
60        let joins = &select.joins;
61        if joins.is_empty() {
62            continue;
63        }
64
65        // Iterate joins in reverse order (like the Python implementation)
66        // so that index-based removal is stable.
67        for (idx, join) in joins.iter().enumerate().rev() {
68            if is_semi_or_anti_join(join) {
69                continue;
70            }
71
72            let alias = join_alias_or_name(join);
73            let alias = match alias {
74                Some(a) => a,
75                None => continue,
76            };
77
78            if should_eliminate_join(&mut scope, &select, idx, join, &alias) {
79                removals.push(JoinRemoval {
80                    select_id: select_identity(&select),
81                    join_index: idx,
82                    source_alias: alias,
83                });
84            }
85        }
86    }
87
88    if removals.is_empty() {
89        return expression;
90    }
91
92    apply_removals(expression, &removals)
93}
94
95// ---------------------------------------------------------------------------
96// Internal types
97// ---------------------------------------------------------------------------
98
/// Describes a join that should be removed.
///
/// Instances are recorded by `eliminate_joins` during scope analysis and
/// consumed by `apply_removals` when the tree is rewritten.
struct JoinRemoval {
    /// An identity key for the Select node that owns this join.
    /// `apply_removals` compares it against `select_identity` of every
    /// Select it visits.
    select_id: SelectIdentity,
    /// The index of the join in the Select's joins vec, as observed before
    /// any join was removed.
    join_index: usize,
    /// The alias (or name) of the joined source. Intended for scope
    /// bookkeeping; nothing in this file reads it yet, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    source_alias: String,
}
110
111/// A lightweight identity for a Select node so we can match it when
112/// walking the cloned tree. We use a combination of the number of
113/// select-list expressions and the number of joins since that is
114/// sufficient for the simple cases we handle and avoids needing
115/// pointer identity across a clone.
116#[derive(Debug, Clone, PartialEq, Eq)]
117struct SelectIdentity {
118    num_expressions: usize,
119    num_joins: usize,
120    /// First select expression as generated text (for disambiguation).
121    first_expr_debug: String,
122}
123
124fn select_identity(select: &Select) -> SelectIdentity {
125    SelectIdentity {
126        num_expressions: select.expressions.len(),
127        num_joins: select.joins.len(),
128        first_expr_debug: select
129            .expressions
130            .first()
131            .map(|e| format!("{:?}", e))
132            .unwrap_or_default(),
133    }
134}
135
136// ---------------------------------------------------------------------------
137// Helpers
138// ---------------------------------------------------------------------------
139
140/// Returns `true` if the join is a SEMI or ANTI join (any directional
141/// variant). These joins affect result cardinality even when no columns
142/// are selected, so they must not be eliminated.
143fn is_semi_or_anti_join(join: &Join) -> bool {
144    matches!(
145        join.kind,
146        JoinKind::Semi
147            | JoinKind::Anti
148            | JoinKind::LeftSemi
149            | JoinKind::LeftAnti
150            | JoinKind::RightSemi
151            | JoinKind::RightAnti
152    )
153}
154
155/// Extract the alias or table name from a join's source expression.
156fn join_alias_or_name(join: &Join) -> Option<String> {
157    get_table_alias_or_name(&join.this)
158}
159
160/// Get alias or name from a table/subquery expression.
161fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
162    match expr {
163        Expression::Table(table) => {
164            if let Some(ref alias) = table.alias {
165                Some(alias.name.clone())
166            } else {
167                Some(table.name.name.clone())
168            }
169        }
170        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
171        _ => None,
172    }
173}
174
175/// Determine whether a join should be eliminated.
176///
177/// A join is eliminable when:
178/// 1. It is a LEFT JOIN, AND
179/// 2. No columns from the joined source appear outside the ON clause
180///
181/// The scope's `source_columns` includes JOIN conditions, so this check
182/// explicitly subtracts the current join's own ON / MATCH_CONDITION references
183/// and verifies whether any remaining references to the joined source exist.
184fn should_eliminate_join(
185    scope: &mut crate::scope::Scope,
186    _select: &Select,
187    _join_index: usize,
188    join: &Join,
189    alias: &str,
190) -> bool {
191    // Only LEFT JOINs can be safely eliminated in the general case.
192    // (INNER JOINs can filter rows, RIGHT/FULL JOINs can introduce NULLs
193    // on the left side, CROSS JOINs affect cardinality.)
194    if join.kind != JoinKind::Left {
195        return false;
196    }
197
198    // Check whether any columns from this source are referenced outside this join's
199    // own join condition (ON / MATCH_CONDITION). With `scope.columns()` now
200    // including all JOIN conditions, we subtract this join's own condition refs.
201    let source_cols = scope.source_columns(alias);
202    if source_cols.is_empty() {
203        return true;
204    }
205
206    let mut source_counts: HashMap<(String, String), usize> = HashMap::new();
207    for col in &source_cols {
208        if let Some(table) = &col.table {
209            *source_counts
210                .entry((table.clone(), col.name.clone()))
211                .or_insert(0) += 1;
212        }
213    }
214
215    if let Some(on) = &join.on {
216        subtract_columns_from_counts(alias, on, &mut source_counts);
217    }
218    if let Some(match_condition) = &join.match_condition {
219        subtract_columns_from_counts(alias, match_condition, &mut source_counts);
220    }
221
222    !source_counts.values().any(|&count| count > 0)
223}
224
225fn subtract_columns_from_counts(
226    alias: &str,
227    expr: &Expression,
228    counts: &mut HashMap<(String, String), usize>,
229) {
230    let mut cols: Vec<ColumnRef> = Vec::new();
231    collect_columns_in_expression(expr, &mut cols);
232
233    for col in cols {
234        if col.table.as_deref() != Some(alias) {
235            continue;
236        }
237        let key = (alias.to_string(), col.name);
238        if let Some(value) = counts.get_mut(&key) {
239            if *value > 0 {
240                *value -= 1;
241            }
242        }
243    }
244}
245
246fn collect_columns_in_expression(expr: &Expression, columns: &mut Vec<ColumnRef>) {
247    match expr {
248        Expression::Column(col) => {
249            columns.push(ColumnRef {
250                table: col.table.as_ref().map(|t| t.name.clone()),
251                name: col.name.name.clone(),
252            });
253        }
254        Expression::Select(select) => {
255            for e in &select.expressions {
256                collect_columns_in_expression(e, columns);
257            }
258            if let Some(from) = &select.from {
259                for e in &from.expressions {
260                    collect_columns_in_expression(e, columns);
261                }
262            }
263            for join in &select.joins {
264                collect_columns_in_expression(&join.this, columns);
265                if let Some(on) = &join.on {
266                    collect_columns_in_expression(on, columns);
267                }
268                if let Some(match_condition) = &join.match_condition {
269                    collect_columns_in_expression(match_condition, columns);
270                }
271            }
272            if let Some(where_clause) = &select.where_clause {
273                collect_columns_in_expression(&where_clause.this, columns);
274            }
275            if let Some(group_by) = &select.group_by {
276                for e in &group_by.expressions {
277                    collect_columns_in_expression(e, columns);
278                }
279            }
280            if let Some(having) = &select.having {
281                collect_columns_in_expression(&having.this, columns);
282            }
283            if let Some(order_by) = &select.order_by {
284                for o in &order_by.expressions {
285                    collect_columns_in_expression(&o.this, columns);
286                }
287            }
288            if let Some(qualify) = &select.qualify {
289                collect_columns_in_expression(&qualify.this, columns);
290            }
291            if let Some(limit) = &select.limit {
292                collect_columns_in_expression(&limit.this, columns);
293            }
294            if let Some(offset) = &select.offset {
295                collect_columns_in_expression(&offset.this, columns);
296            }
297        }
298        Expression::Alias(alias) => {
299            collect_columns_in_expression(&alias.this, columns);
300        }
301        Expression::Function(func) => {
302            for arg in &func.args {
303                collect_columns_in_expression(arg, columns);
304            }
305        }
306        Expression::AggregateFunction(agg) => {
307            for arg in &agg.args {
308                collect_columns_in_expression(arg, columns);
309            }
310        }
311        Expression::And(bin)
312        | Expression::Or(bin)
313        | Expression::Eq(bin)
314        | Expression::Neq(bin)
315        | Expression::Lt(bin)
316        | Expression::Lte(bin)
317        | Expression::Gt(bin)
318        | Expression::Gte(bin)
319        | Expression::Add(bin)
320        | Expression::Sub(bin)
321        | Expression::Mul(bin)
322        | Expression::Div(bin)
323        | Expression::Mod(bin)
324        | Expression::BitwiseAnd(bin)
325        | Expression::BitwiseOr(bin)
326        | Expression::BitwiseXor(bin)
327        | Expression::Concat(bin) => {
328            collect_columns_in_expression(&bin.left, columns);
329            collect_columns_in_expression(&bin.right, columns);
330        }
331        Expression::Like(like) | Expression::ILike(like) => {
332            collect_columns_in_expression(&like.left, columns);
333            collect_columns_in_expression(&like.right, columns);
334            if let Some(escape) = &like.escape {
335                collect_columns_in_expression(escape, columns);
336            }
337        }
338        Expression::Not(unary) | Expression::Neg(unary) | Expression::BitwiseNot(unary) => {
339            collect_columns_in_expression(&unary.this, columns);
340        }
341        Expression::Case(case) => {
342            if let Some(operand) = &case.operand {
343                collect_columns_in_expression(operand, columns);
344            }
345            for (when_expr, then_expr) in &case.whens {
346                collect_columns_in_expression(when_expr, columns);
347                collect_columns_in_expression(then_expr, columns);
348            }
349            if let Some(else_) = &case.else_ {
350                collect_columns_in_expression(else_, columns);
351            }
352        }
353        Expression::Cast(cast) => {
354            collect_columns_in_expression(&cast.this, columns);
355        }
356        Expression::In(in_expr) => {
357            collect_columns_in_expression(&in_expr.this, columns);
358            for e in &in_expr.expressions {
359                collect_columns_in_expression(e, columns);
360            }
361            if let Some(query) = &in_expr.query {
362                collect_columns_in_expression(query, columns);
363            }
364        }
365        Expression::Between(between) => {
366            collect_columns_in_expression(&between.this, columns);
367            collect_columns_in_expression(&between.low, columns);
368            collect_columns_in_expression(&between.high, columns);
369        }
370        Expression::Exists(exists) => {
371            collect_columns_in_expression(&exists.this, columns);
372        }
373        Expression::Subquery(subquery) => {
374            collect_columns_in_expression(&subquery.this, columns);
375        }
376        Expression::WindowFunction(wf) => {
377            collect_columns_in_expression(&wf.this, columns);
378            for p in &wf.over.partition_by {
379                collect_columns_in_expression(p, columns);
380            }
381            for o in &wf.over.order_by {
382                collect_columns_in_expression(&o.this, columns);
383            }
384            if let Some(frame) = &wf.over.frame {
385                collect_columns_from_window_bound(&frame.start, columns);
386                if let Some(end) = &frame.end {
387                    collect_columns_from_window_bound(end, columns);
388                }
389            }
390        }
391        Expression::Ordered(ord) => {
392            collect_columns_in_expression(&ord.this, columns);
393        }
394        Expression::Paren(paren) => {
395            collect_columns_in_expression(&paren.this, columns);
396        }
397        Expression::Join(join) => {
398            collect_columns_in_expression(&join.this, columns);
399            if let Some(on) = &join.on {
400                collect_columns_in_expression(on, columns);
401            }
402            if let Some(match_condition) = &join.match_condition {
403                collect_columns_in_expression(match_condition, columns);
404            }
405        }
406        _ => {}
407    }
408}
409
410fn collect_columns_from_window_bound(bound: &WindowFrameBound, columns: &mut Vec<ColumnRef>) {
411    match bound {
412        WindowFrameBound::Preceding(expr)
413        | WindowFrameBound::Following(expr)
414        | WindowFrameBound::Value(expr) => collect_columns_in_expression(expr, columns),
415        WindowFrameBound::CurrentRow
416        | WindowFrameBound::UnboundedPreceding
417        | WindowFrameBound::UnboundedFollowing
418        | WindowFrameBound::BarePreceding
419        | WindowFrameBound::BareFollowing => {}
420    }
421}
422
423/// Walk the expression tree, find matching Select nodes, and remove the
424/// indicated joins.
425fn apply_removals(expression: Expression, removals: &[JoinRemoval]) -> Expression {
426    match expression {
427        Expression::Select(select) => {
428            let id = select_identity(&select);
429
430            // Collect join indices to drop for this Select.
431            let mut indices_to_drop: Vec<usize> = removals
432                .iter()
433                .filter(|r| r.select_id == id)
434                .map(|r| r.join_index)
435                .collect();
436            indices_to_drop.sort_unstable();
437            indices_to_drop.dedup();
438
439            let mut new_select = select.clone();
440
441            // Remove joins (iterate in reverse to keep indices valid).
442            for &idx in indices_to_drop.iter().rev() {
443                if idx < new_select.joins.len() {
444                    new_select.joins.remove(idx);
445                }
446            }
447
448            // Recursively process subqueries in other parts of the Select
449            new_select.expressions = new_select
450                .expressions
451                .into_iter()
452                .map(|e| apply_removals(e, removals))
453                .collect();
454
455            if let Some(ref mut from) = new_select.from {
456                from.expressions = from
457                    .expressions
458                    .clone()
459                    .into_iter()
460                    .map(|e| apply_removals(e, removals))
461                    .collect();
462            }
463
464            if let Some(ref mut w) = new_select.where_clause {
465                w.this = apply_removals(w.this.clone(), removals);
466            }
467
468            // Process remaining joins' subqueries
469            new_select.joins = new_select
470                .joins
471                .into_iter()
472                .map(|mut j| {
473                    j.this = apply_removals(j.this, removals);
474                    if let Some(on) = j.on {
475                        j.on = Some(apply_removals(on, removals));
476                    }
477                    j
478                })
479                .collect();
480
481            if let Some(ref mut with) = new_select.with {
482                with.ctes = with
483                    .ctes
484                    .iter()
485                    .map(|cte| {
486                        let mut new_cte = cte.clone();
487                        new_cte.this = apply_removals(new_cte.this, removals);
488                        new_cte
489                    })
490                    .collect();
491            }
492
493            Expression::Select(new_select)
494        }
495        Expression::Subquery(mut subquery) => {
496            subquery.this = apply_removals(subquery.this, removals);
497            Expression::Subquery(subquery)
498        }
499        Expression::Union(mut union) => {
500            union.left = apply_removals(union.left, removals);
501            union.right = apply_removals(union.right, removals);
502            Expression::Union(union)
503        }
504        Expression::Intersect(mut intersect) => {
505            intersect.left = apply_removals(intersect.left, removals);
506            intersect.right = apply_removals(intersect.right, removals);
507            Expression::Intersect(intersect)
508        }
509        Expression::Except(mut except) => {
510            except.left = apply_removals(except.left, removals);
511            except.right = apply_removals(except.right, removals);
512            Expression::Except(except)
513        }
514        other => other,
515    }
516}
517
518// ===========================================================================
519// Tests
520// ===========================================================================
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parse `sql`, run join elimination on it and render the result.
    fn optimize(sql: &str) -> String {
        to_sql(&eliminate_joins(parse(sql)))
    }

    // -----------------------------------------------------------------------
    // LEFT JOIN where no columns from the joined table are used => removed
    // -----------------------------------------------------------------------

    #[test]
    fn test_eliminate_unused_left_join() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b");

        // No column of y is used outside the ON clause, so the join goes away.
        assert!(
            !sql.contains("JOIN"),
            "Expected JOIN to be eliminated, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT x.a FROM x"),
            "Expected simple select, got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // LEFT JOIN where columns from the joined table ARE used => kept
    // -----------------------------------------------------------------------

    #[test]
    fn test_keep_used_left_join() {
        let sql = optimize("SELECT x.a, y.c FROM x LEFT JOIN y ON x.b = y.b");

        // y.c is part of the SELECT list, so the join has to stay.
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved, got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // INNER JOIN where no columns are used => NOT removed (INNER affects rows)
    // -----------------------------------------------------------------------

    #[test]
    fn test_inner_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x JOIN y ON x.b = y.b");

        // An INNER JOIN can filter rows even when none of its columns are
        // selected, so it is never removed.
        assert!(
            sql.contains("JOIN"),
            "Expected INNER JOIN to be preserved, got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // LEFT JOIN with column in WHERE => kept
    // -----------------------------------------------------------------------

    #[test]
    fn test_keep_left_join_column_in_where() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b WHERE y.c > 1");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in WHERE), got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // Multiple joins: only the unused one is removed
    // -----------------------------------------------------------------------

    #[test]
    fn test_eliminate_one_of_multiple_joins() {
        let sql =
            optimize("SELECT x.a, z.d FROM x LEFT JOIN y ON x.b = y.b LEFT JOIN z ON x.c = z.c");

        // y is only mentioned in its own ON clause, while z.d is selected:
        // the join to y is dropped and the join to z survives.
        assert!(
            sql.contains("JOIN"),
            "Expected at least one JOIN to remain, got: {}",
            sql
        );
        assert!(
            !sql.contains("JOIN y"),
            "Expected JOIN y to be removed, got: {}",
            sql
        );
        assert!(sql.contains("z"), "Expected z to remain, got: {}", sql);
    }

    // -----------------------------------------------------------------------
    // No joins at all => expression unchanged
    // -----------------------------------------------------------------------

    #[test]
    fn test_no_joins_unchanged() {
        let expr = parse("SELECT a FROM x");
        let before = to_sql(&expr);
        let after = to_sql(&eliminate_joins(expr));

        assert_eq!(before, after);
    }

    // -----------------------------------------------------------------------
    // CROSS JOIN => not eliminated (affects cardinality)
    // -----------------------------------------------------------------------

    #[test]
    fn test_cross_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x CROSS JOIN y");

        assert!(
            sql.contains("CROSS JOIN"),
            "Expected CROSS JOIN to be preserved, got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // Unqualified columns => skip elimination (conservative)
    // -----------------------------------------------------------------------

    #[test]
    fn test_skip_with_unqualified_columns() {
        // 'a' carries no qualifier, so it might just as well come from y.
        let sql = optimize("SELECT a FROM x LEFT JOIN y ON x.b = y.b");

        // The pass must stay conservative and keep the join.
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (unqualified columns), got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // LEFT JOIN column used in GROUP BY => kept
    // -----------------------------------------------------------------------

    #[test]
    fn test_keep_left_join_column_in_group_by() {
        let sql = optimize("SELECT x.a, COUNT(*) FROM x LEFT JOIN y ON x.b = y.b GROUP BY y.c");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in GROUP BY), got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // LEFT JOIN column used in ORDER BY => kept
    // -----------------------------------------------------------------------

    #[test]
    fn test_keep_left_join_column_in_order_by() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b ORDER BY y.c");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in ORDER BY), got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // LEFT JOIN column used in another join's ON clause => kept
    // -----------------------------------------------------------------------

    #[test]
    fn test_keep_left_join_used_in_other_join_condition() {
        let sql =
            optimize("SELECT x.a FROM x LEFT JOIN y ON x.y_id = y.id LEFT JOIN z ON y.id = z.y_id");

        assert!(
            sql.contains("JOIN y"),
            "Expected JOIN y to be preserved (used in another JOIN ON), got: {}",
            sql
        );
    }
}