Skip to main content

nodedb_sql/planner/
aggregate.rs

1//! GROUP BY and aggregate planning.
2
3use sqlparser::ast::{self, GroupByExpr};
4
5use crate::engine_rules::{self, AggregateParams};
6use crate::error::Result;
7use crate::functions::registry::{FunctionRegistry, SearchTrigger};
8use crate::parser::normalize::normalize_ident;
9use crate::resolver::columns::ResolvedTable;
10use crate::resolver::expr::convert_expr;
11use crate::types::*;
12
13/// Plan an aggregate query (GROUP BY + aggregate functions).
14pub fn plan_aggregate(
15    select: &ast::Select,
16    table: &ResolvedTable,
17    filters: &[Filter],
18    _scope: &crate::resolver::columns::TableScope,
19    functions: &FunctionRegistry,
20) -> Result<SqlPlan> {
21    let group_by_exprs = convert_group_by(&select.group_by)?;
22    let aggregates = extract_aggregates_from_projection(&select.projection, functions)?;
23    let having = match &select.having {
24        Some(expr) => super::select::convert_where_to_filters(expr)?,
25        None => Vec::new(),
26    };
27
28    // Extract timeseries-specific params (bucket interval, group columns) if applicable.
29    let (bucket_interval_ms, group_columns) =
30        extract_timeseries_params(&select.group_by, &select.projection, functions)?;
31
32    let rules = engine_rules::resolve_engine_rules(table.info.engine);
33    rules.plan_aggregate(AggregateParams {
34        collection: table.name.clone(),
35        alias: table.alias.clone(),
36        filters: filters.to_vec(),
37        group_by: group_by_exprs,
38        aggregates,
39        having,
40        limit: 10000,
41        bucket_interval_ms,
42        group_columns,
43        has_auto_tier: table.info.has_auto_tier,
44    })
45}
46
47/// Extract timeseries-specific parameters from GROUP BY (bucket interval, group columns).
48///
49/// Returns `(Some(interval_ms), group_columns)` if a `time_bucket()` call is found,
50/// or `(None, empty)` otherwise. Non-timeseries engines ignore these values.
51fn extract_timeseries_params(
52    raw_group_by: &GroupByExpr,
53    select_items: &[ast::SelectItem],
54    functions: &FunctionRegistry,
55) -> Result<(Option<i64>, Vec<String>)> {
56    let mut bucket_interval_ms: Option<i64> = None;
57    let mut group_columns = Vec::new();
58
59    if let GroupByExpr::Expressions(exprs, _) = raw_group_by {
60        for expr in exprs {
61            let resolved = resolve_group_by_expr(expr, select_items);
62            let check_expr = resolved.unwrap_or(expr);
63
64            if let Some(interval) = try_extract_time_bucket(check_expr, functions)? {
65                bucket_interval_ms = Some(interval);
66                continue;
67            }
68
69            if let ast::Expr::Identifier(ident) = expr {
70                group_columns.push(normalize_ident(ident));
71            }
72        }
73    }
74
75    Ok((bucket_interval_ms, group_columns))
76}
77
78/// Check if an expression is a time_bucket() call and extract the interval.
79fn try_extract_time_bucket(expr: &ast::Expr, functions: &FunctionRegistry) -> Result<Option<i64>> {
80    if let ast::Expr::Function(func) = expr {
81        let name = func
82            .name
83            .0
84            .iter()
85            .map(|p| match p {
86                ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
87                _ => String::new(),
88            })
89            .collect::<Vec<_>>()
90            .join(".");
91        if functions.search_trigger(&name) == SearchTrigger::TimeBucket {
92            return Ok(Some(extract_bucket_interval(func)?));
93        }
94    }
95    Ok(None)
96}
97
98/// Resolve a GROUP BY expression that references a SELECT alias or ordinal.
99///
100/// `GROUP BY b` where `b` is an alias → returns the aliased expression.
101/// `GROUP BY 1` → returns the 1st SELECT expression (0-indexed).
102fn resolve_group_by_expr<'a>(
103    expr: &ast::Expr,
104    select_items: &'a [ast::SelectItem],
105) -> Option<&'a ast::Expr> {
106    match expr {
107        ast::Expr::Identifier(ident) => {
108            let alias_name = normalize_ident(ident);
109            select_items.iter().find_map(|item| {
110                if let ast::SelectItem::ExprWithAlias { expr, alias } = item
111                    && normalize_ident(alias) == alias_name
112                {
113                    Some(expr)
114                } else {
115                    None
116                }
117            })
118        }
119        ast::Expr::Value(v) => {
120            if let ast::Value::Number(n, _) = &v.value
121                && let Ok(idx) = n.parse::<usize>()
122                && idx >= 1
123                && idx <= select_items.len()
124            {
125                match &select_items[idx - 1] {
126                    ast::SelectItem::UnnamedExpr(e) => Some(e),
127                    ast::SelectItem::ExprWithAlias { expr, .. } => Some(expr),
128                    _ => None,
129                }
130            } else {
131                None
132            }
133        }
134        _ => None,
135    }
136}
137
138/// Extract the bucket interval from a time_bucket() call.
139///
140/// Handles both argument orders:
141/// - `time_bucket('1 hour', timestamp)` — interval first
142/// - `time_bucket(timestamp, '1 hour')` — timestamp first
143/// - `time_bucket(3600, timestamp)` — integer seconds
144fn extract_bucket_interval(func: &ast::Function) -> Result<i64> {
145    let args = match &func.args {
146        ast::FunctionArguments::List(args) => &args.args,
147        _ => return Ok(0),
148    };
149    // Try each argument position for the interval literal.
150    for arg in args {
151        if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(ast::Expr::Value(v))) = arg {
152            match &v.value {
153                ast::Value::SingleQuotedString(s) => {
154                    let ms = parse_interval_to_ms(s);
155                    if ms > 0 {
156                        return Ok(ms);
157                    }
158                }
159                ast::Value::Number(n, _) => {
160                    if let Ok(secs) = n.parse::<i64>()
161                        && secs > 0
162                    {
163                        return Ok(secs * 1000);
164                    }
165                }
166                _ => {}
167            }
168        }
169    }
170    Ok(0)
171}
172
173/// Parse an interval string to milliseconds.
174///
175/// Delegates to the canonical `nodedb_types::kv_parsing::parse_interval_to_ms`.
176fn parse_interval_to_ms(s: &str) -> i64 {
177    nodedb_types::kv_parsing::parse_interval_to_ms(s)
178        .map(|ms| ms as i64)
179        .unwrap_or(0)
180}
181
182/// Convert GROUP BY clause to SqlExpr list.
183pub fn convert_group_by(group_by: &GroupByExpr) -> Result<Vec<SqlExpr>> {
184    match group_by {
185        GroupByExpr::All(_) => Ok(Vec::new()),
186        GroupByExpr::Expressions(exprs, _) => exprs.iter().map(convert_expr).collect(),
187    }
188}
189
190/// Extract aggregate expressions from SELECT projection.
191pub fn extract_aggregates_from_projection(
192    items: &[ast::SelectItem],
193    functions: &FunctionRegistry,
194) -> Result<Vec<AggregateExpr>> {
195    let mut aggregates = Vec::new();
196    for item in items {
197        let (expr, alias): (&ast::Expr, String) = match item {
198            ast::SelectItem::UnnamedExpr(expr) => (expr, format!("{expr}")),
199            ast::SelectItem::ExprWithAlias { expr, alias } => (expr, normalize_ident(alias)),
200            _ => continue,
201        };
202        let mut extracted = crate::aggregate_walk::extract_aggregates(expr, &alias, functions)?;
203        aggregates.append(&mut extracted);
204    }
205    Ok(aggregates)
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_intervals() {
        // (input, expected milliseconds): compact suffix forms, then word forms.
        let cases: [(&str, i64); 10] = [
            ("1h", 3_600_000),
            ("15m", 900_000),
            ("30s", 30_000),
            ("7d", 604_800_000),
            ("1 hour", 3_600_000),
            ("2 hours", 7_200_000),
            ("15 minutes", 900_000),
            ("30 seconds", 30_000),
            ("1 day", 86_400_000),
            ("5 min", 300_000),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_interval_to_ms(input), expected);
        }
    }

    /// Parse `sql` as a single SELECT statement and return its body.
    fn parse_select(sql: &str) -> ast::Select {
        use sqlparser::dialect::GenericDialect;
        use sqlparser::parser::Parser;

        let statement = Parser::parse_sql(&GenericDialect {}, sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        let ast::Statement::Query(query) = statement else {
            panic!("expected query");
        };
        let ast::SetExpr::Select(select) = *query.body else {
            panic!("expected SELECT");
        };
        *select
    }

    /// Borrow the explicit GROUP BY keys of `select`; panics for `GROUP BY ALL`.
    fn group_by_keys(select: &ast::Select) -> &[ast::Expr] {
        match &select.group_by {
            GroupByExpr::Expressions(exprs, _) => exprs,
            _ => panic!("expected GROUP BY expressions"),
        }
    }

    #[test]
    fn resolve_group_by_alias_to_time_bucket() {
        let select = parse_select(
            "SELECT time_bucket('1 hour', timestamp) AS b, COUNT(*) FROM t GROUP BY b",
        );
        let functions = FunctionRegistry::new();
        let key = &group_by_keys(&select)[0];

        let resolved =
            resolve_group_by_expr(key, &select.projection).expect("alias 'b' should resolve");
        let interval = try_extract_time_bucket(resolved, &functions).unwrap();
        assert_eq!(interval, Some(3_600_000));
    }

    #[test]
    fn resolve_group_by_ordinal_to_time_bucket() {
        let select =
            parse_select("SELECT time_bucket('5 minutes', timestamp), COUNT(*) FROM t GROUP BY 1");
        let functions = FunctionRegistry::new();
        let key = &group_by_keys(&select)[0];

        let resolved =
            resolve_group_by_expr(key, &select.projection).expect("ordinal 1 should resolve");
        let interval = try_extract_time_bucket(resolved, &functions).unwrap();
        assert_eq!(interval, Some(300_000));
    }

    #[test]
    fn resolve_group_by_plain_column_not_time_bucket() {
        let select = parse_select("SELECT qtype, COUNT(*) FROM t GROUP BY qtype");
        let functions = FunctionRegistry::new();
        let key = &group_by_keys(&select)[0];

        // 'qtype' is not an alias in SELECT, so resolve returns None.
        assert!(resolve_group_by_expr(key, &select.projection).is_none());
        let interval = try_extract_time_bucket(key, &functions).unwrap();
        assert_eq!(interval, None);
    }
}