use crate::ast::*;
/// Rewrites eligible `EXISTS` / `IN` subquery predicates in a statement's
/// `WHERE` clause into joins against derived tables.
///
/// Only `Statement::Select` is rewritten (via `unnest_select`); every other
/// statement kind is handed back exactly as it was received.
pub fn unnest_subqueries(statement: Statement) -> Statement {
    if let Statement::Select(select) = statement {
        Statement::Select(unnest_select(select))
    } else {
        statement
    }
}
/// Hands out sequential aliases (`_u0`, `_u1`, ...) for generated derived tables.
struct AliasGen {
    /// Numeric suffix of the next alias to be issued.
    counter: usize,
}

impl AliasGen {
    /// Creates a generator whose first alias is `_u0`.
    fn new() -> Self {
        AliasGen { counter: 0 }
    }

    /// Returns the next unused alias and advances the counter by one.
    fn next(&mut self) -> String {
        let id = self.counter;
        self.counter = id + 1;
        format!("_u{id}")
    }
}
/// Unnests subquery predicates found in the `WHERE` clause of `sel`.
///
/// The clause is replaced by whatever predicate remains after unnesting
/// (possibly `None`), and every generated join is appended to `sel.joins`.
/// A statement without a `WHERE` clause is returned unchanged.
fn unnest_select(mut sel: SelectStatement) -> SelectStatement {
    let Some(predicate) = sel.where_clause.take() else {
        return sel;
    };

    // Aliases are numbered from zero for each SELECT processed here.
    let mut alias_gen = AliasGen::new();
    let (remaining, generated_joins) = unnest_where(predicate, &mut alias_gen);

    sel.where_clause = remaining;
    sel.joins.extend(generated_joins);
    sel
}
/// Runs `unnest_expr` over a whole `WHERE` predicate.
///
/// Returns the predicate that is left over (if any) together with the joins
/// that replace the unnested subqueries, in the order they were generated.
fn unnest_where(expr: Expr, alias_gen: &mut AliasGen) -> (Option<Expr>, Vec<JoinClause>) {
    let mut generated: Vec<JoinClause> = Vec::new();
    let remaining = unnest_expr(expr, &mut generated, alias_gen);
    (remaining, generated)
}
/// Recursively rewrites one `WHERE` predicate, pushing a join onto `joins`
/// for every subquery that could be unnested.
///
/// Only `AND` chains and parenthesised expressions are descended into; any
/// other expression (including `OR`) is returned untouched. The return value
/// is the predicate that must stay in the `WHERE` clause, or `None` when the
/// whole predicate was absorbed by joins.
fn unnest_expr(expr: Expr, joins: &mut Vec<JoinClause>, alias_gen: &mut AliasGen) -> Option<Expr> {
    match expr {
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            // Left operand first, so aliases are issued in source order.
            let lhs = unnest_expr(*left, joins, alias_gen);
            let rhs = unnest_expr(*right, joins, alias_gen);
            match (lhs, rhs) {
                (Some(l), Some(r)) => Some(Expr::BinaryOp {
                    left: Box::new(l),
                    op: BinaryOperator::And,
                    right: Box::new(r),
                }),
                // At most one side survived: keep it, or nothing at all.
                (l, r) => l.or(r),
            }
        }
        Expr::Exists { subquery, negated } => {
            let original = *subquery;
            // The helper consumes its argument, so it gets a copy and the
            // original is kept for the "could not unnest" fallback.
            match try_unnest_exists(original.clone(), negated, alias_gen) {
                Some((join, residual)) => {
                    joins.push(join);
                    residual
                }
                None => Some(Expr::Exists {
                    subquery: Box::new(original),
                    negated,
                }),
            }
        }
        // `NOT (EXISTS ...)` spelled as a unary NOT around a plain EXISTS.
        Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: operand,
        } if matches!(operand.as_ref(), Expr::Exists { negated: false, .. }) => {
            let Expr::Exists { subquery, .. } = *operand else {
                // The guard above already established the variant.
                unreachable!()
            };
            let original = *subquery;
            match try_unnest_exists(original.clone(), true, alias_gen) {
                Some((join, residual)) => {
                    joins.push(join);
                    residual
                }
                None => Some(Expr::UnaryOp {
                    op: UnaryOperator::Not,
                    expr: Box::new(Expr::Exists {
                        subquery: Box::new(original),
                        negated: false,
                    }),
                }),
            }
        }
        Expr::InSubquery {
            expr: lhs,
            subquery,
            negated,
        } => {
            let probe = *lhs;
            let original = *subquery;
            match try_unnest_in_subquery(probe.clone(), original.clone(), negated, alias_gen) {
                Some((join, residual)) => {
                    joins.push(join);
                    residual
                }
                None => Some(Expr::InSubquery {
                    expr: Box::new(probe),
                    subquery: Box::new(original),
                    negated,
                }),
            }
        }
        Expr::Nested(inner) => {
            // `?` yields `None` when the parenthesised content was fully absorbed.
            let rewritten = unnest_expr(*inner, joins, alias_gen)?;
            // Literals and bare columns are returned without the parentheses.
            let is_atom = rewritten.is_literal() || matches!(rewritten, Expr::Column { .. });
            Some(if is_atom {
                rewritten
            } else {
                Expr::Nested(Box::new(rewritten))
            })
        }
        untouched => Some(untouched),
    }
}
/// Attempts to turn `[NOT] EXISTS (subquery)` into a join against a derived table.
///
/// Returns `None` (meaning: leave the predicate alone) when the subquery is
/// not a `SELECT`, has no `WHERE` clause, has no equality correlation
/// predicate, or has any predicate that `extract_correlation_predicates`
/// reports as a non-equality correlation.
///
/// On success the result is the join plus the predicate that must replace the
/// `EXISTS` in the outer `WHERE`:
/// * not negated: an `INNER JOIN` and no residual predicate;
/// * negated: a `LEFT JOIN` and `<alias>._sentinel IS NULL`.
///
/// An alias is consumed from `alias_gen` only once all checks have passed.
fn try_unnest_exists(
    subquery: Statement,
    negated: bool,
    alias_gen: &mut AliasGen,
) -> Option<(JoinClause, Option<Expr>)> {
    let Statement::Select(inner_select) = &subquery else {
        return None;
    };
    let inner_where = inner_select.where_clause.as_ref()?;

    let (eq_preds, non_eq_preds) = extract_correlation_predicates(inner_where);
    if eq_preds.is_empty() || !non_eq_preds.is_empty() {
        return None;
    }

    let alias = alias_gen.next();
    let on_condition = build_join_on(&eq_preds, &alias);
    let derived = build_derived_table_from_exists(subquery, &eq_preds, &alias);

    // The two flavours differ only in join type and in the residual filter.
    let (join_type, residual) = if negated {
        let no_match = Expr::IsNull {
            expr: Box::new(sentinel_column(&alias)),
            negated: false,
        };
        (JoinType::Left, Some(no_match))
    } else {
        (JoinType::Inner, None)
    };

    let join = JoinClause {
        join_type,
        table: derived,
        on: Some(on_condition),
        using: vec![],
    };
    Some((join, residual))
}
fn try_unnest_in_subquery(
lhs: Expr,
subquery: Statement,
negated: bool,
alias_gen: &mut AliasGen,
) -> Option<(JoinClause, Option<Expr>)> {
let inner_select = match &subquery {
Statement::Select(sel) => sel,
_ => return None,
};
if inner_select.columns.len() != 1 {
return None;
}
let alias = alias_gen.next();
let inner_col_alias = "_col0".to_string();
let on_condition = Expr::BinaryOp {
left: Box::new(lhs),
op: BinaryOperator::Eq,
right: Box::new(Expr::Column {
table: Some(alias.clone()),
name: inner_col_alias.clone(),
quote_style: QuoteStyle::None,
table_quote_style: QuoteStyle::None,
}),
};
let derived = build_derived_table_from_in(subquery, &inner_col_alias, &alias);
if negated {
let null_check = Expr::IsNull {
expr: Box::new(Expr::Column {
table: Some(alias.clone()),
name: inner_col_alias,
quote_style: QuoteStyle::None,
table_quote_style: QuoteStyle::None,
}),
negated: false,
};
let join = JoinClause {
join_type: JoinType::Left,
table: derived,
on: Some(on_condition),
using: vec![],
};
Some((join, Some(null_check)))
} else {
let join = JoinClause {
join_type: JoinType::Inner,
table: derived,
on: Some(on_condition),
using: vec![],
};
Some((join, None))
}
}
/// One equality conjunct `left = right` between two table-qualified columns
/// whose table qualifiers differ, as collected from a subquery's `WHERE` clause.
///
/// NOTE(review): the collector stores the left operand as `outer_col` and the
/// right operand as `inner_col` purely by position; nothing checks which of
/// the two tables actually belongs to the subquery. Confirm this matches how
/// queries are written (e.g. `b.id = a.id` with `b` as the subquery table).
#[derive(Debug, Clone)]
struct CorrelationPredicate {
/// Left operand of the equality; used verbatim in the generated join condition.
outer_col: Expr,
/// Right operand of the equality; re-qualified with the derived-table alias
/// in the generated join condition.
inner_col: Expr,
}
/// Classifies the conjuncts of a subquery's `WHERE` predicate.
///
/// Returns `(equalities, blockers)`: the column-to-column equality
/// correlations, and the other cross-table predicates found by
/// `collect_correlation_predicates`.
fn extract_correlation_predicates(expr: &Expr) -> (Vec<CorrelationPredicate>, Vec<Expr>) {
    let (mut equalities, mut blockers) = (Vec::new(), Vec::new());
    collect_correlation_predicates(expr, &mut equalities, &mut blockers);
    (equalities, blockers)
}
/// Walks the `AND` chain of a subquery predicate and sorts each conjunct.
///
/// * `col = col` between two table-qualified columns with different
///   qualifiers is appended to `eq_preds`.
/// * Every other conjunct that mentions more than one table qualifier is
///   appended to `non_eq_preds`. This covers inequalities as well as `OR`,
///   parenthesised, negated, or otherwise wrapped expressions, so that a
///   correlation hidden inside them is reported instead of silently ignored.
/// * Conjuncts that mention at most one table qualifier are not recorded.
///
/// NOTE(review): the left operand of an equality is always stored as
/// `outer_col` and the right one as `inner_col`; no check is made as to which
/// table belongs to the subquery. Unqualified column references are invisible
/// to this classification.
fn collect_correlation_predicates(
    expr: &Expr,
    eq_preds: &mut Vec<CorrelationPredicate>,
    non_eq_preds: &mut Vec<Expr>,
) {
    match expr {
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            collect_correlation_predicates(left, eq_preds, non_eq_preds);
            collect_correlation_predicates(right, eq_preds, non_eq_preds);
        }
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Eq,
            right,
        } => {
            let operands = (extract_column_ref(left), extract_column_ref(right));
            if let (Some((l_table, _)), Some((r_table, _))) = operands {
                if l_table != r_table {
                    eq_preds.push(CorrelationPredicate {
                        outer_col: left.as_ref().clone(),
                        inner_col: right.as_ref().clone(),
                    });
                    return;
                }
            }
            // An equality that is not plain column-to-column (or stays within
            // one table) only matters if it still spans several tables.
            if has_potential_outer_reference(expr) {
                non_eq_preds.push(expr.clone());
            }
        }
        // Anything else — comparisons, OR, parentheses, NOT, function calls,
        // ... — blocks unnesting as soon as it references more than one table.
        other => {
            if is_cross_table_predicate(other) {
                non_eq_preds.push(other.clone());
            }
        }
    }
}
/// Returns `(table, name)` when `expr` is a table-qualified column reference,
/// and `None` for unqualified columns and every other expression.
fn extract_column_ref(expr: &Expr) -> Option<(String, String)> {
    if let Expr::Column {
        table: Some(qualifier),
        name,
        ..
    } = expr
    {
        Some((qualifier.clone(), name.clone()))
    } else {
        None
    }
}
/// Reports whether `expr` mentions columns qualified with at least two
/// different table names. Unqualified columns are not counted.
fn is_cross_table_predicate(expr: &Expr) -> bool {
    let mut qualifiers: Vec<String> = Vec::new();
    expr.walk(&mut |node| {
        if let Expr::Column {
            table: Some(qualifier),
            ..
        } = node
        {
            if !qualifiers.contains(qualifier) {
                qualifiers.push(qualifier.clone());
            }
        }
        true
    });
    qualifiers.len() >= 2
}
/// Reports whether `expr` might reference the outer query.
///
/// This simply delegates to `is_cross_table_predicate`, i.e. it is true when
/// the expression mentions more than one distinct table qualifier.
fn has_potential_outer_reference(expr: &Expr) -> bool {
is_cross_table_predicate(expr)
}
/// Builds the join condition for an unnested `EXISTS`.
///
/// Each predicate contributes `outer_col = <alias>.<inner column name>`, and
/// the equalities are combined with `AND` in the order given.
///
/// Panics (inside `and_all`) when `preds` is empty.
fn build_join_on(preds: &[CorrelationPredicate], alias: &str) -> Expr {
    let mut equalities: Vec<Expr> = Vec::with_capacity(preds.len());
    for pred in preds {
        equalities.push(Expr::BinaryOp {
            left: Box::new(pred.outer_col.clone()),
            op: BinaryOperator::Eq,
            right: Box::new(rewrite_column_table(&pred.inner_col, alias)),
        });
    }
    and_all(equalities)
}
/// Combines `exprs` into one left-nested `AND` chain:
/// `((e0 AND e1) AND e2) ...`. A single expression is returned as is.
///
/// # Panics
///
/// Panics when `exprs` is empty.
fn and_all(exprs: Vec<Expr>) -> Expr {
    let mut terms = exprs.into_iter();
    let Some(first) = terms.next() else {
        panic!("and_all requires at least one expression");
    };
    terms.fold(first, |conjunction, term| Expr::BinaryOp {
        left: Box::new(conjunction),
        op: BinaryOperator::And,
        right: Box::new(term),
    })
}
/// Returns a copy of `expr` re-qualified with `new_table` when it is a column
/// reference; any other expression is cloned unchanged.
///
/// The column name and its quote style are kept; the table quote style is
/// reset to `QuoteStyle::None`.
fn rewrite_column_table(expr: &Expr, new_table: &str) -> Expr {
    let Expr::Column {
        name, quote_style, ..
    } = expr
    else {
        return expr.clone();
    };
    Expr::Column {
        table: Some(new_table.to_owned()),
        name: name.clone(),
        quote_style: *quote_style,
        table_quote_style: QuoteStyle::None,
    }
}
fn build_derived_table_from_exists(
subquery: Statement,
eq_preds: &[CorrelationPredicate],
alias: &str,
) -> TableSource {
let mut inner_select = match subquery {
Statement::Select(sel) => sel,
_ => unreachable!("Caller ensures this is a SELECT"),
};
if let Some(where_clause) = inner_select.where_clause.take() {
inner_select.where_clause = strip_correlation_predicates(where_clause, eq_preds);
}
inner_select.distinct = true;
inner_select.columns = vec![SelectItem::Expr {
expr: Expr::Number("1".to_string()),
alias: Some("_sentinel".to_string()),
alias_quote_style: QuoteStyle::None,
}];
TableSource::Subquery {
query: Box::new(Statement::Select(inner_select)),
alias: Some(alias.to_string()),
alias_quote_style: QuoteStyle::None,
}
}
/// Turns the subquery of an `IN` predicate into the derived table that is
/// joined in its place.
///
/// The subquery becomes `SELECT DISTINCT`, and when its first select item is
/// a `SelectItem::Expr` that item is aliased as `col_alias`. Any other kind
/// of first item is left as it is.
///
/// # Panics
///
/// Panics if `subquery` is not a `SELECT`; callers check this beforehand.
fn build_derived_table_from_in(
    subquery: Statement,
    col_alias: &str,
    table_alias: &str,
) -> TableSource {
    let Statement::Select(mut inner_select) = subquery else {
        unreachable!("Caller ensures this is a SELECT")
    };

    inner_select.distinct = true;
    if let Some(SelectItem::Expr { alias: slot, .. }) = inner_select.columns.first_mut() {
        *slot = Some(col_alias.to_owned());
    }

    let query = Box::new(Statement::Select(inner_select));
    TableSource::Subquery {
        query,
        alias: Some(table_alias.to_owned()),
        alias_quote_style: QuoteStyle::None,
    }
}
/// Builds the unquoted column reference `<alias>._sentinel`.
fn sentinel_column(alias: &str) -> Expr {
    let table = Some(alias.to_owned());
    let name = String::from("_sentinel");
    Expr::Column {
        table,
        name,
        quote_style: QuoteStyle::None,
        table_quote_style: QuoteStyle::None,
    }
}
/// Removes from an `AND` chain every equality that matches one of `eq_preds`
/// (in either operand order).
///
/// Returns the remaining predicate, or `None` when nothing is left. Only
/// `AND` is descended into; every other expression is kept whole.
fn strip_correlation_predicates(expr: Expr, eq_preds: &[CorrelationPredicate]) -> Option<Expr> {
    match expr {
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            let kept_left = strip_correlation_predicates(*left, eq_preds);
            let kept_right = strip_correlation_predicates(*right, eq_preds);
            match (kept_left, kept_right) {
                (Some(l), Some(r)) => Some(Expr::BinaryOp {
                    left: Box::new(l),
                    op: BinaryOperator::And,
                    right: Box::new(r),
                }),
                // At most one side survived: keep it, or nothing at all.
                (l, r) => l.or(r),
            }
        }
        Expr::BinaryOp {
            ref left,
            op: BinaryOperator::Eq,
            ref right,
        } => {
            let (lhs, rhs) = (left.as_ref(), right.as_ref());
            let is_correlation = eq_preds.iter().any(|pred| {
                (*lhs == pred.outer_col && *rhs == pred.inner_col)
                    || (*lhs == pred.inner_col && *rhs == pred.outer_col)
            });
            if is_correlation { None } else { Some(expr) }
        }
        kept => Some(kept),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::generator::generate;
    use crate::parser::Parser;

    /// Parses `sql`, runs the unnesting pass, and renders the result as ANSI SQL.
    fn parse_and_unnest(sql: &str) -> String {
        let parsed = Parser::new(sql).unwrap().parse_statement().unwrap();
        generate(&unnest_subqueries(parsed), Dialect::Ansi)
    }

    /// Fails with `"{why}: {result}"` unless `result` contains `needle`.
    #[track_caller]
    fn assert_has(result: &str, needle: &str, why: &str) {
        assert!(result.contains(needle), "{why}: {result}");
    }

    /// Fails with `"{why}: {result}"` if `result` contains `needle`.
    #[track_caller]
    fn assert_lacks(result: &str, needle: &str, why: &str) {
        assert!(!result.contains(needle), "{why}: {result}");
    }

    #[test]
    fn test_exists_to_inner_join() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id)",
        );
        assert_has(&result, "INNER JOIN", "Expected INNER JOIN in");
        assert_has(&result, "_u0", "Expected derived alias _u0 in");
        assert_lacks(&result, "EXISTS", "Should not contain EXISTS");
    }

    #[test]
    fn test_not_exists_to_left_join() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE NOT EXISTS (SELECT 1 FROM b WHERE b.id = a.id)",
        );
        assert_has(&result, "LEFT JOIN", "Expected LEFT JOIN in");
        assert_has(&result, "IS NULL", "Expected IS NULL check in");
        assert_lacks(&result, "NOT EXISTS", "Should not contain NOT EXISTS");
    }

    #[test]
    fn test_in_subquery_to_inner_join() {
        let result = parse_and_unnest("SELECT a.id FROM a WHERE a.id IN (SELECT b.id FROM b)");
        assert_has(&result, "INNER JOIN", "Expected INNER JOIN in");
        assert_lacks(&result, " IN ", "Should not contain IN");
    }

    #[test]
    fn test_not_in_subquery_to_left_join() {
        let result = parse_and_unnest("SELECT a.id FROM a WHERE a.id NOT IN (SELECT b.id FROM b)");
        assert_has(&result, "LEFT JOIN", "Expected LEFT JOIN in");
        assert_has(&result, "IS NULL", "Expected IS NULL check in");
    }

    #[test]
    fn test_no_correlation_not_unnested() {
        let result =
            parse_and_unnest("SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.x > 10)");
        assert_has(&result, "EXISTS", "Uncorrelated EXISTS should remain");
    }

    #[test]
    fn test_non_equality_correlation_not_unnested() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.val < a.val AND b.id = a.id)",
        );
        assert_has(
            &result,
            "EXISTS",
            "Subquery with non-eq correlation should not be unnested",
        );
    }

    #[test]
    fn test_subquery_in_select_not_unnested() {
        let result = parse_and_unnest(
            "SELECT COALESCE((SELECT MAX(b.val) FROM b WHERE b.id = a.id), a.val) AS result FROM a",
        );
        assert_lacks(&result, "JOIN", "Subquery in SELECT should not become a JOIN");
    }

    #[test]
    fn test_exists_with_additional_where() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE a.x > 5 AND EXISTS (SELECT 1 FROM b WHERE b.id = a.id)",
        );
        assert_has(&result, "INNER JOIN", "Expected INNER JOIN in");
        assert!(
            result.contains("a.x > 5") || result.contains("a.x >"),
            "Should keep non-subquery predicate: {result}"
        );
    }

    #[test]
    fn test_non_select_statement_unchanged() {
        let sql = "INSERT INTO t (a) VALUES (1)";
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        let result = unnest_subqueries(stmt.clone());
        assert_eq!(
            format!("{result:?}"),
            format!("{stmt:?}"),
            "Non-SELECT statements should pass through unchanged"
        );
    }

    #[test]
    fn test_exists_multiple_correlations() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id AND b.org = a.org)",
        );
        assert_has(&result, "INNER JOIN", "Expected INNER JOIN in");
        assert_lacks(&result, "EXISTS", "Should not contain EXISTS");
        assert_has(
            &result,
            " AND ",
            "ON clause should have AND for multiple correlations",
        );
        assert_has(&result, ".id", "ON should reference id");
        assert_has(&result, ".org", "ON should reference org");
    }

    #[test]
    fn test_multiple_subqueries_in_where() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id) AND a.id IN (SELECT c.id FROM c)",
        );
        assert_lacks(&result, "EXISTS", "EXISTS should be unnested");
        assert_lacks(&result, " IN ", "IN should be unnested");
        assert_has(&result, "_u0", "Expected first alias _u0");
        assert_has(&result, "_u1", "Expected second alias _u1");
    }

    #[test]
    fn test_exists_with_inner_residual_where() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id AND b.active = 1)",
        );
        assert_has(&result, "INNER JOIN", "Expected INNER JOIN in");
        assert_lacks(&result, "EXISTS", "Should not contain EXISTS");
        assert!(
            result.contains("active") && result.contains("1"),
            "Inner residual WHERE should be preserved: {result}"
        );
    }

    #[test]
    fn test_parenthesized_exists() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE (EXISTS (SELECT 1 FROM b WHERE b.id = a.id))",
        );
        assert_has(&result, "INNER JOIN", "Expected INNER JOIN in");
        assert_lacks(&result, "EXISTS", "Should not contain EXISTS");
    }

    #[test]
    fn test_in_subquery_multi_column_not_unnested() {
        let result =
            parse_and_unnest("SELECT a.id FROM a WHERE a.id IN (SELECT b.id, b.name FROM b)");
        assert_has(&result, " IN ", "Multi-column IN should remain");
    }

    #[test]
    fn test_or_with_exists_not_unnested() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE a.x > 1 OR EXISTS (SELECT 1 FROM b WHERE b.id = a.id)",
        );
        assert_has(&result, "EXISTS", "EXISTS in OR should remain");
    }

    #[test]
    fn test_scalar_subquery_in_where_not_unnested() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE a.val = (SELECT MAX(b.val) FROM b WHERE b.id = a.id)",
        );
        assert_lacks(
            &result,
            "JOIN",
            "Scalar subquery in WHERE should not become JOIN",
        );
    }

    #[test]
    fn test_sqlglot_issue_7295_exact_reproducer() {
        let result = parse_and_unnest(
            "SELECT COALESCE((SELECT MAX(b.val) FROM t AS b WHERE b.val < a.val AND b.id = a.id), a.val) AS result FROM t AS a",
        );
        assert_lacks(
            &result,
            "JOIN",
            "Issue #7295 query must NOT be rewritten to JOIN",
        );
        assert_has(&result, "COALESCE", "COALESCE should remain");
    }

    #[test]
    fn test_no_where_clause_unchanged() {
        let result = parse_and_unnest("SELECT a.id FROM a");
        assert_eq!(result, "SELECT a.id FROM a", "No WHERE should be unchanged");
    }

    #[test]
    fn test_where_without_subqueries_unchanged() {
        let result = parse_and_unnest("SELECT a.id FROM a WHERE a.x > 1 AND a.y = 2");
        assert_lacks(&result, "JOIN", "No subqueries, no joins should be added");
        assert_has(&result, "a.x > 1", "Original predicates should remain");
    }

    #[test]
    fn test_exists_no_where_not_unnested() {
        let result = parse_and_unnest("SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b)");
        assert_has(&result, "EXISTS", "EXISTS without inner WHERE should remain");
    }

    #[test]
    fn test_exists_same_table_predicate_not_unnested() {
        let result =
            parse_and_unnest("SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.x = b.y)");
        assert_has(&result, "EXISTS", "Same-table predicate is not correlation");
    }

    #[test]
    fn test_exists_produces_distinct_derived_table() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id)",
        );
        assert_has(&result, "DISTINCT", "Derived table should use DISTINCT");
    }

    #[test]
    fn test_in_produces_distinct_derived_table() {
        let result = parse_and_unnest("SELECT a.id FROM a WHERE a.id IN (SELECT b.id FROM b)");
        assert_has(&result, "DISTINCT", "IN-derived table should use DISTINCT");
    }

    #[test]
    fn test_not_in_preserves_inner_where() {
        let result = parse_and_unnest(
            "SELECT a.id FROM a WHERE a.id NOT IN (SELECT b.id FROM b WHERE b.active = 1)",
        );
        assert_has(&result, "LEFT JOIN", "Expected LEFT JOIN");
        assert_has(&result, "IS NULL", "Expected IS NULL");
        assert_has(&result, "active", "Inner WHERE should be preserved");
    }

    #[test]
    fn test_alias_gen_sequential() {
        let mut alias_gen = AliasGen::new();
        for expected in ["_u0", "_u1", "_u2"] {
            assert_eq!(alias_gen.next(), expected);
        }
    }
}