Skip to main content

polyglot_sql/optimizer/
eliminate_joins.rs

//! Join Elimination Module
//!
//! This module removes unused joins from SQL queries. Only `LEFT JOIN`s are
//! candidates: one is eliminated when no columns from the joined table are
//! referenced outside the join's own ON clause.
//!
//! Ported from sqlglot's optimizer/eliminate_joins.py
7
8use crate::expressions::*;
9use crate::scope::traverse_scope;
10
11/// Remove unused joins from an expression.
12///
13/// A LEFT JOIN can be eliminated when no columns from the joined table are
14/// referenced in the SELECT list, WHERE clause, GROUP BY, HAVING, ORDER BY,
15/// or any other part of the query outside the JOIN's own ON clause.
16///
17/// Semi and anti joins are never eliminated because they affect the result
18/// set cardinality even when no columns are selected from them.
19///
20/// If the scope contains unqualified columns, we conservatively skip
21/// elimination since we cannot determine which source an unqualified
22/// column belongs to.
23///
24/// # Example
25///
26/// ```sql
27/// -- Before:
28/// SELECT x.a FROM x LEFT JOIN y ON x.b = y.b
29/// -- After:
30/// SELECT x.a FROM x
31/// ```
32///
33/// # Arguments
34/// * `expression` - The expression to optimize
35///
36/// # Returns
37/// The optimized expression with unnecessary joins removed
38pub fn eliminate_joins(expression: Expression) -> Expression {
39    let scopes = traverse_scope(&expression);
40
41    // Collect (source_alias, join_index) pairs to remove across all scopes.
42    // We gather them first and then apply removals so that scope analysis
43    // (which borrows the expression immutably) is finished before we mutate.
44    let mut removals: Vec<JoinRemoval> = Vec::new();
45
46    for mut scope in scopes {
47        // If there are unqualified columns we cannot safely determine which
48        // source they belong to, so skip this scope.
49        if !scope.unqualified_columns().is_empty() {
50            continue;
51        }
52
53        let select = match &scope.expression {
54            Expression::Select(s) => s.clone(),
55            _ => continue,
56        };
57
58        let joins = &select.joins;
59        if joins.is_empty() {
60            continue;
61        }
62
63        // Iterate joins in reverse order (like the Python implementation)
64        // so that index-based removal is stable.
65        for (idx, join) in joins.iter().enumerate().rev() {
66            if is_semi_or_anti_join(join) {
67                continue;
68            }
69
70            let alias = join_alias_or_name(join);
71            let alias = match alias {
72                Some(a) => a,
73                None => continue,
74            };
75
76            if should_eliminate_join(&mut scope, join, &alias) {
77                removals.push(JoinRemoval {
78                    select_id: select_identity(&select),
79                    join_index: idx,
80                    source_alias: alias,
81                });
82            }
83        }
84    }
85
86    if removals.is_empty() {
87        return expression;
88    }
89
90    apply_removals(expression, &removals)
91}
92
93// ---------------------------------------------------------------------------
94// Internal types
95// ---------------------------------------------------------------------------
96
97/// Describes a join that should be removed.
98struct JoinRemoval {
99    /// An identity key for the Select node that owns this join.
100    select_id: SelectIdentity,
101    /// The index of the join in the Select's joins vec.
102    join_index: usize,
103    /// The alias (or name) of the joined source so we can also remove it
104    /// from scope bookkeeping.
105    #[allow(dead_code)]
106    source_alias: String,
107}
108
/// A lightweight, value-based identity for a `Select` node.
///
/// It lets a Select that was observed during analysis be recognised again
/// while walking the owned tree, without relying on pointer identity (which
/// does not survive a clone). The key is made of three parts: the number of
/// select-list expressions, the number of joins, and the `Debug` rendering
/// of the first select-list expression.
///
/// This is a heuristic, not a unique key: two distinct Selects that agree on
/// all three parts compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SelectIdentity {
    /// Number of expressions in the select list.
    num_expressions: usize,
    /// Number of joins attached to the Select.
    num_joins: usize,
    /// `{:?}` rendering of the first select-list expression, used for
    /// disambiguation; empty when the select list is empty.
    first_expr_debug: String,
}
121
122fn select_identity(select: &Select) -> SelectIdentity {
123    SelectIdentity {
124        num_expressions: select.expressions.len(),
125        num_joins: select.joins.len(),
126        first_expr_debug: select
127            .expressions
128            .first()
129            .map(|e| format!("{:?}", e))
130            .unwrap_or_default(),
131    }
132}
133
134// ---------------------------------------------------------------------------
135// Helpers
136// ---------------------------------------------------------------------------
137
138/// Returns `true` if the join is a SEMI or ANTI join (any directional
139/// variant). These joins affect result cardinality even when no columns
140/// are selected, so they must not be eliminated.
141fn is_semi_or_anti_join(join: &Join) -> bool {
142    matches!(
143        join.kind,
144        JoinKind::Semi
145            | JoinKind::Anti
146            | JoinKind::LeftSemi
147            | JoinKind::LeftAnti
148            | JoinKind::RightSemi
149            | JoinKind::RightAnti
150    )
151}
152
153/// Extract the alias or table name from a join's source expression.
154fn join_alias_or_name(join: &Join) -> Option<String> {
155    get_table_alias_or_name(&join.this)
156}
157
158/// Get alias or name from a table/subquery expression.
159fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
160    match expr {
161        Expression::Table(table) => {
162            if let Some(ref alias) = table.alias {
163                Some(alias.name.clone())
164            } else {
165                Some(table.name.name.clone())
166            }
167        }
168        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
169        _ => None,
170    }
171}
172
173/// Determine whether a join should be eliminated.
174///
175/// A join is eliminable when:
176/// 1. It is a LEFT JOIN, AND
177/// 2. No columns from the joined source appear outside the ON clause
178///
179/// The scope's `source_columns` method collects column references from
180/// the SELECT list, WHERE, HAVING, GROUP BY, and ORDER BY -- but not
181/// from JOIN ON clauses (those belong to the join, not the query body).
182/// So if `source_columns(alias)` is empty, the joined table is unused.
183fn should_eliminate_join(scope: &mut crate::scope::Scope, join: &Join, alias: &str) -> bool {
184    // Only LEFT JOINs can be safely eliminated in the general case.
185    // (INNER JOINs can filter rows, RIGHT/FULL JOINs can introduce NULLs
186    // on the left side, CROSS JOINs affect cardinality.)
187    if join.kind != JoinKind::Left {
188        return false;
189    }
190
191    // Check whether any columns from this source are referenced
192    // outside the ON clause.
193    let source_cols = scope.source_columns(alias);
194    source_cols.is_empty()
195}
196
197/// Walk the expression tree, find matching Select nodes, and remove the
198/// indicated joins.
199fn apply_removals(expression: Expression, removals: &[JoinRemoval]) -> Expression {
200    match expression {
201        Expression::Select(select) => {
202            let id = select_identity(&select);
203
204            // Collect join indices to drop for this Select.
205            let mut indices_to_drop: Vec<usize> = removals
206                .iter()
207                .filter(|r| r.select_id == id)
208                .map(|r| r.join_index)
209                .collect();
210            indices_to_drop.sort_unstable();
211            indices_to_drop.dedup();
212
213            let mut new_select = select.clone();
214
215            // Remove joins (iterate in reverse to keep indices valid).
216            for &idx in indices_to_drop.iter().rev() {
217                if idx < new_select.joins.len() {
218                    new_select.joins.remove(idx);
219                }
220            }
221
222            // Recursively process subqueries in other parts of the Select
223            new_select.expressions = new_select
224                .expressions
225                .into_iter()
226                .map(|e| apply_removals(e, removals))
227                .collect();
228
229            if let Some(ref mut from) = new_select.from {
230                from.expressions = from
231                    .expressions
232                    .clone()
233                    .into_iter()
234                    .map(|e| apply_removals(e, removals))
235                    .collect();
236            }
237
238            if let Some(ref mut w) = new_select.where_clause {
239                w.this = apply_removals(w.this.clone(), removals);
240            }
241
242            // Process remaining joins' subqueries
243            new_select.joins = new_select
244                .joins
245                .into_iter()
246                .map(|mut j| {
247                    j.this = apply_removals(j.this, removals);
248                    if let Some(on) = j.on {
249                        j.on = Some(apply_removals(on, removals));
250                    }
251                    j
252                })
253                .collect();
254
255            if let Some(ref mut with) = new_select.with {
256                with.ctes = with
257                    .ctes
258                    .iter()
259                    .map(|cte| {
260                        let mut new_cte = cte.clone();
261                        new_cte.this = apply_removals(new_cte.this, removals);
262                        new_cte
263                    })
264                    .collect();
265            }
266
267            Expression::Select(new_select)
268        }
269        Expression::Subquery(mut subquery) => {
270            subquery.this = apply_removals(subquery.this, removals);
271            Expression::Subquery(subquery)
272        }
273        Expression::Union(mut union) => {
274            union.left = apply_removals(union.left, removals);
275            union.right = apply_removals(union.right, removals);
276            Expression::Union(union)
277        }
278        Expression::Intersect(mut intersect) => {
279            intersect.left = apply_removals(intersect.left, removals);
280            intersect.right = apply_removals(intersect.right, removals);
281            Expression::Intersect(intersect)
282        }
283        Expression::Except(mut except) => {
284            except.left = apply_removals(except.left, removals);
285            except.right = apply_removals(except.right, removals);
286            Expression::Except(except)
287        }
288        other => other,
289    }
290}
291
292// ===========================================================================
293// Tests
294// ===========================================================================
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    ///
    /// Deliberately not called `gen`: that identifier is a reserved keyword
    /// from the 2024 edition onwards.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("Parser returned no statements")
    }

    /// Parse `sql`, run the join-elimination pass and render the result.
    fn optimize(sql: &str) -> String {
        to_sql(&eliminate_joins(parse(sql)))
    }

    /// LEFT JOIN where no columns from the joined table are used => removed.
    #[test]
    fn test_eliminate_unused_left_join() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b");

        // The LEFT JOIN to y should be removed because no columns from y
        // appear in the SELECT list (or WHERE, GROUP BY, etc.).
        assert!(
            !sql.contains("JOIN"),
            "Expected JOIN to be eliminated, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT x.a FROM x"),
            "Expected simple select, got: {}",
            sql
        );
    }

    /// LEFT JOIN where columns from the joined table ARE used => kept.
    #[test]
    fn test_keep_used_left_join() {
        let sql = optimize("SELECT x.a, y.c FROM x LEFT JOIN y ON x.b = y.b");

        // The LEFT JOIN should be preserved because y.c is in the SELECT list.
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved, got: {}",
            sql
        );
    }

    /// INNER JOIN where no columns are used => NOT removed.
    #[test]
    fn test_inner_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x JOIN y ON x.b = y.b");

        // INNER JOINs can filter rows, so they should not be removed even
        // when no columns from the inner source are selected.
        assert!(
            sql.contains("JOIN"),
            "Expected INNER JOIN to be preserved, got: {}",
            sql
        );
    }

    /// LEFT JOIN with a column of the joined table in WHERE => kept.
    #[test]
    fn test_keep_left_join_column_in_where() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b WHERE y.c > 1");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in WHERE), got: {}",
            sql
        );
    }

    /// Multiple joins: only the unused one is removed.
    #[test]
    fn test_eliminate_one_of_multiple_joins() {
        let sql = optimize(
            "SELECT x.a, z.d FROM x LEFT JOIN y ON x.b = y.b LEFT JOIN z ON x.c = z.c",
        );

        // y is unused (no y.* columns outside ON), z is used (z.d in SELECT).
        // So the JOIN to y should be removed but the JOIN to z kept.
        assert!(
            sql.contains("JOIN"),
            "Expected at least one JOIN to remain, got: {}",
            sql
        );
        assert!(
            !sql.contains("JOIN y"),
            "Expected JOIN y to be removed, got: {}",
            sql
        );
        // Checking for a bare "z" would always pass because of `z.d` in the
        // select list, so look for the join itself.
        assert!(
            sql.contains("JOIN z"),
            "Expected JOIN z to remain, got: {}",
            sql
        );
    }

    /// No joins at all => expression unchanged.
    #[test]
    fn test_no_joins_unchanged() {
        let expr = parse("SELECT a FROM x");
        let original_sql = to_sql(&expr);
        let result_sql = to_sql(&eliminate_joins(expr));

        assert_eq!(original_sql, result_sql);
    }

    /// CROSS JOIN => not eliminated (affects cardinality).
    #[test]
    fn test_cross_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x CROSS JOIN y");

        assert!(
            sql.contains("CROSS JOIN"),
            "Expected CROSS JOIN to be preserved, got: {}",
            sql
        );
    }

    /// Unqualified columns => skip elimination (conservative).
    #[test]
    fn test_skip_with_unqualified_columns() {
        // 'a' is unqualified -- we cannot be sure it doesn't come from y
        let sql = optimize("SELECT a FROM x LEFT JOIN y ON x.b = y.b");

        // Because 'a' is unqualified the pass should conservatively keep the join.
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (unqualified columns), got: {}",
            sql
        );
    }

    /// LEFT JOIN column used in GROUP BY => kept.
    #[test]
    fn test_keep_left_join_column_in_group_by() {
        let sql = optimize(
            "SELECT x.a, COUNT(*) FROM x LEFT JOIN y ON x.b = y.b GROUP BY y.c",
        );

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in GROUP BY), got: {}",
            sql
        );
    }

    /// LEFT JOIN column used in ORDER BY => kept.
    #[test]
    fn test_keep_left_join_column_in_order_by() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b ORDER BY y.c");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in ORDER BY), got: {}",
            sql
        );
    }
}