1use crate::expressions::*;
9use crate::scope::traverse_scope;
10use crate::scope::ColumnRef;
11use std::collections::HashMap;
12
13pub fn eliminate_joins(expression: Expression) -> Expression {
41 let scopes = traverse_scope(&expression);
42
43 let mut removals: Vec<JoinRemoval> = Vec::new();
47
48 for mut scope in scopes {
49 if !scope.unqualified_columns().is_empty() {
52 continue;
53 }
54
55 let select = match &scope.expression {
56 Expression::Select(s) => s.clone(),
57 _ => continue,
58 };
59
60 let joins = &select.joins;
61 if joins.is_empty() {
62 continue;
63 }
64
65 for (idx, join) in joins.iter().enumerate().rev() {
68 if is_semi_or_anti_join(join) {
69 continue;
70 }
71
72 let alias = join_alias_or_name(join);
73 let alias = match alias {
74 Some(a) => a,
75 None => continue,
76 };
77
78 if should_eliminate_join(&mut scope, &select, idx, join, &alias) {
79 removals.push(JoinRemoval {
80 select_id: select_identity(&select),
81 join_index: idx,
82 source_alias: alias,
83 });
84 }
85 }
86 }
87
88 if removals.is_empty() {
89 return expression;
90 }
91
92 apply_removals(expression, &removals)
93}
94
/// One join that the analysis in `eliminate_joins` decided to drop.
///
/// Removals are collected while walking the scopes and applied afterwards by
/// `apply_removals`, which matches them to `SELECT` nodes via `select_id`.
struct JoinRemoval {
    /// Fingerprint of the `SELECT` that owns the join.
    select_id: SelectIdentity,
    /// Position of the join inside that `SELECT`'s `joins` list, as seen
    /// before any join was removed.
    join_index: usize,
    // Not read anywhere in this file, hence the lint allowance; presumably
    // retained for debugging — confirm before removing.
    #[allow(dead_code)]
    /// Alias (or table name) of the joined source that is being dropped.
    source_alias: String,
}
110
111#[derive(Debug, Clone, PartialEq, Eq)]
117struct SelectIdentity {
118 num_expressions: usize,
119 num_joins: usize,
120 first_expr_debug: String,
122}
123
124fn select_identity(select: &Select) -> SelectIdentity {
125 SelectIdentity {
126 num_expressions: select.expressions.len(),
127 num_joins: select.joins.len(),
128 first_expr_debug: select
129 .expressions
130 .first()
131 .map(|e| format!("{:?}", e))
132 .unwrap_or_default(),
133 }
134}
135
136fn is_semi_or_anti_join(join: &Join) -> bool {
144 matches!(
145 join.kind,
146 JoinKind::Semi
147 | JoinKind::Anti
148 | JoinKind::LeftSemi
149 | JoinKind::LeftAnti
150 | JoinKind::RightSemi
151 | JoinKind::RightAnti
152 )
153}
154
155fn join_alias_or_name(join: &Join) -> Option<String> {
157 get_table_alias_or_name(&join.this)
158}
159
160fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
162 match expr {
163 Expression::Table(table) => {
164 if let Some(ref alias) = table.alias {
165 Some(alias.name.clone())
166 } else {
167 Some(table.name.name.clone())
168 }
169 }
170 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
171 _ => None,
172 }
173}
174
175fn should_eliminate_join(
185 scope: &mut crate::scope::Scope,
186 _select: &Select,
187 _join_index: usize,
188 join: &Join,
189 alias: &str,
190) -> bool {
191 if join.kind != JoinKind::Left {
195 return false;
196 }
197
198 let source_cols = scope.source_columns(alias);
202 if source_cols.is_empty() {
203 return true;
204 }
205
206 let mut source_counts: HashMap<(String, String), usize> = HashMap::new();
207 for col in &source_cols {
208 if let Some(table) = &col.table {
209 *source_counts
210 .entry((table.clone(), col.name.clone()))
211 .or_insert(0) += 1;
212 }
213 }
214
215 if let Some(on) = &join.on {
216 subtract_columns_from_counts(alias, on, &mut source_counts);
217 }
218 if let Some(match_condition) = &join.match_condition {
219 subtract_columns_from_counts(alias, match_condition, &mut source_counts);
220 }
221
222 !source_counts.values().any(|&count| count > 0)
223}
224
225fn subtract_columns_from_counts(
226 alias: &str,
227 expr: &Expression,
228 counts: &mut HashMap<(String, String), usize>,
229) {
230 let mut cols: Vec<ColumnRef> = Vec::new();
231 collect_columns_in_expression(expr, &mut cols);
232
233 for col in cols {
234 if col.table.as_deref() != Some(alias) {
235 continue;
236 }
237 let key = (alias.to_string(), col.name);
238 if let Some(value) = counts.get_mut(&key) {
239 if *value > 0 {
240 *value -= 1;
241 }
242 }
243 }
244}
245
246fn collect_columns_in_expression(expr: &Expression, columns: &mut Vec<ColumnRef>) {
247 match expr {
248 Expression::Column(col) => {
249 columns.push(ColumnRef {
250 table: col.table.as_ref().map(|t| t.name.clone()),
251 name: col.name.name.clone(),
252 });
253 }
254 Expression::Select(select) => {
255 for e in &select.expressions {
256 collect_columns_in_expression(e, columns);
257 }
258 if let Some(from) = &select.from {
259 for e in &from.expressions {
260 collect_columns_in_expression(e, columns);
261 }
262 }
263 for join in &select.joins {
264 collect_columns_in_expression(&join.this, columns);
265 if let Some(on) = &join.on {
266 collect_columns_in_expression(on, columns);
267 }
268 if let Some(match_condition) = &join.match_condition {
269 collect_columns_in_expression(match_condition, columns);
270 }
271 }
272 if let Some(where_clause) = &select.where_clause {
273 collect_columns_in_expression(&where_clause.this, columns);
274 }
275 if let Some(group_by) = &select.group_by {
276 for e in &group_by.expressions {
277 collect_columns_in_expression(e, columns);
278 }
279 }
280 if let Some(having) = &select.having {
281 collect_columns_in_expression(&having.this, columns);
282 }
283 if let Some(order_by) = &select.order_by {
284 for o in &order_by.expressions {
285 collect_columns_in_expression(&o.this, columns);
286 }
287 }
288 if let Some(qualify) = &select.qualify {
289 collect_columns_in_expression(&qualify.this, columns);
290 }
291 if let Some(limit) = &select.limit {
292 collect_columns_in_expression(&limit.this, columns);
293 }
294 if let Some(offset) = &select.offset {
295 collect_columns_in_expression(&offset.this, columns);
296 }
297 }
298 Expression::Alias(alias) => {
299 collect_columns_in_expression(&alias.this, columns);
300 }
301 Expression::Function(func) => {
302 for arg in &func.args {
303 collect_columns_in_expression(arg, columns);
304 }
305 }
306 Expression::AggregateFunction(agg) => {
307 for arg in &agg.args {
308 collect_columns_in_expression(arg, columns);
309 }
310 }
311 Expression::And(bin)
312 | Expression::Or(bin)
313 | Expression::Eq(bin)
314 | Expression::Neq(bin)
315 | Expression::Lt(bin)
316 | Expression::Lte(bin)
317 | Expression::Gt(bin)
318 | Expression::Gte(bin)
319 | Expression::Add(bin)
320 | Expression::Sub(bin)
321 | Expression::Mul(bin)
322 | Expression::Div(bin)
323 | Expression::Mod(bin)
324 | Expression::BitwiseAnd(bin)
325 | Expression::BitwiseOr(bin)
326 | Expression::BitwiseXor(bin)
327 | Expression::Concat(bin) => {
328 collect_columns_in_expression(&bin.left, columns);
329 collect_columns_in_expression(&bin.right, columns);
330 }
331 Expression::Like(like) | Expression::ILike(like) => {
332 collect_columns_in_expression(&like.left, columns);
333 collect_columns_in_expression(&like.right, columns);
334 if let Some(escape) = &like.escape {
335 collect_columns_in_expression(escape, columns);
336 }
337 }
338 Expression::Not(unary) | Expression::Neg(unary) | Expression::BitwiseNot(unary) => {
339 collect_columns_in_expression(&unary.this, columns);
340 }
341 Expression::Case(case) => {
342 if let Some(operand) = &case.operand {
343 collect_columns_in_expression(operand, columns);
344 }
345 for (when_expr, then_expr) in &case.whens {
346 collect_columns_in_expression(when_expr, columns);
347 collect_columns_in_expression(then_expr, columns);
348 }
349 if let Some(else_) = &case.else_ {
350 collect_columns_in_expression(else_, columns);
351 }
352 }
353 Expression::Cast(cast) => {
354 collect_columns_in_expression(&cast.this, columns);
355 }
356 Expression::In(in_expr) => {
357 collect_columns_in_expression(&in_expr.this, columns);
358 for e in &in_expr.expressions {
359 collect_columns_in_expression(e, columns);
360 }
361 if let Some(query) = &in_expr.query {
362 collect_columns_in_expression(query, columns);
363 }
364 }
365 Expression::Between(between) => {
366 collect_columns_in_expression(&between.this, columns);
367 collect_columns_in_expression(&between.low, columns);
368 collect_columns_in_expression(&between.high, columns);
369 }
370 Expression::Exists(exists) => {
371 collect_columns_in_expression(&exists.this, columns);
372 }
373 Expression::Subquery(subquery) => {
374 collect_columns_in_expression(&subquery.this, columns);
375 }
376 Expression::WindowFunction(wf) => {
377 collect_columns_in_expression(&wf.this, columns);
378 for p in &wf.over.partition_by {
379 collect_columns_in_expression(p, columns);
380 }
381 for o in &wf.over.order_by {
382 collect_columns_in_expression(&o.this, columns);
383 }
384 if let Some(frame) = &wf.over.frame {
385 collect_columns_from_window_bound(&frame.start, columns);
386 if let Some(end) = &frame.end {
387 collect_columns_from_window_bound(end, columns);
388 }
389 }
390 }
391 Expression::Ordered(ord) => {
392 collect_columns_in_expression(&ord.this, columns);
393 }
394 Expression::Paren(paren) => {
395 collect_columns_in_expression(&paren.this, columns);
396 }
397 Expression::Join(join) => {
398 collect_columns_in_expression(&join.this, columns);
399 if let Some(on) = &join.on {
400 collect_columns_in_expression(on, columns);
401 }
402 if let Some(match_condition) = &join.match_condition {
403 collect_columns_in_expression(match_condition, columns);
404 }
405 }
406 _ => {}
407 }
408}
409
410fn collect_columns_from_window_bound(bound: &WindowFrameBound, columns: &mut Vec<ColumnRef>) {
411 match bound {
412 WindowFrameBound::Preceding(expr)
413 | WindowFrameBound::Following(expr)
414 | WindowFrameBound::Value(expr) => collect_columns_in_expression(expr, columns),
415 WindowFrameBound::CurrentRow
416 | WindowFrameBound::UnboundedPreceding
417 | WindowFrameBound::UnboundedFollowing
418 | WindowFrameBound::BarePreceding
419 | WindowFrameBound::BareFollowing => {}
420 }
421}
422
423fn apply_removals(expression: Expression, removals: &[JoinRemoval]) -> Expression {
426 match expression {
427 Expression::Select(select) => {
428 let id = select_identity(&select);
429
430 let mut indices_to_drop: Vec<usize> = removals
432 .iter()
433 .filter(|r| r.select_id == id)
434 .map(|r| r.join_index)
435 .collect();
436 indices_to_drop.sort_unstable();
437 indices_to_drop.dedup();
438
439 let mut new_select = select.clone();
440
441 for &idx in indices_to_drop.iter().rev() {
443 if idx < new_select.joins.len() {
444 new_select.joins.remove(idx);
445 }
446 }
447
448 new_select.expressions = new_select
450 .expressions
451 .into_iter()
452 .map(|e| apply_removals(e, removals))
453 .collect();
454
455 if let Some(ref mut from) = new_select.from {
456 from.expressions = from
457 .expressions
458 .clone()
459 .into_iter()
460 .map(|e| apply_removals(e, removals))
461 .collect();
462 }
463
464 if let Some(ref mut w) = new_select.where_clause {
465 w.this = apply_removals(w.this.clone(), removals);
466 }
467
468 new_select.joins = new_select
470 .joins
471 .into_iter()
472 .map(|mut j| {
473 j.this = apply_removals(j.this, removals);
474 if let Some(on) = j.on {
475 j.on = Some(apply_removals(on, removals));
476 }
477 j
478 })
479 .collect();
480
481 if let Some(ref mut with) = new_select.with {
482 with.ctes = with
483 .ctes
484 .iter()
485 .map(|cte| {
486 let mut new_cte = cte.clone();
487 new_cte.this = apply_removals(new_cte.this, removals);
488 new_cte
489 })
490 .collect();
491 }
492
493 Expression::Select(new_select)
494 }
495 Expression::Subquery(mut subquery) => {
496 subquery.this = apply_removals(subquery.this, removals);
497 Expression::Subquery(subquery)
498 }
499 Expression::Union(mut union) => {
500 union.left = apply_removals(union.left, removals);
501 union.right = apply_removals(union.right, removals);
502 Expression::Union(union)
503 }
504 Expression::Intersect(mut intersect) => {
505 intersect.left = apply_removals(intersect.left, removals);
506 intersect.right = apply_removals(intersect.right, removals);
507 Expression::Intersect(intersect)
508 }
509 Expression::Except(mut except) => {
510 except.left = apply_removals(except.left, removals);
511 except.right = apply_removals(except.right, removals);
512 Expression::Except(except)
513 }
514 other => other,
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql`, runs the join elimination pass and renders the result.
    fn optimize(sql: &str) -> String {
        to_sql(&eliminate_joins(parse(sql)))
    }

    // An unused LEFT JOIN disappears completely.
    #[test]
    fn test_eliminate_unused_left_join() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b");

        assert!(
            !sql.contains("JOIN"),
            "Expected JOIN to be eliminated, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT x.a FROM x"),
            "Expected simple select, got: {}",
            sql
        );
    }

    // A projected column from the joined table keeps the join alive.
    #[test]
    fn test_keep_used_left_join() {
        let sql = optimize("SELECT x.a, y.c FROM x LEFT JOIN y ON x.b = y.b");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved, got: {}",
            sql
        );
    }

    // Inner joins can filter rows and are never removed.
    #[test]
    fn test_inner_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x JOIN y ON x.b = y.b");

        assert!(
            sql.contains("JOIN"),
            "Expected INNER JOIN to be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_keep_left_join_column_in_where() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b WHERE y.c > 1");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in WHERE), got: {}",
            sql
        );
    }

    // Only the unused join of several is dropped.
    #[test]
    fn test_eliminate_one_of_multiple_joins() {
        let sql =
            optimize("SELECT x.a, z.d FROM x LEFT JOIN y ON x.b = y.b LEFT JOIN z ON x.c = z.c");

        assert!(
            sql.contains("JOIN"),
            "Expected at least one JOIN to remain, got: {}",
            sql
        );
        assert!(
            !sql.contains("JOIN y"),
            "Expected JOIN y to be removed, got: {}",
            sql
        );
        assert!(sql.contains("z"), "Expected z to remain, got: {}", sql);
    }

    #[test]
    fn test_no_joins_unchanged() {
        let expr = parse("SELECT a FROM x");
        let original_sql = to_sql(&expr);
        let result_sql = to_sql(&eliminate_joins(expr));

        assert_eq!(original_sql, result_sql);
    }

    #[test]
    fn test_cross_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x CROSS JOIN y");

        assert!(
            sql.contains("CROSS JOIN"),
            "Expected CROSS JOIN to be preserved, got: {}",
            sql
        );
    }

    // An unqualified column could belong to either side, so nothing is removed.
    #[test]
    fn test_skip_with_unqualified_columns() {
        let sql = optimize("SELECT a FROM x LEFT JOIN y ON x.b = y.b");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (unqualified columns), got: {}",
            sql
        );
    }

    #[test]
    fn test_keep_left_join_column_in_group_by() {
        let sql = optimize("SELECT x.a, COUNT(*) FROM x LEFT JOIN y ON x.b = y.b GROUP BY y.c");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in GROUP BY), got: {}",
            sql
        );
    }

    #[test]
    fn test_keep_left_join_column_in_order_by() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b ORDER BY y.c");

        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in ORDER BY), got: {}",
            sql
        );
    }

    // A join referenced by a later join's ON condition must stay.
    #[test]
    fn test_keep_left_join_used_in_other_join_condition() {
        let sql =
            optimize("SELECT x.a FROM x LEFT JOIN y ON x.y_id = y.id LEFT JOIN z ON y.id = z.y_id");

        assert!(
            sql.contains("JOIN y"),
            "Expected JOIN y to be preserved (used in another JOIN ON), got: {}",
            sql
        );
    }
}