use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
use crate::base::{
database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable},
map::{indexmap, IndexMap, IndexSet},
scalar::Scalar,
};
use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec};
use bumpalo::Bump;
use itertools::{izip, Itertools};
use proof_of_sql_parser::{
intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression},
Identifier,
};
use serde::{Deserialize, Serialize};
/// A group-by postprocessing step.
///
/// Built by `GroupByPostprocessing::try_new`, which splits every result
/// expression into the aggregations it contains and a "remainder" expression
/// that refers to those aggregations through generated column identifiers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroupByPostprocessing {
/// Result expressions in which every aggregation has been replaced by a
/// reference to the generated column holding that aggregation's result.
remainder_exprs: Vec<AliasedResultExpr>,
/// Identifiers to group by, de-duplicated in first-occurrence order.
group_by_identifiers: Vec<Identifier>,
/// Each distinct aggregation as `(operator, aggregated expression, generated column identifier)`.
aggregation_exprs: Vec<(AggregationOperator, Expression, Identifier)>,
}
/// Reports whether `expr` contains an aggregation placed inside another aggregation.
///
/// `is_agg` tells whether `expr` is already enclosed by an aggregation. With
/// `is_agg == true`, any aggregation found anywhere in `expr` counts as nested.
fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool {
    match expr {
        Expression::Aggregation { expr: inner, .. } => {
            if is_agg {
                // An aggregation met while already inside one is nested by definition.
                true
            } else {
                // Everything beneath this node is now inside an aggregation.
                contains_nested_aggregation(inner, true)
            }
        }
        Expression::Unary { expr: operand, .. } => contains_nested_aggregation(operand, is_agg),
        Expression::Binary { left, right, .. } => [left, right]
            .iter()
            .any(|side| contains_nested_aggregation(side, is_agg)),
        // Leaves hold no aggregation at all.
        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => false,
    }
}
/// Collects the column identifiers of `expr` that are *not* enclosed by an aggregation.
///
/// Aggregation subtrees are not descended into, so columns that only appear
/// inside an aggregation are not reported. Identifiers are returned in
/// first-occurrence order, visiting the left operand before the right one.
fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet<Identifier> {
    match expr {
        Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => {
            IndexSet::default()
        }
        Expression::Column(identifier) => {
            let mut found = IndexSet::default();
            found.insert(*identifier);
            found
        }
        Expression::Unary { expr: operand, .. } => get_free_identifiers_from_expr(operand),
        Expression::Binary { left, right, .. } => {
            let mut found = get_free_identifiers_from_expr(left);
            found.extend(get_free_identifiers_from_expr(right));
            found
        }
    }
}
/// Replaces every aggregation in `expr` with a reference to a generated column
/// and returns the resulting "remainder" expression.
///
/// Each distinct `(operator, aggregated expression)` pair is recorded in
/// `aggregation_expr_map` under an identifier of the form `__col_agg_<n>`,
/// where `<n>` is the number of entries in the map at the time the pair is
/// first seen. A pair that is already present reuses its existing identifier,
/// so identical aggregations map to one column.
///
/// Aggregation operands are stored as-is and are not rewritten recursively.
///
/// # Panics
///
/// Panics if the generated `__col_agg_<n>` name fails to parse as an
/// [`Identifier`].
fn get_aggregate_and_remainder_expressions(
    expr: Expression,
    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Identifier>,
) -> Expression {
    match expr {
        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => expr,
        Expression::Aggregation { op, expr } => {
            // Read the length before taking the entry: it is the index a new alias would get.
            let next_index = aggregation_expr_map.len();
            // A single entry lookup both finds an existing alias and registers a new one.
            let column_id = *aggregation_expr_map
                .entry((op, *expr))
                .or_insert_with(|| format!("__col_agg_{}", next_index).parse().unwrap());
            Expression::Column(column_id)
        }
        Expression::Binary { op, left, right } => {
            // Left is rewritten before right so aliases are numbered in reading order.
            let left_remainder =
                get_aggregate_and_remainder_expressions(*left, aggregation_expr_map);
            let right_remainder =
                get_aggregate_and_remainder_expressions(*right, aggregation_expr_map);
            Expression::Binary {
                op,
                left: Box::new(left_remainder),
                right: Box::new(right_remainder),
            }
        }
        Expression::Unary { op, expr } => {
            let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map);
            Expression::Unary {
                op,
                expr: Box::new(remainder),
            }
        }
    }
}
/// Validates one aliased result expression of a group-by query and rewrites it
/// into its remainder form.
///
/// On success the returned expression keeps the original alias, and every
/// aggregation it contained has been registered in `aggregation_expr_map` and
/// replaced by the corresponding generated column.
///
/// # Errors
///
/// * [`PostprocessingError::NestedAggregationInGroupByClause`] if an aggregation
///   appears inside another aggregation.
/// * [`PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause`]
///   if a column outside every aggregation is not one of `group_by_identifiers`;
///   the first such column is reported.
fn check_and_get_aggregation_and_remainder(
    expr: AliasedResultExpr,
    group_by_identifiers: &[Identifier],
    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Identifier>,
) -> PostprocessingResult<AliasedResultExpr> {
    // Nested aggregations take precedence over any identifier problem.
    if contains_nested_aggregation(&expr.expr, false) {
        return Err(PostprocessingError::NestedAggregationInGroupByClause {
            error: format!("Nested aggregations found {:?}", expr.expr),
        });
    }

    let allowed: IndexSet<Identifier> = group_by_identifiers.iter().copied().collect();
    let free = get_free_identifiers_from_expr(&expr.expr);
    // Any free identifier outside the group-by set makes the expression invalid.
    if let Some(offending) = free.difference(&allowed).next() {
        return Err(
            PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause {
                column: *offending,
            },
        );
    }

    let AliasedResultExpr { alias, expr: inner } = expr;
    let remainder = get_aggregate_and_remainder_expressions(*inner, aggregation_expr_map);
    Ok(AliasedResultExpr {
        alias,
        expr: Box::new(remainder),
    })
}
impl GroupByPostprocessing {
pub fn try_new(
by_ids: Vec<Identifier>,
aliased_exprs: Vec<AliasedResultExpr>,
) -> PostprocessingResult<Self> {
let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Identifier> =
IndexMap::default();
let remainder_exprs: Vec<AliasedResultExpr> = aliased_exprs
.into_iter()
.map(|aliased_expr| -> PostprocessingResult<_> {
check_and_get_aggregation_and_remainder(
aliased_expr,
&by_ids,
&mut aggregation_expr_map,
)
})
.collect::<PostprocessingResult<Vec<AliasedResultExpr>>>()?;
let group_by_identifiers = Vec::from_iter(IndexSet::from_iter(by_ids));
Ok(Self {
remainder_exprs,
group_by_identifiers,
aggregation_exprs: aggregation_expr_map
.into_iter()
.map(|((op, expr), id)| (op, expr, id))
.collect(),
})
}
#[must_use]
pub fn group_by(&self) -> &[Identifier] {
&self.group_by_identifiers
}
#[must_use]
pub fn remainder_exprs(&self) -> &[AliasedResultExpr] {
&self.remainder_exprs
}
#[must_use]
pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Identifier)] {
&self.aggregation_exprs
}
}
impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
#[allow(clippy::too_many_lines)]
fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
let alloc = Bump::new();
let evaluated_columns = self
.aggregation_exprs
.iter()
.map(|(agg_op, expr, id)| -> PostprocessingResult<_> {
let evaluated_owned_column = owned_table.evaluate(expr)?;
Ok((*agg_op, (*id, evaluated_owned_column)))
})
.process_results(|iter| {
iter.fold(
IndexMap::<_, Vec<_>>::default(),
|mut lookup, (key, val)| {
lookup.entry(key).or_default().push(val);
lookup
},
)
})?;
let group_by_ins = self
.group_by_identifiers
.iter()
.map(|id| {
let column = owned_table.inner_table().get(id).ok_or(
PostprocessingError::ColumnNotFound {
column: id.to_string(),
},
)?;
Ok(Column::<S>::from_owned_column(column, &alloc))
})
.collect::<PostprocessingResult<Vec<_>>>()?;
let selection_in = vec![true; owned_table.num_rows()];
let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Sum)
.map_or((vec![], vec![]), |tuple| {
tuple
.iter()
.map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
.unzip()
});
let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Max)
.map_or((vec![], vec![]), |tuple| {
tuple
.iter()
.map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
.unzip()
});
let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns
.get(&AggregationOperator::Min)
.map_or((vec![], vec![]), |tuple| {
tuple
.iter()
.map(|(id, c)| (*id, Column::<S>::from_owned_column(c, &alloc)))
.unzip()
});
let aggregation_results = aggregate_columns(
&alloc,
&group_by_ins,
&sum_columns,
&max_columns,
&min_columns,
&selection_in,
)?;
let group_by_outs = aggregation_results
.group_by_columns
.iter()
.zip(self.group_by_identifiers.iter())
.map(|(column, id)| Ok((*id, OwnedColumn::from(column))));
let sum_outs = izip!(
aggregation_results.sum_columns,
sum_identifiers,
sum_columns,
)
.map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_scalars(c_out, c_in.column_type())?,
))
});
let max_outs = izip!(
aggregation_results.max_columns,
max_identifiers,
max_columns,
)
.map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
))
});
let min_outs = izip!(
aggregation_results.min_columns,
min_identifiers,
min_columns,
)
.map(|(c_out, id, c_in)| {
Ok((
id,
OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
))
});
let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec());
let count_outs = evaluated_columns
.get(&AggregationOperator::Count)
.into_iter()
.flatten()
.map(|(id, _)| -> PostprocessingResult<_> { Ok((*id, count_column.clone())) });
let new_owned_table: OwnedTable<S> = group_by_outs
.into_iter()
.chain(sum_outs)
.chain(max_outs)
.chain(min_outs)
.chain(count_outs)
.process_results(|iter| OwnedTable::try_from_iter(iter))??;
let target_table = if new_owned_table.is_empty() {
OwnedTable::try_new(indexmap! {"__count__".parse().unwrap() => count_column})?
} else {
new_owned_table
};
let result = self
.remainder_exprs
.iter()
.map(|aliased_expr| -> PostprocessingResult<_> {
let column = target_table.evaluate(&aliased_expr.expr)?;
Ok((aliased_expr.alias, column))
})
.process_results(|iter| OwnedTable::try_from_iter(iter))??;
Ok(result)
}
}
// Unit tests for the private expression-analysis helpers of this module.
#[cfg(test)]
mod tests {
use super::*;
use proof_of_sql_parser::utility::*;
// `contains_nested_aggregation` is checked twice per expression: once as a
// top-level expression (`is_agg == false`) and once as if it were already
// inside an aggregation (`is_agg == true`).
#[test]
fn we_can_detect_nested_aggregation() {
// An aggregation directly inside another one is nested either way.
let expr = sum(sum(col("a")));
assert!(contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
// Sibling aggregations are only nested when an outer aggregation is assumed.
let expr = add(max(col("a")), sum(col("b")));
assert!(!contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
let expr = add(col("a"), sum(col("b")));
assert!(!contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
let expr = sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c"))));
assert!(!contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
// Nesting buried in one operand of a binary expression is still found.
let expr = add(col("a"), count(sum(col("a"))));
assert!(contains_nested_aggregation(&expr, false));
assert!(contains_nested_aggregation(&expr, true));
// No aggregation at all: never nested.
let expr = add(add(col("a"), col("b")), lit(1));
assert!(!contains_nested_aggregation(&expr, false));
assert!(!contains_nested_aggregation(&expr, true));
}
#[test]
fn we_can_get_free_identifiers_from_expr() {
// A literal has no identifiers.
let expr = lit("Not an identifier");
let expected: IndexSet<Identifier> = IndexSet::default();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
// Columns outside any aggregation are all reported.
let expr = add(add(col("a"), col("b")), lit(1));
let expected: IndexSet<Identifier> = [ident("a"), ident("b")].iter().copied().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
// A column that occurs twice is reported once.
let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a"))));
let expected: IndexSet<Identifier> = [ident("a"), ident("b"), ident("c")]
.iter()
.copied()
.collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
// Columns that only appear inside an aggregation are not free.
let expr = mul(sum(add(col("a"), col("b"))), lit(2));
let expected: IndexSet<Identifier> = IndexSet::default();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
// Mixed case: only the columns outside the aggregation are reported.
let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d"));
let expected: IndexSet<Identifier> = [ident("c"), ident("d")].iter().copied().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
}
// The map is shared by all cases below, so generated `__col_agg_<n>` indices
// keep increasing and aggregations seen earlier are reused.
#[test]
fn we_can_get_aggregate_and_remainder_expressions() {
let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Identifier> =
IndexMap::default();
// First aggregation gets index 0; the plain column is left untouched.
let expr = add(sum(col("a")), col("b"));
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
ident("__col_agg_0")
);
assert_eq!(remainder_expr, *add(col("__col_agg_0"), col("b")));
assert_eq!(aggregation_expr_map.len(), 1);
// `sum(a)` is reused; `sum(b)` is new and gets index 1.
let expr = add(sum(col("a")), sum(col("b")));
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
ident("__col_agg_0")
);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))],
ident("__col_agg_1")
);
assert_eq!(remainder_expr, *add(col("__col_agg_0"), col("__col_agg_1")));
assert_eq!(aggregation_expr_map.len(), 2);
// Aggregations over compound operands are keyed by the whole operand expression.
let expr = add(
add(
max(col("a") + lit(1)),
min(sub(mul(lit(2), col("b")), lit(4))),
),
col("c"),
);
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))],
ident("__col_agg_2")
);
assert_eq!(
aggregation_expr_map[&(
AggregationOperator::Min,
*sub(mul(lit(2), col("b")), lit(4))
)],
ident("__col_agg_3")
);
assert_eq!(
remainder_expr,
*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c"))
);
assert_eq!(aggregation_expr_map.len(), 4);
// A new count gets index 4 while `sum(b)` still resolves to `__col_agg_1`.
let expr = add(
add(mul(count(mul(lit(2), col("a"))), lit(2)), sum(col("b"))),
lit(1),
);
let remainder_expr =
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))],
ident("__col_agg_4")
);
assert_eq!(
remainder_expr,
*add(
add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")),
lit(1)
)
);
assert_eq!(aggregation_expr_map.len(), 5);
}
}