use crate::query_plan::pipeline::ASTTransformer;
use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
use anyhow::Result;
use std::collections::HashMap;
use tracing::debug;
/// AST pass that rewrites references to SELECT-list aliases inside the WHERE
/// clause into the aliased expressions, e.g. for `SELECT a * 2 AS double_a`
/// the filter `WHERE double_a > 10` becomes `WHERE a * 2 > 10`.
#[derive(Debug)]
pub struct WhereAliasExpander {
    /// Number of WHERE conditions rewritten since the counter was last reset
    /// (by `new()` or `begin()`).
    expansions: usize,
}
impl WhereAliasExpander {
    /// Creates an expander whose expansion counter starts at zero.
    pub fn new() -> Self {
        Self { expansions: 0 }
    }

    /// Builds the `alias -> expression` map from the SELECT list.
    ///
    /// Only `SelectItem::Expression` items with a non-empty alias contribute;
    /// when an alias is repeated, the last occurrence wins.
    ///
    /// An alias whose own expression reads an unqualified column with the same
    /// name (e.g. `a * 2 AS a`) is left out. Expanding it would rewrite
    /// `WHERE a > 10` into `WHERE a * 2 > 10`, silently filtering on the
    /// computed value instead of the input column, and running the pass again
    /// would expand it once more (`a * 2 * 2`). Leaving it out keeps the bare
    /// name in WHERE as a plain column reference.
    fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
        let mut aliases = HashMap::new();
        for item in select_items {
            if let SelectItem::Expression { expr, alias, .. } = item {
                if alias.is_empty() {
                    continue;
                }
                if Self::references_unqualified_column(expr, alias) {
                    debug!(
                        "Skipping self-referencing SELECT alias: {} -> {:?}",
                        alias, expr
                    );
                    continue;
                }
                aliases.insert(alias.clone(), expr.clone());
                debug!("Found SELECT alias: {} -> {:?}", alias, expr);
            }
        }
        aliases
    }

    /// Returns `true` if `expr` contains an unqualified column reference named
    /// `name`.
    ///
    /// Traverses the same expression variants as [`Self::expand_expression`];
    /// any other variant is reported as not referencing the column.
    fn references_unqualified_column(expr: &SqlExpression, name: &str) -> bool {
        match expr {
            SqlExpression::Column(col_ref) => {
                col_ref.table_prefix.is_none() && col_ref.name == name
            }
            SqlExpression::BinaryOp { left, right, .. } => {
                Self::references_unqualified_column(left, name)
                    || Self::references_unqualified_column(right, name)
            }
            SqlExpression::Not { expr: inner } => Self::references_unqualified_column(inner, name),
            SqlExpression::FunctionCall { args, .. } => args
                .iter()
                .any(|arg| Self::references_unqualified_column(arg, name)),
            SqlExpression::InList { expr: inner, values }
            | SqlExpression::NotInList { expr: inner, values } => {
                Self::references_unqualified_column(inner, name)
                    || values
                        .iter()
                        .any(|val| Self::references_unqualified_column(val, name))
            }
            SqlExpression::Between {
                expr: inner,
                lower,
                upper,
            } => {
                Self::references_unqualified_column(inner, name)
                    || Self::references_unqualified_column(lower, name)
                    || Self::references_unqualified_column(upper, name)
            }
            SqlExpression::CaseExpression {
                when_branches,
                else_branch,
            } => {
                when_branches.iter().any(|branch| {
                    Self::references_unqualified_column(&branch.condition, name)
                        || Self::references_unqualified_column(&branch.result, name)
                }) || else_branch
                    .iter()
                    .any(|e| Self::references_unqualified_column(e, name))
            }
            SqlExpression::SimpleCaseExpression {
                expr: inner,
                when_branches,
                else_branch,
            } => {
                Self::references_unqualified_column(inner, name)
                    || when_branches.iter().any(|branch| {
                        Self::references_unqualified_column(&branch.value, name)
                            || Self::references_unqualified_column(&branch.result, name)
                    })
                    || else_branch
                        .iter()
                        .any(|e| Self::references_unqualified_column(e, name))
            }
            _ => false,
        }
    }

    /// Returns a copy of `expr` in which every unqualified column reference
    /// whose name is a key of `aliases` is replaced by a clone of the aliased
    /// expression, together with whether any replacement happened.
    ///
    /// Substituted expressions are not re-expanded, so an alias that refers to
    /// another alias is expanded one level only. Only the variants matched
    /// below are traversed; any other variant is cloned unchanged even if it
    /// contains alias references.
    fn expand_expression(
        expr: &SqlExpression,
        aliases: &HashMap<String, SqlExpression>,
    ) -> (SqlExpression, bool) {
        match expr {
            SqlExpression::Column(col_ref) => {
                // A table-qualified reference (`t.col`) never names a SELECT alias.
                if col_ref.table_prefix.is_none() {
                    if let Some(alias_expr) = aliases.get(&col_ref.name) {
                        debug!(
                            "Expanding alias '{}' in WHERE to: {:?}",
                            col_ref.name, alias_expr
                        );
                        return (alias_expr.clone(), true);
                    }
                }
                (expr.clone(), false)
            }
            SqlExpression::BinaryOp { left, op, right } => {
                let (new_left, left_expanded) = Self::expand_expression(left, aliases);
                let (new_right, right_expanded) = Self::expand_expression(right, aliases);
                (
                    SqlExpression::BinaryOp {
                        left: Box::new(new_left),
                        op: op.clone(),
                        right: Box::new(new_right),
                    },
                    left_expanded || right_expanded,
                )
            }
            SqlExpression::Not { expr: inner } => {
                let (new_inner, expanded) = Self::expand_expression(inner, aliases);
                (
                    SqlExpression::Not {
                        expr: Box::new(new_inner),
                    },
                    expanded,
                )
            }
            SqlExpression::FunctionCall {
                name,
                args,
                distinct,
            } => {
                let (new_args, expanded) = Self::expand_all(args, aliases);
                (
                    SqlExpression::FunctionCall {
                        name: name.clone(),
                        args: new_args,
                        distinct: *distinct,
                    },
                    expanded,
                )
            }
            SqlExpression::InList {
                expr: inner,
                values,
            } => {
                let (new_inner, inner_expanded) = Self::expand_expression(inner, aliases);
                let (new_values, values_expanded) = Self::expand_all(values, aliases);
                (
                    SqlExpression::InList {
                        expr: Box::new(new_inner),
                        values: new_values,
                    },
                    inner_expanded || values_expanded,
                )
            }
            SqlExpression::NotInList {
                expr: inner,
                values,
            } => {
                let (new_inner, inner_expanded) = Self::expand_expression(inner, aliases);
                let (new_values, values_expanded) = Self::expand_all(values, aliases);
                (
                    SqlExpression::NotInList {
                        expr: Box::new(new_inner),
                        values: new_values,
                    },
                    inner_expanded || values_expanded,
                )
            }
            SqlExpression::Between {
                expr: inner,
                lower,
                upper,
            } => {
                let (new_inner, inner_expanded) = Self::expand_expression(inner, aliases);
                let (new_lower, lower_expanded) = Self::expand_expression(lower, aliases);
                let (new_upper, upper_expanded) = Self::expand_expression(upper, aliases);
                (
                    SqlExpression::Between {
                        expr: Box::new(new_inner),
                        lower: Box::new(new_lower),
                        upper: Box::new(new_upper),
                    },
                    inner_expanded || lower_expanded || upper_expanded,
                )
            }
            SqlExpression::CaseExpression {
                when_branches,
                else_branch,
            } => {
                let mut branches_expanded = false;
                let new_branches: Vec<_> = when_branches
                    .iter()
                    .map(|branch| {
                        let (new_condition, cond_expanded) =
                            Self::expand_expression(&branch.condition, aliases);
                        let (new_result, result_expanded) =
                            Self::expand_expression(&branch.result, aliases);
                        branches_expanded |= cond_expanded || result_expanded;
                        crate::sql::parser::ast::WhenBranch {
                            condition: Box::new(new_condition),
                            result: Box::new(new_result),
                        }
                    })
                    .collect();
                let (new_else, else_expanded) =
                    Self::expand_optional(else_branch.as_deref(), aliases);
                (
                    SqlExpression::CaseExpression {
                        when_branches: new_branches,
                        else_branch: new_else,
                    },
                    branches_expanded || else_expanded,
                )
            }
            SqlExpression::SimpleCaseExpression {
                expr: inner,
                when_branches,
                else_branch,
            } => {
                let (new_inner, inner_expanded) = Self::expand_expression(inner, aliases);
                let mut branches_expanded = false;
                let new_branches: Vec<_> = when_branches
                    .iter()
                    .map(|branch| {
                        let (new_value, value_expanded) =
                            Self::expand_expression(&branch.value, aliases);
                        let (new_result, result_expanded) =
                            Self::expand_expression(&branch.result, aliases);
                        branches_expanded |= value_expanded || result_expanded;
                        crate::sql::parser::ast::SimpleWhenBranch {
                            value: Box::new(new_value),
                            result: Box::new(new_result),
                        }
                    })
                    .collect();
                let (new_else, else_expanded) =
                    Self::expand_optional(else_branch.as_deref(), aliases);
                (
                    SqlExpression::SimpleCaseExpression {
                        expr: Box::new(new_inner),
                        when_branches: new_branches,
                        else_branch: new_else,
                    },
                    inner_expanded || branches_expanded || else_expanded,
                )
            }
            _ => (expr.clone(), false),
        }
    }

    /// Expands every expression in `exprs`, in order, reporting whether any of
    /// them changed.
    fn expand_all(
        exprs: &[SqlExpression],
        aliases: &HashMap<String, SqlExpression>,
    ) -> (Vec<SqlExpression>, bool) {
        let mut any_expanded = false;
        let new_exprs: Vec<SqlExpression> = exprs
            .iter()
            .map(|e| {
                let (new_e, changed) = Self::expand_expression(e, aliases);
                any_expanded |= changed;
                new_e
            })
            .collect();
        (new_exprs, any_expanded)
    }

    /// Expands an optional expression (e.g. a CASE `ELSE` arm), re-boxing the
    /// result; `None` stays `None` and reports no change.
    fn expand_optional(
        expr: Option<&SqlExpression>,
        aliases: &HashMap<String, SqlExpression>,
    ) -> (Option<Box<SqlExpression>>, bool) {
        match expr {
            Some(e) => {
                let (new_e, changed) = Self::expand_expression(e, aliases);
                (Some(Box::new(new_e)), changed)
            }
            None => (None, false),
        }
    }

    /// Rewrites, in place, every WHERE condition that contains at least one
    /// alias reference.
    ///
    /// `self.expansions` is incremented once per rewritten condition (not per
    /// reference). Returns whether any condition was changed.
    fn expand_where_clause(
        &mut self,
        where_clause: &mut crate::sql::parser::ast::WhereClause,
        aliases: &HashMap<String, SqlExpression>,
    ) -> bool {
        let mut any_expanded = false;
        for condition in &mut where_clause.conditions {
            let (new_expr, expanded) = Self::expand_expression(&condition.expr, aliases);
            if expanded {
                condition.expr = new_expr;
                any_expanded = true;
                self.expansions += 1;
            }
        }
        any_expanded
    }
}
impl Default for WhereAliasExpander {
fn default() -> Self {
Self::new()
}
}
impl ASTTransformer for WhereAliasExpander {
    fn name(&self) -> &str {
        "WhereAliasExpander"
    }

    fn description(&self) -> &str {
        "Expands SELECT aliases in WHERE clauses to their full expressions"
    }

    /// Rewrites SELECT-alias references inside the WHERE clause of `stmt`.
    ///
    /// Statements with no WHERE clause, or whose SELECT list yields no aliases,
    /// are returned untouched. This implementation never returns `Err`.
    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
        if stmt.where_clause.is_none() {
            return Ok(stmt);
        }
        let aliases = Self::extract_aliases(&stmt.select_items);
        if aliases.is_empty() {
            return Ok(stmt);
        }
        // `self.expansions` keeps accumulating across statements until the next
        // `begin()`, and counts rewritten conditions rather than individual
        // references; log only what this statement contributed.
        let before = self.expansions;
        if let Some(where_clause) = stmt.where_clause.as_mut() {
            if self.expand_where_clause(where_clause, &aliases) {
                debug!(
                    "Expanded alias reference(s) in {} WHERE condition(s)",
                    self.expansions - before
                );
            }
        }
        Ok(stmt)
    }

    /// Resets the expansion counter (presumably invoked by the pipeline before
    /// a run — confirm against the `ASTTransformer` caller).
    fn begin(&mut self) -> Result<()> {
        self.expansions = 0;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, Condition, QuoteStyle, WhereClause};

    /// Unqualified, unquoted column reference.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Numeric literal with the given textual value.
    fn num(text: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(text.to_string())
    }

    /// Binary operation `left op right`.
    fn binop(left: SqlExpression, op: &str, right: SqlExpression) -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(left),
            op: op.to_string(),
            right: Box::new(right),
        }
    }

    /// The `a * 2` expression that these tests alias as `double_a`.
    fn double_a() -> SqlExpression {
        binop(col("a"), "*", num("2"))
    }

    /// SELECT item `a * 2 AS double_a`.
    fn double_a_item() -> SelectItem {
        SelectItem::Expression {
            expr: double_a(),
            alias: "double_a".to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// Alias map holding only `double_a -> a * 2`.
    fn double_a_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("double_a".to_string(), double_a())])
    }

    #[test]
    fn test_extract_aliases() {
        let aliases = WhereAliasExpander::extract_aliases(&[double_a_item()]);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("double_a"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&col("double_a"), &double_a_aliases());
        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_expand_in_binary_op() {
        let filter = binop(col("double_a"), ">", num("10"));
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&filter, &double_a_aliases());
        assert!(changed);
        match expanded {
            SqlExpression::BinaryOp { left, op, right } => {
                assert_eq!(op, ">");
                assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
                assert!(matches!(
                    right.as_ref(),
                    SqlExpression::NumberLiteral(s) if s == "10"
                ));
            }
            _ => panic!("Expected BinaryOp"),
        }
    }

    #[test]
    fn test_transform_with_no_where() {
        let stmt = SelectStatement {
            where_clause: None,
            ..Default::default()
        };
        assert!(WhereAliasExpander::new().transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = SelectStatement {
            select_items: vec![double_a_item()],
            where_clause: Some(WhereClause {
                conditions: vec![Condition {
                    expr: binop(col("double_a"), ">", num("10")),
                    connector: None,
                }],
            }),
            ..Default::default()
        };
        let result = transformer.transform(stmt).unwrap();
        let where_clause = result
            .where_clause
            .as_ref()
            .expect("Expected WHERE clause");
        match &where_clause.conditions[0].expr {
            SqlExpression::BinaryOp { left, .. } => {
                assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            }
            _ => panic!("Expected BinaryOp in WHERE"),
        }
        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        let qualified = SqlExpression::Column(ColumnRef {
            name: "double_a".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&qualified, &double_a_aliases());
        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }
}