Skip to main content

nodedb_sql/planner/
grouping_sets.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! ROLLUP / CUBE / GROUPING SETS expansion and canonical key extraction.
4//!
5//! The public entry point is `expand_group_by`, which inspects the raw AST
6//! GROUP BY clause and returns:
7//!
8//! - `canonical_keys`: the full deduplicated column list (indices 0..N).
9//! - `grouping_sets`: one entry per set; each entry is the subset of
10//!   canonical-key indices that are *active* (non-NULL) in that set.
11//!
12//! `GROUPING(col)` at query time resolves `col` to its canonical-key index and
13//! checks whether that bit is absent from the row's active set.
14
15use sqlparser::ast::{self, GroupByExpr};
16
17use crate::error::{Result, SqlError};
18use crate::parser::normalize::normalize_ident;
19use crate::resolver::expr::convert_expr;
20use crate::types::SqlExpr;
21
22/// Result of expanding a GROUP BY clause that contains ROLLUP/CUBE/GROUPING SETS.
23#[derive(Debug, Clone)]
24pub struct GroupingSetsExpansion {
25    /// All distinct group-key expressions in canonical order.
26    pub canonical_keys: Vec<SqlExpr>,
27    /// One entry per logical grouping set; each entry is the indices into
28    /// `canonical_keys` that are *present* (non-NULL) for rows in that set.
29    pub grouping_sets: Vec<Vec<usize>>,
30}
31
32/// Expand the GROUP BY clause if it contains ROLLUP/CUBE/GROUPING SETS.
33///
34/// Returns `None` when the GROUP BY is a plain expression list with no
35/// extensions — callers fall back to the existing single-set path.
36pub fn expand_group_by(group_by: &GroupByExpr) -> Result<Option<GroupingSetsExpansion>> {
37    let exprs = match group_by {
38        GroupByExpr::All(_) => return Ok(None),
39        GroupByExpr::Expressions(exprs, _) => exprs,
40    };
41
42    // Check whether any expression is ROLLUP / CUBE / GROUPING SETS.
43    let has_extension = exprs.iter().any(is_grouping_extension);
44    if !has_extension {
45        return Ok(None);
46    }
47
48    // Split into plain columns and the single extension expression.
49    // SQL standard: only one extension per GROUP BY; mixed is allowed but
50    // forms a cross-product with the plain columns.
51    let mut plain_ast: Vec<&ast::Expr> = Vec::new();
52    let mut extension_sets: Option<Vec<Vec<&ast::Expr>>> = None;
53
54    for expr in exprs {
55        match expr {
56            ast::Expr::Rollup(groups) => {
57                if extension_sets.is_some() {
58                    return Err(SqlError::Unsupported {
59                        detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
60                            .into(),
61                    });
62                }
63                extension_sets = Some(expand_rollup(groups));
64            }
65            ast::Expr::Cube(groups) => {
66                if extension_sets.is_some() {
67                    return Err(SqlError::Unsupported {
68                        detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
69                            .into(),
70                    });
71                }
72                extension_sets = Some(expand_cube(groups));
73            }
74            ast::Expr::GroupingSets(sets) => {
75                if extension_sets.is_some() {
76                    return Err(SqlError::Unsupported {
77                        detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
78                            .into(),
79                    });
80                }
81                // GroupingSets: each inner Vec<Expr> is one set.
82                extension_sets = Some(sets.iter().map(|s| s.iter().collect()).collect());
83            }
84            other => {
85                plain_ast.push(other);
86            }
87        }
88    }
89
90    let ext_sets = extension_sets.unwrap_or_default();
91
92    // Build canonical key list: plain columns first, then extension columns
93    // (deduped by display name so identical columns share an index).
94    let mut canonical_names: Vec<String> = Vec::new();
95    let mut canonical_exprs: Vec<SqlExpr> = Vec::new();
96
97    let mut intern = |e: &ast::Expr| -> Result<usize> {
98        let display = format!("{e}");
99        if let Some(pos) = canonical_names.iter().position(|n| n == &display) {
100            return Ok(pos);
101        }
102        let idx = canonical_names.len();
103        canonical_names.push(display);
104        canonical_exprs.push(convert_expr(e)?);
105        Ok(idx)
106    };
107
108    // Plain columns get canonical indices first.
109    let plain_indices: Vec<usize> = plain_ast.iter().map(|e| intern(e)).collect::<Result<_>>()?;
110
111    // Extension sets: each set is a list of ast::Expr refs → indices.
112    let ext_sets_indexed: Vec<Vec<usize>> = ext_sets
113        .into_iter()
114        .map(|set| set.into_iter().map(&mut intern).collect::<Result<_>>())
115        .collect::<Result<_>>()?;
116
117    // Cross-product: plain_indices × ext_sets_indexed.
118    // For each extension set, prepend the plain indices.
119    let grouping_sets: Vec<Vec<usize>> = if ext_sets_indexed.is_empty() {
120        // Only plain columns — this shouldn't happen (caught by has_extension),
121        // but handle gracefully.
122        vec![plain_indices]
123    } else {
124        ext_sets_indexed
125            .into_iter()
126            .map(|ext_set| {
127                // plain_indices always present; ext_set columns also present.
128                let mut combined = plain_indices.clone();
129                for idx in &ext_set {
130                    if !combined.contains(idx) {
131                        combined.push(*idx);
132                    }
133                }
134                combined
135            })
136            .collect()
137    };
138
139    Ok(Some(GroupingSetsExpansion {
140        canonical_keys: canonical_exprs,
141        grouping_sets,
142    }))
143}
144
145/// Returns true if the expression is a ROLLUP/CUBE/GROUPING SETS node.
146fn is_grouping_extension(expr: &ast::Expr) -> bool {
147    matches!(
148        expr,
149        ast::Expr::Rollup(_) | ast::Expr::Cube(_) | ast::Expr::GroupingSets(_)
150    )
151}
152
153/// Expand `ROLLUP(a, b, c)` → `[[a,b,c], [a,b], [a], []]`.
154///
155/// The input is `Vec<Vec<Expr>>` where each inner vec is one composite element.
156/// We flatten composite elements to individual expressions for simplicity — the
157/// outer product is: suffix-strip from all-present down to empty.
158fn expand_rollup(groups: &[Vec<ast::Expr>]) -> Vec<Vec<&ast::Expr>> {
159    // Flatten composite groups (e.g. `(a, b)` as one element) into atoms.
160    let atoms: Vec<&ast::Expr> = groups.iter().flat_map(|g| g.iter()).collect();
161    let n = atoms.len();
162    // Prefixes: atoms[0..n], atoms[0..n-1], ..., atoms[0..0] (empty).
163    (0..=n).rev().map(|len| atoms[..len].to_vec()).collect()
164}
165
166/// Expand `CUBE(a, b)` → all 2^N subsets.
167fn expand_cube(groups: &[Vec<ast::Expr>]) -> Vec<Vec<&ast::Expr>> {
168    let atoms: Vec<&ast::Expr> = groups.iter().flat_map(|g| g.iter()).collect();
169    let n = atoms.len();
170    let count = 1usize << n;
171    let mut sets: Vec<Vec<&ast::Expr>> = Vec::with_capacity(count);
172    // Enumerate all bitmasks from (all-present) down to 0 (empty).
173    for mask in (0..count).rev() {
174        let set: Vec<&ast::Expr> = (0..n)
175            .filter(|i| (mask >> i) & 1 == 1)
176            .map(|i| atoms[i])
177            .collect();
178        sets.push(set);
179    }
180    sets
181}
182
183/// Resolve the canonical index for a `GROUPING(col)` argument.
184///
185/// Matches by display string against `canonical_names`.
186pub fn resolve_grouping_col(col_expr: &ast::Expr, canonical_keys: &[SqlExpr]) -> Result<usize> {
187    let display = format!("{col_expr}");
188    // Find by rebuilding the display of each canonical key via its SqlExpr.
189    // Since canonical_keys are built from the same AST exprs, we can compare
190    // their SqlExpr display strings.
191    for (i, key) in canonical_keys.iter().enumerate() {
192        if format!("{key:?}").contains(&display) {
193            return Ok(i);
194        }
195    }
196    // Fallback: try normalized ident match.
197    if let ast::Expr::Identifier(ident) = col_expr {
198        let name = normalize_ident(ident);
199        for (i, key) in canonical_keys.iter().enumerate() {
200            if let SqlExpr::Column { name: col_name, .. } = key
201                && col_name.eq_ignore_ascii_case(&name)
202            {
203                return Ok(i);
204            }
205        }
206    }
207    Err(SqlError::Unsupported {
208        detail: format!(
209            "GROUPING({col_expr}) references a column not found in the canonical key list"
210        ),
211    })
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` with the generic dialect and return the GROUP BY clause of
    /// its first statement, which must be a plain `SELECT`.
    fn parse_group_by(sql: &str) -> GroupByExpr {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;
        let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        match stmts.into_iter().next().unwrap() {
            ast::Statement::Query(q) => match *q.body {
                ast::SetExpr::Select(s) => s.group_by,
                _ => panic!("expected SELECT"),
            },
            _ => panic!("expected query"),
        }
    }

    /// Expand the GROUP BY clause of `sql`, asserting that it both succeeds
    /// and actually contains a grouping extension.
    fn expand(sql: &str) -> GroupingSetsExpansion {
        expand_group_by(&parse_group_by(sql))
            .expect("expansion should succeed")
            .expect("GROUP BY should contain a grouping extension")
    }

    #[test]
    fn rollup_two_cols() {
        let result = expand(
            "SELECT region, country, SUM(sales) FROM orders GROUP BY ROLLUP (region, country)",
        );
        // ROLLUP(region, country) → (region, country), (region), ()
        let expected: Vec<Vec<usize>> = vec![vec![0, 1], vec![0], vec![]];
        assert_eq!(result.canonical_keys.len(), 2);
        assert_eq!(result.grouping_sets, expected);
    }

    #[test]
    fn cube_two_cols() {
        let result = expand(
            "SELECT region, country, SUM(sales) FROM orders GROUP BY CUBE (region, country)",
        );
        // CUBE(region, country), emitted by descending bitmask:
        // (region, country), (country), (region), ()
        let expected: Vec<Vec<usize>> = vec![vec![0, 1], vec![1], vec![0], vec![]];
        assert_eq!(result.canonical_keys.len(), 2);
        assert_eq!(result.grouping_sets, expected);
    }

    #[test]
    fn grouping_sets_explicit() {
        let result = expand(
            "SELECT region, country, SUM(sales) FROM orders \
             GROUP BY GROUPING SETS ((region, country), (region), ())",
        );
        let expected: Vec<Vec<usize>> = vec![vec![0, 1], vec![0], vec![]];
        assert_eq!(result.canonical_keys.len(), 2);
        assert_eq!(result.grouping_sets, expected);
    }

    #[test]
    fn plain_group_by_returns_none() {
        let gb = parse_group_by("SELECT region, COUNT(*) FROM orders GROUP BY region");
        let result = expand_group_by(&gb).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn mixed_plain_and_rollup() {
        let result = expand("SELECT a, b, c, SUM(x) FROM t GROUP BY a, ROLLUP (b, c)");
        // Canonical: a(0), b(1), c(2).
        // Extension sets (from ROLLUP(b,c)): [[b,c], [b], []].
        // Cross-product with plain [a]:
        //   set 0: [a, b, c] = [0,1,2]
        //   set 1: [a, b]    = [0,1]
        //   set 2: [a]       = [0]
        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![0, 1], vec![0]];
        assert_eq!(result.canonical_keys.len(), 3);
        assert_eq!(result.grouping_sets, expected);
    }

    #[test]
    fn rollup_three_cols() {
        let result = expand("SELECT a, b, c, SUM(x) FROM t GROUP BY ROLLUP (a, b, c)");
        // (a,b,c), (a,b), (a), ()
        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![0, 1], vec![0], vec![]];
        assert_eq!(result.canonical_keys.len(), 3);
        assert_eq!(result.grouping_sets, expected);
    }

    #[test]
    fn multiple_extensions_rejected() {
        let gb = parse_group_by("SELECT a, b, SUM(x) FROM t GROUP BY ROLLUP (a), CUBE (b)");
        assert!(expand_group_by(&gb).is_err());
    }
}