Skip to main content

nodedb_sql/
aggregate_walk.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AST traversal for aggregate detection and extraction.
4//!
5//! Uses sqlparser's own `Visitor` so aggregates nested inside any
6//! expression position — `CASE`, `CAST`, `COALESCE`, `-SUM(x)`,
7//! `SUM(x) BETWEEN ...`, `SUM(x) IN (...)`, window specs, array
8//! elements — are found without us maintaining a hand-written walker
9//! that has to enumerate every `Expr` variant.
10//!
11//! # What this replaces
12//!
13//! Two earlier hand-written walkers only descended through `Function`,
14//! `BinaryOp`, and `Nested`. Every other position was a silent
15//! fall-through that produced wrong plans for ordinary SQL:
16//!
17//! - `SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t`
18//! - `SELECT CAST(SUM(x) AS TEXT) FROM t`
19//! - `SELECT -SUM(x) FROM t`
20//! - `SELECT COALESCE(SUM(x), 0) FROM t`
21//!
22//! None of those were recognised as aggregates; the planner took the
23//! non-aggregate path and produced a plan that executed the inner
24//! column scan without grouping.
25
26use core::ops::ControlFlow;
27
28use sqlparser::ast::{self, Expr, Visit, Visitor};
29
30use crate::error::{Result, SqlError};
31use crate::functions::registry::FunctionRegistry;
32use crate::parser::normalize::normalize_ident;
33use crate::resolver::expr::convert_expr;
34use crate::types::{AggregateExpr, SqlExpr};
35
36/// Return `true` if any aggregate function call appears anywhere inside
37/// `expr` — at any nesting depth, inside any expression position.
38pub fn contains_aggregate(expr: &Expr, functions: &FunctionRegistry) -> bool {
39    let mut detector = AggregateDetector {
40        functions,
41        found: false,
42    };
43    let _ = expr.visit(&mut detector);
44    detector.found
45}
46
47/// Extract every aggregate function call from `expr`, binding each to
48/// the given output `alias`. Nested aggregates (e.g. `SUM(AVG(x))`,
49/// which is illegal SQL in Postgres and most other systems) are
50/// reported as a planner error rather than silently double-extracted.
51pub fn extract_aggregates(
52    expr: &Expr,
53    alias: &str,
54    functions: &FunctionRegistry,
55) -> Result<Vec<AggregateExpr>> {
56    let mut extractor = AggregateExtractor {
57        functions,
58        alias,
59        inside_aggregate: 0,
60        out: Vec::new(),
61        error: None,
62    };
63    let _ = expr.visit(&mut extractor);
64    if let Some(e) = extractor.error {
65        return Err(e);
66    }
67    Ok(extractor.out)
68}
69
70// ── Detector ────────────────────────────────────────────────────────
71
72struct AggregateDetector<'a> {
73    functions: &'a FunctionRegistry,
74    found: bool,
75}
76
77impl Visitor for AggregateDetector<'_> {
78    type Break = ();
79
80    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
81        if let Expr::Function(f) = expr
82            && self.functions.is_aggregate(&function_name(f))
83            // A function with an OVER clause is a window function, not an
84            // aggregate. Window functions are handled by the window planner
85            // and must not trigger the GROUP BY / aggregate plan path.
86            && f.over.is_none()
87        {
88            self.found = true;
89            return ControlFlow::Break(());
90        }
91        ControlFlow::Continue(())
92    }
93}
94
95// ── Extractor ───────────────────────────────────────────────────────
96
97struct AggregateExtractor<'a> {
98    functions: &'a FunctionRegistry,
99    alias: &'a str,
100    /// Depth counter: >0 means we're currently inside the argument
101    /// subtree of an already-extracted aggregate. A second aggregate
102    /// found in that subtree is an illegal nested aggregate.
103    inside_aggregate: u32,
104    out: Vec<AggregateExpr>,
105    /// Deferred error — `Visitor::Break` is `()` so we can't carry the
106    /// error through the control flow directly without boxing. Storing
107    /// it and short-circuiting the traversal on next pre-visit gives
108    /// the same observable behavior.
109    error: Option<SqlError>,
110}
111
112impl Visitor for AggregateExtractor<'_> {
113    type Break = ();
114
115    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
116        if self.error.is_some() {
117            return ControlFlow::Break(());
118        }
119        if let Expr::Function(f) = expr
120            && self.functions.is_aggregate(&function_name(f))
121            // Skip window function calls — those with OVER are handled by the
122            // window planner, not the aggregate planner.
123            && f.over.is_none()
124        {
125            if self.inside_aggregate > 0 {
126                self.error = Some(SqlError::Unsupported {
127                    detail: format!(
128                        "nested aggregate functions are not allowed: {}(...{}...)",
129                        function_name(f),
130                        function_name(f),
131                    ),
132                });
133                return ControlFlow::Break(());
134            }
135            let (args, distinct) = function_args_and_distinct(f);
136            self.out.push(AggregateExpr {
137                function: function_name(f),
138                args,
139                alias: self.alias.into(),
140                distinct,
141                grouping_col_index: None,
142            });
143            self.inside_aggregate += 1;
144        }
145        ControlFlow::Continue(())
146    }
147
148    fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
149        if let Expr::Function(f) = expr
150            && self.functions.is_aggregate(&function_name(f))
151            && f.over.is_none()
152            && self.inside_aggregate > 0
153        {
154            self.inside_aggregate -= 1;
155        }
156        ControlFlow::Continue(())
157    }
158}
159
160// ── Helpers ─────────────────────────────────────────────────────────
161
162/// Return the function name, or a qualified stub that will never match any
163/// registered function if the name is schema-qualified.
164///
165/// Detection (`contains_aggregate`) uses this to look up the function in the
166/// registry — a qualified name won't be found, so the aggregate check is a
167/// safe no-op. Extraction (`extract_aggregates`) will not reach this path
168/// for schema-qualified names because `convert_expr` rejects them first.
169fn function_name(f: &ast::Function) -> String {
170    if f.name.0.len() > 1 {
171        // Return a sentinel that cannot match any registry entry.
172        // The actual rejection happens in convert_expr / convert_function_depth.
173        let qualified: String = f
174            .name
175            .0
176            .iter()
177            .map(|p| match p {
178                ast::ObjectNamePart::Identifier(ident) => ident.value.clone(),
179                _ => String::new(),
180            })
181            .collect::<Vec<_>>()
182            .join(".");
183        return format!("__schema_qualified__{qualified}");
184    }
185    f.name
186        .0
187        .iter()
188        .map(|p| match p {
189            ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
190            _ => String::new(),
191        })
192        .collect::<Vec<_>>()
193        .join(".")
194}
195
196fn function_args_and_distinct(f: &ast::Function) -> (Vec<SqlExpr>, bool) {
197    let ast::FunctionArguments::List(args) = &f.args else {
198        return (Vec::new(), false);
199    };
200    let parsed = args
201        .args
202        .iter()
203        .filter_map(|a| match a {
204            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => convert_expr(e).ok(),
205            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => Some(SqlExpr::Wildcard),
206            _ => None,
207        })
208        .collect();
209    let distinct = matches!(
210        args.duplicate_treatment,
211        Some(ast::DuplicateTreatment::Distinct)
212    );
213    (parsed, distinct)
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parse `sql` and return the projection list of its first statement,
    /// which must be a plain `SELECT`.
    fn first_select_projection(sql: &str) -> Vec<ast::SelectItem> {
        let stmts = parse_sql(sql).unwrap();
        match stmts.into_iter().next().unwrap() {
            ast::Statement::Query(q) => match *q.body {
                ast::SetExpr::Select(s) => s.projection,
                _ => panic!("expected a SELECT body in: {sql}"),
            },
            _ => panic!("expected a query statement in: {sql}"),
        }
    }

    /// Return the expression of the first projection item of `sql`.
    fn first_expr(sql: &str) -> ast::Expr {
        match first_select_projection(sql).into_iter().next().unwrap() {
            ast::SelectItem::UnnamedExpr(e) | ast::SelectItem::ExprWithAlias { expr: e, .. } => e,
            _ => panic!("expected an expression projection in: {sql}"),
        }
    }

    fn functions() -> FunctionRegistry {
        FunctionRegistry::new()
    }

    /// Run aggregate detection on the first projection item of `sql`.
    fn detects(sql: &str) -> bool {
        contains_aggregate(&first_expr(sql), &functions())
    }

    /// Run aggregate extraction on the first projection item of `sql`.
    fn extract(sql: &str, alias: &str) -> Result<Vec<AggregateExpr>> {
        extract_aggregates(&first_expr(sql), alias, &functions())
    }

    /// Assert that extraction of `sql` fails with an error identifying the
    /// nested-aggregate class.
    fn assert_nested_aggregate_error(sql: &str) {
        let err = extract(sql, "r").unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.to_lowercase().contains("nested aggregate"),
            "error must identify the nested-aggregate class: {msg}"
        );
    }

    // ── detection ──

    #[test]
    fn detect_plain_aggregate() {
        assert!(detects("SELECT SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_case() {
        assert!(detects("SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_cast() {
        assert!(detects("SELECT CAST(SUM(x) AS TEXT) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_unary_op() {
        assert!(detects("SELECT -SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_coalesce() {
        assert!(detects("SELECT COALESCE(SUM(x), 0) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_between() {
        assert!(detects("SELECT SUM(x) BETWEEN 1 AND 10 FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_in_list() {
        assert!(detects("SELECT SUM(x) IN (1, 2, 3) FROM t"));
    }

    #[test]
    fn no_aggregate_in_plain_select() {
        for sql in [
            "SELECT x FROM t",
            "SELECT x + 1 FROM t",
            "SELECT upper(name) FROM t",
        ] {
            assert!(!detects(sql), "unexpected aggregate detected in: {sql}");
        }
    }

    /// A call with `OVER` is a window function: it must neither trigger
    /// the aggregate plan path nor be extracted.
    #[test]
    fn window_function_not_detected_or_extracted() {
        let sql = "SELECT SUM(x) OVER (PARTITION BY y) FROM t";
        assert!(!detects(sql));
        assert!(extract(sql, "w").unwrap().is_empty());
    }

    /// An aggregate used as a window function's argument is an aggregate
    /// of the enclosing query (window over a grouped result).
    #[test]
    fn aggregate_inside_window_function_args_detected() {
        let sql = "SELECT SUM(SUM(x)) OVER () FROM t";
        assert!(detects(sql));
        let aggs = extract(sql, "w").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    /// Schema-qualified calls map to a sentinel name that never matches
    /// the registry.
    #[test]
    fn schema_qualified_call_is_not_an_aggregate() {
        assert!(!detects("SELECT myschema.sum(x) FROM t"));
    }

    // ── extraction ──

    #[test]
    fn extract_plain_aggregate() {
        let aggs = extract("SELECT SUM(x) FROM t", "total").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
        assert_eq!(aggs[0].alias, "total");
        assert!(aggs[0].grouping_col_index.is_none());
    }

    #[test]
    fn extract_aggregate_inside_cast() {
        let aggs = extract("SELECT CAST(SUM(x) AS TEXT) AS n FROM t", "n").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_case() {
        let aggs = extract("SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t", "r").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_coalesce() {
        let aggs = extract("SELECT COALESCE(SUM(x), 0) FROM t", "r").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_two_aggregates_under_one_alias() {
        let aggs = extract("SELECT SUM(x) + COUNT(y) AS total FROM t", "total").unwrap();
        assert_eq!(aggs.len(), 2);
        let names: Vec<&str> = aggs.iter().map(|a| a.function.as_str()).collect();
        assert!(names.contains(&"sum"));
        assert!(names.contains(&"count"));
        assert!(aggs.iter().all(|a| a.alias == "total"));
    }

    #[test]
    fn nested_aggregate_directly_inside_aggregate_rejected() {
        assert_nested_aggregate_error("SELECT SUM(AVG(x)) FROM t");
    }

    /// Nested aggregate buried under an intermediate non-aggregate
    /// (`AVG(x)` wrapped in `CAST(...)` inside `SUM(...)`) must also be
    /// caught: the check tracks the ancestor aggregate, not the direct
    /// parent.
    #[test]
    fn nested_aggregate_through_cast_rejected() {
        assert_nested_aggregate_error("SELECT SUM(CAST(AVG(x) AS BIGINT)) FROM t");
    }

    /// Two top-level aggregates separated by a non-aggregate binary
    /// operator are siblings, not nested, so both must extract cleanly.
    #[test]
    fn sibling_aggregates_not_treated_as_nested() {
        let aggs = extract(
            "SELECT CAST(SUM(x) AS TEXT) || CAST(COUNT(y) AS TEXT) FROM t",
            "r",
        )
        .unwrap();
        assert_eq!(aggs.len(), 2);
    }

    #[test]
    fn extract_distinct_preserved() {
        let aggs = extract("SELECT COUNT(DISTINCT x) FROM t", "c").unwrap();
        assert_eq!(aggs.len(), 1);
        assert!(aggs[0].distinct);
    }

    /// `COUNT(*)` keeps its wildcard argument and is not marked DISTINCT.
    #[test]
    fn extract_count_star_keeps_wildcard_arg() {
        let aggs = extract("SELECT COUNT(*) FROM t", "c").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "count");
        assert!(matches!(aggs[0].args.as_slice(), [SqlExpr::Wildcard]));
        assert!(!aggs[0].distinct);
    }
}