use crate::query_plan::pipeline::ASTTransformer;
use crate::sql::parser::ast::{
ColumnRef, Condition, LogicalOp, OrderByItem, PivotAggregate, QuoteStyle, SelectItem,
SelectStatement, SqlExpression, TableSource, WhenBranch,
};
use anyhow::{anyhow, Result};
/// Query-plan transformer that rewrites `PIVOT` table sources into plain
/// `GROUP BY` queries built from conditional aggregates
/// (`AGG(CASE WHEN pivot_col = 'v' THEN agg_col ELSE NULL END) AS v`).
///
/// The transformer holds no state; all work happens in the associated
/// functions of its inherent `impl`.
#[derive(Debug, Clone)]
pub struct PivotExpander;
impl PivotExpander {
pub fn expand(mut statement: SelectStatement) -> Result<SelectStatement> {
if let Some(ref from_source) = statement.from_source {
match from_source {
TableSource::Pivot {
source,
aggregate,
pivot_column,
pivot_values,
alias,
} => {
return Self::expand_pivot(
source,
aggregate,
pivot_column,
pivot_values,
alias,
);
}
TableSource::DerivedTable { query, .. } => {
let processed_subquery = Self::expand(*query.clone())?;
statement.from_source = Some(TableSource::DerivedTable {
query: Box::new(processed_subquery),
alias: match from_source {
TableSource::DerivedTable { alias, .. } => alias.clone(),
_ => String::new(),
},
});
}
TableSource::Table(_) => {
}
}
}
Ok(statement)
}
pub fn expand_pivot(
source: &TableSource,
aggregate: &PivotAggregate,
pivot_column: &str,
pivot_values: &[String],
alias: &Option<String>,
) -> Result<SelectStatement> {
let (base_table, base_alias, base_subquery) = Self::extract_base_source(source)?;
let group_by_columns = Self::determine_group_by_columns(
&base_table,
&base_alias,
&base_subquery,
pivot_column,
&aggregate.column,
)?;
let mut select_items = Vec::new();
for col in &group_by_columns {
select_items.push(SelectItem::Column {
column: ColumnRef::unquoted(col.clone()),
leading_comments: Vec::new(),
trailing_comment: None,
});
}
for pivot_value in pivot_values {
let case_expr = Self::build_pivot_case_expression(
pivot_column,
pivot_value,
&aggregate.column,
&aggregate.function,
)?;
select_items.push(SelectItem::Expression {
expr: case_expr,
alias: pivot_value.clone(),
leading_comments: Vec::new(),
trailing_comment: None,
});
}
let from_source = if let Some(ref table) = base_table {
Some(TableSource::Table(table.clone()))
} else if let Some(ref subquery) = base_subquery {
Some(TableSource::DerivedTable {
query: subquery.clone(),
alias: base_alias.clone().unwrap_or_default(),
})
} else {
None
};
let mut result = SelectStatement {
distinct: false,
columns: Vec::new(), select_items,
from_source,
#[allow(deprecated)]
from_table: base_table,
#[allow(deprecated)]
from_subquery: base_subquery,
#[allow(deprecated)]
from_function: None,
#[allow(deprecated)]
from_alias: base_alias.or_else(|| alias.clone()),
joins: Vec::new(),
where_clause: None,
order_by: None,
group_by: Some(
group_by_columns
.iter()
.map(|col| SqlExpression::Column(ColumnRef::unquoted(col.clone())))
.collect(),
),
having: None,
qualify: None,
limit: None,
offset: None,
ctes: Vec::new(),
into_table: None,
set_operations: Vec::new(),
leading_comments: Vec::new(),
trailing_comment: None,
};
Ok(result)
}
fn extract_base_source(
source: &TableSource,
) -> Result<(Option<String>, Option<String>, Option<Box<SelectStatement>>)> {
match source {
TableSource::Table(name) => Ok((Some(name.clone()), None, None)),
TableSource::DerivedTable { query, alias } => {
Ok((None, Some(alias.clone()), Some(query.clone())))
}
TableSource::Pivot { .. } => Err(anyhow!("Nested PIVOT operations are not supported")),
}
}
fn determine_group_by_columns(
base_table: &Option<String>,
base_alias: &Option<String>,
base_subquery: &Option<Box<SelectStatement>>,
pivot_column: &str,
aggregate_column: &str,
) -> Result<Vec<String>> {
if let Some(subquery) = base_subquery {
let mut columns = Vec::new();
for item in &subquery.select_items {
match item {
SelectItem::Column { column, .. } => {
let col_name = column.name.clone();
if col_name != pivot_column && col_name != aggregate_column {
columns.push(col_name);
}
}
SelectItem::Expression { alias, .. } => {
if alias != pivot_column && alias != aggregate_column {
columns.push(alias.clone());
}
}
SelectItem::Star { .. } => {
return Err(anyhow!(
"PIVOT with SELECT * is not supported. Please specify columns explicitly."
));
}
SelectItem::StarExclude { .. } => {
return Err(anyhow!(
"PIVOT with SELECT * EXCLUDE is not supported. Please specify columns explicitly."
));
}
}
}
Ok(columns)
} else {
Err(anyhow!(
"PIVOT on table sources requires explicit column specification. \
Use a subquery: SELECT col1, col2, pivot_col, agg_col FROM table"
))
}
}
fn build_pivot_case_expression(
pivot_column: &str,
pivot_value: &str,
aggregate_column: &str,
aggregate_function: &str,
) -> Result<SqlExpression> {
let case_expr = SqlExpression::CaseExpression {
when_branches: vec![WhenBranch {
condition: Box::new(SqlExpression::BinaryOp {
left: Box::new(SqlExpression::Column(ColumnRef::unquoted(
pivot_column.to_string(),
))),
op: "=".to_string(),
right: Box::new(SqlExpression::StringLiteral(pivot_value.to_string())),
}),
result: Box::new(SqlExpression::Column(ColumnRef::unquoted(
aggregate_column.to_string(),
))),
}],
else_branch: Some(Box::new(SqlExpression::Null)),
};
let aggregated = SqlExpression::FunctionCall {
name: aggregate_function.to_uppercase(),
args: vec![case_expr],
distinct: false,
};
Ok(aggregated)
}
}
impl ASTTransformer for PivotExpander {
fn name(&self) -> &str {
"PivotExpander"
}
fn description(&self) -> &str {
"Expands PIVOT operations into CASE expressions with GROUP BY"
}
fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
Self::expand(stmt)
}
}
impl Default for PivotExpander {
fn default() -> Self {
Self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `SELECT <names...> FROM <table>` with every other clause empty.
    fn select_from_table(table: &str, names: &[&str]) -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: Vec::new(),
            select_items: names
                .iter()
                .map(|name| SelectItem::Column {
                    column: ColumnRef::unquoted((*name).to_string()),
                    leading_comments: Vec::new(),
                    trailing_comment: None,
                })
                .collect(),
            from_source: Some(TableSource::Table(table.to_string())),
            #[allow(deprecated)]
            from_table: Some(table.to_string()),
            #[allow(deprecated)]
            from_subquery: None,
            #[allow(deprecated)]
            from_function: None,
            #[allow(deprecated)]
            from_alias: None,
            joins: Vec::new(),
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            qualify: None,
            limit: None,
            offset: None,
            ctes: Vec::new(),
            into_table: None,
            set_operations: Vec::new(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        }
    }

    #[test]
    fn test_build_pivot_case_expression() {
        // Lower-case input also checks that the function name is upper-cased.
        let expr =
            PivotExpander::build_pivot_case_expression("FoodName", "Sammich", "AmountEaten", "max")
                .unwrap();
        let (name, args, distinct) = match expr {
            SqlExpression::FunctionCall {
                name,
                args,
                distinct,
            } => (name, args, distinct),
            _ => panic!("Expected FunctionCall"),
        };
        assert_eq!(name, "MAX");
        assert!(!distinct);
        assert_eq!(args.len(), 1);

        let (when_branches, else_branch) = match &args[0] {
            SqlExpression::CaseExpression {
                when_branches,
                else_branch,
            } => (when_branches, else_branch),
            _ => panic!("Expected CaseExpression inside function call"),
        };
        assert_eq!(when_branches.len(), 1);
        assert!(matches!(else_branch.as_deref(), Some(SqlExpression::Null)));

        match &*when_branches[0].condition {
            SqlExpression::BinaryOp { left, op, right } => {
                assert_eq!(op, "=");
                assert!(matches!(&**left, SqlExpression::Column(col) if col.name == "FoodName"));
                assert!(
                    matches!(&**right, SqlExpression::StringLiteral(value) if value == "Sammich")
                );
            }
            _ => panic!("Expected an equality comparison on the pivot column"),
        }
        assert!(matches!(
            &*when_branches[0].result,
            SqlExpression::Column(col) if col.name == "AmountEaten"
        ));
    }

    #[test]
    fn test_determine_group_by_columns_with_subquery() {
        let subquery = select_from_table("food_eaten", &["Date", "FoodName", "AmountEaten"]);
        let columns = PivotExpander::determine_group_by_columns(
            &None,
            &Some("src".to_string()),
            &Some(Box::new(subquery)),
            "FoodName",
            "AmountEaten",
        )
        .unwrap();
        assert_eq!(columns.len(), 1);
        assert_eq!(columns[0], "Date");
    }

    #[test]
    fn test_determine_group_by_columns_uses_expression_aliases() {
        let mut subquery = select_from_table("food_eaten", &["Date", "FoodName"]);
        subquery.select_items.push(SelectItem::Expression {
            expr: SqlExpression::Column(ColumnRef::unquoted("Amount".to_string())),
            alias: "AmountEaten".to_string(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        });
        subquery.select_items.push(SelectItem::Expression {
            expr: SqlExpression::Column(ColumnRef::unquoted("Meal".to_string())),
            alias: "MealName".to_string(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        });
        let columns = PivotExpander::determine_group_by_columns(
            &None,
            &Some("src".to_string()),
            &Some(Box::new(subquery)),
            "FoodName",
            "AmountEaten",
        )
        .unwrap();
        assert_eq!(columns, vec!["Date", "MealName"]);
    }

    #[test]
    fn test_determine_group_by_columns_requires_subquery() {
        let result = PivotExpander::determine_group_by_columns(
            &Some("food_eaten".to_string()),
            &None,
            &None,
            "FoodName",
            "AmountEaten",
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_base_source_table() {
        let (table, alias, subquery) =
            PivotExpander::extract_base_source(&TableSource::Table("food_eaten".to_string()))
                .unwrap();
        assert_eq!(table.as_deref(), Some("food_eaten"));
        assert!(alias.is_none());
        assert!(subquery.is_none());
    }

    #[test]
    fn test_extract_base_source_derived_table() {
        let source = TableSource::DerivedTable {
            query: Box::new(select_from_table(
                "food_eaten",
                &["Date", "FoodName", "AmountEaten"],
            )),
            alias: "src".to_string(),
        };
        let (table, alias, subquery) = PivotExpander::extract_base_source(&source).unwrap();
        assert!(table.is_none());
        assert_eq!(alias.as_deref(), Some("src"));
        assert_eq!(subquery.map(|query| query.select_items.len()), Some(3));
    }

    #[test]
    fn test_expand_leaves_plain_table_untouched() {
        let statement = select_from_table("food_eaten", &["Date", "FoodName"]);
        let expanded = PivotExpander::expand(statement).unwrap();
        assert!(matches!(
            &expanded.from_source,
            Some(TableSource::Table(name)) if name == "food_eaten"
        ));
        assert_eq!(expanded.select_items.len(), 2);
    }

    #[test]
    fn test_expand_recurses_into_derived_table() {
        let mut outer = select_from_table("unused", &["Date"]);
        outer.from_source = Some(TableSource::DerivedTable {
            query: Box::new(select_from_table("food_eaten", &["Date", "FoodName"])),
            alias: "src".to_string(),
        });
        let expanded = PivotExpander::expand(outer).unwrap();
        match expanded.from_source {
            Some(TableSource::DerivedTable { query, alias }) => {
                assert_eq!(alias, "src");
                assert_eq!(query.select_items.len(), 2);
                assert!(matches!(
                    &query.from_source,
                    Some(TableSource::Table(name)) if name == "food_eaten"
                ));
            }
            _ => panic!("Expected the derived table to be preserved"),
        }
    }
}