polyglot_sql/optimizer/eliminate_joins.rs
1//! Join Elimination Module
2//!
//! This module removes unused joins from SQL queries. Only LEFT JOINs are
//! considered: one is eliminated when no columns from the joined table are
//! referenced outside its own ON clause.
5//!
6//! Ported from sqlglot's optimizer/eliminate_joins.py
7
8use crate::expressions::*;
9use crate::scope::traverse_scope;
10
11/// Remove unused joins from an expression.
12///
13/// A LEFT JOIN can be eliminated when no columns from the joined table are
14/// referenced in the SELECT list, WHERE clause, GROUP BY, HAVING, ORDER BY,
15/// or any other part of the query outside the JOIN's own ON clause.
16///
17/// Semi and anti joins are never eliminated because they affect the result
18/// set cardinality even when no columns are selected from them.
19///
20/// If the scope contains unqualified columns, we conservatively skip
21/// elimination since we cannot determine which source an unqualified
22/// column belongs to.
23///
24/// # Example
25///
26/// ```sql
27/// -- Before:
28/// SELECT x.a FROM x LEFT JOIN y ON x.b = y.b
29/// -- After:
30/// SELECT x.a FROM x
31/// ```
32///
33/// # Arguments
34/// * `expression` - The expression to optimize
35///
36/// # Returns
37/// The optimized expression with unnecessary joins removed
38pub fn eliminate_joins(expression: Expression) -> Expression {
39 let scopes = traverse_scope(&expression);
40
41 // Collect (source_alias, join_index) pairs to remove across all scopes.
42 // We gather them first and then apply removals so that scope analysis
43 // (which borrows the expression immutably) is finished before we mutate.
44 let mut removals: Vec<JoinRemoval> = Vec::new();
45
46 for mut scope in scopes {
47 // If there are unqualified columns we cannot safely determine which
48 // source they belong to, so skip this scope.
49 if !scope.unqualified_columns().is_empty() {
50 continue;
51 }
52
53 let select = match &scope.expression {
54 Expression::Select(s) => s.clone(),
55 _ => continue,
56 };
57
58 let joins = &select.joins;
59 if joins.is_empty() {
60 continue;
61 }
62
63 // Iterate joins in reverse order (like the Python implementation)
64 // so that index-based removal is stable.
65 for (idx, join) in joins.iter().enumerate().rev() {
66 if is_semi_or_anti_join(join) {
67 continue;
68 }
69
70 let alias = join_alias_or_name(join);
71 let alias = match alias {
72 Some(a) => a,
73 None => continue,
74 };
75
76 if should_eliminate_join(&mut scope, join, &alias) {
77 removals.push(JoinRemoval {
78 select_id: select_identity(&select),
79 join_index: idx,
80 source_alias: alias,
81 });
82 }
83 }
84 }
85
86 if removals.is_empty() {
87 return expression;
88 }
89
90 apply_removals(expression, &removals)
91}
92
93// ---------------------------------------------------------------------------
94// Internal types
95// ---------------------------------------------------------------------------
96
/// Describes a join that should be removed.
///
/// Instances are created by `eliminate_joins` while it analyses scopes and
/// are consumed by `apply_removals` when the tree is rebuilt.
struct JoinRemoval {
    /// An identity key for the Select node that owns this join.
    select_id: SelectIdentity,
    /// The index of the join in the Select's joins vec, as observed before
    /// any join was removed.
    join_index: usize,
    /// The alias (or name) of the joined source so we can also remove it
    /// from scope bookkeeping. Nothing in this file reads the field yet,
    /// which is why the `dead_code` lint is silenced for it.
    #[allow(dead_code)]
    source_alias: String,
}
108
109/// A lightweight identity for a Select node so we can match it when
110/// walking the cloned tree. We use a combination of the number of
111/// select-list expressions and the number of joins since that is
112/// sufficient for the simple cases we handle and avoids needing
113/// pointer identity across a clone.
114#[derive(Debug, Clone, PartialEq, Eq)]
115struct SelectIdentity {
116 num_expressions: usize,
117 num_joins: usize,
118 /// First select expression as generated text (for disambiguation).
119 first_expr_debug: String,
120}
121
122fn select_identity(select: &Select) -> SelectIdentity {
123 SelectIdentity {
124 num_expressions: select.expressions.len(),
125 num_joins: select.joins.len(),
126 first_expr_debug: select
127 .expressions
128 .first()
129 .map(|e| format!("{:?}", e))
130 .unwrap_or_default(),
131 }
132}
133
134// ---------------------------------------------------------------------------
135// Helpers
136// ---------------------------------------------------------------------------
137
138/// Returns `true` if the join is a SEMI or ANTI join (any directional
139/// variant). These joins affect result cardinality even when no columns
140/// are selected, so they must not be eliminated.
141fn is_semi_or_anti_join(join: &Join) -> bool {
142 matches!(
143 join.kind,
144 JoinKind::Semi
145 | JoinKind::Anti
146 | JoinKind::LeftSemi
147 | JoinKind::LeftAnti
148 | JoinKind::RightSemi
149 | JoinKind::RightAnti
150 )
151}
152
153/// Extract the alias or table name from a join's source expression.
154fn join_alias_or_name(join: &Join) -> Option<String> {
155 get_table_alias_or_name(&join.this)
156}
157
158/// Get alias or name from a table/subquery expression.
159fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
160 match expr {
161 Expression::Table(table) => {
162 if let Some(ref alias) = table.alias {
163 Some(alias.name.clone())
164 } else {
165 Some(table.name.name.clone())
166 }
167 }
168 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
169 _ => None,
170 }
171}
172
173/// Determine whether a join should be eliminated.
174///
175/// A join is eliminable when:
176/// 1. It is a LEFT JOIN, AND
177/// 2. No columns from the joined source appear outside the ON clause
178///
179/// The scope's `source_columns` method collects column references from
180/// the SELECT list, WHERE, HAVING, GROUP BY, and ORDER BY -- but not
181/// from JOIN ON clauses (those belong to the join, not the query body).
182/// So if `source_columns(alias)` is empty, the joined table is unused.
183fn should_eliminate_join(scope: &mut crate::scope::Scope, join: &Join, alias: &str) -> bool {
184 // Only LEFT JOINs can be safely eliminated in the general case.
185 // (INNER JOINs can filter rows, RIGHT/FULL JOINs can introduce NULLs
186 // on the left side, CROSS JOINs affect cardinality.)
187 if join.kind != JoinKind::Left {
188 return false;
189 }
190
191 // Check whether any columns from this source are referenced
192 // outside the ON clause.
193 let source_cols = scope.source_columns(alias);
194 source_cols.is_empty()
195}
196
197/// Walk the expression tree, find matching Select nodes, and remove the
198/// indicated joins.
199fn apply_removals(expression: Expression, removals: &[JoinRemoval]) -> Expression {
200 match expression {
201 Expression::Select(select) => {
202 let id = select_identity(&select);
203
204 // Collect join indices to drop for this Select.
205 let mut indices_to_drop: Vec<usize> = removals
206 .iter()
207 .filter(|r| r.select_id == id)
208 .map(|r| r.join_index)
209 .collect();
210 indices_to_drop.sort_unstable();
211 indices_to_drop.dedup();
212
213 let mut new_select = select.clone();
214
215 // Remove joins (iterate in reverse to keep indices valid).
216 for &idx in indices_to_drop.iter().rev() {
217 if idx < new_select.joins.len() {
218 new_select.joins.remove(idx);
219 }
220 }
221
222 // Recursively process subqueries in other parts of the Select
223 new_select.expressions = new_select
224 .expressions
225 .into_iter()
226 .map(|e| apply_removals(e, removals))
227 .collect();
228
229 if let Some(ref mut from) = new_select.from {
230 from.expressions = from
231 .expressions
232 .clone()
233 .into_iter()
234 .map(|e| apply_removals(e, removals))
235 .collect();
236 }
237
238 if let Some(ref mut w) = new_select.where_clause {
239 w.this = apply_removals(w.this.clone(), removals);
240 }
241
242 // Process remaining joins' subqueries
243 new_select.joins = new_select
244 .joins
245 .into_iter()
246 .map(|mut j| {
247 j.this = apply_removals(j.this, removals);
248 if let Some(on) = j.on {
249 j.on = Some(apply_removals(on, removals));
250 }
251 j
252 })
253 .collect();
254
255 if let Some(ref mut with) = new_select.with {
256 with.ctes = with
257 .ctes
258 .iter()
259 .map(|cte| {
260 let mut new_cte = cte.clone();
261 new_cte.this = apply_removals(new_cte.this, removals);
262 new_cte
263 })
264 .collect();
265 }
266
267 Expression::Select(new_select)
268 }
269 Expression::Subquery(mut subquery) => {
270 subquery.this = apply_removals(subquery.this, removals);
271 Expression::Subquery(subquery)
272 }
273 Expression::Union(mut union) => {
274 union.left = apply_removals(union.left, removals);
275 union.right = apply_removals(union.right, removals);
276 Expression::Union(union)
277 }
278 Expression::Intersect(mut intersect) => {
279 intersect.left = apply_removals(intersect.left, removals);
280 intersect.right = apply_removals(intersect.right, removals);
281 Expression::Intersect(intersect)
282 }
283 Expression::Except(mut except) => {
284 except.left = apply_removals(except.left, removals);
285 except.right = apply_removals(except.right, removals);
286 Expression::Except(except)
287 }
288 other => other,
289 }
290}
291
292// ===========================================================================
293// Tests
294// ===========================================================================
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parse `sql`, run the join-elimination pass, and render the result.
    fn optimize(sql: &str) -> String {
        render(&eliminate_joins(parse(sql)))
    }

    // -----------------------------------------------------------------------
    // Joins that must be removed
    // -----------------------------------------------------------------------

    /// A LEFT JOIN whose source contributes no column outside its ON clause
    /// is dropped.
    #[test]
    fn test_eliminate_unused_left_join() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b");

        assert!(
            !sql.contains("JOIN"),
            "Expected JOIN to be eliminated, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT x.a FROM x"),
            "Expected simple select, got: {}",
            sql
        );
    }

    /// With several joins, only the unused one (y) goes away; z stays
    /// because z.d is selected.
    #[test]
    fn test_eliminate_one_of_multiple_joins() {
        let sql = optimize(
            "SELECT x.a, z.d FROM x LEFT JOIN y ON x.b = y.b LEFT JOIN z ON x.c = z.c",
        );

        assert!(
            sql.contains("JOIN"),
            "Expected at least one JOIN to remain, got: {}",
            sql
        );
        assert!(
            !sql.contains("JOIN y"),
            "Expected JOIN y to be removed, got: {}",
            sql
        );
        assert!(sql.contains("z"), "Expected z to remain, got: {}", sql);
    }

    // -----------------------------------------------------------------------
    // Joins that must be kept because the joined source is referenced
    // -----------------------------------------------------------------------

    /// y.c appears in the SELECT list.
    #[test]
    fn test_keep_used_left_join() {
        let sql = optimize("SELECT x.a, y.c FROM x LEFT JOIN y ON x.b = y.b");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved, got: {}",
            sql
        );
    }

    /// y.c appears in WHERE.
    #[test]
    fn test_keep_left_join_column_in_where() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b WHERE y.c > 1");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in WHERE), got: {}",
            sql
        );
    }

    /// y.c appears in GROUP BY.
    #[test]
    fn test_keep_left_join_column_in_group_by() {
        let sql = optimize(
            "SELECT x.a, COUNT(*) FROM x LEFT JOIN y ON x.b = y.b GROUP BY y.c",
        );

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in GROUP BY), got: {}",
            sql
        );
    }

    /// y.c appears in ORDER BY.
    #[test]
    fn test_keep_left_join_column_in_order_by() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b ORDER BY y.c");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in ORDER BY), got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // Joins that must be kept because of their kind
    // -----------------------------------------------------------------------

    /// INNER JOINs can filter rows, so they stay even when no column of the
    /// inner source is selected.
    #[test]
    fn test_inner_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x JOIN y ON x.b = y.b");

        assert!(
            sql.contains("JOIN"),
            "Expected INNER JOIN to be preserved, got: {}",
            sql
        );
    }

    /// CROSS JOINs affect cardinality.
    #[test]
    fn test_cross_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x CROSS JOIN y");

        assert!(
            sql.contains("CROSS JOIN"),
            "Expected CROSS JOIN to be preserved, got: {}",
            sql
        );
    }

    // -----------------------------------------------------------------------
    // Conservative / no-op cases
    // -----------------------------------------------------------------------

    /// A query without joins comes back unchanged.
    #[test]
    fn test_no_joins_unchanged() {
        let expr = parse("SELECT a FROM x");
        let before = render(&expr);
        let after = render(&eliminate_joins(expr));

        assert_eq!(before, after);
    }

    /// 'a' is unqualified, so it might come from y; the pass must keep the
    /// join rather than guess.
    #[test]
    fn test_skip_with_unqualified_columns() {
        let sql = optimize("SELECT a FROM x LEFT JOIN y ON x.b = y.b");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (unqualified columns), got: {}",
            sql
        );
    }
}