1use core::ops::ControlFlow;
27
28use sqlparser::ast::{self, Expr, Visit, Visitor};
29
30use crate::error::{Result, SqlError};
31use crate::functions::registry::FunctionRegistry;
32use crate::parser::normalize::normalize_ident;
33use crate::resolver::expr::convert_expr;
34use crate::types::{AggregateExpr, SqlExpr};
35
36pub fn contains_aggregate(expr: &Expr, functions: &FunctionRegistry) -> bool {
39 let mut detector = AggregateDetector {
40 functions,
41 found: false,
42 };
43 let _ = expr.visit(&mut detector);
44 detector.found
45}
46
47pub fn extract_aggregates(
52 expr: &Expr,
53 alias: &str,
54 functions: &FunctionRegistry,
55) -> Result<Vec<AggregateExpr>> {
56 let mut extractor = AggregateExtractor {
57 functions,
58 alias,
59 inside_aggregate: 0,
60 out: Vec::new(),
61 error: None,
62 };
63 let _ = expr.visit(&mut extractor);
64 if let Some(e) = extractor.error {
65 return Err(e);
66 }
67 Ok(extractor.out)
68}
69
70struct AggregateDetector<'a> {
73 functions: &'a FunctionRegistry,
74 found: bool,
75}
76
77impl Visitor for AggregateDetector<'_> {
78 type Break = ();
79
80 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
81 if let Expr::Function(f) = expr
82 && self.functions.is_aggregate(&function_name(f))
83 && f.over.is_none()
87 {
88 self.found = true;
89 return ControlFlow::Break(());
90 }
91 ControlFlow::Continue(())
92 }
93}
94
95struct AggregateExtractor<'a> {
98 functions: &'a FunctionRegistry,
99 alias: &'a str,
100 inside_aggregate: u32,
104 out: Vec<AggregateExpr>,
105 error: Option<SqlError>,
110}
111
112impl Visitor for AggregateExtractor<'_> {
113 type Break = ();
114
115 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
116 if self.error.is_some() {
117 return ControlFlow::Break(());
118 }
119 if let Expr::Function(f) = expr
120 && self.functions.is_aggregate(&function_name(f))
121 && f.over.is_none()
124 {
125 if self.inside_aggregate > 0 {
126 self.error = Some(SqlError::Unsupported {
127 detail: format!(
128 "nested aggregate functions are not allowed: {}(...{}...)",
129 function_name(f),
130 function_name(f),
131 ),
132 });
133 return ControlFlow::Break(());
134 }
135 let (args, distinct) = function_args_and_distinct(f);
136 self.out.push(AggregateExpr {
137 function: function_name(f),
138 args,
139 alias: self.alias.into(),
140 distinct,
141 grouping_col_index: None,
142 });
143 self.inside_aggregate += 1;
144 }
145 ControlFlow::Continue(())
146 }
147
148 fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<()> {
149 if let Expr::Function(f) = expr
150 && self.functions.is_aggregate(&function_name(f))
151 && f.over.is_none()
152 && self.inside_aggregate > 0
153 {
154 self.inside_aggregate -= 1;
155 }
156 ControlFlow::Continue(())
157 }
158}
159
160fn function_name(f: &ast::Function) -> String {
170 if f.name.0.len() > 1 {
171 let qualified: String = f
174 .name
175 .0
176 .iter()
177 .map(|p| match p {
178 ast::ObjectNamePart::Identifier(ident) => ident.value.clone(),
179 _ => String::new(),
180 })
181 .collect::<Vec<_>>()
182 .join(".");
183 return format!("__schema_qualified__{qualified}");
184 }
185 f.name
186 .0
187 .iter()
188 .map(|p| match p {
189 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
190 _ => String::new(),
191 })
192 .collect::<Vec<_>>()
193 .join(".")
194}
195
196fn function_args_and_distinct(f: &ast::Function) -> (Vec<SqlExpr>, bool) {
197 let ast::FunctionArguments::List(args) = &f.args else {
198 return (Vec::new(), false);
199 };
200 let parsed = args
201 .args
202 .iter()
203 .filter_map(|a| match a {
204 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => convert_expr(e).ok(),
205 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => Some(SqlExpr::Wildcard),
206 _ => None,
207 })
208 .collect();
209 let distinct = matches!(
210 args.duplicate_treatment,
211 Some(ast::DuplicateTreatment::Distinct)
212 );
213 (parsed, distinct)
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parses `sql` and returns the projection list of its first statement, which must
    /// be a plain `SELECT`.
    fn first_select_projection(sql: &str) -> Vec<ast::SelectItem> {
        let statement = parse_sql(sql).unwrap().into_iter().next().unwrap();
        let ast::Statement::Query(query) = statement else {
            panic!();
        };
        let ast::SetExpr::Select(select) = *query.body else {
            panic!();
        };
        select.projection
    }

    /// Returns the expression of the first projection item (aliased or not).
    fn first_expr(sql: &str) -> ast::Expr {
        let item = first_select_projection(sql).into_iter().next().unwrap();
        match item {
            ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } => {
                expr
            }
            _ => panic!(),
        }
    }

    /// Default function registry used by every test.
    fn functions() -> FunctionRegistry {
        FunctionRegistry::new()
    }

    /// Runs `contains_aggregate` on the first projection item of `sql`.
    fn detects(sql: &str) -> bool {
        contains_aggregate(&first_expr(sql), &functions())
    }

    /// Runs `extract_aggregates` on the first projection item of `sql`, expecting success.
    fn extract_ok(sql: &str, alias: &str) -> Vec<AggregateExpr> {
        extract_aggregates(&first_expr(sql), alias, &functions()).unwrap()
    }

    /// Runs `extract_aggregates` expecting failure; returns the error's `Debug` text.
    fn extract_err(sql: &str) -> String {
        let err = extract_aggregates(&first_expr(sql), "r", &functions()).unwrap_err();
        format!("{err:?}")
    }

    #[test]
    fn detect_plain_aggregate() {
        assert!(detects("SELECT SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_case() {
        assert!(detects("SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_cast() {
        assert!(detects("SELECT CAST(SUM(x) AS TEXT) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_unary_op() {
        assert!(detects("SELECT -SUM(x) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_coalesce() {
        assert!(detects("SELECT COALESCE(SUM(x), 0) FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_between() {
        assert!(detects("SELECT SUM(x) BETWEEN 1 AND 10 FROM t"));
    }

    #[test]
    fn detect_aggregate_inside_in_list() {
        assert!(detects("SELECT SUM(x) IN (1, 2, 3) FROM t"));
    }

    #[test]
    fn no_aggregate_in_plain_select() {
        for sql in [
            "SELECT x FROM t",
            "SELECT x + 1 FROM t",
            "SELECT upper(name) FROM t",
        ] {
            assert!(!detects(sql));
        }
    }

    #[test]
    fn extract_plain_aggregate() {
        let aggs = extract_ok("SELECT SUM(x) FROM t", "total");
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
        assert_eq!(aggs[0].alias, "total");
    }

    #[test]
    fn extract_aggregate_inside_cast() {
        let aggs = extract_ok("SELECT CAST(SUM(x) AS TEXT) AS n FROM t", "n");
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_case() {
        let aggs = extract_ok("SELECT CASE WHEN x > 0 THEN SUM(y) ELSE 0 END FROM t", "r");
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_aggregate_inside_coalesce() {
        let aggs = extract_ok("SELECT COALESCE(SUM(x), 0) FROM t", "r");
        assert_eq!(aggs.len(), 1);
        assert_eq!(aggs[0].function, "sum");
    }

    #[test]
    fn extract_two_aggregates_under_one_alias() {
        let aggs = extract_ok("SELECT SUM(x) + COUNT(y) AS total FROM t", "total");
        assert_eq!(aggs.len(), 2);
        let names: Vec<&str> = aggs.iter().map(|a| a.function.as_str()).collect();
        assert!(names.contains(&"sum"));
        assert!(names.contains(&"count"));
    }

    #[test]
    fn nested_aggregate_directly_inside_aggregate_rejected() {
        let msg = extract_err("SELECT SUM(AVG(x)) FROM t");
        assert!(
            msg.to_lowercase().contains("nested aggregate"),
            "error must identify the nested-aggregate class: {msg}"
        );
    }

    #[test]
    fn nested_aggregate_through_cast_rejected() {
        let msg = extract_err("SELECT SUM(CAST(AVG(x) AS BIGINT)) FROM t");
        assert!(
            msg.to_lowercase().contains("nested aggregate"),
            "got: {msg}"
        );
    }

    #[test]
    fn sibling_aggregates_not_treated_as_nested() {
        let aggs = extract_ok(
            "SELECT CAST(SUM(x) AS TEXT) || CAST(COUNT(y) AS TEXT) FROM t",
            "r",
        );
        assert_eq!(aggs.len(), 2);
    }

    #[test]
    fn extract_distinct_preserved() {
        let aggs = extract_ok("SELECT COUNT(DISTINCT x) FROM t", "c");
        assert_eq!(aggs.len(), 1);
        assert!(aggs[0].distinct);
    }
}