use sqlparser::ast::{self, GroupByExpr};
use crate::error::{Result, SqlError};
use crate::parser::normalize::normalize_ident;
use crate::resolver::expr::convert_expr;
use crate::types::SqlExpr;
/// Result of expanding a `GROUP BY` clause that uses `ROLLUP`, `CUBE`, or
/// `GROUPING SETS` (produced by `expand_group_by`).
#[derive(Debug, Clone)]
pub struct GroupingSetsExpansion {
/// Distinct grouping keys, deduplicated by their SQL text and listed in order
/// of first appearance: plain `GROUP BY` items first, then keys that only occur
/// inside the grouping extension.
pub canonical_keys: Vec<SqlExpr>,
/// One entry per grouping set, in output order; each entry lists indices into
/// `canonical_keys`. An empty entry is the grand-total grouping set `()`.
pub grouping_sets: Vec<Vec<usize>>,
}
pub fn expand_group_by(group_by: &GroupByExpr) -> Result<Option<GroupingSetsExpansion>> {
let exprs = match group_by {
GroupByExpr::All(_) => return Ok(None),
GroupByExpr::Expressions(exprs, _) => exprs,
};
let has_extension = exprs.iter().any(is_grouping_extension);
if !has_extension {
return Ok(None);
}
let mut plain_ast: Vec<&ast::Expr> = Vec::new();
let mut extension_sets: Option<Vec<Vec<&ast::Expr>>> = None;
for expr in exprs {
match expr {
ast::Expr::Rollup(groups) => {
if extension_sets.is_some() {
return Err(SqlError::Unsupported {
detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
.into(),
});
}
extension_sets = Some(expand_rollup(groups));
}
ast::Expr::Cube(groups) => {
if extension_sets.is_some() {
return Err(SqlError::Unsupported {
detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
.into(),
});
}
extension_sets = Some(expand_cube(groups));
}
ast::Expr::GroupingSets(sets) => {
if extension_sets.is_some() {
return Err(SqlError::Unsupported {
detail: "only one ROLLUP/CUBE/GROUPING SETS per GROUP BY is supported"
.into(),
});
}
extension_sets = Some(sets.iter().map(|s| s.iter().collect()).collect());
}
other => {
plain_ast.push(other);
}
}
}
let ext_sets = extension_sets.unwrap_or_default();
let mut canonical_names: Vec<String> = Vec::new();
let mut canonical_exprs: Vec<SqlExpr> = Vec::new();
let mut intern = |e: &ast::Expr| -> Result<usize> {
let display = format!("{e}");
if let Some(pos) = canonical_names.iter().position(|n| n == &display) {
return Ok(pos);
}
let idx = canonical_names.len();
canonical_names.push(display);
canonical_exprs.push(convert_expr(e)?);
Ok(idx)
};
let plain_indices: Vec<usize> = plain_ast.iter().map(|e| intern(e)).collect::<Result<_>>()?;
let ext_sets_indexed: Vec<Vec<usize>> = ext_sets
.into_iter()
.map(|set| set.into_iter().map(&mut intern).collect::<Result<_>>())
.collect::<Result<_>>()?;
let grouping_sets: Vec<Vec<usize>> = if ext_sets_indexed.is_empty() {
vec![plain_indices]
} else {
ext_sets_indexed
.into_iter()
.map(|ext_set| {
let mut combined = plain_indices.clone();
for idx in &ext_set {
if !combined.contains(idx) {
combined.push(*idx);
}
}
combined
})
.collect()
};
Ok(Some(GroupingSetsExpansion {
canonical_keys: canonical_exprs,
grouping_sets,
}))
}
/// Reports whether `expr` is one of the grouping extensions this module
/// expands: `ROLLUP (...)`, `CUBE (...)`, or `GROUPING SETS (...)`.
fn is_grouping_extension(expr: &ast::Expr) -> bool {
    let is_rollup = matches!(expr, ast::Expr::Rollup(_));
    let is_cube = matches!(expr, ast::Expr::Cube(_));
    let is_grouping_sets = matches!(expr, ast::Expr::GroupingSets(_));
    is_rollup || is_cube || is_grouping_sets
}
/// Expands `ROLLUP (e1, ..., en)` into its `n + 1` grouping sets, from the full
/// list down to the empty set: `(e1..en), (e1..en-1), ..., (e1), ()`.
///
/// Each entry of `groups` is one ROLLUP element. A parenthesised composite
/// element such as `(a, b)` in `ROLLUP ((a, b), c)` is kept together, so that
/// example yields `(a, b, c), (a, b), ()` and never `(a)` on its own.
///
/// Generic over the element type because the expansion is pure combinatorics.
fn expand_rollup<T>(groups: &[Vec<T>]) -> Vec<Vec<&T>> {
    (0..=groups.len())
        .rev()
        .map(|len| groups[..len].iter().flatten().collect())
        .collect()
}
/// Expands `CUBE (e1, ..., en)` into all `2^n` subsets of its elements.
///
/// Sets are ordered by descending bitmask, where bit `i` selects `groups[i]`:
/// the first set is the full list and the last is the empty set `()`. Elements
/// keep their original order within each set. A composite element such as
/// `(a, b)` in `CUBE ((a, b), c)` is included or excluded as a unit, giving
/// `(a, b, c), (c), (a, b), ()`.
///
/// Generic over the element type because the expansion is pure combinatorics.
///
/// # Panics
///
/// Panics if `groups` has `usize::BITS` or more elements, because the subset
/// count would not fit in a `usize`. Callers should bound the element count far
/// below that anyway, since the output grows as `2^n`.
fn expand_cube<T>(groups: &[Vec<T>]) -> Vec<Vec<&T>> {
    let n = groups.len();
    assert!(
        n < usize::BITS as usize,
        "CUBE with {n} elements is too large to expand"
    );
    (0..(1usize << n))
        .rev()
        .map(|mask| {
            groups
                .iter()
                .enumerate()
                .filter(|&(i, _)| (mask >> i) & 1 == 1)
                .flat_map(|(_, group)| group)
                .collect()
        })
        .collect()
}
pub fn resolve_grouping_col(col_expr: &ast::Expr, canonical_keys: &[SqlExpr]) -> Result<usize> {
let display = format!("{col_expr}");
for (i, key) in canonical_keys.iter().enumerate() {
if format!("{key:?}").contains(&display) {
return Ok(i);
}
}
if let ast::Expr::Identifier(ident) = col_expr {
let name = normalize_ident(ident);
for (i, key) in canonical_keys.iter().enumerate() {
if let SqlExpr::Column { name: col_name, .. } = key
&& col_name.eq_ignore_ascii_case(&name)
{
return Ok(i);
}
}
}
Err(SqlError::Unsupported {
detail: format!(
"GROUPING({col_expr}) references a column not found in the canonical key list"
),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` with the generic dialect and returns the `GROUP BY` clause
    /// of its first statement, which must be a plain SELECT.
    fn parse_group_by(sql: &str) -> GroupByExpr {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;
        let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        match stmts.into_iter().next().unwrap() {
            ast::Statement::Query(q) => match *q.body {
                ast::SetExpr::Select(s) => s.group_by,
                _ => panic!("expected SELECT"),
            },
            _ => panic!("expected query"),
        }
    }

    /// Parses `sql` and expands its GROUP BY, which must use a grouping extension.
    fn expand(sql: &str) -> GroupingSetsExpansion {
        expand_group_by(&parse_group_by(sql))
            .expect("expansion should succeed")
            .expect("GROUP BY should contain a grouping extension")
    }

    #[test]
    fn rollup_two_cols() {
        let result = expand(
            "SELECT region, country, SUM(sales) FROM orders GROUP BY ROLLUP (region, country)",
        );
        assert_eq!(result.canonical_keys.len(), 2);
        assert_eq!(result.grouping_sets, vec![vec![0, 1], vec![0], vec![]]);
    }

    #[test]
    fn cube_two_cols() {
        let result = expand(
            "SELECT region, country, SUM(sales) FROM orders GROUP BY CUBE (region, country)",
        );
        assert_eq!(result.canonical_keys.len(), 2);
        // Full set first, grand total `()` last.
        assert_eq!(result.grouping_sets, vec![vec![0, 1], vec![1], vec![0], vec![]]);
    }

    #[test]
    fn grouping_sets_explicit() {
        let result = expand(
            "SELECT region, country, SUM(sales) FROM orders \
             GROUP BY GROUPING SETS ((region, country), (region), ())",
        );
        assert_eq!(result.canonical_keys.len(), 2);
        assert_eq!(result.grouping_sets, vec![vec![0, 1], vec![0], vec![]]);
    }

    #[test]
    fn plain_group_by_returns_none() {
        let gb = parse_group_by("SELECT region, COUNT(*) FROM orders GROUP BY region");
        let result = expand_group_by(&gb).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn mixed_plain_and_rollup() {
        let result = expand("SELECT a, b, c, SUM(x) FROM t GROUP BY a, ROLLUP (b, c)");
        assert_eq!(result.canonical_keys.len(), 3);
        // The plain key `a` is prepended to every set produced by the ROLLUP.
        assert_eq!(result.grouping_sets, vec![vec![0, 1, 2], vec![0, 1], vec![0]]);
    }

    #[test]
    fn rollup_three_cols() {
        let result = expand("SELECT a, b, c, SUM(x) FROM t GROUP BY ROLLUP (a, b, c)");
        assert_eq!(result.canonical_keys.len(), 3);
        assert_eq!(
            result.grouping_sets,
            vec![vec![0, 1, 2], vec![0, 1], vec![0], vec![]]
        );
    }

    #[test]
    fn keys_shared_between_plain_list_and_extension_are_deduplicated() {
        let result =
            expand("SELECT a, b, SUM(x) FROM t GROUP BY a, GROUPING SETS ((a, b), (b))");
        assert_eq!(result.canonical_keys.len(), 2);
        assert_eq!(result.grouping_sets, vec![vec![0, 1], vec![0, 1]]);
    }

    #[test]
    fn multiple_extensions_are_rejected() {
        let gb = parse_group_by("SELECT a, b, SUM(x) FROM t GROUP BY ROLLUP (a), CUBE (b)");
        assert!(expand_group_by(&gb).is_err());
    }
}