use crate::query_plan::pipeline::ASTTransformer;
use crate::sql::parser::ast::{
CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, TableSource,
};
use anyhow::Result;
use std::collections::{HashMap, HashSet};
use tracing::debug;
/// Prefix of the aliases given to ORDER BY expressions that this module
/// appends to the SELECT list as extra items; the full alias is this prefix
/// followed by a running number (e.g. `__hidden_orderby_1`).
// NOTE(review): presumably exported so later stages can recognise and drop
// these columns from the visible output — confirm against the consumers.
pub const HIDDEN_ORDERBY_PREFIX: &str = "__hidden_orderby_";
/// AST transformer that rewrites ORDER BY entries to reference SELECT-list
/// aliases: aggregate calls are mapped to the alias of the matching SELECT
/// item, and any remaining entry that is not already a visible column or alias
/// is added to the SELECT list under a generated hidden alias.
///
/// Both counters live on the transformer, so the generated names keep
/// increasing across nested statements handled by the same instance.
pub struct OrderByAliasTransformer {
/// Number of `__orderby_agg_N` aliases generated so far.
alias_counter: usize,
/// Number of `HIDDEN_ORDERBY_PREFIX`-based aliases generated so far.
hidden_counter: usize,
}
impl OrderByAliasTransformer {
/// Creates a transformer whose alias counters both start at zero, so the
/// first generated aliases are numbered `1`.
pub fn new() -> Self {
    let alias_counter = 0;
    let hidden_counter = 0;
    Self {
        alias_counter,
        hidden_counter,
    }
}
/// Returns `true` when `expr` is a function call whose upper-cased name is one
/// of `COUNT`, `SUM`, `AVG`, `MIN`, `MAX` or `COUNT_DISTINCT`.
///
/// Any other expression kind, or any other function name, yields `false`.
fn is_aggregate_function(expr: &SqlExpression) -> bool {
    const AGGREGATE_NAMES: [&str; 6] = ["COUNT", "SUM", "AVG", "MIN", "MAX", "COUNT_DISTINCT"];

    match expr {
        SqlExpression::FunctionCall { name, .. } => {
            let upper = name.to_uppercase();
            AGGREGATE_NAMES.contains(&upper.as_str())
        }
        _ => false,
    }
}
/// Advances the aggregate-alias counter and returns the next alias in the
/// `__orderby_agg_N` series (the first call yields `__orderby_agg_1`).
fn generate_alias(&mut self) -> String {
    let next = self.alias_counter + 1;
    self.alias_counter = next;
    format!("__orderby_agg_{}", next)
}
/// Builds the canonical lookup key for an aggregate call, e.g.
/// `SUM(SALES_AMOUNT)` or `COUNT(DISTINCT REGION)`.
///
/// The key is the fully upper-cased output of `expression_to_string`, so it
/// is always identical to the key derived for an ORDER BY entry that is
/// rendered with `expression_to_string` and then upper-cased. Returns an empty
/// string when `expr` is not a function call.
///
/// Because the whole key is upper-cased, arguments that differ only in letter
/// case (including string literals) share a key.
fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
    match expr {
        SqlExpression::FunctionCall { .. } => Self::expression_to_string(expr).to_uppercase(),
        _ => String::new(),
    }
}

/// Collects every aggregate call in the SELECT list into a map from its
/// normalized key (see `normalize_aggregate_expr`) to its alias.
///
/// Aggregates that have no alias yet are given a generated one in place, so
/// `select_items` may be modified. When two SELECT items normalize to the same
/// key, the later item's alias replaces the earlier one in the map.
fn build_aggregate_map(
    &mut self,
    select_items: &mut Vec<SelectItem>,
) -> HashMap<String, String> {
    let mut aggregate_map = HashMap::new();
    for item in select_items.iter_mut() {
        if let SelectItem::Expression { expr, alias, .. } = item {
            if !Self::is_aggregate_function(expr) {
                continue;
            }
            let normalized = Self::normalize_aggregate_expr(expr);
            if alias.is_empty() {
                // The aggregate needs a name the ORDER BY clause can refer to.
                *alias = self.generate_alias();
                debug!(
                    "Generated alias '{}' for aggregate in ORDER BY: {}",
                    alias, normalized
                );
            }
            debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
            aggregate_map.insert(normalized, alias.clone());
        }
    }
    aggregate_map
}

/// Renders an expression as text for aggregate matching.
///
/// * columns become their upper-cased name (any table prefix is ignored);
/// * the string literal `*` becomes `*`, other string literals are quoted;
/// * number literals are rendered verbatim;
/// * function calls become `NAME(args)` with an upper-cased name, a leading
///   `DISTINCT ` marker when the call is DISTINCT, and recursively rendered
///   arguments;
/// * every other expression falls back to its `Debug` form, so structurally
///   different expressions do not collapse onto one shared placeholder that
///   could also be mistaken for a column name.
fn expression_to_string(expr: &SqlExpression) -> String {
    match expr {
        SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
        SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
        SqlExpression::StringLiteral(s) => format!("'{}'", s),
        SqlExpression::NumberLiteral(n) => n.to_string(),
        SqlExpression::FunctionCall {
            name,
            args,
            distinct,
            ..
        } => {
            let args_str = args
                .iter()
                .map(Self::expression_to_string)
                .collect::<Vec<_>>()
                .join(", ");
            // Keep `COUNT(x)` and `COUNT(DISTINCT x)` apart.
            let distinct_marker = if *distinct { "DISTINCT " } else { "" };
            format!("{}({}{})", name.to_uppercase(), distinct_marker, args_str)
        }
        other => format!("{:?}", other),
    }
}
/// Returns the upper-cased form of `column_name` when it looks like an
/// aggregate call, i.e. it starts (case-insensitively) with one of
/// `COUNT(`, `SUM(`, `AVG(`, `MIN(`, `MAX(`, `COUNT_DISTINCT(` and ends with
/// `)`. Returns `None` for anything else.
fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
    const AGGREGATE_PREFIXES: [&str; 6] = [
        "COUNT(",
        "SUM(",
        "AVG(",
        "MIN(",
        "MAX(",
        "COUNT_DISTINCT(",
    ];

    let upper = column_name.to_uppercase();
    let looks_like_aggregate = upper.ends_with(')')
        && AGGREGATE_PREFIXES
            .iter()
            .any(|prefix| upper.starts_with(*prefix));

    if looks_like_aggregate {
        Some(upper)
    } else {
        None
    }
}
}
impl Default for OrderByAliasTransformer {
fn default() -> Self {
Self::new()
}
}
impl ASTTransformer for OrderByAliasTransformer {
    /// Name under which this transformer identifies itself.
    fn name(&self) -> &str {
        const NAME: &str = "OrderByAliasTransformer";
        NAME
    }

    /// One-line summary of what this transformer does.
    fn description(&self) -> &str {
        const DESCRIPTION: &str = "Rewrites ORDER BY aggregate expressions to use SELECT aliases";
        DESCRIPTION
    }

    /// Rewrites `stmt` (and the statements nested inside it) by delegating to
    /// `transform_statement`, propagating any error it returns.
    fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
        let rewritten = self.transform_statement(stmt)?;
        Ok(rewritten)
    }
}
impl OrderByAliasTransformer {
/// Rewrites `stmt` after first rewriting every nested statement it contains.
///
/// Nested statements are visited in this order: standard CTE bodies, a
/// derived table in `from_source`, `from_subquery`, and finally the right-hand
/// side of each set operation. Only then is the ORDER BY rewrite applied to
/// `stmt` itself.
///
/// # Errors
/// Returns the first error produced while transforming a nested statement.
// NOTE(review): the `allow` presumably covers the read of `from_subquery`
// below — confirm which field is marked deprecated.
#[allow(deprecated)]
fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
    for cte in &mut stmt.ctes {
        if let CTEType::Standard(inner) = &mut cte.cte_type {
            self.transform_in_place(inner)?;
        }
    }

    if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
        self.transform_in_place(query.as_mut())?;
    }

    if let Some(subq) = stmt.from_subquery.as_mut() {
        self.transform_in_place(subq.as_mut())?;
    }

    for (_op, rhs) in stmt.set_operations.iter_mut() {
        self.transform_in_place(rhs.as_mut())?;
    }

    self.apply_rewrite(&mut stmt);
    Ok(stmt)
}

/// Takes the statement out of `slot` (leaving its `Default` value behind),
/// transforms it, and stores the result back into `slot`.
///
/// If the transformation fails, `slot` keeps the default value and the error
/// is returned.
fn transform_in_place(&mut self, slot: &mut SelectStatement) -> Result<()> {
    let taken = std::mem::take(slot);
    *slot = self.transform_statement(taken)?;
    Ok(())
}
/// Applies both ORDER BY rewrites to a single statement (nested statements
/// are not visited here).
///
/// Does nothing when the statement has no ORDER BY clause. Otherwise:
/// 1. every aggregate in the SELECT list is mapped to an alias (generating
///    one where missing, which modifies the SELECT list);
/// 2. each ORDER BY entry that renders as an aggregate call present in that
///    map is replaced by an unquoted column reference to the alias;
/// 3. the remaining entries are handed to `promote_hidden_order_by_columns`.
fn apply_rewrite(&mut self, stmt: &mut SelectStatement) {
    if stmt.order_by.is_none() {
        return;
    }

    let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
    if !aggregate_map.is_empty() {
        if let Some(order_by) = stmt.order_by.as_mut() {
            // Count actual rewrites so the summary log reports what changed,
            // not how many aggregates happen to be in the SELECT list.
            let mut rewritten: usize = 0;
            for order_col in order_by.iter_mut() {
                let expr_str = Self::expression_to_string(&order_col.expr);
                if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
                    if let Some(alias) = aggregate_map.get(&normalized) {
                        debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
                        order_col.expr =
                            SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
                        rewritten += 1;
                    }
                }
            }
            if rewritten > 0 {
                debug!(
                    "Rewrote ORDER BY to use {} aggregate alias(es)",
                    rewritten
                );
            }
        }
    }

    self.promote_hidden_order_by_columns(stmt);
}
/// Makes every ORDER BY entry refer to something the SELECT list provides.
///
/// Nothing happens when ORDER BY is absent or empty, or when the SELECT list
/// contains a `Star` / `StarExclude` item. Otherwise each ORDER BY entry that
/// is not a column whose lower-cased name equals a selected column name or a
/// non-empty expression alias is appended to the SELECT list as an expression
/// with a generated `HIDDEN_ORDERBY_PREFIX` alias, and the ORDER BY entry is
/// replaced by an unquoted column reference to that alias.
///
/// Repeated entries reuse the alias generated for the first occurrence:
/// columns are matched by lower-cased name, other expressions by their
/// `Debug` representation.
fn promote_hidden_order_by_columns(&mut self, stmt: &mut SelectStatement) {
    let order_by = match stmt.order_by.as_mut() {
        Some(entries) if !entries.is_empty() => entries,
        _ => return,
    };

    let selects_star = stmt
        .select_items
        .iter()
        .any(|item| matches!(item, SelectItem::Star { .. } | SelectItem::StarExclude { .. }));
    if selects_star {
        return;
    }

    // Lower-cased names an ORDER BY column may already refer to.
    let visible: HashSet<String> = stmt
        .select_items
        .iter()
        .filter_map(|item| match item {
            SelectItem::Column { column, .. } => Some(column.name.to_lowercase()),
            SelectItem::Expression { alias, .. } if !alias.is_empty() => {
                Some(alias.to_lowercase())
            }
            _ => None,
        })
        .collect();

    // Aliases already handed out during this call, kept apart for plain
    // columns and for other expressions.
    let mut column_aliases: HashMap<String, String> = HashMap::new();
    let mut expr_aliases: HashMap<String, String> = HashMap::new();

    for order_col in order_by.iter_mut() {
        let (dedup_key, registry) = match &order_col.expr {
            SqlExpression::Column(column) => {
                let lowered = column.name.to_lowercase();
                if visible.contains(&lowered) {
                    continue;
                }
                (lowered, &mut column_aliases)
            }
            other => (format!("{:?}", other), &mut expr_aliases),
        };

        let previously_assigned = registry.get(&dedup_key).cloned();
        let hidden_alias = match previously_assigned {
            Some(alias) => alias,
            None => {
                self.hidden_counter += 1;
                let alias = format!("{}{}", HIDDEN_ORDERBY_PREFIX, self.hidden_counter);
                let promoted = order_col.expr.clone();
                debug!(
                    "Promoting ORDER BY expression as hidden SELECT item '{}': {:?}",
                    alias, promoted
                );
                stmt.select_items.push(SelectItem::Expression {
                    expr: promoted,
                    alias: alias.clone(),
                    leading_comments: Vec::new(),
                    trailing_comment: None,
                });
                registry.insert(dedup_key, alias.clone());
                alias
            }
        };

        order_col.expr = SqlExpression::Column(ColumnRef::unquoted(hidden_alias));
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle, SortDirection};

    /// Builds a non-DISTINCT function call expression with the given
    /// name and arguments.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    /// Aggregate-looking strings are returned upper-cased; plain names are
    /// rejected.
    #[test]
    fn test_extract_aggregate_from_order_column() {
        let cases: [(&str, Option<&str>); 4] = [
            ("SUM(sales_amount)", Some("SUM(SALES_AMOUNT)")),
            ("COUNT(*)", Some("COUNT(*)")),
            ("region", None),
            ("total", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                OrderByAliasTransformer::extract_aggregate_from_order_column(input),
                expected.map(String::from)
            );
        }
    }

    /// A single-column aggregate normalizes to `NAME(COLUMN)` in upper case.
    #[test]
    fn test_normalize_aggregate_expr() {
        let sales_amount = SqlExpression::Column(ColumnRef {
            name: "sales_amount".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        });
        let expr = call("SUM", vec![sales_amount]);
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "SUM(SALES_AMOUNT)"
        );
    }

    /// `SUM` is recognised as an aggregate, `UPPER` is not.
    #[test]
    fn test_is_aggregate_function() {
        let sum_expr = call("SUM", vec![]);
        assert!(OrderByAliasTransformer::is_aggregate_function(&sum_expr));

        let upper_expr = call("UPPER", vec![]);
        assert!(!OrderByAliasTransformer::is_aggregate_function(&upper_expr));
    }
}