proof_of_sql/sql/postprocessing/
group_by_postprocessing.rs

1use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
2use crate::base::{
3    database::{group_by_util::aggregate_columns, Column, OwnedColumn, OwnedTable},
4    map::{indexmap, IndexMap, IndexSet},
5    scalar::Scalar,
6};
7use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec};
8use bumpalo::Bump;
9use itertools::{izip, Itertools};
10use proof_of_sql_parser::{
11    intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression},
12    Identifier,
13};
14use serde::{Deserialize, Serialize};
15use sqlparser::ast::Ident;
16
17/// A group by expression
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
19pub struct GroupByPostprocessing {
20    /// A list of `AliasedResultExpr` that exclusively use identifiers in the group by clause or results of aggregation expressions
21    remainder_exprs: Vec<AliasedResultExpr>,
22
23    /// A list of identifiers in the group by clause
24    group_by_identifiers: Vec<Ident>,
25
26    /// A list of aggregation expressions
27    aggregation_exprs: Vec<(AggregationOperator, Expression, Ident)>,
28}
29
30/// Check whether multiple layers of aggregation exist within the same GROUP BY clause
31/// since this is not allowed in SQL
32///
33/// If the context is within an aggregation function, then any aggregation function is considered nested.
34/// Otherwise we need two layers of aggregation functions to be nested.
35fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool {
36    match expr {
37        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => false,
38        Expression::Aggregation { expr, .. } => is_agg || contains_nested_aggregation(expr, true),
39        Expression::Binary { left, right, .. } => {
40            contains_nested_aggregation(left, is_agg) || contains_nested_aggregation(right, is_agg)
41        }
42        Expression::Unary { expr, .. } => contains_nested_aggregation(expr, is_agg),
43    }
44}
45
46/// Get identifiers NOT in aggregate functions
47fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet<Ident> {
48    match expr {
49        Expression::Column(identifier) => IndexSet::from_iter([(*identifier).into()]),
50        Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => {
51            IndexSet::default()
52        }
53        Expression::Binary { left, right, .. } => {
54            let mut left_identifiers = get_free_identifiers_from_expr(left);
55            let right_identifiers = get_free_identifiers_from_expr(right);
56            left_identifiers.extend(right_identifiers);
57            left_identifiers
58        }
59        Expression::Unary { expr, .. } => get_free_identifiers_from_expr(expr),
60    }
61}
62
63/// Get aggregate expressions from an expression as well as the remainder
64///
65/// The idea here is to recursively traverse the expression tree and collect all the aggregation expressions
66/// and then label them as new columns post-aggregation and replace them with these new columns so that
67/// the post-aggregation expression tree doesn't contain any aggregation expressions and can be simply evaluated.
68/// # Panics
69///
70/// Will panic if the key for an aggregation expression cannot be parsed as a valid identifier
71/// or if there are issues retrieving an identifier from the map.
72fn get_aggregate_and_remainder_expressions(
73    expr: Expression,
74    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>,
75) -> Result<Expression, PostprocessingError> {
76    match expr {
77        Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => Ok(expr),
78        Expression::Aggregation { op, expr } => {
79            let key = (op, (*expr));
80            if let Some(ident) = aggregation_expr_map.get(&key) {
81                let identifier = Identifier::try_from(ident.clone()).map_err(|e| {
82                    PostprocessingError::IdentifierConversionError {
83                        error: format!("Failed to convert Ident to Identifier: {e}"),
84                    }
85                })?;
86                Ok(Expression::Column(identifier))
87            } else {
88                let new_ident = Ident {
89                    value: format!("__col_agg_{}", aggregation_expr_map.len()),
90                    quote_style: None,
91                };
92
93                let new_identifier = Identifier::try_from(new_ident.clone()).map_err(|e| {
94                    PostprocessingError::IdentifierConversionError {
95                        error: format!("Failed to convert Ident to Identifier: {e}"),
96                    }
97                })?;
98
99                aggregation_expr_map.insert(key, new_ident);
100                Ok(Expression::Column(new_identifier))
101            }
102        }
103        Expression::Binary { op, left, right } => {
104            let left_remainder =
105                get_aggregate_and_remainder_expressions(*left, aggregation_expr_map);
106            let right_remainder =
107                get_aggregate_and_remainder_expressions(*right, aggregation_expr_map);
108            Ok(Expression::Binary {
109                op,
110                left: Box::new(left_remainder?),
111                right: Box::new(right_remainder?),
112            })
113        }
114        Expression::Unary { op, expr } => {
115            let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map);
116            Ok(Expression::Unary {
117                op,
118                expr: Box::new(remainder?),
119            })
120        }
121    }
122}
123
124/// Given an `AliasedResultExpr`, check if it is legitimate and if so grab the relevant aggregation expression
125/// # Panics
126///
127/// Will panic if there is an issue retrieving the first element from the difference of free identifiers and group-by identifiers, indicating a logical inconsistency in the identifiers.
128fn check_and_get_aggregation_and_remainder(
129    expr: AliasedResultExpr,
130    group_by_identifiers: &[Ident],
131    aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>,
132) -> PostprocessingResult<AliasedResultExpr> {
133    let free_identifiers = get_free_identifiers_from_expr(&expr.expr);
134    let group_by_identifier_set = group_by_identifiers
135        .iter()
136        .cloned()
137        .collect::<IndexSet<_>>();
138    if contains_nested_aggregation(&expr.expr, false) {
139        return Err(PostprocessingError::NestedAggregationInGroupByClause {
140            error: format!("Nested aggregations found {:?}", expr.expr),
141        });
142    }
143    if free_identifiers.is_subset(&group_by_identifier_set) {
144        let remainder = get_aggregate_and_remainder_expressions(*expr.expr, aggregation_expr_map);
145        Ok(AliasedResultExpr {
146            alias: expr.alias,
147            expr: Box::new(remainder?),
148        })
149    } else {
150        let diff = free_identifiers
151            .difference(&group_by_identifier_set)
152            .next()
153            .unwrap();
154        Err(
155            PostprocessingError::IdentNotInAggregationOperatorOrGroupByClause {
156                column: diff.clone(),
157            },
158        )
159    }
160}
161
162impl GroupByPostprocessing {
163    /// Create a new group by expression containing the group by and aggregation expressions
164    pub fn try_new(
165        by_ids: Vec<Ident>,
166        aliased_exprs: Vec<AliasedResultExpr>,
167    ) -> PostprocessingResult<Self> {
168        let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> =
169            IndexMap::default();
170        // Look for aggregation expressions and check for non-aggregation expressions that contain identifiers not in the group by clause
171        let remainder_exprs: Vec<AliasedResultExpr> = aliased_exprs
172            .into_iter()
173            .map(|aliased_expr| -> PostprocessingResult<_> {
174                check_and_get_aggregation_and_remainder(
175                    aliased_expr,
176                    &by_ids,
177                    &mut aggregation_expr_map,
178                )
179            })
180            .collect::<PostprocessingResult<Vec<AliasedResultExpr>>>()?;
181        let group_by_identifiers = Vec::from_iter(IndexSet::from_iter(by_ids));
182        Ok(Self {
183            remainder_exprs,
184            group_by_identifiers,
185            aggregation_exprs: aggregation_expr_map
186                .into_iter()
187                .map(|((op, expr), id)| (op, expr, id))
188                .collect(),
189        })
190    }
191
192    /// Get group by identifiers
193    #[must_use]
194    pub fn group_by(&self) -> &[Ident] {
195        &self.group_by_identifiers
196    }
197
198    /// Get remainder expressions for SELECT
199    #[must_use]
200    pub fn remainder_exprs(&self) -> &[AliasedResultExpr] {
201        &self.remainder_exprs
202    }
203
204    /// Get aggregation expressions
205    #[must_use]
206    pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Ident)] {
207        &self.aggregation_exprs
208    }
209}
210
211impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
212    /// Apply the group by transformation to the given `OwnedTable`.
213    #[expect(clippy::too_many_lines)]
214    fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
215        // First evaluate all the aggregated columns
216        let alloc = Bump::new();
217        let evaluated_columns = self
218            .aggregation_exprs
219            .iter()
220            .map(|(agg_op, expr, id)| -> PostprocessingResult<_> {
221                let evaluated_owned_column = owned_table.evaluate(expr)?;
222                Ok((*agg_op, (id.clone(), evaluated_owned_column)))
223            })
224            .process_results(|iter| {
225                iter.fold(
226                    IndexMap::<_, Vec<_>>::default(),
227                    |mut lookup, (key, val)| {
228                        lookup.entry(key).or_default().push(val);
229                        lookup
230                    },
231                )
232            })?;
233        // Next actually do the GROUP BY
234        let group_by_ins = self
235            .group_by_identifiers
236            .iter()
237            .map(|id| {
238                let column = owned_table.inner_table().get(id).ok_or(
239                    PostprocessingError::ColumnNotFound {
240                        column: id.to_string(),
241                    },
242                )?;
243                Ok(Column::<S>::from_owned_column(column, &alloc))
244            })
245            .collect::<PostprocessingResult<Vec<_>>>()?;
246        // TODO: Allow a filter
247        let selection_in = vec![true; owned_table.num_rows()];
248        let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = evaluated_columns
249            .get(&AggregationOperator::Sum)
250            .map_or((vec![], vec![]), |tuple| {
251                tuple
252                    .iter()
253                    .map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
254                    .unzip()
255            });
256        let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns
257            .get(&AggregationOperator::Max)
258            .map_or((vec![], vec![]), |tuple| {
259                tuple
260                    .iter()
261                    .map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
262                    .unzip()
263            });
264        let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns
265            .get(&AggregationOperator::Min)
266            .map_or((vec![], vec![]), |tuple| {
267                tuple
268                    .iter()
269                    .map(|(id, c)| (id.clone(), Column::<S>::from_owned_column(c, &alloc)))
270                    .unzip()
271            });
272        let aggregation_results = aggregate_columns(
273            &alloc,
274            &group_by_ins,
275            &sum_columns,
276            &max_columns,
277            &min_columns,
278            &selection_in,
279        )?;
280        // Finally do another round of evaluation to get the final result
281        // Gather the results into a new OwnedTable
282        let group_by_outs = aggregation_results
283            .group_by_columns
284            .iter()
285            .zip(self.group_by_identifiers.iter())
286            .map(|(column, id)| Ok((id.clone(), OwnedColumn::from(column))));
287        let sum_outs = izip!(
288            aggregation_results.sum_columns,
289            sum_identifiers,
290            sum_columns,
291        )
292        .map(|(c_out, id, c_in)| {
293            Ok((
294                id,
295                OwnedColumn::try_from_scalars(c_out, c_in.column_type())?,
296            ))
297        });
298        let max_outs = izip!(
299            aggregation_results.max_columns,
300            max_identifiers,
301            max_columns,
302        )
303        .map(|(c_out, id, c_in)| {
304            Ok((
305                id,
306                OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
307            ))
308        });
309        let min_outs = izip!(
310            aggregation_results.min_columns,
311            min_identifiers,
312            min_columns,
313        )
314        .map(|(c_out, id, c_in)| {
315            Ok((
316                id,
317                OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?,
318            ))
319        });
320        //TODO: When we have NULLs we need to differentiate between count(1) and count(expression)
321        let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec());
322        let count_outs = evaluated_columns
323            .get(&AggregationOperator::Count)
324            .into_iter()
325            .flatten()
326            .map(|(id, _)| -> PostprocessingResult<_> { Ok((id.clone(), count_column.clone())) });
327        let new_owned_table: OwnedTable<S> = group_by_outs
328            .into_iter()
329            .chain(sum_outs)
330            .chain(max_outs)
331            .chain(min_outs)
332            .chain(count_outs)
333            .process_results(|iter| OwnedTable::try_from_iter(iter))??;
334        // If there are no columns at all we need to have the count column so that we can handle
335        // queries such as `SELECT 1 FROM table`
336        let target_table = if new_owned_table.is_empty() {
337            OwnedTable::try_new(indexmap! {"__count__".into() => count_column})?
338        } else {
339            new_owned_table
340        };
341        let result = self
342            .remainder_exprs
343            .iter()
344            .map(|aliased_expr| -> PostprocessingResult<_> {
345                let column = target_table.evaluate(&aliased_expr.expr)?;
346                let alias: Ident = aliased_expr.alias.into();
347                Ok((alias, column))
348            })
349            .process_results(|iter| OwnedTable::try_from_iter(iter))??;
350        Ok(result)
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use proof_of_sql_parser::utility::*;

    /// Lookup table from an aggregation to the identifier of its result column.
    type AggregationMap = IndexMap<(AggregationOperator, Expression), Ident>;

    #[test]
    fn we_can_detect_nested_aggregation() {
        // Each case: (expression, nested when outside an aggregation, nested when inside one)
        let cases = [
            // SUM(SUM(a))
            (sum(sum(col("a"))), true, true),
            // MAX(a) + SUM(b)
            (add(max(col("a")), sum(col("b"))), false, true),
            // a + SUM(b)
            (add(col("a"), sum(col("b"))), false, true),
            // SUM(a) + b - SUM(2 * c)
            (
                sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c")))),
                false,
                true,
            ),
            // a + COUNT(SUM(a))
            (add(col("a"), count(sum(col("a")))), true, true),
            // a + b + 1
            (add(add(col("a"), col("b")), lit(1)), false, false),
        ];
        for (expr, nested_outside, nested_inside) in cases {
            assert_eq!(
                contains_nested_aggregation(&expr, false),
                nested_outside,
                "outside an aggregation: {expr:?}"
            );
            assert_eq!(
                contains_nested_aggregation(&expr, true),
                nested_inside,
                "inside an aggregation: {expr:?}"
            );
        }
    }

    #[test]
    fn we_can_get_free_identifiers_from_expr() {
        // Each case: (expression, names of the identifiers expected to be free)
        let cases = [
            // Literal
            (lit("Not an identifier"), vec![]),
            // a + b + 1
            (add(add(col("a"), col("b")), lit(1)), vec!["a", "b"]),
            // ! (a == b || c >= a)
            (
                not(or(equal(col("a"), col("b")), ge(col("c"), col("a")))),
                vec!["a", "b", "c"],
            ),
            // SUM(a + b) * 2
            (mul(sum(add(col("a"), col("b"))), lit(2)), vec![]),
            // (COUNT(a + b) + c) * d
            (
                mul(add(count(add(col("a"), col("b"))), col("c")), col("d")),
                vec!["c", "d"],
            ),
        ];
        for (expr, expected_names) in cases {
            let expected: IndexSet<Ident> = expected_names.into_iter().map(Ident::from).collect();
            let actual = get_free_identifiers_from_expr(&expr);
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn we_can_get_aggregate_and_remainder_expressions() {
        // The map is shared across all cases, so generated column indices keep increasing.
        let mut agg_map = AggregationMap::default();
        // Identifier registered for the aggregation `op(arg)`
        let registered = |map: &AggregationMap, op: AggregationOperator, arg: Box<Expression>| {
            map[&(op, *arg)].clone()
        };

        // SUM(a) + b
        let remainder = get_aggregate_and_remainder_expressions(
            *add(sum(col("a")), col("b")),
            &mut agg_map,
        );
        assert_eq!(remainder, Ok(*add(col("__col_agg_0"), col("b"))));
        assert_eq!(
            registered(&agg_map, AggregationOperator::Sum, col("a")),
            Ident::from("__col_agg_0")
        );
        assert_eq!(agg_map.len(), 1);

        // SUM(a) + SUM(b)
        let remainder = get_aggregate_and_remainder_expressions(
            *add(sum(col("a")), sum(col("b"))),
            &mut agg_map,
        );
        assert_eq!(
            remainder,
            Ok(*add(col("__col_agg_0"), col("__col_agg_1")))
        );
        assert_eq!(
            registered(&agg_map, AggregationOperator::Sum, col("a")),
            Ident::from("__col_agg_0")
        );
        assert_eq!(
            registered(&agg_map, AggregationOperator::Sum, col("b")),
            Ident::from("__col_agg_1")
        );
        assert_eq!(agg_map.len(), 2);

        // MAX(a + 1) + MIN(2 * b - 4) + c
        let remainder = get_aggregate_and_remainder_expressions(
            *add(
                add(
                    max(col("a") + lit(1)),
                    min(sub(mul(lit(2), col("b")), lit(4))),
                ),
                col("c"),
            ),
            &mut agg_map,
        );
        assert_eq!(
            remainder,
            Ok(*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c")))
        );
        assert_eq!(
            registered(&agg_map, AggregationOperator::Max, add(col("a"), lit(1))),
            Ident::from("__col_agg_2")
        );
        assert_eq!(
            registered(
                &agg_map,
                AggregationOperator::Min,
                sub(mul(lit(2), col("b")), lit(4))
            ),
            Ident::from("__col_agg_3")
        );
        assert_eq!(agg_map.len(), 4);

        // COUNT(2 * a) * 2 + SUM(b) + 1
        let remainder = get_aggregate_and_remainder_expressions(
            *add(
                add(mul(count(mul(lit(2), col("a"))), lit(2)), sum(col("b"))),
                lit(1),
            ),
            &mut agg_map,
        );
        assert_eq!(
            remainder,
            Ok(*add(
                add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")),
                lit(1)
            ))
        );
        assert_eq!(
            registered(&agg_map, AggregationOperator::Count, mul(lit(2), col("a"))),
            Ident::from("__col_agg_4")
        );
        assert_eq!(agg_map.len(), 5);
    }
}
505}