use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
use crate::base::{
database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable},
map::{indexmap, IndexMap, IndexSet},
scalar::Scalar,
};
use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec};
use bumpalo::Bump;
use itertools::{izip, Itertools};
use proof_of_sql_parser::{
intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression},
Identifier,
};
use serde::{Deserialize, Serialize};
use sqlparser::ast::Ident;
/// A postprocessing step that applies a `GROUP BY` to an [`OwnedTable`].
///
/// Aggregations found in the result expressions are pulled out into `aggregation_exprs`
/// and replaced by references to synthetic columns, so that `remainder_exprs` can be
/// evaluated on the aggregated table.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroupByPostprocessing {
/// Result expressions in which every aggregation has been replaced by a column reference
/// to its synthetic `__col_agg_{n}` identifier.
remainder_exprs: Vec<AliasedResultExpr>,
/// The group-by columns with duplicates removed, in first-occurrence order.
group_by_identifiers: Vec<Ident>,
/// Each distinct aggregation: its operator, its argument expression, and the identifier of
/// the synthetic column that holds its result.
aggregation_exprs: Vec<(AggregationOperator, Expression, Ident)>,
}
/// Reports whether `expr` contains an aggregation nested inside another aggregation.
///
/// `is_agg` indicates that `expr` already sits inside an aggregation. In that case any
/// aggregation encountered, including `expr` itself, counts as nested.
fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool {
    match expr {
        Expression::Aggregation { expr: argument, .. } => {
            if is_agg {
                // An aggregation found while already inside one is nested by definition.
                return true;
            }
            contains_nested_aggregation(argument, true)
        }
        Expression::Binary { left, right, .. } => [left, right]
            .iter()
            .any(|operand| contains_nested_aggregation(operand, is_agg)),
        Expression::Unary { expr: operand, .. } => contains_nested_aggregation(operand, is_agg),
        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => false,
    }
}
/// Collects the column identifiers that `expr` references outside of any aggregation.
///
/// Identifiers appear in first-occurrence order, with the left operand before the right.
/// Columns that appear only as aggregation arguments are not free and are omitted.
fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet<Ident> {
    match expr {
        Expression::Unary { expr: operand, .. } => get_free_identifiers_from_expr(operand),
        Expression::Binary { left, right, .. } => get_free_identifiers_from_expr(left)
            .into_iter()
            .chain(get_free_identifiers_from_expr(right))
            .collect(),
        Expression::Column(identifier) => {
            let ident: Ident = (*identifier).into();
            core::iter::once(ident).collect()
        }
        Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => {
            IndexSet::default()
        }
    }
}
fn get_aggregate_and_remainder_expressions(
expr: Expression,
aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>,
) -> Result<Expression, PostprocessingError> {
match expr {
Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => Ok(expr),
Expression::Aggregation { op, expr } => {
let key = (op, (*expr));
if let Some(ident) = aggregation_expr_map.get(&key) {
let identifier = Identifier::try_from(ident.clone()).map_err(|e| {
PostprocessingError::IdentifierConversionError {
error: format!("Failed to convert Ident to Identifier: {e}"),
}
})?;
Ok(Expression::Column(identifier))
} else {
let new_ident = Ident {
value: format!("__col_agg_{}", aggregation_expr_map.len()),
quote_style: None,
};
let new_identifier = Identifier::try_from(new_ident.clone()).map_err(|e| {
PostprocessingError::IdentifierConversionError {
error: format!("Failed to convert Ident to Identifier: {e}"),
}
})?;
aggregation_expr_map.insert(key, new_ident);
Ok(Expression::Column(new_identifier))
}
}
Expression::Binary { op, left, right } => {
let left_remainder =
get_aggregate_and_remainder_expressions(*left, aggregation_expr_map);
let right_remainder =
get_aggregate_and_remainder_expressions(*right, aggregation_expr_map);
Ok(Expression::Binary {
op,
left: Box::new(left_remainder?),
right: Box::new(right_remainder?),
})
}
Expression::Unary { op, expr } => {
let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map);
Ok(Expression::Unary {
op,
expr: Box::new(remainder?),
})
}
}
}
/// Validates one aliased result expression of a `GROUP BY` query and extracts its
/// aggregations.
///
/// Aggregations are recorded in `aggregation_expr_map`. The expression is returned with those
/// aggregations replaced by synthetic column references.
///
/// # Errors
/// - [`PostprocessingError::NestedAggregationInGroupByClause`] if an aggregation is nested
///   inside another aggregation.
/// - [`PostprocessingError::IdentNotInAggregationOperatorOrGroupByClause`] for the first
///   column referenced outside any aggregation that is not one of `group_by_identifiers`.
/// - Any error from [`get_aggregate_and_remainder_expressions`].
fn check_and_get_aggregation_and_remainder(
    expr: AliasedResultExpr,
    group_by_identifiers: &[Ident],
    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>,
) -> PostprocessingResult<AliasedResultExpr> {
    // Cheapest check first: it needs no auxiliary sets.
    if contains_nested_aggregation(&expr.expr, false) {
        return Err(PostprocessingError::NestedAggregationInGroupByClause {
            error: format!("Nested aggregations found {:?}", expr.expr),
        });
    }
    let free_identifiers = get_free_identifiers_from_expr(&expr.expr);
    let group_by_identifier_set = group_by_identifiers
        .iter()
        .cloned()
        .collect::<IndexSet<_>>();
    // Every free column must be grouped by; report the first one (in expression order) that is not.
    if let Some(column) = free_identifiers
        .difference(&group_by_identifier_set)
        .next()
    {
        return Err(
            PostprocessingError::IdentNotInAggregationOperatorOrGroupByClause {
                column: column.clone(),
            },
        );
    }
    let remainder = get_aggregate_and_remainder_expressions(*expr.expr, aggregation_expr_map)?;
    Ok(AliasedResultExpr {
        alias: expr.alias,
        expr: Box::new(remainder),
    })
}
impl GroupByPostprocessing {
pub fn try_new(
by_ids: Vec<Ident>,
aliased_exprs: Vec<AliasedResultExpr>,
) -> PostprocessingResult<Self> {
let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> =
IndexMap::default();
let remainder_exprs: Vec<AliasedResultExpr> = aliased_exprs
.into_iter()
.map(|aliased_expr| -> PostprocessingResult<_> {
check_and_get_aggregation_and_remainder(
aliased_expr,
&by_ids,
&mut aggregation_expr_map,
)
})
.collect::<PostprocessingResult<Vec<AliasedResultExpr>>>()?;
let group_by_identifiers = Vec::from_iter(IndexSet::from_iter(by_ids));
Ok(Self {
remainder_exprs,
group_by_identifiers,
aggregation_exprs: aggregation_expr_map
.into_iter()
.map(|((op, expr), id)| (op, expr, id))
.collect(),
})
}
#[must_use]
pub fn group_by(&self) -> &[Ident] {
&self.group_by_identifiers
}
#[must_use]
pub fn remainder_exprs(&self) -> &[AliasedResultExpr] {
&self.remainder_exprs
}
#[must_use]
pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Ident)] {
&self.aggregation_exprs
}
}
impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
#[allow(clippy::too_many_lines)]
fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
let alloc = Bump::new();
let evaluated_columns = self
.aggregation_exprs
.iter()
.map(|(agg_op, expr, id)| -> PostprocessingResult<_> {
let evaluated_owned_column = owned_table.evaluate(expr)?;
Ok((*agg_op, (id.clone(), evaluated_owned_column)))
})
.process_results(|iter| {
iter.fold(
IndexMap::<_, Vec<_>>::default(),
|mut lookup, (key, val)| {
lookup.entry(key).or_default().push(val);
lookup
},
)
})?;
let group_by_ins = self
.group_by_identifiers
.iter()
.map(|id| {
let column = owned_table.inner_table().get(id).ok_or(
PostprocessingError::ColumnNotFound {
column: id.to_string(),
},
)?;
Ok(Column::<S>::from_owned_column(column, &alloc))
})
.collect::<PostprocessingResult<Vec<_>>>()?;
let selection_in = vec![true; owned_table.num_rows()];
let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Sum)
.map_or((vec![], vec![]), |tuple| {
tuple
.iter()
.map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
.unzip()
});
let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Max)
.map_or((vec![], vec![]), |tuple| {
tuple
.iter()
.map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
.unzip()
});
let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Min)
.map_or((vec![], vec![]), |tuple| {
tuple
.iter()
.map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
.unzip()
});
let aggregation_results = aggregate_columns(
&alloc,
&group_by_ins,
&sum_columns,
&max_columns,
&min_columns,
&selection_in,
)?;
let group_by_outs = aggregation_results
.group_by_columns
.iter()
.zip(self.group_by_identifiers.iter())
.map(|(column, id)| Ok((id.clone(), OwnedColumn::from(column))));
let sum_outs = izip!(
aggregation_results.sum_columns,
sum_identifiers,
sum_columns,
)
.map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_scalars(c_out, c_in.column_type())?,
))
});
let max_outs = izip!(
aggregation_results.max_columns,
max_identifiers,
max_columns,
)
.map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
))
});
let min_outs = izip!(
aggregation_results.min_columns,
min_identifiers,
min_columns,
)
.map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
))
});
let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec());
let count_outs = evaluated_columns
.get(&AggregationOperator::Count)
.into_iter()
.flatten()
.map(|(id, _)| -> PostprocessingResult<_> { Ok((id.clone(), count_column.clone())) });
let new_owned_table: OwnedTable<S> = group_by_outs
.into_iter()
.chain(sum_outs)
.chain(max_outs)
.chain(min_outs)
.chain(count_outs)
.process_results(|iter| OwnedTable::try_from_iter(iter))??;
let target_table = if new_owned_table.is_empty() {
OwnedTable::try_new(indexmap! {"__count__".into() => count_column})?
} else {
new_owned_table
};
let result = self
.remainder_exprs
.iter()
.map(|aliased_expr| -> PostprocessingResult<_> {
let column = target_table.evaluate(&aliased_expr.expr)?;
let alias: Ident = aliased_expr.alias.into();
Ok((alias, column))
})
.process_results(|iter| OwnedTable::try_from_iter(iter))??;
Ok(result)
}
}
#[cfg(test)]
mod tests {
// Unit tests for the expression-analysis helpers used by `GroupByPostprocessing::try_new`.
use super::*;
use proof_of_sql_parser::utility::*;
// `contains_nested_aggregation` is checked both at top level (`is_agg = false`) and as if
// already inside an aggregation (`is_agg = true`), where any aggregation counts as nested.
#[test]
fn we_can_detect_nested_aggregation() {
let expr = sum(sum(col("a")));
assert!(contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
let expr = add(max(col("a")), sum(col("b")));
assert!(!contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
let expr = add(col("a"), sum(col("b")));
assert!(!contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
let expr = sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c"))));
assert!(!contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
let expr = add(col("a"), count(sum(col("a"))));
assert!(contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
// No aggregation at all: never nested, regardless of `is_agg`.
let expr = add(add(col("a"), col("b")), lit(1));
assert!(!contains_nested_aggregation(&expr, false));
assert!(!contains_nested_aggregation(&expr, true));
}
// Only columns outside aggregations are free; columns used as aggregation arguments are not.
#[test]
fn we_can_get_free_identifiers_from_expr() {
let expr = lit("Not an identifier");
let expected: IndexSet<Ident> = IndexSet::default();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
let expr = add(add(col("a"), col("b")), lit(1));
let expected: IndexSet<Ident> = ["a".into(), "b".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
// `a` appears twice but is collected once.
let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a"))));
let expected: IndexSet<Ident> = ["a".into(), "b".into(), "c".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
let expr = mul(sum(add(col("a"), col("b"))), lit(2));
let expected: IndexSet<Ident> = IndexSet::default();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d"));
let expected: IndexSet<Ident> = ["c".into(), "d".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
}
// All calls below share one map, so the synthetic `__col_agg_{n}` names keep counting up
// across calls, and an aggregation seen in an earlier call reuses its existing name.
#[test]
fn we_can_get_aggregate_and_remainder_expressions() {
let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> =
IndexMap::default();
let expr = add(sum(col("a")), col("b"));
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
"__col_agg_0".into()
);
assert_eq!(remainder_expr, Ok(*add(col("__col_agg_0"), col("b"))));
assert_eq!(aggregation_expr_map.len(), 1);
// `sum(a)` is already recorded, so only `sum(b)` gets a new name.
let expr = add(sum(col("a")), sum(col("b")));
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
"__col_agg_0".into()
);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))],
"__col_agg_1".into()
);
assert_eq!(
remainder_expr,
Ok(*add(col("__col_agg_0"), col("__col_agg_1")))
);
assert_eq!(aggregation_expr_map.len(), 2);
let expr = add(
add(
max(col("a") + lit(1)),
min(sub(mul(lit(2), col("b")), lit(4))),
),
col("c"),
);
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))],
"__col_agg_2".into()
);
assert_eq!(
aggregation_expr_map[&(
AggregationOperator::Min,
*sub(mul(lit(2), col("b")), lit(4))
)],
"__col_agg_3".into()
);
assert_eq!(
remainder_expr,
Ok(*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c")))
);
assert_eq!(aggregation_expr_map.len(), 4);
// `sum(b)` was recorded by the second call, so it resolves to `__col_agg_1`.
let expr = add(
add(mul(count(mul(lit(2), col("a"))), lit(2)), sum(col("b"))),
lit(1),
);
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))],
"__col_agg_4".into()
);
assert_eq!(
remainder_expr,
Ok(*add(
add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")),
lit(1)
))
);
assert_eq!(aggregation_expr_map.len(), 5);
}
}