Skip to main content

nodedb_sql/
aggregate_walk.rs

1//! AST traversal for aggregate detection and extraction.
2//!
3//! Uses sqlparser's own `Visitor` so aggregates nested inside any
4//! expression position — `CASE`, `CAST`, `COALESCE`, `-SUM(x)`,
5//! `SUM(x) BETWEEN ...`, `SUM(x) IN (...)`, window specs, array
6//! elements — are found without us maintaining a hand-written walker
7//! that has to enumerate every `Expr` variant.
8//!
9//! # What this replaces
10//!
11//! Two earlier hand-written walkers only descended through `Function`,
12//! `BinaryOp`, and `Nested`. Every other position was a silent
13//! fall-through that produced wrong plans for ordinary SQL:
14//!
15//! - `SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t`
16//! - `SELECT CAST(SUM(x) AS TEXT) FROM t`
17//! - `SELECT -SUM(x) FROM t`
18//! - `SELECT COALESCE(SUM(x), 0) FROM t`
19//!
20//! None of those were recognised as aggregates; the planner took the
21//! non-aggregate path and produced a plan that executed the inner
22//! column scan without grouping.
23
24use core::ops::ControlFlow;
25
26use sqlparser::ast::{self, Expr, Visit, Visitor};
27
28use crate::error::{Result, SqlError};
29use crate::functions::registry::FunctionRegistry;
30use crate::parser::normalize::normalize_ident;
31use crate::resolver::expr::convert_expr;
32use crate::types::{AggregateExpr, SqlExpr};
33
34/// Return `true` if any aggregate function call appears anywhere inside
35/// `expr` — at any nesting depth, inside any expression position.
36pub fn contains_aggregate(expr: &Expr, functions: &FunctionRegistry) -> bool {
37    let mut detector = AggregateDetector {
38        functions,
39        found: false,
40    };
41    let _ = expr.visit(&mut detector);
42    detector.found
43}
44
45/// Extract every aggregate function call from `expr`, binding each to
46/// the given output `alias`. Nested aggregates (e.g. `SUM(AVG(x))`,
47/// which is illegal SQL in Postgres and most other systems) are
48/// reported as a planner error rather than silently double-extracted.
49pub fn extract_aggregates(
50    expr: &Expr,
51    alias: &str,
52    functions: &FunctionRegistry,
53) -> Result<Vec<AggregateExpr>> {
54    let mut extractor = AggregateExtractor {
55        functions,
56        alias,
57        inside_aggregate: 0,
58        out: Vec::new(),
59        error: None,
60    };
61    let _ = expr.visit(&mut extractor);
62    if let Some(e) = extractor.error {
63        return Err(e);
64    }
65    Ok(extractor.out)
66}
67
68// ── Detector ────────────────────────────────────────────────────────
69
70struct AggregateDetector<'a> {
71    functions: &'a FunctionRegistry,
72    found: bool,
73}
74
75impl Visitor for AggregateDetector<'_> {
76    type Break = ();
77
78    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
79        if let Expr::Function(f) = expr
80            && self.functions.is_aggregate(&function_name(f))
81        {
82            self.found = true;
83            return ControlFlow::Break(());
84        }
85        ControlFlow::Continue(())
86    }
87}
88
89// ── Extractor ───────────────────────────────────────────────────────
90
91struct AggregateExtractor<'a> {
92    functions: &'a FunctionRegistry,
93    alias: &'a str,
94    /// Depth counter: >0 means we're currently inside the argument
95    /// subtree of an already-extracted aggregate. A second aggregate
96    /// found in that subtree is an illegal nested aggregate.
97    inside_aggregate: u32,
98    out: Vec<AggregateExpr>,
99    /// Deferred error — `Visitor::Break` is `()` so we can't carry the
100    /// error through the control flow directly without boxing. Storing
101    /// it and short-circuiting the traversal on next pre-visit gives
102    /// the same observable behavior.
103    error: Option<SqlError>,
104}
105
106impl Visitor for AggregateExtractor<'_> {
107    type Break = ();
108
109    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
110        if self.error.is_some() {
111            return ControlFlow::Break(());
112        }
113        if let Expr::Function(f) = expr
114            && self.functions.is_aggregate(&function_name(f))
115        {
116            if self.inside_aggregate > 0 {
117                self.error = Some(SqlError::Unsupported {
118                    detail: format!(
119                        "nested aggregate functions are not allowed: {}(...{}...)",
120                        function_name(f),
121                        function_name(f),
122                    ),
123                });
124                return ControlFlow::Break(());
125            }
126            let (args, distinct) = function_args_and_distinct(f);
127            self.out.push(AggregateExpr {
128                function: function_name(f),
129                args,
130                alias: self.alias.into(),
131                distinct,
132            });
133            self.inside_aggregate += 1;
134        }
135        ControlFlow::Continue(())
136    }
137
138    fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
139        if let Expr::Function(f) = expr
140            && self.functions.is_aggregate(&function_name(f))
141            && self.inside_aggregate > 0
142        {
143            self.inside_aggregate -= 1;
144        }
145        ControlFlow::Continue(())
146    }
147}
148
149// ── Helpers ─────────────────────────────────────────────────────────
150
151fn function_name(f: &ast::Function) -> String {
152    f.name
153        .0
154        .iter()
155        .map(|p| match p {
156            ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
157            _ => String::new(),
158        })
159        .collect::<Vec<_>>()
160        .join(".")
161}
162
163fn function_args_and_distinct(f: &ast::Function) -> (Vec<SqlExpr>, bool) {
164    let ast::FunctionArguments::List(args) = &f.args else {
165        return (Vec::new(), false);
166    };
167    let parsed = args
168        .args
169        .iter()
170        .filter_map(|a| match a {
171            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => convert_expr(e).ok(),
172            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => Some(SqlExpr::Wildcard),
173            _ => None,
174        })
175        .collect();
176    let distinct = matches!(
177        args.duplicate_treatment,
178        Some(ast::DuplicateTreatment::Distinct)
179    );
180    (parsed, distinct)
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Projection list of the first statement in `sql`, which must be a
    /// plain `SELECT` query.
    fn first_select_projection(sql: &str) -> Vec<ast::SelectItem> {
        let statement = parse_sql(sql).unwrap().into_iter().next().unwrap();
        let ast::Statement::Query(query) = statement else {
            panic!()
        };
        let ast::SetExpr::Select(select) = *query.body else {
            panic!()
        };
        select.projection
    }

    /// First projected expression of `sql`, with any `AS alias` stripped.
    fn first_expr(sql: &str) -> ast::Expr {
        let first_item = first_select_projection(sql).into_iter().next().unwrap();
        let (ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. }) =
            first_item
        else {
            panic!()
        };
        expr
    }

    fn functions() -> FunctionRegistry {
        FunctionRegistry::new()
    }

    /// Does the first projected expression of `sql` contain an aggregate?
    fn detects(sql: &str) -> bool {
        contains_aggregate(&first_expr(sql), &functions())
    }

    /// Extract aggregates from the first projected expression of `sql`,
    /// binding them to `alias`.
    fn extract(sql: &str, alias: &str) -> Result<Vec<AggregateExpr>> {
        extract_aggregates(&first_expr(sql), alias, &functions())
    }

    // ── detection ──

    #[test]
    fn detect_plain_aggregate() {
        assert!(detects("SELECT SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_case() {
        assert!(detects("SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_cast() {
        assert!(detects("SELECT CAST(SUM(x) AS TEXT) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_unary_op() {
        assert!(detects("SELECT -SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_coalesce() {
        assert!(detects("SELECT COALESCE(SUM(x), 0) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_between() {
        assert!(detects("SELECT SUM(x) BETWEEN 1 AND 10 FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_in_list() {
        assert!(detects("SELECT SUM(x) IN (1, 2, 3) FROM t"));
    }

    #[test]
    fn no_aggregate_in_plain_select() {
        let plain = [
            "SELECT x FROM t",
            "SELECT x + 1 FROM t",
            "SELECT upper(name) FROM t",
        ];
        for sql in plain {
            assert!(!detects(sql));
        }
    }

    // ── extraction ──

    #[test]
    fn extract_plain_aggregate() {
        let aggs = extract("SELECT SUM(x) FROM t", "total").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
        assert_eq!(aggs[0].alias, "total");
    }

    #[test]
    fn extract_aggregate_inside_cast() {
        let aggs = extract("SELECT CAST(SUM(x) AS TEXT) AS n FROM t", "n").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_case() {
        let aggs = extract("SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t", "r").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_coalesce() {
        let aggs = extract("SELECT COALESCE(SUM(x), 0) FROM t", "r").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_two_aggregates_under_one_alias() {
        let aggs = extract("SELECT SUM(x) + COUNT(y) AS total FROM t", "total").unwrap();
        assert_eq!(aggs.len(), 2);
        assert!(aggs.iter().any(|agg| agg.function == "sum"));
        assert!(aggs.iter().any(|agg| agg.function == "count"));
    }

    #[test]
    fn nested_aggregate_directly_inside_aggregate_rejected() {
        let err = extract("SELECT SUM(AVG(x)) FROM t", "r").unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.to_lowercase().contains("nested aggregate"),
            "error must identify the nested-aggregate class: {msg}"
        );
    }

    /// A nested aggregate hidden behind an intermediate non-aggregate
    /// (`AVG(x)` wrapped in `CAST(...)` inside `SUM(...)`) is still an
    /// error: the check is against the ancestor aggregate, not the
    /// direct parent expression.
    #[test]
    fn nested_aggregate_through_cast_rejected() {
        let err = extract("SELECT SUM(CAST(AVG(x) AS BIGINT)) FROM t", "r").unwrap_err();
        let rendered = format!("{err:?}").to_lowercase();
        assert!(rendered.contains("nested aggregate"), "got: {err:?}");
    }

    /// Two aggregates joined by a non-aggregate operator are siblings,
    /// not nested, so both extract without error.
    #[test]
    fn sibling_aggregates_not_treated_as_nested() {
        let aggs = extract(
            "SELECT CAST(SUM(x) AS TEXT) || CAST(COUNT(y) AS TEXT) FROM t",
            "r",
        )
        .unwrap();
        assert_eq!(aggs.len(), 2);
    }

    #[test]
    fn extract_distinct_preserved() {
        let aggs = extract("SELECT COUNT(DISTINCT x) FROM t", "c").unwrap();
        assert_eq!(aggs.len(), 1);
        assert!(aggs[0].distinct);
    }
}