use sqlparser::ast::{self, Expr, SetExpr};
use crate::error::{Result, SqlError};
use crate::functions::registry::FunctionRegistry;
use crate::parser::normalize::normalize_ident;
use crate::types::*;
/// Result of pulling subquery predicates out of a WHERE expression.
///
/// Produced by [`extract_subqueries`]: predicates that could be rewritten are
/// represented as entries in `joins`, and everything else is kept in
/// `remaining_where`.
pub struct SubqueryExtraction {
/// Joins that replace the subquery predicates that were extracted.
pub joins: Vec<SubqueryJoin>,
/// The part of the predicate that was not consumed by a join, or `None`
/// when every conjunct was turned into a join.
pub remaining_where: Option<Expr>,
}
/// A join that stands in for a subquery predicate of the outer query.
pub struct SubqueryJoin {
/// Normalized column name on the outer query side of the join.
/// Left empty for the cross joins built for scalar subqueries.
pub outer_column: String,
/// Plan built for the subquery itself.
pub inner_plan: SqlPlan,
/// Normalized column name on the subquery side of the join.
/// Left empty for the cross joins built for scalar subqueries.
pub inner_column: String,
/// Kind of join to perform: in this file `Semi`/`Anti` are used for
/// `IN`/`EXISTS` (and their negated forms) and `Cross` for scalar subqueries.
pub join_type: JoinType,
}
/// Builds the key used to refer to an aggregate result column.
///
/// The key is the function name immediately followed by the field wrapped in
/// parentheses, e.g. `("avg", "amount")` becomes `avg(amount)`.
fn canonical_aggregate_key(function: &str, field: &str) -> String {
    // Two extra bytes for the surrounding parentheses.
    let mut key = String::with_capacity(function.len() + field.len() + 2);
    key.push_str(function);
    key.push('(');
    key.push_str(field);
    key.push(')');
    key
}
/// Splits a WHERE predicate into subquery joins and a residual predicate.
///
/// Delegates the traversal to `extract_recursive`, which collects one
/// [`SubqueryJoin`] per subquery predicate it manages to rewrite.
///
/// # Errors
///
/// Propagates any error returned while traversing the predicate.
pub fn extract_subqueries(
    expr: &Expr,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<SubqueryExtraction> {
    let mut collected = Vec::new();
    extract_recursive(expr, &mut collected, catalog, functions).map(|remaining_where| {
        SubqueryExtraction {
            joins: collected,
            remaining_where,
        }
    })
}
/// Walks a predicate and converts the subquery predicates it recognizes into joins.
///
/// Handled shapes:
/// - `a AND b`: both sides are processed and whatever remains is re-joined with `AND`.
/// - `x [NOT] IN (subquery)`: becomes a join via `try_plan_in_subquery`.
/// - `lhs <cmp> (subquery)` or `(subquery) <cmp> rhs`: the subquery operand is
///   replaced by the expression returned from `try_plan_scalar_subquery` and the
///   accompanying join is recorded. When both operands are subqueries only the
///   right-hand one is rewritten.
/// - `[NOT] EXISTS (subquery)`: becomes a join via `try_plan_exists_subquery`.
/// - Parenthesized expressions are processed through their inner expression.
///
/// Anything else is returned unchanged.
///
/// Returns `Ok(None)` when the whole expression was replaced by joins, otherwise
/// `Ok(Some(predicate))` with the predicate that still has to be evaluated.
///
/// # Errors
///
/// Propagates errors from the `try_plan_*` helpers.
fn extract_recursive(
    expr: &Expr,
    joins: &mut Vec<SubqueryJoin>,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<Option<Expr>> {
    match expr {
        Expr::BinaryOp {
            left,
            op: ast::BinaryOperator::And,
            right,
        } => {
            let left_remaining = extract_recursive(left, joins, catalog, functions)?;
            let right_remaining = extract_recursive(right, joins, catalog, functions)?;
            Ok(match (left_remaining, right_remaining) {
                (Some(l), Some(r)) => Some(Expr::BinaryOp {
                    left: Box::new(l),
                    op: ast::BinaryOperator::And,
                    right: Box::new(r),
                }),
                // At most one side survived; keep whichever it is.
                (l, r) => l.or(r),
            })
        }
        Expr::InSubquery {
            expr: outer_expr,
            subquery,
            negated,
        } => {
            match try_plan_in_subquery(outer_expr, subquery, *negated, catalog, functions)? {
                Some(join) => {
                    joins.push(join);
                    Ok(None)
                }
                None => Ok(Some(expr.clone())),
            }
        }
        Expr::BinaryOp { left, op, right } if is_comparison_op(op) => {
            // Accept the subquery on either side of the comparison; prefer the
            // right operand when both sides are subqueries.
            let (subquery, subquery_on_right) = match (left.as_ref(), right.as_ref()) {
                (_, Expr::Subquery(query)) => (query, true),
                (Expr::Subquery(query), _) => (query, false),
                _ => return Ok(Some(expr.clone())),
            };
            let Some(scalar) = try_plan_scalar_subquery(subquery, catalog, functions)? else {
                return Ok(Some(expr.clone()));
            };
            joins.push(scalar.join);
            // Keep the operand order so non-symmetric operators keep their meaning.
            let replacement = Box::new(scalar.replacement_expr);
            let (new_left, new_right) = if subquery_on_right {
                (left.clone(), replacement)
            } else {
                (replacement, right.clone())
            };
            Ok(Some(Expr::BinaryOp {
                left: new_left,
                op: op.clone(),
                right: new_right,
            }))
        }
        Expr::Exists { subquery, negated } => {
            match try_plan_exists_subquery(subquery, *negated, catalog, functions)? {
                Some(join) => {
                    joins.push(join);
                    Ok(None)
                }
                None => Ok(Some(expr.clone())),
            }
        }
        // NOTE(review): the parenthesis node itself is not re-created around the
        // returned predicate. The tree shape still encodes the grouping, but
        // confirm that no consumer renders this predicate back to SQL text.
        Expr::Nested(inner) => extract_recursive(inner, joins, catalog, functions),
        _ => Ok(Some(expr.clone())),
    }
}
/// Attempts to turn `outer_expr [NOT] IN (subquery)` into a semi/anti join.
///
/// Only a bare column or a two-part `table.column` reference is accepted on the
/// outer side; for any other expression `Ok(None)` is returned and the caller
/// keeps the predicate as is. For a two-part reference only the column part is used.
///
/// A negated predicate yields `JoinType::Anti`, otherwise `JoinType::Semi`.
///
/// # Errors
///
/// Propagates errors from planning the subquery and from
/// `extract_single_projected_column`.
fn try_plan_in_subquery(
    outer_expr: &Expr,
    subquery: &ast::Query,
    negated: bool,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<Option<SubqueryJoin>> {
    let outer_column = match outer_expr {
        Expr::Identifier(column) => normalize_ident(column),
        Expr::CompoundIdentifier(parts) => match parts.as_slice() {
            [_, column] => normalize_ident(column),
            _ => return Ok(None),
        },
        _ => return Ok(None),
    };

    let join_type = if negated {
        JoinType::Anti
    } else {
        JoinType::Semi
    };

    let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
    let inner_column = extract_single_projected_column(subquery)?;

    Ok(Some(SubqueryJoin {
        outer_column,
        inner_plan,
        inner_column,
        join_type,
    }))
}
/// Returns the normalized name of the single column a subquery projects.
///
/// The name is the alias for `expr AS alias`, the identifier for a bare column,
/// or the column part of a two-part `table.column` reference.
///
/// # Errors
///
/// Returns `SqlError::Unsupported` when the query body is not a plain `SELECT`,
/// when it does not project exactly one item, or when that item is neither
/// aliased nor a column reference.
fn extract_single_projected_column(query: &ast::Query) -> Result<String> {
    let SetExpr::Select(select) = &*query.body else {
        return Err(SqlError::Unsupported {
            detail: "subquery must be a simple SELECT".into(),
        });
    };

    let [item] = select.projection.as_slice() else {
        return Err(SqlError::Unsupported {
            detail: format!(
                "subquery must select exactly 1 column, got {}",
                select.projection.len()
            ),
        });
    };

    match item {
        ast::SelectItem::ExprWithAlias { alias, .. } => Ok(normalize_ident(alias)),
        ast::SelectItem::UnnamedExpr(Expr::Identifier(ident)) => Ok(normalize_ident(ident)),
        ast::SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) if parts.len() == 2 => {
            Ok(normalize_ident(&parts[1]))
        }
        _ => Err(SqlError::Unsupported {
            detail: "subquery projection must be a column reference".into(),
        }),
    }
}
fn try_plan_exists_subquery(
subquery: &ast::Query,
negated: bool,
catalog: &dyn SqlCatalog,
functions: &FunctionRegistry,
) -> Result<Option<SubqueryJoin>> {
let select = match &*subquery.body {
SetExpr::Select(s) => s,
_ => return Ok(None),
};
let (outer_col, inner_col) = match &select.selection {
Some(expr) => match extract_correlated_eq(expr) {
Some(pair) => pair,
None => return Ok(None),
},
None => return Ok(None),
};
let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
Ok(Some(SubqueryJoin {
outer_column: outer_col,
inner_plan,
inner_column: inner_col,
join_type: if negated {
JoinType::Anti
} else {
JoinType::Semi
},
}))
}
fn extract_correlated_eq(expr: &Expr) -> Option<(String, String)> {
match expr {
Expr::BinaryOp {
left,
op: ast::BinaryOperator::Eq,
right,
} => {
let left_parts = extract_qualified_column(left);
let right_parts = extract_qualified_column(right);
match (left_parts, right_parts) {
(Some((_lt, lc)), Some((_rt, rc))) => {
Some((rc, lc))
}
_ => None,
}
}
Expr::BinaryOp {
left,
op: ast::BinaryOperator::And,
right,
} => extract_correlated_eq(left).or_else(|| extract_correlated_eq(right)),
Expr::Nested(inner) => extract_correlated_eq(inner),
_ => None,
}
}
/// Splits a column reference into `(qualifier, column)`, both normalized.
///
/// A bare identifier yields an empty qualifier. A compound identifier is only
/// accepted when it has exactly two parts. Any other expression yields `None`.
fn extract_qualified_column(expr: &Expr) -> Option<(String, String)> {
    let pair = match expr {
        Expr::Identifier(column) => (String::new(), normalize_ident(column)),
        Expr::CompoundIdentifier(parts) => match parts.as_slice() {
            [qualifier, column] => (normalize_ident(qualifier), normalize_ident(column)),
            _ => return None,
        },
        _ => return None,
    };
    Some(pair)
}
fn is_comparison_op(op: &ast::BinaryOperator) -> bool {
matches!(
op,
ast::BinaryOperator::Gt
| ast::BinaryOperator::GtEq
| ast::BinaryOperator::Lt
| ast::BinaryOperator::LtEq
| ast::BinaryOperator::Eq
| ast::BinaryOperator::NotEq
)
}
/// Outcome of rewriting a scalar subquery used as a comparison operand.
struct ScalarSubqueryResult {
/// Join that makes the subquery's result available to the outer query.
join: SubqueryJoin,
/// Expression that takes the subquery's place in the comparison.
replacement_expr: Expr,
}
/// Attempts to rewrite a scalar subquery as a cross join plus a column reference.
///
/// The name of the subquery's result column is determined by
/// `extract_scalar_column`; when it cannot be determined `Ok(None)` is returned
/// and the caller keeps the original predicate. On success the result carries a
/// `JoinType::Cross` join over the planned subquery (with empty join columns) and
/// an identifier expression naming the result column.
///
/// # Errors
///
/// Propagates errors from planning the subquery. Planning is only attempted once
/// the result column is known, so unsupported subquery shapes are declined
/// without being planned.
fn try_plan_scalar_subquery(
    subquery: &ast::Query,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<Option<ScalarSubqueryResult>> {
    // Cheap syntactic check first: there is no point planning a subquery whose
    // plan would be thrown away.
    let Some(result_col) = extract_scalar_column(subquery) else {
        return Ok(None);
    };
    let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
    let replacement_expr = Expr::Identifier(ast::Ident::new(result_col));
    Ok(Some(ScalarSubqueryResult {
        join: SubqueryJoin {
            outer_column: String::new(),
            inner_plan,
            inner_column: String::new(),
            join_type: JoinType::Cross,
        },
        replacement_expr,
    }))
}
/// Determines the name of the single column produced by a scalar subquery.
///
/// Returns `None` unless the query body is a plain `SELECT` projecting exactly
/// one item. For that item the name is:
/// - the normalized alias for `expr AS alias`;
/// - the normalized identifier for a bare column, or the column part of a
///   two-part `table.column` reference;
/// - for a function call, the key built by `canonical_aggregate_key` from the
///   lowercased, dot-joined function name and a label for its first argument
///   (see `scalar_argument_label`), with `*` used when no label is available.
///
/// Any other projection yields `None`.
fn extract_scalar_column(query: &ast::Query) -> Option<String> {
    let SetExpr::Select(select) = &*query.body else {
        return None;
    };
    let [item] = select.projection.as_slice() else {
        return None;
    };

    let projected = match item {
        ast::SelectItem::ExprWithAlias { alias, .. } => return Some(normalize_ident(alias)),
        ast::SelectItem::UnnamedExpr(projected) => projected,
        _ => return None,
    };

    match projected {
        Expr::Identifier(ident) => Some(normalize_ident(ident)),
        Expr::CompoundIdentifier(parts) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
        Expr::Function(func) => {
            let mut name_parts = Vec::with_capacity(func.name.0.len());
            for part in &func.name.0 {
                name_parts.push(match part {
                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
                    // Non-identifier name parts contribute an empty segment.
                    _ => String::new(),
                });
            }
            let func_name = name_parts.join(".").to_lowercase();
            let arg = scalar_argument_label(&func.args).unwrap_or_else(|| "*".to_string());
            Some(canonical_aggregate_key(&func_name, &arg))
        }
        _ => None,
    }
}

/// Labels the first argument of a function call for use in an aggregate key.
///
/// Yields the normalized column name for a bare or two-part column argument and
/// the literal `all` for a (qualified) wildcard. Returns `None` when the
/// arguments are not a list, the list is empty, or the first argument has any
/// other form.
fn scalar_argument_label(args: &ast::FunctionArguments) -> Option<String> {
    let ast::FunctionArguments::List(arg_list) = args else {
        return None;
    };
    match arg_list.args.first()? {
        ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(Expr::Identifier(ident))) => {
            Some(normalize_ident(ident))
        }
        ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts)))
            if parts.len() == 2 =>
        {
            Some(normalize_ident(&parts[1]))
        }
        ast::FunctionArg::Unnamed(
            ast::FunctionArgExpr::Wildcard | ast::FunctionArgExpr::QualifiedWildcard(_),
        ) => Some("all".to_string()),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::extract_scalar_column;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// An unaliased aggregate in a scalar subquery is named by its canonical
    /// aggregate key (lowercased function name plus argument).
    #[test]
    fn unaliased_scalar_aggregate_uses_canonical_aggregate_key() {
        let parsed = parse_sql("SELECT AVG(amount) FROM orders").unwrap();
        let query = match &parsed[0] {
            Statement::Query(query) => query,
            _ => panic!("expected query"),
        };
        let column = extract_scalar_column(query);
        assert_eq!(column, Some("avg(amount)".into()));
    }
}