use crate::expressions::*;
use crate::scope::traverse_scope;
use crate::scope::ColumnRef;
use std::collections::HashMap;
pub fn eliminate_joins(expression: Expression) -> Expression {
let scopes = traverse_scope(&expression);
let mut removals: Vec<JoinRemoval> = Vec::new();
for mut scope in scopes {
if !scope.unqualified_columns().is_empty() {
continue;
}
let select = match &scope.expression {
Expression::Select(s) => s.clone(),
_ => continue,
};
let joins = &select.joins;
if joins.is_empty() {
continue;
}
for (idx, join) in joins.iter().enumerate().rev() {
if is_semi_or_anti_join(join) {
continue;
}
let alias = join_alias_or_name(join);
let alias = match alias {
Some(a) => a,
None => continue,
};
if should_eliminate_join(&mut scope, &select, idx, join, &alias) {
removals.push(JoinRemoval {
select_id: select_identity(&select),
join_index: idx,
source_alias: alias,
});
}
}
}
if removals.is_empty() {
return expression;
}
apply_removals(expression, &removals)
}
/// A join that the analysis phase decided to drop; consumed by `apply_removals`.
struct JoinRemoval {
/// Identity of the SELECT that owns the join (see `select_identity`).
select_id: SelectIdentity,
/// Position of the join inside that SELECT's `joins` list at analysis time.
join_index: usize,
/// Alias (or table name) of the joined source, as returned by `join_alias_or_name`.
// Nothing in this file reads the field after construction, hence the lint
// allowance; presumably it is kept as a debugging aid.
#[allow(dead_code)]
source_alias: String,
}
/// Structural fingerprint used to find, during the rewrite phase, the SELECT
/// for which a join removal was recorded during analysis.
///
/// The fingerprint is the `Debug` rendering of the *whole* SELECT rather than
/// of its first projection only. Two different SELECTs that merely share their
/// first projection and join count (for example the branches of
/// `SELECT x.a FROM x LEFT JOIN y ON ... UNION SELECT x.a FROM x LEFT JOIN y ON ... WHERE y.c > 1`)
/// therefore no longer compare equal, so a removal decided for one of them is
/// not applied to the other.
///
/// Structurally identical SELECTs still share an identity.
// NOTE(review): relies on `Select: Debug` printing every clause (true for a
// derived impl) — confirm against the definition in `crate::expressions`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SelectIdentity {
    /// Number of projections; cheap field compared first by the derived `PartialEq`.
    num_expressions: usize,
    /// Number of joins in the SELECT when the identity was taken.
    num_joins: usize,
    /// `Debug` rendering of the complete SELECT node.
    select_debug: String,
}

/// Computes the [`SelectIdentity`] of `select`.
///
/// Must be called on the SELECT *before* any of its joins are removed, both
/// when recording a removal and when looking it up, so the two sides agree.
fn select_identity(select: &Select) -> SelectIdentity {
    SelectIdentity {
        num_expressions: select.expressions.len(),
        num_joins: select.joins.len(),
        select_debug: format!("{:?}", select),
    }
}
/// Returns `true` when `join` is any flavour of SEMI or ANTI join
/// (plain, left or right).
fn is_semi_or_anti_join(join: &Join) -> bool {
    let kind = &join.kind;
    let is_semi = matches!(
        kind,
        JoinKind::Semi | JoinKind::LeftSemi | JoinKind::RightSemi
    );
    let is_anti = matches!(
        kind,
        JoinKind::Anti | JoinKind::LeftAnti | JoinKind::RightAnti
    );
    is_semi || is_anti
}
/// Returns the name under which the joined source (`join.this`) can be
/// referenced, by delegating to `get_table_alias_or_name`; `None` when that
/// helper cannot derive a name.
fn join_alias_or_name(join: &Join) -> Option<String> {
get_table_alias_or_name(&join.this)
}
/// Returns the name a source expression is referenced by.
///
/// * Table: its alias when present, otherwise the table name.
/// * Subquery: its alias, or `None` when it has none.
/// * Anything else: `None`.
fn get_table_alias_or_name(expr: &Expression) -> Option<String> {
    match expr {
        Expression::Table(table) => {
            let visible_name = match table.alias.as_ref() {
                Some(alias) => &alias.name,
                None => &table.name.name,
            };
            Some(visible_name.clone())
        }
        Expression::Subquery(subquery) => {
            let alias = subquery.alias.as_ref()?;
            Some(alias.name.clone())
        }
        _ => None,
    }
}
/// Decides whether `join` (whose source is referenced as `alias`) can be dropped.
///
/// Only `JoinKind::Left` joins are ever candidates. Such a join is reported as
/// removable when every column that `scope.source_columns(alias)` returns is
/// accounted for by the join's own `ON` / `MATCH_CONDITION` expressions, i.e.
/// nothing else in the scope refers to the joined source.
///
/// `_select` and `_join_index` are accepted but not used.
// NOTE(review): no uniqueness check on the joined side is performed here. A
// LEFT JOIN whose right side matches several rows duplicates left rows, so
// dropping it can change the row count — confirm this is acceptable to callers.
fn should_eliminate_join(
    scope: &mut crate::scope::Scope,
    _select: &Select,
    _join_index: usize,
    join: &Join,
    alias: &str,
) -> bool {
    if join.kind != JoinKind::Left {
        return false;
    }

    // Presumably every column reference in the scope that is qualified with
    // `alias`; the helper lives in `crate::scope`.
    let source_cols = scope.source_columns(alias);
    if source_cols.is_empty() {
        return true;
    }

    // Tally the references per (qualifier, column name); references without a
    // qualifier are not counted.
    let mut outstanding: HashMap<(String, String), usize> = HashMap::new();
    for column in &source_cols {
        if let Some(qualifier) = column.table.as_ref() {
            let slot = outstanding
                .entry((qualifier.clone(), column.name.clone()))
                .or_default();
            *slot += 1;
        }
    }

    // References made by the join's own conditions do not count as "uses".
    if let Some(on) = join.on.as_ref() {
        subtract_columns_from_counts(alias, on, &mut outstanding);
    }
    if let Some(match_condition) = join.match_condition.as_ref() {
        subtract_columns_from_counts(alias, match_condition, &mut outstanding);
    }

    // Removable only if no reference is left over.
    outstanding.values().all(|&remaining| remaining == 0)
}
/// Decrements, for every column in `expr` that is qualified with `alias`, the
/// matching `(alias, column name)` entry of `counts`.
///
/// Entries never go below zero and keys absent from `counts` are left absent.
fn subtract_columns_from_counts(
    alias: &str,
    expr: &Expression,
    counts: &mut HashMap<(String, String), usize>,
) {
    let mut referenced: Vec<ColumnRef> = Vec::new();
    collect_columns_in_expression(expr, &mut referenced);

    let qualified_with_alias = referenced
        .into_iter()
        .filter(|column| column.table.as_deref() == Some(alias));

    for column in qualified_with_alias {
        if let Some(remaining) = counts.get_mut(&(alias.to_string(), column.name)) {
            *remaining = remaining.saturating_sub(1);
        }
    }
}
/// Recursively walks `expr` and appends one `ColumnRef` (optional table
/// qualifier plus column name) to `columns` for every `Expression::Column`
/// reached, in traversal order.
///
/// Only the expression variants listed in the `match` below are descended
/// into; any other variant hits the final catch-all arm and is ignored, so
/// columns nested beneath such a variant are not collected.
fn collect_columns_in_expression(expr: &Expression, columns: &mut Vec<ColumnRef>) {
match expr {
// Leaf case: record the column itself.
Expression::Column(col) => {
columns.push(ColumnRef {
table: col.table.as_ref().map(|t| t.name.clone()),
name: col.name.name.clone(),
});
}
// A nested SELECT is walked clause by clause: projections, FROM, joins,
// WHERE, GROUP BY, HAVING, ORDER BY, QUALIFY, LIMIT and OFFSET.
Expression::Select(select) => {
for e in &select.expressions {
collect_columns_in_expression(e, columns);
}
if let Some(from) = &select.from {
for e in &from.expressions {
collect_columns_in_expression(e, columns);
}
}
for join in &select.joins {
collect_columns_in_expression(&join.this, columns);
if let Some(on) = &join.on {
collect_columns_in_expression(on, columns);
}
if let Some(match_condition) = &join.match_condition {
collect_columns_in_expression(match_condition, columns);
}
}
if let Some(where_clause) = &select.where_clause {
collect_columns_in_expression(&where_clause.this, columns);
}
if let Some(group_by) = &select.group_by {
for e in &group_by.expressions {
collect_columns_in_expression(e, columns);
}
}
if let Some(having) = &select.having {
collect_columns_in_expression(&having.this, columns);
}
if let Some(order_by) = &select.order_by {
for o in &order_by.expressions {
collect_columns_in_expression(&o.this, columns);
}
}
if let Some(qualify) = &select.qualify {
collect_columns_in_expression(&qualify.this, columns);
}
if let Some(limit) = &select.limit {
collect_columns_in_expression(&limit.this, columns);
}
if let Some(offset) = &select.offset {
collect_columns_in_expression(&offset.this, columns);
}
}
// Wrappers and calls: descend into the wrapped expression / arguments.
Expression::Alias(alias) => {
collect_columns_in_expression(&alias.this, columns);
}
Expression::Function(func) => {
for arg in &func.args {
collect_columns_in_expression(arg, columns);
}
}
Expression::AggregateFunction(agg) => {
for arg in &agg.args {
collect_columns_in_expression(arg, columns);
}
}
// Binary operators: left operand first, then right.
Expression::And(bin)
| Expression::Or(bin)
| Expression::Eq(bin)
| Expression::Neq(bin)
| Expression::Lt(bin)
| Expression::Lte(bin)
| Expression::Gt(bin)
| Expression::Gte(bin)
| Expression::Add(bin)
| Expression::Sub(bin)
| Expression::Mul(bin)
| Expression::Div(bin)
| Expression::Mod(bin)
| Expression::BitwiseAnd(bin)
| Expression::BitwiseOr(bin)
| Expression::BitwiseXor(bin)
| Expression::Concat(bin) => {
collect_columns_in_expression(&bin.left, columns);
collect_columns_in_expression(&bin.right, columns);
}
// LIKE / ILIKE: both operands plus the optional ESCAPE expression.
Expression::Like(like) | Expression::ILike(like) => {
collect_columns_in_expression(&like.left, columns);
collect_columns_in_expression(&like.right, columns);
if let Some(escape) = &like.escape {
collect_columns_in_expression(escape, columns);
}
}
// Unary operators.
Expression::Not(unary) | Expression::Neg(unary) | Expression::BitwiseNot(unary) => {
collect_columns_in_expression(&unary.this, columns);
}
// CASE: optional operand, every WHEN/THEN pair, optional ELSE.
Expression::Case(case) => {
if let Some(operand) = &case.operand {
collect_columns_in_expression(operand, columns);
}
for (when_expr, then_expr) in &case.whens {
collect_columns_in_expression(when_expr, columns);
collect_columns_in_expression(then_expr, columns);
}
if let Some(else_) = &case.else_ {
collect_columns_in_expression(else_, columns);
}
}
Expression::Cast(cast) => {
collect_columns_in_expression(&cast.this, columns);
}
// IN: tested expression, the value list, and the optional subquery.
Expression::In(in_expr) => {
collect_columns_in_expression(&in_expr.this, columns);
for e in &in_expr.expressions {
collect_columns_in_expression(e, columns);
}
if let Some(query) = &in_expr.query {
collect_columns_in_expression(query, columns);
}
}
Expression::Between(between) => {
collect_columns_in_expression(&between.this, columns);
collect_columns_in_expression(&between.low, columns);
collect_columns_in_expression(&between.high, columns);
}
Expression::Exists(exists) => {
collect_columns_in_expression(&exists.this, columns);
}
Expression::Subquery(subquery) => {
collect_columns_in_expression(&subquery.this, columns);
}
// Window function: the function itself, PARTITION BY, ORDER BY, and any
// expressions carried by the frame bounds.
Expression::WindowFunction(wf) => {
collect_columns_in_expression(&wf.this, columns);
for p in &wf.over.partition_by {
collect_columns_in_expression(p, columns);
}
for o in &wf.over.order_by {
collect_columns_in_expression(&o.this, columns);
}
if let Some(frame) = &wf.over.frame {
collect_columns_from_window_bound(&frame.start, columns);
if let Some(end) = &frame.end {
collect_columns_from_window_bound(end, columns);
}
}
}
Expression::Ordered(ord) => {
collect_columns_in_expression(&ord.this, columns);
}
Expression::Paren(paren) => {
collect_columns_in_expression(&paren.this, columns);
}
// A join appearing as a standalone expression: source, ON, MATCH_CONDITION.
Expression::Join(join) => {
collect_columns_in_expression(&join.this, columns);
if let Some(on) = &join.on {
collect_columns_in_expression(on, columns);
}
if let Some(match_condition) = &join.match_condition {
collect_columns_in_expression(match_condition, columns);
}
}
// Every other variant is skipped without descending into it.
_ => {}
}
}
/// Collects the columns referenced by a window-frame bound.
///
/// Only the bounds that carry an expression (`Preceding`, `Following`,
/// `Value`) can reference columns; the remaining bounds contribute nothing.
fn collect_columns_from_window_bound(bound: &WindowFrameBound, columns: &mut Vec<ColumnRef>) {
    let carried_expression = match bound {
        WindowFrameBound::Preceding(inner)
        | WindowFrameBound::Following(inner)
        | WindowFrameBound::Value(inner) => Some(inner),
        WindowFrameBound::CurrentRow
        | WindowFrameBound::UnboundedPreceding
        | WindowFrameBound::UnboundedFollowing
        | WindowFrameBound::BarePreceding
        | WindowFrameBound::BareFollowing => None,
    };

    if let Some(inner) = carried_expression {
        collect_columns_in_expression(inner, columns);
    }
}
/// Rewrites `expression`, dropping every join listed in `removals`.
///
/// For each SELECT reached, the removals whose `select_id` equals the SELECT's
/// identity (taken *before* any join is dropped) are applied to its join list.
/// The function then recurses into the SELECT's projections, FROM expressions,
/// WHERE expression, remaining join sources and `ON` conditions, and CTE
/// bodies. Subqueries and UNION / INTERSECT / EXCEPT operands are descended
/// into as well. Any other expression kind is returned unchanged, so SELECTs
/// nested beneath such a node are not visited.
///
/// The tree is owned by the caller, so nodes are moved and patched in place
/// rather than cloned.
fn apply_removals(expression: Expression, removals: &[JoinRemoval]) -> Expression {
    match expression {
        Expression::Select(mut select) => {
            // Identity must be computed on the untouched SELECT so it matches
            // what the analysis phase recorded.
            let id = select_identity(&select);
            let mut doomed: Vec<usize> = removals
                .iter()
                .filter(|removal| removal.select_id == id)
                .map(|removal| removal.join_index)
                .collect();

            if !doomed.is_empty() {
                // Sorted so membership can be tested by binary search while a
                // single `retain` pass drops all doomed joins at once.
                doomed.sort_unstable();
                let mut position = 0usize;
                select.joins.retain(|_| {
                    let keep = doomed.binary_search(&position).is_err();
                    position += 1;
                    keep
                });
            }

            select.expressions = std::mem::take(&mut select.expressions)
                .into_iter()
                .map(|projection| apply_removals(projection, removals))
                .collect();

            if let Some(from) = select.from.as_mut() {
                from.expressions = std::mem::take(&mut from.expressions)
                    .into_iter()
                    .map(|source| apply_removals(source, removals))
                    .collect();
            }

            if let Some(where_clause) = select.where_clause.as_mut() {
                // Temporarily park a NULL in the slot so the condition can be
                // moved out and rewritten without cloning it.
                let condition = std::mem::replace(&mut where_clause.this, Expression::Null(Null));
                where_clause.this = apply_removals(condition, removals);
            }

            select.joins = std::mem::take(&mut select.joins)
                .into_iter()
                .map(|mut join| {
                    join.this = apply_removals(join.this, removals);
                    join.on = join.on.take().map(|on| apply_removals(on, removals));
                    join
                })
                .collect();

            if let Some(with) = select.with.as_mut() {
                with.ctes = std::mem::take(&mut with.ctes)
                    .into_iter()
                    .map(|mut cte| {
                        cte.this = apply_removals(cte.this, removals);
                        cte
                    })
                    .collect();
            }

            Expression::Select(select)
        }
        Expression::Subquery(mut subquery) => {
            subquery.this = apply_removals(subquery.this, removals);
            Expression::Subquery(subquery)
        }
        // Set operations: rewrite both operands, left first.
        Expression::Union(mut union) => {
            let left = std::mem::replace(&mut union.left, Expression::Null(Null));
            union.left = apply_removals(left, removals);
            let right = std::mem::replace(&mut union.right, Expression::Null(Null));
            union.right = apply_removals(right, removals);
            Expression::Union(union)
        }
        Expression::Intersect(mut intersect) => {
            let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
            intersect.left = apply_removals(left, removals);
            let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
            intersect.right = apply_removals(right, removals);
            Expression::Intersect(intersect)
        }
        Expression::Except(mut except) => {
            let left = std::mem::replace(&mut except.left, Expression::Null(Null));
            except.left = apply_removals(left, removals);
            let right = std::mem::replace(&mut except.right, Expression::Null(Null));
            except.right = apply_removals(right, removals);
            Expression::Except(except)
        }
        other => other,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text.
    fn render_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse_first(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parses `sql`, runs `eliminate_joins` on it and renders the result.
    fn optimize(sql: &str) -> String {
        render_sql(&eliminate_joins(parse_first(sql)))
    }

    /// A LEFT JOIN whose source is referenced only in its own ON clause goes away.
    #[test]
    fn test_eliminate_unused_left_join() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b");
        assert!(
            !sql.contains("JOIN"),
            "Expected JOIN to be eliminated, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT x.a FROM x"),
            "Expected simple select, got: {}",
            sql
        );
    }

    /// A projected column of the joined source keeps the join.
    #[test]
    fn test_keep_used_left_join() {
        let sql = optimize("SELECT x.a, y.c FROM x LEFT JOIN y ON x.b = y.b");
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved, got: {}",
            sql
        );
    }

    /// Inner joins are never candidates.
    #[test]
    fn test_inner_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x JOIN y ON x.b = y.b");
        assert!(
            sql.contains("JOIN"),
            "Expected INNER JOIN to be preserved, got: {}",
            sql
        );
    }

    /// A WHERE reference to the joined source keeps the join.
    #[test]
    fn test_keep_left_join_column_in_where() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b WHERE y.c > 1");
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in WHERE), got: {}",
            sql
        );
    }

    /// Only the unused join of several is removed.
    #[test]
    fn test_eliminate_one_of_multiple_joins() {
        let sql =
            optimize("SELECT x.a, z.d FROM x LEFT JOIN y ON x.b = y.b LEFT JOIN z ON x.c = z.c");
        assert!(
            sql.contains("JOIN"),
            "Expected at least one JOIN to remain, got: {}",
            sql
        );
        assert!(
            !sql.contains("JOIN y"),
            "Expected JOIN y to be removed, got: {}",
            sql
        );
        assert!(sql.contains("z"), "Expected z to remain, got: {}", sql);
    }

    /// A query without joins round-trips unchanged.
    #[test]
    fn test_no_joins_unchanged() {
        let input = parse_first("SELECT a FROM x");
        let original_sql = render_sql(&input);
        let result_sql = render_sql(&eliminate_joins(input));
        assert_eq!(original_sql, result_sql);
    }

    /// Cross joins are never candidates.
    #[test]
    fn test_cross_join_not_eliminated() {
        let sql = optimize("SELECT x.a FROM x CROSS JOIN y");
        assert!(
            sql.contains("CROSS JOIN"),
            "Expected CROSS JOIN to be preserved, got: {}",
            sql
        );
    }

    /// Scopes with unqualified columns are left alone.
    #[test]
    fn test_skip_with_unqualified_columns() {
        let sql = optimize("SELECT a FROM x LEFT JOIN y ON x.b = y.b");
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (unqualified columns), got: {}",
            sql
        );
    }

    /// A GROUP BY reference to the joined source keeps the join.
    #[test]
    fn test_keep_left_join_column_in_group_by() {
        let sql = optimize("SELECT x.a, COUNT(*) FROM x LEFT JOIN y ON x.b = y.b GROUP BY y.c");
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in GROUP BY), got: {}",
            sql
        );
    }

    /// An ORDER BY reference to the joined source keeps the join.
    #[test]
    fn test_keep_left_join_column_in_order_by() {
        let sql = optimize("SELECT x.a FROM x LEFT JOIN y ON x.b = y.b ORDER BY y.c");
        assert!(
            sql.contains("JOIN"),
            "Expected JOIN to be preserved (column in ORDER BY), got: {}",
            sql
        );
    }

    /// A reference from another join's ON clause keeps the join.
    #[test]
    fn test_keep_left_join_used_in_other_join_condition() {
        let sql =
            optimize("SELECT x.a FROM x LEFT JOIN y ON x.y_id = y.id LEFT JOIN z ON y.id = z.y_id");
        assert!(
            sql.contains("JOIN y"),
            "Expected JOIN y to be preserved (used in another JOIN ON), got: {}",
            sql
        );
    }
}