Skip to main content

nodedb_sql/planner/
aggregate.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! GROUP BY and aggregate planning.
4
5use sqlparser::ast::{self, GroupByExpr};
6
7use crate::engine_rules::{self, AggregateParams};
8use crate::error::Result;
9use crate::functions::registry::{FunctionRegistry, SearchTrigger};
10use crate::parser::normalize::normalize_ident;
11use crate::planner::grouping_sets::expand_group_by;
12use crate::resolver::columns::ResolvedTable;
13use crate::resolver::expr::convert_expr;
14use crate::temporal::TemporalScope;
15use crate::types::*;
16
17/// Plan an aggregate query (GROUP BY + aggregate functions).
18pub fn plan_aggregate(
19    select: &ast::Select,
20    table: &ResolvedTable,
21    filters: &[Filter],
22    _scope: &crate::resolver::columns::TableScope,
23    functions: &FunctionRegistry,
24    temporal: &TemporalScope,
25) -> Result<SqlPlan> {
26    // Detect ROLLUP / CUBE / GROUPING SETS before falling through to plain convert.
27    let grouping_expansion = expand_group_by(&select.group_by)?;
28
29    let (group_by_exprs, grouping_sets) = if let Some(exp) = grouping_expansion {
30        (exp.canonical_keys, Some(exp.grouping_sets))
31    } else {
32        (convert_group_by(&select.group_by)?, None)
33    };
34
35    let mut aggregates = extract_aggregates_from_projection(&select.projection, functions)?;
36    let having = match &select.having {
37        Some(expr) => super::select::convert_where_to_filters(expr)?,
38        None => Vec::new(),
39    };
40
41    // When grouping sets are present, detect GROUPING(col) in the projection and
42    // synthesize AggregateExpr entries so the executor can compute them per-set.
43    if grouping_sets.is_some() {
44        let grouping_aggs = extract_grouping_calls(&select.projection, &group_by_exprs)?;
45        aggregates.extend(grouping_aggs);
46    }
47
48    // Extract timeseries-specific params (bucket interval, group columns) if applicable.
49    let (bucket_interval_ms, group_columns) =
50        extract_timeseries_params(&select.group_by, &select.projection, functions)?;
51
52    let rules = engine_rules::resolve_engine_rules(table.info.engine);
53    let base_plan = rules.plan_aggregate(AggregateParams {
54        collection: table.name.clone(),
55        alias: table.alias.clone(),
56        filters: filters.to_vec(),
57        group_by: group_by_exprs.clone(),
58        aggregates: aggregates.clone(),
59        having: having.clone(),
60        limit: 10000,
61        bucket_interval_ms,
62        group_columns,
63        has_auto_tier: table.info.has_auto_tier,
64        bitemporal: table.info.bitemporal,
65        temporal: *temporal,
66    })?;
67
68    // Wrap the plan to attach grouping sets if present.
69    if let Some(sets) = grouping_sets {
70        return Ok(attach_grouping_sets(
71            base_plan,
72            group_by_exprs,
73            aggregates,
74            having,
75            sets,
76        ));
77    }
78
79    Ok(base_plan)
80}
81
82/// Attach `grouping_sets` to an existing `SqlPlan::Aggregate` node, or wrap it
83/// in a new one when the engine rules returned a non-Aggregate plan.
84fn attach_grouping_sets(
85    base_plan: SqlPlan,
86    group_by: Vec<SqlExpr>,
87    aggregates: Vec<AggregateExpr>,
88    having: Vec<Filter>,
89    grouping_sets: Vec<Vec<usize>>,
90) -> SqlPlan {
91    match base_plan {
92        SqlPlan::Aggregate {
93            input,
94            limit,
95            grouping_sets: _,
96            sort_keys,
97            ..
98        } => SqlPlan::Aggregate {
99            input,
100            group_by,
101            aggregates,
102            having,
103            limit,
104            grouping_sets: Some(grouping_sets),
105            sort_keys,
106        },
107        other => {
108            // Engine returned something other than Aggregate (e.g. TimeseriesIngest).
109            // Wrap it so grouping sets are not silently dropped.
110            SqlPlan::Aggregate {
111                input: Box::new(other),
112                group_by,
113                aggregates,
114                having,
115                limit: 10000,
116                grouping_sets: Some(grouping_sets),
117                sort_keys: Vec::new(),
118            }
119        }
120    }
121}
122
123/// Extract timeseries-specific parameters from GROUP BY (bucket interval, group columns).
124///
125/// Returns `(Some(interval_ms), group_columns)` if a `time_bucket()` call is found,
126/// or `(None, empty)` otherwise. Non-timeseries engines ignore these values.
127fn extract_timeseries_params(
128    raw_group_by: &GroupByExpr,
129    select_items: &[ast::SelectItem],
130    functions: &FunctionRegistry,
131) -> Result<(Option<i64>, Vec<String>)> {
132    let mut bucket_interval_ms: Option<i64> = None;
133    let mut group_columns = Vec::new();
134
135    if let GroupByExpr::Expressions(exprs, _) = raw_group_by {
136        for expr in exprs {
137            let resolved = resolve_group_by_expr(expr, select_items);
138            let check_expr = resolved.unwrap_or(expr);
139
140            if let Some(interval) = try_extract_time_bucket(check_expr, functions)? {
141                bucket_interval_ms = Some(interval);
142                continue;
143            }
144
145            if let ast::Expr::Identifier(ident) = expr {
146                group_columns.push(normalize_ident(ident));
147            }
148        }
149    }
150
151    Ok((bucket_interval_ms, group_columns))
152}
153
154/// Check if an expression is a time_bucket() call and extract the interval.
155fn try_extract_time_bucket(expr: &ast::Expr, functions: &FunctionRegistry) -> Result<Option<i64>> {
156    if let ast::Expr::Function(func) = expr {
157        let name = func
158            .name
159            .0
160            .iter()
161            .map(|p| match p {
162                ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
163                _ => String::new(),
164            })
165            .collect::<Vec<_>>()
166            .join(".");
167        if functions.search_trigger(&name) == SearchTrigger::TimeBucket {
168            return Ok(Some(extract_bucket_interval(func)?));
169        }
170    }
171    Ok(None)
172}
173
174/// Resolve a GROUP BY expression that references a SELECT alias or ordinal.
175///
176/// `GROUP BY b` where `b` is an alias → returns the aliased expression.
177/// `GROUP BY 1` → returns the 1st SELECT expression (0-indexed).
178fn resolve_group_by_expr<'a>(
179    expr: &ast::Expr,
180    select_items: &'a [ast::SelectItem],
181) -> Option<&'a ast::Expr> {
182    match expr {
183        ast::Expr::Identifier(ident) => {
184            let alias_name = normalize_ident(ident);
185            select_items.iter().find_map(|item| {
186                if let ast::SelectItem::ExprWithAlias { expr, alias } = item
187                    && normalize_ident(alias) == alias_name
188                {
189                    Some(expr)
190                } else {
191                    None
192                }
193            })
194        }
195        ast::Expr::Value(v) => {
196            if let ast::Value::Number(n, _) = &v.value
197                && let Ok(idx) = n.parse::<usize>()
198                && idx >= 1
199                && idx <= select_items.len()
200            {
201                match &select_items[idx - 1] {
202                    ast::SelectItem::UnnamedExpr(e) => Some(e),
203                    ast::SelectItem::ExprWithAlias { expr, .. } => Some(expr),
204                    _ => None,
205                }
206            } else {
207                None
208            }
209        }
210        _ => None,
211    }
212}
213
214/// Extract the bucket interval from a time_bucket() call.
215///
216/// Handles both argument orders:
217/// - `time_bucket('1 hour', timestamp)` — interval first
218/// - `time_bucket(timestamp, '1 hour')` — timestamp first
219/// - `time_bucket(3600, timestamp)` — integer seconds
220fn extract_bucket_interval(func: &ast::Function) -> Result<i64> {
221    let args = match &func.args {
222        ast::FunctionArguments::List(args) => &args.args,
223        _ => return Ok(0),
224    };
225    // Try each argument position for the interval literal.
226    for arg in args {
227        if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(ast::Expr::Value(v))) = arg {
228            match &v.value {
229                ast::Value::SingleQuotedString(s) => {
230                    let ms = parse_interval_to_ms(s);
231                    if ms > 0 {
232                        return Ok(ms);
233                    }
234                }
235                ast::Value::Number(n, _) => {
236                    if let Ok(secs) = n.parse::<i64>()
237                        && secs > 0
238                    {
239                        return Ok(secs * 1000);
240                    }
241                }
242                _ => {}
243            }
244        }
245    }
246    Ok(0)
247}
248
249/// Parse an interval string to milliseconds.
250///
251/// Delegates to the canonical `nodedb_types::kv_parsing::parse_interval_to_ms`.
252fn parse_interval_to_ms(s: &str) -> i64 {
253    nodedb_types::kv_parsing::parse_interval_to_ms(s)
254        .map(|ms| ms as i64)
255        .unwrap_or(0)
256}
257
258/// Convert GROUP BY clause to SqlExpr list.
259pub fn convert_group_by(group_by: &GroupByExpr) -> Result<Vec<SqlExpr>> {
260    match group_by {
261        GroupByExpr::All(_) => Ok(Vec::new()),
262        GroupByExpr::Expressions(exprs, _) => exprs.iter().map(convert_expr).collect(),
263    }
264}
265
266/// Scan the SELECT projection for `GROUPING(col)` calls and return synthetic
267/// `AggregateExpr` entries so the executor can compute the per-set bitmask value.
268///
269/// For `GROUPING(col)`, the canonical index of `col` in `canonical_keys` is
270/// encoded into the `field` of the resulting `AggregateExpr` (as a decimal string).
271/// The executor reads this index and checks the grouping-set bitmask at run time.
272fn extract_grouping_calls(
273    items: &[ast::SelectItem],
274    canonical_keys: &[SqlExpr],
275) -> Result<Vec<AggregateExpr>> {
276    let mut out = Vec::new();
277    for item in items {
278        let (expr, alias): (&ast::Expr, String) = match item {
279            ast::SelectItem::UnnamedExpr(expr) => (expr, format!("{expr}")),
280            ast::SelectItem::ExprWithAlias { expr, alias } => (expr, normalize_ident(alias)),
281            _ => continue,
282        };
283        collect_grouping_from_expr(expr, &alias, canonical_keys, &mut out)?;
284    }
285    Ok(out)
286}
287
288/// Recursively collect `GROUPING(col)` calls from an expression.
289fn collect_grouping_from_expr(
290    expr: &ast::Expr,
291    alias: &str,
292    canonical_keys: &[SqlExpr],
293    out: &mut Vec<AggregateExpr>,
294) -> Result<()> {
295    match expr {
296        ast::Expr::Function(f) => {
297            let name = normalize_function_name(f);
298            if name.eq_ignore_ascii_case("grouping") {
299                // Extract the column argument(s).
300                let args = function_args_exprs(f);
301                for col_expr in &args {
302                    let canonical_idx = crate::planner::grouping_sets::resolve_grouping_col(
303                        col_expr,
304                        canonical_keys,
305                    )?;
306                    // Encode index in the field name; alias is user-visible output name.
307                    out.push(AggregateExpr {
308                        function: "grouping".into(),
309                        args: vec![convert_expr(col_expr)?],
310                        alias: alias.to_string(),
311                        distinct: false,
312                        grouping_col_index: Some(canonical_idx),
313                    });
314                }
315            }
316        }
317        // Recurse into binary ops and other wrappers.
318        ast::Expr::BinaryOp { left, right, .. } => {
319            collect_grouping_from_expr(left, alias, canonical_keys, out)?;
320            collect_grouping_from_expr(right, alias, canonical_keys, out)?;
321        }
322        _ => {}
323    }
324    Ok(())
325}
326
327/// Extract the positional expression arguments from a function call.
328fn function_args_exprs(f: &ast::Function) -> Vec<&ast::Expr> {
329    match &f.args {
330        ast::FunctionArguments::List(list) => list
331            .args
332            .iter()
333            .filter_map(|a| match a {
334                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e),
335                _ => None,
336            })
337            .collect(),
338        _ => Vec::new(),
339    }
340}
341
342/// Return the simple lowercase name of a function (unqualified part only).
343fn normalize_function_name(f: &ast::Function) -> String {
344    f.name
345        .0
346        .last()
347        .map(|p| match p {
348            ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
349            _ => String::new(),
350        })
351        .unwrap_or_default()
352}
353
354/// Extract aggregate expressions from SELECT projection.
355pub fn extract_aggregates_from_projection(
356    items: &[ast::SelectItem],
357    functions: &FunctionRegistry,
358) -> Result<Vec<AggregateExpr>> {
359    let mut aggregates = Vec::new();
360    for item in items {
361        let (expr, alias): (&ast::Expr, String) = match item {
362            // Lowercase the unparsed expression text so the resulting
363            // alias (which becomes the JSON row key after the
364            // `apply_user_aliases_to_rows` rename) matches the pgwire
365            // lookup key built by `expr_column_names`, which also
366            // lowercases. Without this match the row description
367            // column `count(distinct user_id)` would not resolve to
368            // the JSON value stored under `COUNT(DISTINCT user_id)`,
369            // and the client would see NULL.
370            ast::SelectItem::UnnamedExpr(expr) => (expr, format!("{expr}").to_lowercase()),
371            ast::SelectItem::ExprWithAlias { expr, alias } => (expr, normalize_ident(alias)),
372            _ => continue,
373        };
374        let mut extracted = crate::aggregate_walk::extract_aggregates(expr, &alias, functions)?;
375        aggregates.append(&mut extracted);
376    }
377    Ok(aggregates)
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Short and word-form interval strings map to the expected millisecond counts.
    #[test]
    fn parse_intervals() {
        // Compact forms.
        assert_eq!(parse_interval_to_ms("1h"), 3_600_000);
        assert_eq!(parse_interval_to_ms("15m"), 900_000);
        assert_eq!(parse_interval_to_ms("30s"), 30_000);
        assert_eq!(parse_interval_to_ms("7d"), 604_800_000);
        // Word-form intervals.
        assert_eq!(parse_interval_to_ms("1 hour"), 3_600_000);
        assert_eq!(parse_interval_to_ms("2 hours"), 7_200_000);
        assert_eq!(parse_interval_to_ms("15 minutes"), 900_000);
        assert_eq!(parse_interval_to_ms("30 seconds"), 30_000);
        assert_eq!(parse_interval_to_ms("1 day"), 86_400_000);
        assert_eq!(parse_interval_to_ms("5 min"), 300_000);
    }

    /// Helper: parse a single SQL SELECT statement and return its `Select` body.
    fn parse_select(sql: &str) -> ast::Select {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;

        let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        let ast::Statement::Query(query) = stmts.into_iter().next().unwrap() else {
            panic!("expected query");
        };
        match *query.body {
            ast::SetExpr::Select(s) => *s,
            _ => panic!("expected SELECT"),
        }
    }

    /// Helper: borrow the GROUP BY expression list, panicking on any other form.
    fn group_by_list(select: &ast::Select) -> &[ast::Expr] {
        match &select.group_by {
            GroupByExpr::Expressions(exprs, _) => exprs.as_slice(),
            _ => panic!("expected GROUP BY expressions"),
        }
    }

    /// `GROUP BY <alias>` resolves to the aliased `time_bucket()` call.
    #[test]
    fn resolve_group_by_alias_to_time_bucket() {
        let select = parse_select(
            "SELECT time_bucket('1 hour', timestamp) AS b, COUNT(*) FROM t GROUP BY b",
        );
        let functions = FunctionRegistry::new();
        let exprs = group_by_list(&select);

        let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
        assert!(resolved.is_some(), "alias 'b' should resolve");
        let interval = try_extract_time_bucket(resolved.unwrap(), &functions).unwrap();
        assert_eq!(interval, Some(3_600_000));
    }

    /// `GROUP BY 1` resolves to the first projected `time_bucket()` call.
    #[test]
    fn resolve_group_by_ordinal_to_time_bucket() {
        let select =
            parse_select("SELECT time_bucket('5 minutes', timestamp), COUNT(*) FROM t GROUP BY 1");
        let functions = FunctionRegistry::new();
        let exprs = group_by_list(&select);

        let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
        assert!(resolved.is_some(), "ordinal 1 should resolve");
        let interval = try_extract_time_bucket(resolved.unwrap(), &functions).unwrap();
        assert_eq!(interval, Some(300_000));
    }

    /// A plain column in GROUP BY neither resolves nor counts as a time bucket.
    #[test]
    fn resolve_group_by_plain_column_not_time_bucket() {
        let select = parse_select("SELECT qtype, COUNT(*) FROM t GROUP BY qtype");
        let functions = FunctionRegistry::new();
        let exprs = group_by_list(&select);

        // 'qtype' is not an alias in SELECT, so resolve returns None.
        let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
        assert!(resolved.is_none());
        let interval = try_extract_time_bucket(&exprs[0], &functions).unwrap();
        assert_eq!(interval, None);
    }
}