1use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
2use crate::base::{
3 database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable},
4 map::{indexmap, IndexMap, IndexSet},
5 scalar::Scalar,
6};
7use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec};
8use bumpalo::Bump;
9use itertools::{izip, Itertools};
10use proof_of_sql_parser::{
11 intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression},
12 Identifier,
13};
14use serde::{Deserialize, Serialize};
15use sqlparser::ast::Ident;
16
/// A postprocessing step that groups a table by a set of columns and evaluates
/// aggregate expressions over each group.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GroupByPostprocessing {
    /// Result expressions in which every aggregation has been replaced by a
    /// reference to the generated column that holds its value.
    remainder_exprs: Vec<AliasedResultExpr>,

    /// Identifiers of the columns to group by; duplicates are removed on construction.
    group_by_identifiers: Vec<Ident>,

    /// Aggregations to compute, each stored as
    /// `(operator, argument expression, generated column identifier)`.
    aggregation_exprs: Vec<(AggregationOperator, Expression, Ident)>,
}
29
30fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool {
36 match expr {
37 Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => false,
38 Expression::Aggregation { expr, .. } => is_agg || contains_nested_aggregation(expr, true),
39 Expression::Binary { left, right, .. } => {
40 contains_nested_aggregation(left, is_agg) || contains_nested_aggregation(right, is_agg)
41 }
42 Expression::Unary { expr, .. } => contains_nested_aggregation(expr, is_agg),
43 }
44}
45
46fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet<Ident> {
48 match expr {
49 Expression::Column(identifier) => IndexSet::from_iter([(*identifier).into()]),
50 Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => {
51 IndexSet::default()
52 }
53 Expression::Binary { left, right, .. } => {
54 let mut left_identifiers = get_free_identifiers_from_expr(left);
55 let right_identifiers = get_free_identifiers_from_expr(right);
56 left_identifiers.extend(right_identifiers);
57 left_identifiers
58 }
59 Expression::Unary { expr, .. } => get_free_identifiers_from_expr(expr),
60 }
61}
62
63fn get_aggregate_and_remainder_expressions(
73 expr: Expression,
74 aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>,
75) -> Result<Expression, PostprocessingError> {
76 match expr {
77 Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => Ok(expr),
78 Expression::Aggregation { op, expr } => {
79 let key = (op, (*expr));
80 if let Some(ident) = aggregation_expr_map.get(&key) {
81 let identifier = Identifier::try_from(ident.clone()).map_err(|e| {
82 PostprocessingError::IdentifierConversionError {
83 error: format!("Failed to convert Ident to Identifier: {e}"),
84 }
85 })?;
86 Ok(Expression::Column(identifier))
87 } else {
88 let new_ident = Ident {
89 value: format!("__col_agg_{}", aggregation_expr_map.len()),
90 quote_style: None,
91 };
92
93 let new_identifier = Identifier::try_from(new_ident.clone()).map_err(|e| {
94 PostprocessingError::IdentifierConversionError {
95 error: format!("Failed to convert Ident to Identifier: {e}"),
96 }
97 })?;
98
99 aggregation_expr_map.insert(key, new_ident);
100 Ok(Expression::Column(new_identifier))
101 }
102 }
103 Expression::Binary { op, left, right } => {
104 let left_remainder =
105 get_aggregate_and_remainder_expressions(*left, aggregation_expr_map);
106 let right_remainder =
107 get_aggregate_and_remainder_expressions(*right, aggregation_expr_map);
108 Ok(Expression::Binary {
109 op,
110 left: Box::new(left_remainder?),
111 right: Box::new(right_remainder?),
112 })
113 }
114 Expression::Unary { op, expr } => {
115 let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map);
116 Ok(Expression::Unary {
117 op,
118 expr: Box::new(remainder?),
119 })
120 }
121 }
122}
123
124fn check_and_get_aggregation_and_remainder(
129 expr: AliasedResultExpr,
130 group_by_identifiers: &[Ident],
131 aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>,
132) -> PostprocessingResult<AliasedResultExpr> {
133 let free_identifiers = get_free_identifiers_from_expr(&expr.expr);
134 let group_by_identifier_set = group_by_identifiers
135 .iter()
136 .cloned()
137 .collect::<IndexSet<_>>();
138 if contains_nested_aggregation(&expr.expr, false) {
139 return Err(PostprocessingError::NestedAggregationInGroupByClause {
140 error: format!("Nested aggregations found {:?}", expr.expr),
141 });
142 }
143 if free_identifiers.is_subset(&group_by_identifier_set) {
144 let remainder = get_aggregate_and_remainder_expressions(*expr.expr, aggregation_expr_map);
145 Ok(AliasedResultExpr {
146 alias: expr.alias,
147 expr: Box::new(remainder?),
148 })
149 } else {
150 let diff = free_identifiers
151 .difference(&group_by_identifier_set)
152 .next()
153 .unwrap();
154 Err(
155 PostprocessingError::IdentNotInAggregationOperatorOrGroupByClause {
156 column: diff.clone(),
157 },
158 )
159 }
160}
161
162impl GroupByPostprocessing {
163 pub fn try_new(
165 by_ids: Vec<Ident>,
166 aliased_exprs: Vec<AliasedResultExpr>,
167 ) -> PostprocessingResult<Self> {
168 let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> =
169 IndexMap::default();
170 let remainder_exprs: Vec<AliasedResultExpr> = aliased_exprs
172 .into_iter()
173 .map(|aliased_expr| -> PostprocessingResult<_> {
174 check_and_get_aggregation_and_remainder(
175 aliased_expr,
176 &by_ids,
177 &mut aggregation_expr_map,
178 )
179 })
180 .collect::<PostprocessingResult<Vec<AliasedResultExpr>>>()?;
181 let group_by_identifiers = Vec::from_iter(IndexSet::from_iter(by_ids));
182 Ok(Self {
183 remainder_exprs,
184 group_by_identifiers,
185 aggregation_exprs: aggregation_expr_map
186 .into_iter()
187 .map(|((op, expr), id)| (op, expr, id))
188 .collect(),
189 })
190 }
191
192 #[must_use]
194 pub fn group_by(&self) -> &[Ident] {
195 &self.group_by_identifiers
196 }
197
198 #[must_use]
200 pub fn remainder_exprs(&self) -> &[AliasedResultExpr] {
201 &self.remainder_exprs
202 }
203
204 #[must_use]
206 pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Ident)] {
207 &self.aggregation_exprs
208 }
209}
210
211impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
212 #[expect(clippy::too_many_lines)]
214 fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
215 let alloc = Bump::new();
217 let evaluated_columns = self
218 .aggregation_exprs
219 .iter()
220 .map(|(agg_op, expr, id)| -> PostprocessingResult<_> {
221 let evaluated_owned_column = owned_table.evaluate(expr)?;
222 Ok((*agg_op, (id.clone(), evaluated_owned_column)))
223 })
224 .process_results(|iter| {
225 iter.fold(
226 IndexMap::<_, Vec<_>>::default(),
227 |mut lookup, (key, val)| {
228 lookup.entry(key).or_default().push(val);
229 lookup
230 },
231 )
232 })?;
233 let group_by_ins = self
235 .group_by_identifiers
236 .iter()
237 .map(|id| {
238 let column = owned_table.inner_table().get(id).ok_or(
239 PostprocessingError::ColumnNotFound {
240 column: id.to_string(),
241 },
242 )?;
243 Ok(Column::<S>::from_owned_column(column, &alloc))
244 })
245 .collect::<PostprocessingResult<Vec<_>>>()?;
246 let selection_in = vec![true; owned_table.num_rows()];
248 let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = evaluated_columns
249 .get(&AggregationOperator::Sum)
250 .map_or((vec![], vec![]), |tuple| {
251 tuple
252 .iter()
253 .map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
254 .unzip()
255 });
256 let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns
257 .get(&AggregationOperator::Max)
258 .map_or((vec![], vec![]), |tuple| {
259 tuple
260 .iter()
261 .map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
262 .unzip()
263 });
264 let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns
265 .get(&AggregationOperator::Min)
266 .map_or((vec![], vec![]), |tuple| {
267 tuple
268 .iter()
269 .map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
270 .unzip()
271 });
272 let aggregation_results = aggregate_columns(
273 &alloc,
274 &group_by_ins,
275 &sum_columns,
276 &max_columns,
277 &min_columns,
278 &selection_in,
279 )?;
280 let group_by_outs = aggregation_results
283 .group_by_columns
284 .iter()
285 .zip(self.group_by_identifiers.iter())
286 .map(|(column, id)| Ok((id.clone(), OwnedColumn::from(column))));
287 let sum_outs = izip!(
288 aggregation_results.sum_columns,
289 sum_identifiers,
290 sum_columns,
291 )
292 .map(|(c_out, id, c_in)| {
293 Ok((
294 id,
295 OwnedColumn::try_from_scalars(c_out, c_in.column_type())?,
296 ))
297 });
298 let max_outs = izip!(
299 aggregation_results.max_columns,
300 max_identifiers,
301 max_columns,
302 )
303 .map(|(c_out, id, c_in)| {
304 Ok((
305 id,
306 OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
307 ))
308 });
309 let min_outs = izip!(
310 aggregation_results.min_columns,
311 min_identifiers,
312 min_columns,
313 )
314 .map(|(c_out, id, c_in)| {
315 Ok((
316 id,
317 OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
318 ))
319 });
320 let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec());
322 let count_outs = evaluated_columns
323 .get(&AggregationOperator::Count)
324 .into_iter()
325 .flatten()
326 .map(|(id, _)| -> PostprocessingResult<_> { Ok((id.clone(), count_column.clone())) });
327 let new_owned_table: OwnedTable<S> = group_by_outs
328 .into_iter()
329 .chain(sum_outs)
330 .chain(max_outs)
331 .chain(min_outs)
332 .chain(count_outs)
333 .process_results(|iter| OwnedTable::try_from_iter(iter))??;
334 let target_table = if new_owned_table.is_empty() {
337 OwnedTable::try_new(indexmap! {"__count__".into() => count_column})?
338 } else {
339 new_owned_table
340 };
341 let result = self
342 .remainder_exprs
343 .iter()
344 .map(|aliased_expr| -> PostprocessingResult<_> {
345 let column = target_table.evaluate(&aliased_expr.expr)?;
346 let alias: Ident = aliased_expr.alias.into();
347 Ok((alias, column))
348 })
349 .process_results(|iter| OwnedTable::try_from_iter(iter))??;
350 Ok(result)
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use proof_of_sql_parser::utility::*;

    #[test]
    fn we_can_detect_nested_aggregation() {
        // (expression, expected result at top level, expected result inside an aggregation)
        let cases = [
            // SUM(SUM(a))
            (sum(sum(col("a"))), true, true),
            // MAX(a) + SUM(b)
            (add(max(col("a")), sum(col("b"))), false, true),
            // a + SUM(b)
            (add(col("a"), sum(col("b"))), false, true),
            // SUM(a) + b - SUM(2 * c)
            (
                sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c")))),
                false,
                true,
            ),
            // a + COUNT(SUM(a))
            (add(col("a"), count(sum(col("a")))), true, true),
            // a + b + 1
            (add(add(col("a"), col("b")), lit(1)), false, false),
        ];

        for (expr, at_top_level, inside_aggregation) in cases {
            assert_eq!(contains_nested_aggregation(&expr, false), at_top_level);
            assert_eq!(contains_nested_aggregation(&expr, true), inside_aggregation);
        }
    }

    #[test]
    fn we_can_get_free_identifiers_from_expr() {
        let idents = |names: &[&str]| -> IndexSet<Ident> {
            names.iter().map(|name| Ident::from(*name)).collect()
        };

        // (expression, identifiers expected outside of any aggregation)
        let cases = [
            // A literal references no column.
            (lit("Not an identifier"), idents(&[])),
            // a + b + 1
            (add(add(col("a"), col("b")), lit(1)), idents(&["a", "b"])),
            // NOT (a = b OR c >= a)
            (
                not(or(equal(col("a"), col("b")), ge(col("c"), col("a")))),
                idents(&["a", "b", "c"]),
            ),
            // SUM(a + b) * 2
            (mul(sum(add(col("a"), col("b"))), lit(2)), idents(&[])),
            // (COUNT(a + b) + c) * d
            (
                mul(add(count(add(col("a"), col("b"))), col("c")), col("d")),
                idents(&["c", "d"]),
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(get_free_identifiers_from_expr(&expr), expected);
        }
    }

    #[test]
    fn we_can_get_aggregate_and_remainder_expressions() {
        // The map is shared across all steps, so generated identifiers keep counting up.
        let mut agg_map: IndexMap<(AggregationOperator, Expression), Ident> = IndexMap::default();

        // SUM(a) + b
        let input = add(sum(col("a")), col("b"));
        let remainder = get_aggregate_and_remainder_expressions(*input, &mut agg_map);
        assert_eq!(agg_map.len(), 1);
        assert_eq!(
            agg_map[&(AggregationOperator::Sum, *col("a"))],
            Ident::from("__col_agg_0")
        );
        assert_eq!(remainder, Ok(*add(col("__col_agg_0"), col("b"))));

        // SUM(a) + SUM(b): SUM(a) is reused, SUM(b) is new.
        let input = add(sum(col("a")), sum(col("b")));
        let remainder = get_aggregate_and_remainder_expressions(*input, &mut agg_map);
        assert_eq!(agg_map.len(), 2);
        assert_eq!(
            agg_map[&(AggregationOperator::Sum, *col("a"))],
            Ident::from("__col_agg_0")
        );
        assert_eq!(
            agg_map[&(AggregationOperator::Sum, *col("b"))],
            Ident::from("__col_agg_1")
        );
        assert_eq!(
            remainder,
            Ok(*add(col("__col_agg_0"), col("__col_agg_1")))
        );

        // MAX(a + 1) + MIN(2 * b - 4) + c
        let input = add(
            add(
                max(col("a") + lit(1)),
                min(sub(mul(lit(2), col("b")), lit(4))),
            ),
            col("c"),
        );
        let remainder = get_aggregate_and_remainder_expressions(*input, &mut agg_map);
        assert_eq!(agg_map.len(), 4);
        assert_eq!(
            agg_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))],
            Ident::from("__col_agg_2")
        );
        assert_eq!(
            agg_map[&(
                AggregationOperator::Min,
                *sub(mul(lit(2), col("b")), lit(4))
            )],
            Ident::from("__col_agg_3")
        );
        assert_eq!(
            remainder,
            Ok(*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c")))
        );

        // COUNT(2 * a) * 2 + SUM(b) + 1: SUM(b) is reused.
        let input = add(
            add(mul(count(mul(lit(2), col("a"))), lit(2)), sum(col("b"))),
            lit(1),
        );
        let remainder = get_aggregate_and_remainder_expressions(*input, &mut agg_map);
        assert_eq!(agg_map.len(), 5);
        assert_eq!(
            agg_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))],
            Ident::from("__col_agg_4")
        );
        assert_eq!(
            remainder,
            Ok(*add(
                add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")),
                lit(1)
            ))
        );
    }
}