Skip to main content

nodedb_sql/planner/
aggregate.rs

1//! GROUP BY and aggregate planning.
2
3use sqlparser::ast::{self, GroupByExpr};
4
5use crate::engine_rules::{self, AggregateParams};
6use crate::error::Result;
7use crate::functions::registry::{FunctionRegistry, SearchTrigger};
8use crate::parser::normalize::normalize_ident;
9use crate::resolver::columns::ResolvedTable;
10use crate::resolver::expr::convert_expr;
11use crate::types::*;
12
13/// Plan an aggregate query (GROUP BY + aggregate functions).
14pub fn plan_aggregate(
15    select: &ast::Select,
16    table: &ResolvedTable,
17    filters: &[Filter],
18    _scope: &crate::resolver::columns::TableScope,
19    functions: &FunctionRegistry,
20) -> Result<SqlPlan> {
21    let group_by_exprs = convert_group_by(&select.group_by)?;
22    let aggregates = extract_aggregates_from_projection(&select.projection, functions)?;
23    let having = match &select.having {
24        Some(expr) => super::select::convert_where_to_filters(expr)?,
25        None => Vec::new(),
26    };
27
28    // Extract timeseries-specific params (bucket interval, group columns) if applicable.
29    let (bucket_interval_ms, group_columns) =
30        extract_timeseries_params(&select.group_by, &select.projection, functions)?;
31
32    let rules = engine_rules::resolve_engine_rules(table.info.engine);
33    rules.plan_aggregate(AggregateParams {
34        collection: table.name.clone(),
35        alias: table.alias.clone(),
36        filters: filters.to_vec(),
37        group_by: group_by_exprs,
38        aggregates,
39        having,
40        limit: 10000,
41        bucket_interval_ms,
42        group_columns,
43        has_auto_tier: table.info.has_auto_tier,
44    })
45}
46
47/// Extract timeseries-specific parameters from GROUP BY (bucket interval, group columns).
48///
49/// Returns `(Some(interval_ms), group_columns)` if a `time_bucket()` call is found,
50/// or `(None, empty)` otherwise. Non-timeseries engines ignore these values.
51fn extract_timeseries_params(
52    raw_group_by: &GroupByExpr,
53    select_items: &[ast::SelectItem],
54    functions: &FunctionRegistry,
55) -> Result<(Option<i64>, Vec<String>)> {
56    let mut bucket_interval_ms: Option<i64> = None;
57    let mut group_columns = Vec::new();
58
59    if let GroupByExpr::Expressions(exprs, _) = raw_group_by {
60        for expr in exprs {
61            let resolved = resolve_group_by_expr(expr, select_items);
62            let check_expr = resolved.unwrap_or(expr);
63
64            if let Some(interval) = try_extract_time_bucket(check_expr, functions)? {
65                bucket_interval_ms = Some(interval);
66                continue;
67            }
68
69            if let ast::Expr::Identifier(ident) = expr {
70                group_columns.push(normalize_ident(ident));
71            }
72        }
73    }
74
75    Ok((bucket_interval_ms, group_columns))
76}
77
78/// Check if an expression is a time_bucket() call and extract the interval.
79fn try_extract_time_bucket(expr: &ast::Expr, functions: &FunctionRegistry) -> Result<Option<i64>> {
80    if let ast::Expr::Function(func) = expr {
81        let name = func
82            .name
83            .0
84            .iter()
85            .map(|p| match p {
86                ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
87                _ => String::new(),
88            })
89            .collect::<Vec<_>>()
90            .join(".");
91        if functions.search_trigger(&name) == SearchTrigger::TimeBucket {
92            return Ok(Some(extract_bucket_interval(func)?));
93        }
94    }
95    Ok(None)
96}
97
98/// Resolve a GROUP BY expression that references a SELECT alias or ordinal.
99///
100/// `GROUP BY b` where `b` is an alias → returns the aliased expression.
101/// `GROUP BY 1` → returns the 1st SELECT expression (0-indexed).
102fn resolve_group_by_expr<'a>(
103    expr: &ast::Expr,
104    select_items: &'a [ast::SelectItem],
105) -> Option<&'a ast::Expr> {
106    match expr {
107        ast::Expr::Identifier(ident) => {
108            let alias_name = normalize_ident(ident);
109            select_items.iter().find_map(|item| {
110                if let ast::SelectItem::ExprWithAlias { expr, alias } = item
111                    && normalize_ident(alias) == alias_name
112                {
113                    Some(expr)
114                } else {
115                    None
116                }
117            })
118        }
119        ast::Expr::Value(v) => {
120            if let ast::Value::Number(n, _) = &v.value
121                && let Ok(idx) = n.parse::<usize>()
122                && idx >= 1
123                && idx <= select_items.len()
124            {
125                match &select_items[idx - 1] {
126                    ast::SelectItem::UnnamedExpr(e) => Some(e),
127                    ast::SelectItem::ExprWithAlias { expr, .. } => Some(expr),
128                    _ => None,
129                }
130            } else {
131                None
132            }
133        }
134        _ => None,
135    }
136}
137
138/// Extract the bucket interval from a time_bucket() call.
139///
140/// Handles both argument orders:
141/// - `time_bucket('1 hour', timestamp)` — interval first
142/// - `time_bucket(timestamp, '1 hour')` — timestamp first
143/// - `time_bucket(3600, timestamp)` — integer seconds
144fn extract_bucket_interval(func: &ast::Function) -> Result<i64> {
145    let args = match &func.args {
146        ast::FunctionArguments::List(args) => &args.args,
147        _ => return Ok(0),
148    };
149    // Try each argument position for the interval literal.
150    for arg in args {
151        if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(ast::Expr::Value(v))) = arg {
152            match &v.value {
153                ast::Value::SingleQuotedString(s) => {
154                    let ms = parse_interval_to_ms(s);
155                    if ms > 0 {
156                        return Ok(ms);
157                    }
158                }
159                ast::Value::Number(n, _) => {
160                    if let Ok(secs) = n.parse::<i64>()
161                        && secs > 0
162                    {
163                        return Ok(secs * 1000);
164                    }
165                }
166                _ => {}
167            }
168        }
169    }
170    Ok(0)
171}
172
173/// Parse an interval string to milliseconds.
174///
175/// Delegates to the canonical `nodedb_types::kv_parsing::parse_interval_to_ms`.
176fn parse_interval_to_ms(s: &str) -> i64 {
177    nodedb_types::kv_parsing::parse_interval_to_ms(s)
178        .map(|ms| ms as i64)
179        .unwrap_or(0)
180}
181
182/// Convert GROUP BY clause to SqlExpr list.
183pub fn convert_group_by(group_by: &GroupByExpr) -> Result<Vec<SqlExpr>> {
184    match group_by {
185        GroupByExpr::All(_) => Ok(Vec::new()),
186        GroupByExpr::Expressions(exprs, _) => exprs.iter().map(convert_expr).collect(),
187    }
188}
189
190/// Extract aggregate expressions from SELECT projection.
191pub fn extract_aggregates_from_projection(
192    items: &[ast::SelectItem],
193    functions: &FunctionRegistry,
194) -> Result<Vec<AggregateExpr>> {
195    let mut aggregates = Vec::new();
196    for item in items {
197        match item {
198            ast::SelectItem::UnnamedExpr(expr) => {
199                extract_aggregates_from_expr(expr, &format!("{expr}"), functions, &mut aggregates)?;
200            }
201            ast::SelectItem::ExprWithAlias { expr, alias } => {
202                extract_aggregates_from_expr(
203                    expr,
204                    &normalize_ident(alias),
205                    functions,
206                    &mut aggregates,
207                )?;
208            }
209            _ => {}
210        }
211    }
212    Ok(aggregates)
213}
214
215fn extract_aggregates_from_expr(
216    expr: &ast::Expr,
217    alias: &str,
218    functions: &FunctionRegistry,
219    out: &mut Vec<AggregateExpr>,
220) -> Result<()> {
221    match expr {
222        ast::Expr::Function(func) => {
223            let name = func
224                .name
225                .0
226                .iter()
227                .map(|p| match p {
228                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
229                    _ => String::new(),
230                })
231                .collect::<Vec<_>>()
232                .join(".");
233            if functions.is_aggregate(&name) {
234                let args = match &func.args {
235                    ast::FunctionArguments::List(args) => args
236                        .args
237                        .iter()
238                        .filter_map(|a| match a {
239                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => {
240                                convert_expr(e).ok()
241                            }
242                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => {
243                                Some(SqlExpr::Wildcard)
244                            }
245                            _ => None,
246                        })
247                        .collect(),
248                    _ => Vec::new(),
249                };
250                let distinct = matches!(&func.args,
251                    ast::FunctionArguments::List(args) if matches!(args.duplicate_treatment, Some(ast::DuplicateTreatment::Distinct))
252                );
253                out.push(AggregateExpr {
254                    function: name,
255                    args,
256                    alias: alias.into(),
257                    distinct,
258                });
259            }
260        }
261        ast::Expr::BinaryOp { left, right, .. } => {
262            extract_aggregates_from_expr(left, alias, functions, out)?;
263            extract_aggregates_from_expr(right, alias, functions, out)?;
264        }
265        ast::Expr::Nested(inner) => {
266            extract_aggregates_from_expr(inner, alias, functions, out)?;
267        }
268        _ => {}
269    }
270    Ok(())
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Short-form and word-form interval strings map to the expected milliseconds.
    #[test]
    fn parse_intervals() {
        let cases = [
            ("1h", 3_600_000),
            ("15m", 900_000),
            ("30s", 30_000),
            ("7d", 604_800_000),
            // Word-form intervals.
            ("1 hour", 3_600_000),
            ("2 hours", 7_200_000),
            ("15 minutes", 900_000),
            ("30 seconds", 30_000),
            ("1 day", 86_400_000),
            ("5 min", 300_000),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_interval_to_ms(input), expected);
        }
    }

    /// Helper: parse a single SQL SELECT statement and return its select body.
    fn parse_select(sql: &str) -> ast::Select {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;

        let statement = Parser::parse_sql(&GenericDialect {}, sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        let ast::Statement::Query(query) = statement else {
            panic!("expected query");
        };
        let ast::SetExpr::Select(body) = *query.body else {
            panic!("expected SELECT");
        };
        *body
    }

    /// Helper: borrow the explicit GROUP BY expression list of a parsed select.
    fn group_by_exprs(select: &ast::Select) -> &[ast::Expr] {
        match &select.group_by {
            GroupByExpr::Expressions(exprs, _) => exprs.as_slice(),
            _ => panic!("expected GROUP BY expressions"),
        }
    }

    #[test]
    fn resolve_group_by_alias_to_time_bucket() {
        let select = parse_select(
            "SELECT time_bucket('1 hour', timestamp) AS b, COUNT(*) FROM t GROUP BY b",
        );
        let functions = FunctionRegistry::new();
        let exprs = group_by_exprs(&select);

        let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
        assert!(resolved.is_some(), "alias 'b' should resolve");
        let interval = try_extract_time_bucket(resolved.unwrap(), &functions).unwrap();
        assert_eq!(interval, Some(3_600_000));
    }

    #[test]
    fn resolve_group_by_ordinal_to_time_bucket() {
        let select =
            parse_select("SELECT time_bucket('5 minutes', timestamp), COUNT(*) FROM t GROUP BY 1");
        let functions = FunctionRegistry::new();
        let exprs = group_by_exprs(&select);

        let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
        assert!(resolved.is_some(), "ordinal 1 should resolve");
        let interval = try_extract_time_bucket(resolved.unwrap(), &functions).unwrap();
        assert_eq!(interval, Some(300_000));
    }

    #[test]
    fn resolve_group_by_plain_column_not_time_bucket() {
        let select = parse_select("SELECT qtype, COUNT(*) FROM t GROUP BY qtype");
        let functions = FunctionRegistry::new();
        let exprs = group_by_exprs(&select);

        // 'qtype' is not an alias in SELECT, so resolve returns None.
        let resolved = resolve_group_by_expr(&exprs[0], &select.projection);
        assert!(resolved.is_none());
        let interval = try_extract_time_bucket(&exprs[0], &functions).unwrap();
        assert_eq!(interval, None);
    }
}