1use core::ops::ControlFlow;
25
26use sqlparser::ast::{self, Expr, Visit, Visitor};
27
28use crate::error::{Result, SqlError};
29use crate::functions::registry::FunctionRegistry;
30use crate::parser::normalize::normalize_ident;
31use crate::resolver::expr::convert_expr;
32use crate::types::{AggregateExpr, SqlExpr};
33
34pub fn contains_aggregate(expr: &Expr, functions: &FunctionRegistry) -> bool {
37 let mut detector = AggregateDetector {
38 functions,
39 found: false,
40 };
41 let _ = expr.visit(&mut detector);
42 detector.found
43}
44
45pub fn extract_aggregates(
50 expr: &Expr,
51 alias: &str,
52 functions: &FunctionRegistry,
53) -> Result<Vec<AggregateExpr>> {
54 let mut extractor = AggregateExtractor {
55 functions,
56 alias,
57 inside_aggregate: 0,
58 out: Vec::new(),
59 error: None,
60 };
61 let _ = expr.visit(&mut extractor);
62 if let Some(e) = extractor.error {
63 return Err(e);
64 }
65 Ok(extractor.out)
66}
67
68struct AggregateDetector<'a> {
71 functions: &'a FunctionRegistry,
72 found: bool,
73}
74
75impl Visitor for AggregateDetector<'_> {
76 type Break = ();
77
78 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
79 if let Expr::Function(f) = expr
80 && self.functions.is_aggregate(&function_name(f))
81 {
82 self.found = true;
83 return ControlFlow::Break(());
84 }
85 ControlFlow::Continue(())
86 }
87}
88
89struct AggregateExtractor<'a> {
92 functions: &'a FunctionRegistry,
93 alias: &'a str,
94 inside_aggregate: u32,
98 out: Vec<AggregateExpr>,
99 error: Option<SqlError>,
104}
105
106impl Visitor for AggregateExtractor<'_> {
107 type Break = ();
108
109 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
110 if self.error.is_some() {
111 return ControlFlow::Break(());
112 }
113 if let Expr::Function(f) = expr
114 && self.functions.is_aggregate(&function_name(f))
115 {
116 if self.inside_aggregate > 0 {
117 self.error = Some(SqlError::Unsupported {
118 detail: format!(
119 "nested aggregate functions are not allowed: {}(...{}...)",
120 function_name(f),
121 function_name(f),
122 ),
123 });
124 return ControlFlow::Break(());
125 }
126 let (args, distinct) = function_args_and_distinct(f);
127 self.out.push(AggregateExpr {
128 function: function_name(f),
129 args,
130 alias: self.alias.into(),
131 distinct,
132 });
133 self.inside_aggregate += 1;
134 }
135 ControlFlow::Continue(())
136 }
137
138 fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
139 if let Expr::Function(f) = expr
140 && self.functions.is_aggregate(&function_name(f))
141 && self.inside_aggregate > 0
142 {
143 self.inside_aggregate -= 1;
144 }
145 ControlFlow::Continue(())
146 }
147}
148
149fn function_name(f: &ast::Function) -> String {
152 f.name
153 .0
154 .iter()
155 .map(|p| match p {
156 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
157 _ => String::new(),
158 })
159 .collect::<Vec<_>>()
160 .join(".")
161}
162
163fn function_args_and_distinct(f: &ast::Function) -> (Vec<SqlExpr>, bool) {
164 let ast::FunctionArguments::List(args) = &f.args else {
165 return (Vec::new(), false);
166 };
167 let parsed = args
168 .args
169 .iter()
170 .filter_map(|a| match a {
171 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => convert_expr(e).ok(),
172 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => Some(SqlExpr::Wildcard),
173 _ => None,
174 })
175 .collect();
176 let distinct = matches!(
177 args.duplicate_treatment,
178 Some(ast::DuplicateTreatment::Distinct)
179 );
180 (parsed, distinct)
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parses `sql` and returns the projection list of its first statement.
    ///
    /// Panics unless that statement is a query whose body is a plain `SELECT`.
    fn first_select_projection(sql: &str) -> Vec<ast::SelectItem> {
        let first_stmt = parse_sql(sql).unwrap().into_iter().next().unwrap();
        let ast::Statement::Query(query) = first_stmt else {
            panic!()
        };
        let ast::SetExpr::Select(select) = *query.body else {
            panic!()
        };
        select.projection
    }

    /// Returns the expression of the first projected item of `sql`, with or
    /// without an alias.
    fn first_expr(sql: &str) -> ast::Expr {
        let first_item = first_select_projection(sql).into_iter().next().unwrap();
        match first_item {
            ast::SelectItem::UnnamedExpr(expr) => expr,
            ast::SelectItem::ExprWithAlias { expr, .. } => expr,
            _ => panic!(),
        }
    }

    /// Registry used by every test in this module.
    fn functions() -> FunctionRegistry {
        FunctionRegistry::new()
    }

    /// Runs `contains_aggregate` on the first projected expression of `sql`.
    fn detects(sql: &str) -> bool {
        contains_aggregate(&first_expr(sql), &functions())
    }

    /// Runs `extract_aggregates` on the first projected expression of `sql`.
    fn extract(sql: &str, alias: &str) -> Result<Vec<AggregateExpr>> {
        extract_aggregates(&first_expr(sql), alias, &functions())
    }

    // ---- detection ------------------------------------------------------

    #[test]
    fn detect_plain_aggregate() {
        assert!(detects("SELECT SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_case() {
        assert!(detects(
            "SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t"
        ));
    }

    #[test]
    fn detect_aggregate_inside_cast() {
        assert!(detects("SELECT CAST(SUM(x) AS TEXT) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_unary_op() {
        assert!(detects("SELECT -SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_coalesce() {
        assert!(detects("SELECT COALESCE(SUM(x), 0) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_between() {
        assert!(detects("SELECT SUM(x) BETWEEN 1 AND 10 FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_in_list() {
        assert!(detects("SELECT SUM(x) IN (1, 2, 3) FROM t"));
    }

    #[test]
    fn no_aggregate_in_plain_select() {
        assert!(!detects("SELECT x FROM t"));
        assert!(!detects("SELECT x + 1 FROM t"));
        assert!(!detects("SELECT upper(name) FROM t"));
    }

    // ---- extraction -----------------------------------------------------

    #[test]
    fn extract_plain_aggregate() {
        let aggs = extract("SELECT SUM(x) FROM t", "total").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
        assert_eq!(aggs[0].alias, "total");
    }

    #[test]
    fn extract_aggregate_inside_cast() {
        let aggs = extract("SELECT CAST(SUM(x) AS TEXT) AS n FROM t", "n").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_case() {
        let aggs = extract(
            "SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t",
            "r",
        )
        .unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_coalesce() {
        let aggs = extract("SELECT COALESCE(SUM(x), 0) FROM t", "r").unwrap();
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_two_aggregates_under_one_alias() {
        let aggs = extract("SELECT SUM(x) + COUNT(y) AS total FROM t", "total").unwrap();
        assert_eq!(aggs.len(), 2);
        assert!(aggs.iter().any(|agg| agg.function == "sum"));
        assert!(aggs.iter().any(|agg| agg.function == "count"));
    }

    // ---- nesting rules --------------------------------------------------

    #[test]
    fn nested_aggregate_directly_inside_aggregate_rejected() {
        let err = extract("SELECT SUM(AVG(x)) FROM t", "r").unwrap_err();
        let msg = format!("{err:?}");
        assert!(
            msg.to_lowercase().contains("nested aggregate"),
            "error must identify the nested-aggregate class: {msg}"
        );
    }

    #[test]
    fn nested_aggregate_through_cast_rejected() {
        let err = extract("SELECT SUM(CAST(AVG(x) AS BIGINT)) FROM t", "r").unwrap_err();
        let rendered = format!("{err:?}").to_lowercase();
        assert!(rendered.contains("nested aggregate"), "got: {err:?}");
    }

    #[test]
    fn sibling_aggregates_not_treated_as_nested() {
        let aggs = extract(
            "SELECT CAST(SUM(x) AS TEXT) || CAST(COUNT(y) AS TEXT) FROM t",
            "r",
        )
        .unwrap();
        assert_eq!(aggs.len(), 2);
    }

    #[test]
    fn extract_distinct_preserved() {
        let aggs = extract("SELECT COUNT(DISTINCT x) FROM t", "c").unwrap();
        assert_eq!(aggs.len(), 1);
        assert!(aggs[0].distinct);
    }
}