use sqlparser::ast::{self, GroupByExpr};
use crate::engine_rules::{self, AggregateParams};
use crate::error::Result;
use crate::functions::registry::{FunctionRegistry, SearchTrigger};
use crate::parser::normalize::normalize_ident;
use crate::planner::grouping_sets::expand_group_by;
use crate::resolver::columns::ResolvedTable;
use crate::resolver::expr::convert_expr;
use crate::temporal::TemporalScope;
use crate::types::*;
pub fn plan_aggregate(
select: &ast::Select,
table: &ResolvedTable,
filters: &[Filter],
_scope: &crate::resolver::columns::TableScope,
functions: &FunctionRegistry,
temporal: &TemporalScope,
) -> Result<SqlPlan> {
let grouping_expansion = expand_group_by(&select.group_by)?;
let (group_by_exprs, grouping_sets) = if let Some(exp) = grouping_expansion {
(exp.canonical_keys, Some(exp.grouping_sets))
} else {
(convert_group_by(&select.group_by)?, None)
};
let mut aggregates = extract_aggregates_from_projection(&select.projection, functions)?;
let having = match &select.having {
Some(expr) => super::select::convert_where_to_filters(expr)?,
None => Vec::new(),
};
if grouping_sets.is_some() {
let grouping_aggs = extract_grouping_calls(&select.projection, &group_by_exprs)?;
aggregates.extend(grouping_aggs);
}
let (bucket_interval_ms, group_columns) =
extract_timeseries_params(&select.group_by, &select.projection, functions)?;
let rules = engine_rules::resolve_engine_rules(table.info.engine);
let base_plan = rules.plan_aggregate(AggregateParams {
collection: table.name.clone(),
alias: table.alias.clone(),
filters: filters.to_vec(),
group_by: group_by_exprs.clone(),
aggregates: aggregates.clone(),
having: having.clone(),
limit: 10000,
bucket_interval_ms,
group_columns,
has_auto_tier: table.info.has_auto_tier,
bitemporal: table.info.bitemporal,
temporal: *temporal,
})?;
if let Some(sets) = grouping_sets {
return Ok(attach_grouping_sets(
base_plan,
group_by_exprs,
aggregates,
having,
sets,
));
}
Ok(base_plan)
}
/// Re-targets `base_plan` so it carries `grouping_sets`.
///
/// When the engine already produced a `SqlPlan::Aggregate`, its input and
/// limit are kept while the keys, aggregates, and having clause are
/// replaced with the canonical ones. Any other plan shape becomes the
/// input of a fresh aggregate node with the default limit.
fn attach_grouping_sets(
    base_plan: SqlPlan,
    group_by: Vec<SqlExpr>,
    aggregates: Vec<AggregateExpr>,
    having: Vec<Filter>,
    grouping_sets: Vec<Vec<usize>>,
) -> SqlPlan {
    let (input, limit) = match base_plan {
        SqlPlan::Aggregate { input, limit, .. } => (input, limit),
        other => (Box::new(other), 10000),
    };
    SqlPlan::Aggregate {
        input,
        group_by,
        aggregates,
        having,
        limit,
        grouping_sets: Some(grouping_sets),
    }
}
/// Scans GROUP BY for time-series hints.
///
/// Returns the bucket width in milliseconds when a time-bucket call is
/// grouped on (directly, via a SELECT alias, or via a positional ordinal),
/// plus the plain column identifiers used as grouping keys.
fn extract_timeseries_params(
    raw_group_by: &GroupByExpr,
    select_items: &[ast::SelectItem],
    functions: &FunctionRegistry,
) -> Result<(Option<i64>, Vec<String>)> {
    let GroupByExpr::Expressions(exprs, _) = raw_group_by else {
        return Ok((None, Vec::new()));
    };
    let mut interval_ms = None;
    let mut columns = Vec::new();
    for expr in exprs {
        // An alias or ordinal in GROUP BY points back at a projection item;
        // inspect the underlying expression when it does.
        let target = resolve_group_by_expr(expr, select_items).unwrap_or(expr);
        if let Some(ms) = try_extract_time_bucket(target, functions)? {
            interval_ms = Some(ms);
        } else if let ast::Expr::Identifier(ident) = expr {
            columns.push(normalize_ident(ident));
        }
    }
    Ok((interval_ms, columns))
}
/// Returns the bucket width in milliseconds when `expr` is a call to a
/// function the registry classifies as a time-bucket trigger, `None`
/// otherwise.
fn try_extract_time_bucket(expr: &ast::Expr, functions: &FunctionRegistry) -> Result<Option<i64>> {
    let ast::Expr::Function(func) = expr else {
        return Ok(None);
    };
    // Normalize the (possibly qualified) name into dotted `part.part` form;
    // non-identifier segments contribute an empty component.
    let mut name = String::new();
    for (i, part) in func.name.0.iter().enumerate() {
        if i > 0 {
            name.push('.');
        }
        if let ast::ObjectNamePart::Identifier(ident) = part {
            name.push_str(&normalize_ident(ident));
        }
    }
    if functions.search_trigger(&name) == SearchTrigger::TimeBucket {
        Ok(Some(extract_bucket_interval(func)?))
    } else {
        Ok(None)
    }
}
/// Maps a GROUP BY item back to the projection expression it refers to.
///
/// Supports two indirect forms: an identifier matching a SELECT alias, and
/// a positional ordinal (1-based) naming a projection item. Anything else
/// yields `None` and should be inspected as-is.
fn resolve_group_by_expr<'a>(
    expr: &ast::Expr,
    select_items: &'a [ast::SelectItem],
) -> Option<&'a ast::Expr> {
    match expr {
        ast::Expr::Identifier(ident) => {
            let wanted = normalize_ident(ident);
            for item in select_items {
                if let ast::SelectItem::ExprWithAlias { expr, alias } = item
                    && normalize_ident(alias) == wanted
                {
                    return Some(expr);
                }
            }
            None
        }
        ast::Expr::Value(v) => {
            let ast::Value::Number(n, _) = &v.value else {
                return None;
            };
            // Ordinals are 1-based; 0 and out-of-range indices resolve to
            // nothing via `checked_sub` / `get`.
            let ordinal = n.parse::<usize>().ok()?;
            match select_items.get(ordinal.checked_sub(1)?)? {
                ast::SelectItem::UnnamedExpr(e) => Some(e),
                ast::SelectItem::ExprWithAlias { expr, .. } => Some(expr),
                _ => None,
            }
        }
        _ => None,
    }
}
/// Pulls the bucket width (in milliseconds) out of a time-bucket call's
/// argument list.
///
/// Accepts either a quoted interval string ('1 hour', '15m', ...) or a bare
/// number interpreted as seconds. Returns 0 when no positive width is found.
fn extract_bucket_interval(func: &ast::Function) -> Result<i64> {
    let ast::FunctionArguments::List(list) = &func.args else {
        return Ok(0);
    };
    for arg in &list.args {
        let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(ast::Expr::Value(v))) = arg
        else {
            continue;
        };
        match &v.value {
            ast::Value::SingleQuotedString(s) => {
                let ms = parse_interval_to_ms(s);
                if ms > 0 {
                    return Ok(ms);
                }
            }
            ast::Value::Number(n, _) => {
                // Bare numbers are seconds; convert to milliseconds.
                if let Ok(secs) = n.parse::<i64>()
                    && secs > 0
                {
                    return Ok(secs * 1000);
                }
            }
            _ => {}
        }
    }
    Ok(0)
}
/// Thin adapter over the shared interval parser; unparseable input maps
/// to 0 so callers can treat "no interval" and "bad interval" uniformly.
fn parse_interval_to_ms(s: &str) -> i64 {
    nodedb_types::kv_parsing::parse_interval_to_ms(s).map_or(0, |ms| ms as i64)
}
/// Converts GROUP BY expressions into planner expressions.
/// `GROUP BY ALL` carries no explicit keys and converts to an empty list.
pub fn convert_group_by(group_by: &GroupByExpr) -> Result<Vec<SqlExpr>> {
    let GroupByExpr::Expressions(exprs, _) = group_by else {
        return Ok(Vec::new());
    };
    exprs.iter().map(convert_expr).collect()
}
/// Collects GROUPING(col) pseudo-aggregates from the projection, resolving
/// each referenced column to its index within the canonical key list.
///
/// Unnamed items use their printed SQL text as the alias; aliased items use
/// the normalized alias. Non-expression items (e.g. `*`) are skipped.
fn extract_grouping_calls(
    items: &[ast::SelectItem],
    canonical_keys: &[SqlExpr],
) -> Result<Vec<AggregateExpr>> {
    let mut found = Vec::new();
    for item in items {
        match item {
            ast::SelectItem::UnnamedExpr(expr) => {
                let alias = format!("{expr}");
                collect_grouping_from_expr(expr, &alias, canonical_keys, &mut found)?;
            }
            ast::SelectItem::ExprWithAlias { expr, alias } => {
                let alias = normalize_ident(alias);
                collect_grouping_from_expr(expr, &alias, canonical_keys, &mut found)?;
            }
            _ => {}
        }
    }
    Ok(found)
}
/// Recursively collects GROUPING(...) calls from `expr` into `out`.
///
/// Each argument of a GROUPING call becomes one pseudo-aggregate whose
/// `grouping_col_index` points at the argument's position in the canonical
/// key list. Recurses through binary operators and parenthesized
/// sub-expressions, so `grouping(a) + grouping(b)` and `(grouping(a))`
/// are both found; previously a parenthesized call was silently dropped.
fn collect_grouping_from_expr(
    expr: &ast::Expr,
    alias: &str,
    canonical_keys: &[SqlExpr],
    out: &mut Vec<AggregateExpr>,
) -> Result<()> {
    match expr {
        ast::Expr::Function(f) => {
            let name = normalize_function_name(f);
            if name.eq_ignore_ascii_case("grouping") {
                for col_expr in function_args_exprs(f) {
                    let canonical_idx = crate::planner::grouping_sets::resolve_grouping_col(
                        col_expr,
                        canonical_keys,
                    )?;
                    out.push(AggregateExpr {
                        function: "grouping".into(),
                        args: vec![convert_expr(col_expr)?],
                        alias: alias.to_string(),
                        distinct: false,
                        grouping_col_index: Some(canonical_idx),
                    });
                }
            }
        }
        ast::Expr::BinaryOp { left, right, .. } => {
            collect_grouping_from_expr(left, alias, canonical_keys, out)?;
            collect_grouping_from_expr(right, alias, canonical_keys, out)?;
        }
        // Parentheses are transparent: `(grouping(a))` must behave like
        // `grouping(a)`.
        ast::Expr::Nested(inner) => {
            collect_grouping_from_expr(inner, alias, canonical_keys, out)?;
        }
        _ => {}
    }
    Ok(())
}
/// Returns the unnamed positional expression arguments of `f`.
/// Named arguments, wildcards, and non-list argument forms are skipped.
fn function_args_exprs(f: &ast::Function) -> Vec<&ast::Expr> {
    let ast::FunctionArguments::List(list) = &f.args else {
        return Vec::new();
    };
    let mut exprs = Vec::with_capacity(list.args.len());
    for arg in &list.args {
        if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) = arg {
            exprs.push(e);
        }
    }
    exprs
}
/// Normalized final segment of a (possibly qualified) function name, e.g.
/// `pg_catalog.grouping` -> `grouping`. Empty when the last segment is not
/// a plain identifier or the name has no segments.
fn normalize_function_name(f: &ast::Function) -> String {
    match f.name.0.last() {
        Some(ast::ObjectNamePart::Identifier(ident)) => normalize_ident(ident),
        Some(_) | None => String::new(),
    }
}
/// Walks every projection item and gathers the aggregate calls it contains.
///
/// Unnamed items use their printed SQL text as the alias; explicitly
/// aliased items use the normalized alias. Non-expression items (e.g. `*`)
/// contribute nothing.
pub fn extract_aggregates_from_projection(
    items: &[ast::SelectItem],
    functions: &FunctionRegistry,
) -> Result<Vec<AggregateExpr>> {
    let mut aggregates = Vec::new();
    for item in items {
        let pair = match item {
            ast::SelectItem::UnnamedExpr(expr) => Some((expr, format!("{expr}"))),
            ast::SelectItem::ExprWithAlias { expr, alias } => {
                Some((expr, normalize_ident(alias)))
            }
            _ => None,
        };
        if let Some((expr, alias)) = pair {
            let mut found = crate::aggregate_walk::extract_aggregates(expr, &alias, functions)?;
            aggregates.append(&mut found);
        }
    }
    Ok(aggregates)
}
#[cfg(test)]
mod tests {
    use super::*;
    // The shared interval parser accepts both compact ("1h") and verbose
    // ("1 hour") spellings; all expected values are milliseconds.
    #[test]
    fn parse_intervals() {
        assert_eq!(parse_interval_to_ms("1h"), 3_600_000);
        assert_eq!(parse_interval_to_ms("15m"), 900_000);
        assert_eq!(parse_interval_to_ms("30s"), 30_000);
        assert_eq!(parse_interval_to_ms("7d"), 604_800_000);
        assert_eq!(parse_interval_to_ms("1 hour"), 3_600_000);
        assert_eq!(parse_interval_to_ms("2 hours"), 7_200_000);
        assert_eq!(parse_interval_to_ms("15 minutes"), 900_000);
        assert_eq!(parse_interval_to_ms("30 seconds"), 30_000);
        assert_eq!(parse_interval_to_ms("1 day"), 86_400_000);
        assert_eq!(parse_interval_to_ms("5 min"), 300_000);
    }
    // Parses `sql` and unwraps the single SELECT body; panics on anything
    // else, which is fine for test fixtures.
    fn parse_select(sql: &str) -> ast::Select {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;
        let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        match stmts.into_iter().next().unwrap() {
            ast::Statement::Query(q) => match *q.body {
                ast::SetExpr::Select(s) => *s,
                _ => panic!("expected SELECT"),
            },
            _ => panic!("expected query"),
        }
    }
    // GROUP BY on a SELECT alias must resolve back to the aliased
    // time_bucket call and yield its interval.
    #[test]
    fn resolve_group_by_alias_to_time_bucket() {
        let select = parse_select(
            "SELECT time_bucket('1 hour', timestamp) AS b, COUNT(*) FROM t GROUP BY b",
        );
        let functions = FunctionRegistry::new();
        if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
            let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
            assert!(resolved.is_some(), "alias 'b' should resolve");
            let interval = try_extract_time_bucket(resolved.unwrap(), &functions).unwrap();
            assert_eq!(interval, Some(3_600_000));
        } else {
            panic!("expected GROUP BY expressions");
        }
    }
    // GROUP BY 1 (positional ordinal) must resolve to the first projection
    // item and yield its interval.
    #[test]
    fn resolve_group_by_ordinal_to_time_bucket() {
        let select =
            parse_select("SELECT time_bucket('5 minutes', timestamp), COUNT(*) FROM t GROUP BY 1");
        let functions = FunctionRegistry::new();
        if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
            let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
            assert!(resolved.is_some(), "ordinal 1 should resolve");
            let interval = try_extract_time_bucket(resolved.unwrap(), &functions).unwrap();
            assert_eq!(interval, Some(300_000));
        } else {
            panic!("expected GROUP BY expressions");
        }
    }
    // A plain column that also appears (unaliased) in the projection should
    // neither resolve as an alias nor look like a time bucket.
    #[test]
    fn resolve_group_by_plain_column_not_time_bucket() {
        let select = parse_select("SELECT qtype, COUNT(*) FROM t GROUP BY qtype");
        let functions = FunctionRegistry::new();
        if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
            let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
            assert!(resolved.is_none());
            let interval = try_extract_time_bucket(&exprs[0], &functions).unwrap();
            assert_eq!(interval, None);
        } else {
            panic!("expected GROUP BY expressions");
        }
    }
}