// nodedb_sql/planner/grouping_sets.rs
use sqlparser::ast::{self, GroupByExpr};
16
17use crate::error::{Result, SqlError};
18use crate::parser::normalize::normalize_ident;
19use crate::resolver::expr::convert_expr;
20use crate::types::SqlExpr;
21
/// Result of expanding ROLLUP / CUBE / GROUPING SETS in a GROUP BY clause.
#[derive(Debug, Clone)]
pub struct GroupingSetsExpansion {
    /// Deduplicated list of all grouping key expressions, in first-seen order.
    pub canonical_keys: Vec<SqlExpr>,
    /// One entry per grouping set; each inner vec holds indices into
    /// `canonical_keys` for the keys active in that set (empty = grand total).
    pub grouping_sets: Vec<Vec<usize>>,
}
31
32pub fn expand_group_by(group_by: &GroupByExpr) -> Result<Option<GroupingSetsExpansion>> {
37 let exprs = match group_by {
38 GroupByExpr::All(_) => return Ok(None),
39 GroupByExpr::Expressions(exprs, _) => exprs,
40 };
41
42 let has_extension = exprs.iter().any(is_grouping_extension);
44 if !has_extension {
45 return Ok(None);
46 }
47
48 let mut plain_ast: Vec<&ast::Expr> = Vec::new();
52 let mut extension_sets: Option<Vec<Vec<&ast::Expr>>> = None;
53
54 for expr in exprs {
55 match expr {
56 ast::Expr::Rollup(groups) => {
57 if extension_sets.is_some() {
58 return Err(SqlError::Unsupported {
59 detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
60 .into(),
61 });
62 }
63 extension_sets = Some(expand_rollup(groups));
64 }
65 ast::Expr::Cube(groups) => {
66 if extension_sets.is_some() {
67 return Err(SqlError::Unsupported {
68 detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
69 .into(),
70 });
71 }
72 extension_sets = Some(expand_cube(groups));
73 }
74 ast::Expr::GroupingSets(sets) => {
75 if extension_sets.is_some() {
76 return Err(SqlError::Unsupported {
77 detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
78 .into(),
79 });
80 }
81 extension_sets = Some(sets.iter().map(|s| s.iter().collect()).collect());
83 }
84 other => {
85 plain_ast.push(other);
86 }
87 }
88 }
89
90 let ext_sets = extension_sets.unwrap_or_default();
91
92 let mut canonical_names: Vec<String> = Vec::new();
95 let mut canonical_exprs: Vec<SqlExpr> = Vec::new();
96
97 let mut intern = |e: &ast::Expr| -> Result<usize> {
98 let display = format!("{e}");
99 if let Some(pos) = canonical_names.iter().position(|n| n == &display) {
100 return Ok(pos);
101 }
102 let idx = canonical_names.len();
103 canonical_names.push(display);
104 canonical_exprs.push(convert_expr(e)?);
105 Ok(idx)
106 };
107
108 let plain_indices: Vec<usize> = plain_ast.iter().map(|e| intern(e)).collect::<Result<_>>()?;
110
111 let ext_sets_indexed: Vec<Vec<usize>> = ext_sets
113 .into_iter()
114 .map(|set| set.into_iter().map(&mut intern).collect::<Result<_>>())
115 .collect::<Result<_>>()?;
116
117 let grouping_sets: Vec<Vec<usize>> = if ext_sets_indexed.is_empty() {
120 vec![plain_indices]
123 } else {
124 ext_sets_indexed
125 .into_iter()
126 .map(|ext_set| {
127 let mut combined = plain_indices.clone();
129 for idx in &ext_set {
130 if !combined.contains(idx) {
131 combined.push(*idx);
132 }
133 }
134 combined
135 })
136 .collect()
137 };
138
139 Ok(Some(GroupingSetsExpansion {
140 canonical_keys: canonical_exprs,
141 grouping_sets,
142 }))
143}
144
145fn is_grouping_extension(expr: &ast::Expr) -> bool {
147 matches!(
148 expr,
149 ast::Expr::Rollup(_) | ast::Expr::Cube(_) | ast::Expr::GroupingSets(_)
150 )
151}
152
153fn expand_rollup(groups: &[Vec<ast::Expr>]) -> Vec<Vec<&ast::Expr>> {
159 let atoms: Vec<&ast::Expr> = groups.iter().flat_map(|g| g.iter()).collect();
161 let n = atoms.len();
162 (0..=n).rev().map(|len| atoms[..len].to_vec()).collect()
164}
165
166fn expand_cube(groups: &[Vec<ast::Expr>]) -> Vec<Vec<&ast::Expr>> {
168 let atoms: Vec<&ast::Expr> = groups.iter().flat_map(|g| g.iter()).collect();
169 let n = atoms.len();
170 let count = 1usize << n;
171 let mut sets: Vec<Vec<&ast::Expr>> = Vec::with_capacity(count);
172 for mask in (0..count).rev() {
174 let set: Vec<&ast::Expr> = (0..n)
175 .filter(|i| (mask >> i) & 1 == 1)
176 .map(|i| atoms[i])
177 .collect();
178 sets.push(set);
179 }
180 sets
181}
182
183pub fn resolve_grouping_col(col_expr: &ast::Expr, canonical_keys: &[SqlExpr]) -> Result<usize> {
187 let display = format!("{col_expr}");
188 for (i, key) in canonical_keys.iter().enumerate() {
192 if format!("{key:?}").contains(&display) {
193 return Ok(i);
194 }
195 }
196 if let ast::Expr::Identifier(ident) = col_expr {
198 let name = normalize_ident(ident);
199 for (i, key) in canonical_keys.iter().enumerate() {
200 if let SqlExpr::Column { name: col_name, .. } = key
201 && col_name.eq_ignore_ascii_case(&name)
202 {
203 return Ok(i);
204 }
205 }
206 }
207 Err(SqlError::Unsupported {
208 detail: format!(
209 "GROUPING({col_expr}) references a column not found in the canonical key list"
210 ),
211 })
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and pulls out the GROUP BY clause of its SELECT.
    fn parse_group_by(sql: &str) -> GroupByExpr {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;
        let statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        let first = statements.into_iter().next().unwrap();
        match first {
            ast::Statement::Query(query) => match *query.body {
                ast::SetExpr::Select(select) => select.group_by,
                _ => panic!("expected SELECT"),
            },
            _ => panic!("expected query"),
        }
    }

    #[test]
    fn rollup_two_cols() {
        let group_by = parse_group_by(
            "SELECT region, country, SUM(sales) FROM orders GROUP BY ROLLUP (region, country)",
        );
        let expansion = expand_group_by(&group_by).unwrap().unwrap();
        assert_eq!(expansion.canonical_keys.len(), 2);
        // ROLLUP over two keys: full set, one-key prefix, grand total.
        assert_eq!(expansion.grouping_sets.len(), 3);
        assert_eq!(expansion.grouping_sets[0], vec![0, 1]);
        assert_eq!(expansion.grouping_sets[1], vec![0]);
        assert_eq!(expansion.grouping_sets[2], Vec::<usize>::new());
    }

    #[test]
    fn cube_two_cols() {
        let group_by = parse_group_by(
            "SELECT region, country, SUM(sales) FROM orders GROUP BY CUBE (region, country)",
        );
        let expansion = expand_group_by(&group_by).unwrap().unwrap();
        assert_eq!(expansion.canonical_keys.len(), 2);
        // CUBE over two keys yields all 2^2 subsets.
        assert_eq!(expansion.grouping_sets.len(), 4);
        assert!(expansion.grouping_sets[0].contains(&0));
        assert!(expansion.grouping_sets[0].contains(&1));
        assert_eq!(*expansion.grouping_sets.last().unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn grouping_sets_explicit() {
        let group_by = parse_group_by(
            "SELECT region, country, SUM(sales) FROM orders \
             GROUP BY GROUPING SETS ((region, country), (region), ())",
        );
        let expansion = expand_group_by(&group_by).unwrap().unwrap();
        assert_eq!(expansion.canonical_keys.len(), 2);
        assert_eq!(expansion.grouping_sets.len(), 3);
        assert_eq!(expansion.grouping_sets[0], vec![0, 1]);
        assert_eq!(expansion.grouping_sets[1], vec![0]);
        assert_eq!(expansion.grouping_sets[2], Vec::<usize>::new());
    }

    #[test]
    fn plain_group_by_returns_none() {
        let group_by = parse_group_by("SELECT region, COUNT(*) FROM orders GROUP BY region");
        // No extension present, so expansion is skipped entirely.
        assert!(expand_group_by(&group_by).unwrap().is_none());
    }

    #[test]
    fn mixed_plain_and_rollup() {
        let group_by = parse_group_by("SELECT a, b, c, SUM(x) FROM t GROUP BY a, ROLLUP (b, c)");
        let expansion = expand_group_by(&group_by).unwrap().unwrap();
        assert_eq!(expansion.canonical_keys.len(), 3);
        assert_eq!(expansion.grouping_sets.len(), 3);
        // The plain key `a` (index 0) appears in every grouping set.
        assert!(expansion.grouping_sets[0].contains(&0));
        assert!(expansion.grouping_sets[1].contains(&0));
        assert!(expansion.grouping_sets[2].contains(&0));
    }

    #[test]
    fn rollup_three_cols() {
        let group_by = parse_group_by("SELECT a, b, c, SUM(x) FROM t GROUP BY ROLLUP (a, b, c)");
        let expansion = expand_group_by(&group_by).unwrap().unwrap();
        assert_eq!(expansion.grouping_sets.len(), 4);
    }
}