1use sqlparser::ast::{
8 Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, OrderByExpr, Select, SelectItem,
9 SetExpr, Statement,
10};
11
/// Classification of the SQL aggregate functions recognised by this analyzer.
///
/// Several dialect spellings can resolve to the same variant; the spellings listed on each
/// variant are the ones the name resolver in this module accepts (case-insensitively).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateType {
    /// `COUNT`.
    Count,
    /// `COUNT` whose argument list is marked distinct, e.g. `COUNT(DISTINCT x)`.
    CountDistinct,
    /// `SUM`.
    Sum,
    /// `MIN`.
    Min,
    /// `MAX`.
    Max,
    /// `AVG` / `MEAN`.
    Avg,
    /// `FIRST_VALUE` / `FIRST`.
    FirstValue,
    /// `LAST_VALUE` / `LAST`.
    LastValue,

    /// `STDDEV` / `STDDEV_SAMP`.
    StdDev,
    /// `STDDEV_POP`.
    StdDevPop,
    /// `VARIANCE` / `VAR_SAMP` / `VAR`.
    Variance,
    /// `VAR_POP` / `VARIANCE_POP`.
    VariancePop,
    /// `MEDIAN`.
    Median,

    /// `PERCENTILE_CONT`.
    PercentileCont,
    /// `PERCENTILE_DISC`.
    PercentileDisc,

    /// `BOOL_AND` / `EVERY`.
    BoolAnd,
    /// `BOOL_OR` / `ANY`.
    BoolOr,

    /// `STRING_AGG` / `LISTAGG` / `GROUP_CONCAT`.
    StringAgg,
    /// `ARRAY_AGG`.
    ArrayAgg,

    /// `APPROX_COUNT_DISTINCT` / `APPROX_DISTINCT`.
    ApproxCountDistinct,
    /// `APPROX_PERCENTILE_CONT` / `APPROX_PERCENTILE`.
    ApproxPercentile,
    /// `APPROX_MEDIAN`.
    ApproxMedian,

    /// `COVAR_SAMP` / `COVAR`.
    Covar,
    /// `COVAR_POP`.
    CovarPop,
    /// `CORR`.
    Corr,
    /// `REGR_SLOPE`.
    RegrSlope,
    /// `REGR_INTERCEPT`.
    RegrIntercept,

    /// `BIT_AND`.
    BitAnd,
    /// `BIT_OR`.
    BitOr,
    /// `BIT_XOR`.
    BitXor,

    /// Aggregate without a built-in mapping. The name resolver in this module never produces
    /// it, and it has no DataFusion name.
    Custom,
}

impl AggregateType {
    /// Returns `true` for aggregates whose result can depend on the order of the input rows.
    ///
    /// This covers `FIRST_VALUE`/`LAST_VALUE` and `STRING_AGG`/`ARRAY_AGG`. It also covers the
    /// percentile functions, which consume an ordering via `WITHIN GROUP (ORDER BY ...)`.
    #[must_use]
    pub fn is_order_sensitive(&self) -> bool {
        const ORDER_SENSITIVE: [AggregateType; 6] = [
            AggregateType::FirstValue,
            AggregateType::LastValue,
            AggregateType::PercentileCont,
            AggregateType::PercentileDisc,
            AggregateType::StringAgg,
            AggregateType::ArrayAgg,
        ];
        ORDER_SENSITIVE.contains(self)
    }

    /// Returns `true` for aggregates whose partial results over disjoint row subsets can be
    /// merged into the final result.
    ///
    /// `COUNT`/`SUM` partials merge by addition. `MIN`, `MAX`, `BOOL_*` and `BIT_*` partials
    /// merge by re-applying the same operation. `COUNT(DISTINCT)`, `AVG` and the statistical
    /// and percentile aggregates are not in this set.
    #[must_use]
    pub fn is_decomposable(&self) -> bool {
        const DECOMPOSABLE: [AggregateType; 9] = [
            AggregateType::Count,
            AggregateType::Sum,
            AggregateType::Min,
            AggregateType::Max,
            AggregateType::BoolAnd,
            AggregateType::BoolOr,
            AggregateType::BitAnd,
            AggregateType::BitOr,
            AggregateType::BitXor,
        ];
        DECOMPOSABLE.contains(self)
    }

    /// The lower-case DataFusion function name this variant maps to.
    ///
    /// Returns `None` only for [`AggregateType::Custom`]. `CountDistinct` shares `"count"`: the
    /// distinct-ness is carried separately, by `AggregateInfo::distinct`.
    #[must_use]
    pub fn datafusion_name(&self) -> Option<&'static str> {
        let name = match self {
            Self::Count | Self::CountDistinct => "count",
            Self::Sum => "sum",
            Self::Min => "min",
            Self::Max => "max",
            Self::Avg => "avg",
            Self::FirstValue => "first_value",
            Self::LastValue => "last_value",
            Self::StdDev => "stddev",
            Self::StdDevPop => "stddev_pop",
            Self::Variance => "variance",
            Self::VariancePop => "variance_pop",
            Self::Median => "median",
            Self::PercentileCont => "percentile_cont",
            Self::PercentileDisc => "percentile_disc",
            Self::BoolAnd => "bool_and",
            Self::BoolOr => "bool_or",
            Self::StringAgg => "string_agg",
            Self::ArrayAgg => "array_agg",
            Self::ApproxCountDistinct => "approx_distinct",
            Self::ApproxPercentile => "approx_percentile_cont",
            Self::ApproxMedian => "approx_median",
            Self::Covar => "covar_samp",
            Self::CovarPop => "covar_pop",
            Self::Corr => "corr",
            Self::RegrSlope => "regr_slope",
            Self::RegrIntercept => "regr_intercept",
            Self::BitAnd => "bit_and",
            Self::BitOr => "bit_or",
            Self::BitXor => "bit_xor",
            Self::Custom => return None,
        };
        Some(name)
    }

    /// Number of value inputs the aggregate consumes.
    ///
    /// Returns 2 for the two-variable statistics (`COVAR*`, `CORR`, `REGR_*`) and 1 for every
    /// other variant. Extra parameters such as `STRING_AGG`'s delimiter are not counted;
    /// presumably this is because they are constants (TODO confirm with callers).
    #[must_use]
    pub fn arity(&self) -> usize {
        if matches!(
            self,
            Self::Covar | Self::CovarPop | Self::Corr | Self::RegrSlope | Self::RegrIntercept
        ) {
            2
        } else {
            1
        }
    }
}
182
183#[derive(Debug, Clone)]
185pub struct AggregateInfo {
186 pub aggregate_type: AggregateType,
188 pub column: Option<String>,
190 pub alias: Option<String>,
192 pub distinct: bool,
194 pub filter: Option<Box<Expr>>,
196 pub within_group: Vec<OrderByExpr>,
198}
199
200impl AggregateInfo {
201 #[must_use]
203 pub fn new(aggregate_type: AggregateType, column: Option<String>) -> Self {
204 Self {
205 aggregate_type,
206 column,
207 alias: None,
208 distinct: false,
209 filter: None,
210 within_group: Vec::new(),
211 }
212 }
213
214 #[must_use]
216 pub fn with_alias(mut self, alias: String) -> Self {
217 self.alias = Some(alias);
218 self
219 }
220
221 #[must_use]
223 pub fn with_distinct(mut self, distinct: bool) -> Self {
224 self.distinct = distinct;
225 self
226 }
227
228 #[must_use]
230 pub fn has_filter(&self) -> bool {
231 self.filter.is_some()
232 }
233
234 #[must_use]
236 pub fn has_within_group(&self) -> bool {
237 !self.within_group.is_empty()
238 }
239}
240
241#[derive(Debug, Clone, Default)]
243pub struct AggregationAnalysis {
244 pub aggregates: Vec<AggregateInfo>,
246 pub group_by_columns: Vec<String>,
248 pub has_having: bool,
250 pub having_expr: Option<String>,
252}
253
254impl AggregationAnalysis {
255 #[must_use]
257 pub fn has_aggregates(&self) -> bool {
258 !self.aggregates.is_empty()
259 }
260
261 #[must_use]
263 pub fn has_order_sensitive(&self) -> bool {
264 self.aggregates
265 .iter()
266 .any(|a| a.aggregate_type.is_order_sensitive())
267 }
268
269 #[must_use]
271 pub fn all_decomposable(&self) -> bool {
272 self.aggregates
273 .iter()
274 .all(|a| a.aggregate_type.is_decomposable())
275 }
276
277 #[must_use]
279 pub fn get_by_type(&self, agg_type: AggregateType) -> Vec<&AggregateInfo> {
280 self.aggregates
281 .iter()
282 .filter(|a| a.aggregate_type == agg_type)
283 .collect()
284 }
285
286 #[must_use]
288 pub fn has_any_filter(&self) -> bool {
289 self.aggregates.iter().any(AggregateInfo::has_filter)
290 }
291
292 #[must_use]
294 pub fn has_any_within_group(&self) -> bool {
295 self.aggregates.iter().any(AggregateInfo::has_within_group)
296 }
297}
298
299#[must_use]
301pub fn analyze_aggregates(stmt: &Statement) -> AggregationAnalysis {
302 let mut analysis = AggregationAnalysis::default();
303
304 if let Statement::Query(query) = stmt {
305 if let SetExpr::Select(select) = query.body.as_ref() {
306 analyze_select(&mut analysis, select);
307 }
308 }
309
310 analysis
311}
312
313fn analyze_select(analysis: &mut AggregationAnalysis, select: &Select) {
315 for item in &select.projection {
317 match item {
318 SelectItem::UnnamedExpr(expr) => {
319 if let Some(agg) = extract_aggregate(expr, None) {
320 analysis.aggregates.push(agg);
321 }
322 }
323 SelectItem::ExprWithAlias { expr, alias } => {
324 if let Some(agg) = extract_aggregate(expr, Some(alias.value.clone())) {
325 analysis.aggregates.push(agg);
326 }
327 }
328 SelectItem::QualifiedWildcard(_, _) | SelectItem::Wildcard(_) => {}
329 }
330 }
331
332 match &select.group_by {
334 GroupByExpr::Expressions(exprs, _modifiers) => {
335 for expr in exprs {
336 if let Some(col) = extract_column_name(expr) {
337 analysis.group_by_columns.push(col);
338 }
339 }
340 }
341 GroupByExpr::All(_) => {}
342 }
343
344 analysis.has_having = select.having.is_some();
346 analysis.having_expr = select.having.as_ref().map(std::string::ToString::to_string);
347}
348
349fn resolve_aggregate_type(name: &str, func: &Function) -> Option<AggregateType> {
352 match name {
353 "COUNT" => {
355 if has_distinct_arg(func) {
356 Some(AggregateType::CountDistinct)
357 } else {
358 Some(AggregateType::Count)
359 }
360 }
361 "SUM" => Some(AggregateType::Sum),
362 "MIN" => Some(AggregateType::Min),
363 "MAX" => Some(AggregateType::Max),
364 "AVG" | "MEAN" => Some(AggregateType::Avg),
365 "FIRST_VALUE" | "FIRST" => Some(AggregateType::FirstValue),
366 "LAST_VALUE" | "LAST" => Some(AggregateType::LastValue),
367
368 "STDDEV" | "STDDEV_SAMP" => Some(AggregateType::StdDev),
370 "STDDEV_POP" => Some(AggregateType::StdDevPop),
371 "VARIANCE" | "VAR_SAMP" | "VAR" => Some(AggregateType::Variance),
372 "VAR_POP" | "VARIANCE_POP" => Some(AggregateType::VariancePop),
373 "MEDIAN" => Some(AggregateType::Median),
374
375 "PERCENTILE_CONT" => Some(AggregateType::PercentileCont),
377 "PERCENTILE_DISC" => Some(AggregateType::PercentileDisc),
378
379 "BOOL_AND" | "EVERY" => Some(AggregateType::BoolAnd),
381 "BOOL_OR" | "ANY" => Some(AggregateType::BoolOr),
382
383 "STRING_AGG" | "LISTAGG" | "GROUP_CONCAT" => Some(AggregateType::StringAgg),
385 "ARRAY_AGG" => Some(AggregateType::ArrayAgg),
386
387 "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(AggregateType::ApproxCountDistinct),
389 "APPROX_PERCENTILE_CONT" | "APPROX_PERCENTILE" => Some(AggregateType::ApproxPercentile),
390 "APPROX_MEDIAN" => Some(AggregateType::ApproxMedian),
391
392 "COVAR_SAMP" | "COVAR" => Some(AggregateType::Covar),
394 "COVAR_POP" => Some(AggregateType::CovarPop),
395 "CORR" => Some(AggregateType::Corr),
396 "REGR_SLOPE" => Some(AggregateType::RegrSlope),
397 "REGR_INTERCEPT" => Some(AggregateType::RegrIntercept),
398
399 "BIT_AND" => Some(AggregateType::BitAnd),
401 "BIT_OR" => Some(AggregateType::BitOr),
402 "BIT_XOR" => Some(AggregateType::BitXor),
403
404 _ => None,
405 }
406}
407
408fn extract_aggregate(expr: &Expr, alias: Option<String>) -> Option<AggregateInfo> {
410 match expr {
411 Expr::Function(func) => {
412 let func_name = func.name.to_string().to_uppercase();
413 let agg_type = resolve_aggregate_type(&func_name, func)?;
414
415 let column = extract_first_arg_column(func);
416 let distinct = has_distinct_arg(func);
417
418 let mut info = AggregateInfo::new(agg_type, column).with_distinct(distinct);
419
420 if let Some(filter_expr) = &func.filter {
422 info.filter = Some(filter_expr.clone());
423 }
424
425 if !func.within_group.is_empty() {
427 info.within_group.clone_from(&func.within_group);
428 }
429
430 if let Some(a) = alias {
431 info = info.with_alias(a);
432 }
433 Some(info)
434 }
435 Expr::Cast { expr, .. } | Expr::Nested(expr) => extract_aggregate(expr, alias),
437 _ => None,
438 }
439}
440
441fn has_distinct_arg(func: &Function) -> bool {
443 match &func.args {
445 sqlparser::ast::FunctionArguments::List(list) => list.duplicate_treatment.is_some(),
446 _ => false,
447 }
448}
449
450fn extract_first_arg_column(func: &Function) -> Option<String> {
452 match &func.args {
454 sqlparser::ast::FunctionArguments::List(list) => {
455 if list.args.is_empty() {
456 return None;
457 }
458 match &list.args[0] {
459 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => extract_column_name(expr),
460 FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => {
461 if let FunctionArgExpr::Expr(expr) = arg {
462 extract_column_name(expr)
463 } else {
464 None
465 }
466 }
467 FunctionArg::Unnamed(_) => None,
469 }
470 }
471 sqlparser::ast::FunctionArguments::Subquery(_)
472 | sqlparser::ast::FunctionArguments::None => None,
473 }
474}
475
476fn extract_column_name(expr: &Expr) -> Option<String> {
478 match expr {
479 Expr::Identifier(ident) => Some(ident.value.clone()),
480 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
481 _ => None,
482 }
483}
484
485#[must_use]
487pub fn has_aggregates(stmt: &Statement) -> bool {
488 analyze_aggregates(stmt).has_aggregates()
489}
490
491#[must_use]
493pub fn count_aggregates(stmt: &Statement) -> usize {
494 analyze_aggregates(stmt).aggregates.len()
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    // ---- helpers ----------------------------------------------------------------------------

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_statement(sql: &str) -> Statement {
        let mut statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        statements.remove(0)
    }

    /// Parses `sql` and runs the aggregate analysis on it.
    fn analyze(sql: &str) -> AggregationAnalysis {
        analyze_aggregates(&parse_statement(sql))
    }

    /// Kinds of all detected aggregates, in select-list order.
    fn kinds(analysis: &AggregationAnalysis) -> Vec<AggregateType> {
        analysis
            .aggregates
            .iter()
            .map(|info| info.aggregate_type)
            .collect()
    }

    /// Kind of the first aggregate detected in `sql`; panics when none was detected.
    fn first_kind(sql: &str) -> AggregateType {
        analyze(sql).aggregates[0].aggregate_type
    }

    /// Per-aggregate `has_filter()` flags, in select-list order.
    fn filter_flags(analysis: &AggregationAnalysis) -> Vec<bool> {
        analysis
            .aggregates
            .iter()
            .map(AggregateInfo::has_filter)
            .collect()
    }

    /// Asserts that `sql` has a HAVING clause and returns the captured predicate text.
    fn having_text(sql: &str) -> String {
        let analysis = analyze(sql);
        assert!(analysis.has_having);
        analysis
            .having_expr
            .expect("HAVING predicate should be captured")
    }

    // ---- core aggregates --------------------------------------------------------------------

    #[test]
    fn test_analyze_count() {
        let analysis = analyze("SELECT COUNT(*) FROM events");
        assert_eq!(kinds(&analysis), [AggregateType::Count]);
        assert!(analysis.aggregates[0].column.is_none());
    }

    #[test]
    fn test_analyze_count_column() {
        let analysis = analyze("SELECT COUNT(id) FROM events");
        assert_eq!(kinds(&analysis), [AggregateType::Count]);
        assert_eq!(analysis.aggregates[0].column.as_deref(), Some("id"));
    }

    #[test]
    fn test_analyze_count_distinct() {
        let analysis = analyze("SELECT COUNT(DISTINCT user_id) FROM events");
        assert_eq!(kinds(&analysis), [AggregateType::CountDistinct]);
        assert!(analysis.aggregates[0].distinct);
    }

    #[test]
    fn test_analyze_sum() {
        let analysis = analyze("SELECT SUM(amount) FROM orders");
        assert_eq!(kinds(&analysis), [AggregateType::Sum]);
        assert_eq!(analysis.aggregates[0].column.as_deref(), Some("amount"));
    }

    #[test]
    fn test_analyze_min_max() {
        let analysis = analyze("SELECT MIN(price), MAX(price) FROM products");
        assert_eq!(kinds(&analysis), [AggregateType::Min, AggregateType::Max]);
    }

    #[test]
    fn test_analyze_avg() {
        let analysis = analyze("SELECT AVG(score) AS avg_score FROM tests");
        assert_eq!(kinds(&analysis), [AggregateType::Avg]);
        assert_eq!(analysis.aggregates[0].alias.as_deref(), Some("avg_score"));
    }

    #[test]
    fn test_analyze_first_last() {
        let analysis =
            analyze("SELECT FIRST_VALUE(price) AS open, LAST_VALUE(price) AS close FROM trades");
        assert_eq!(
            kinds(&analysis),
            [AggregateType::FirstValue, AggregateType::LastValue]
        );
        assert!(analysis.has_order_sensitive());
    }

    // ---- GROUP BY / HAVING presence ---------------------------------------------------------

    #[test]
    fn test_analyze_group_by() {
        let analysis = analyze("SELECT category, COUNT(*) FROM products GROUP BY category");
        assert_eq!(analysis.aggregates.len(), 1);
        assert_eq!(analysis.group_by_columns, ["category"]);
    }

    #[test]
    fn test_analyze_multiple_group_by() {
        let analysis =
            analyze("SELECT region, category, SUM(sales) FROM orders GROUP BY region, category");
        assert_eq!(analysis.group_by_columns, ["region", "category"]);
    }

    #[test]
    fn test_analyze_having() {
        let analysis = analyze(
            "SELECT category, COUNT(*) FROM products GROUP BY category HAVING COUNT(*) > 10",
        );
        assert!(analysis.has_having);
    }

    // ---- summary helpers --------------------------------------------------------------------

    #[test]
    fn test_no_aggregates() {
        let analysis = analyze("SELECT id, name FROM users");
        assert!(!analysis.has_aggregates());
        assert!(analysis.aggregates.is_empty());
    }

    #[test]
    fn test_has_aggregates() {
        assert!(has_aggregates(&parse_statement("SELECT COUNT(*) FROM events")));
        assert!(!has_aggregates(&parse_statement("SELECT * FROM events")));
    }

    #[test]
    fn test_count_aggregates() {
        let stmt = parse_statement(
            "SELECT COUNT(*), SUM(amount), AVG(price), MIN(qty), MAX(qty) FROM orders",
        );
        assert_eq!(count_aggregates(&stmt), 5);
    }

    #[test]
    fn test_decomposable() {
        assert!(
            analyze("SELECT COUNT(*), SUM(amount), MIN(price), MAX(price) FROM orders")
                .all_decomposable()
        );
        assert!(!analyze("SELECT AVG(price), FIRST_VALUE(price) FROM orders").all_decomposable());
    }

    #[test]
    fn test_get_by_type() {
        let analysis = analyze("SELECT COUNT(*), COUNT(id), SUM(amount) FROM orders");
        assert_eq!(analysis.get_by_type(AggregateType::Count).len(), 2);
        assert_eq!(analysis.get_by_type(AggregateType::Sum).len(), 1);
    }

    // ---- statistical / extended aggregates --------------------------------------------------

    #[test]
    fn test_stddev() {
        let analysis = analyze("SELECT STDDEV(price) FROM trades");
        assert_eq!(kinds(&analysis), [AggregateType::StdDev]);
    }

    #[test]
    fn test_stddev_pop() {
        assert_eq!(
            first_kind("SELECT STDDEV_POP(latency) FROM requests"),
            AggregateType::StdDevPop
        );
    }

    #[test]
    fn test_variance() {
        assert_eq!(
            first_kind("SELECT VARIANCE(price) FROM trades"),
            AggregateType::Variance
        );
    }

    #[test]
    fn test_variance_pop() {
        assert_eq!(
            first_kind("SELECT VAR_POP(price) FROM trades"),
            AggregateType::VariancePop
        );
    }

    #[test]
    fn test_median() {
        assert_eq!(
            first_kind("SELECT MEDIAN(response_time) FROM requests"),
            AggregateType::Median
        );
    }

    #[test]
    fn test_percentile_cont() {
        assert_eq!(
            first_kind("SELECT PERCENTILE_CONT(0.95) FROM latencies"),
            AggregateType::PercentileCont
        );
    }

    #[test]
    fn test_percentile_disc() {
        assert_eq!(
            first_kind("SELECT PERCENTILE_DISC(0.5) FROM scores"),
            AggregateType::PercentileDisc
        );
    }

    #[test]
    fn test_bool_and() {
        assert_eq!(
            first_kind("SELECT BOOL_AND(is_active) FROM users"),
            AggregateType::BoolAnd
        );
    }

    #[test]
    fn test_bool_or() {
        assert_eq!(
            first_kind("SELECT BOOL_OR(has_error) FROM events"),
            AggregateType::BoolOr
        );
    }

    #[test]
    fn test_string_agg() {
        let kind = first_kind("SELECT STRING_AGG(name, ',') FROM users");
        assert_eq!(kind, AggregateType::StringAgg);
        assert!(kind.is_order_sensitive());
    }

    #[test]
    fn test_array_agg() {
        assert_eq!(
            first_kind("SELECT ARRAY_AGG(id) FROM events"),
            AggregateType::ArrayAgg
        );
    }

    #[test]
    fn test_approx_count_distinct() {
        assert_eq!(
            first_kind("SELECT APPROX_COUNT_DISTINCT(user_id) FROM events"),
            AggregateType::ApproxCountDistinct
        );
    }

    #[test]
    fn test_approx_percentile() {
        assert_eq!(
            first_kind("SELECT APPROX_PERCENTILE_CONT(latency, 0.99) FROM req"),
            AggregateType::ApproxPercentile
        );
    }

    #[test]
    fn test_approx_median() {
        assert_eq!(
            first_kind("SELECT APPROX_MEDIAN(price) FROM trades"),
            AggregateType::ApproxMedian
        );
    }

    #[test]
    fn test_covar_samp() {
        assert_eq!(
            first_kind("SELECT COVAR_SAMP(x, y) FROM points"),
            AggregateType::Covar
        );
    }

    #[test]
    fn test_covar_pop() {
        assert_eq!(
            first_kind("SELECT COVAR_POP(x, y) FROM points"),
            AggregateType::CovarPop
        );
    }

    #[test]
    fn test_corr() {
        assert_eq!(
            first_kind("SELECT CORR(x, y) FROM points"),
            AggregateType::Corr
        );
    }

    #[test]
    fn test_regr_slope() {
        assert_eq!(
            first_kind("SELECT REGR_SLOPE(y, x) FROM data"),
            AggregateType::RegrSlope
        );
    }

    #[test]
    fn test_regr_intercept() {
        assert_eq!(
            first_kind("SELECT REGR_INTERCEPT(y, x) FROM data"),
            AggregateType::RegrIntercept
        );
    }

    #[test]
    fn test_bit_aggregates() {
        let analysis =
            analyze("SELECT BIT_AND(flags), BIT_OR(flags), BIT_XOR(flags) FROM events");
        assert_eq!(
            kinds(&analysis),
            [
                AggregateType::BitAnd,
                AggregateType::BitOr,
                AggregateType::BitXor
            ]
        );
    }

    // ---- alternate spellings ----------------------------------------------------------------

    #[test]
    fn test_alias_stddev_samp() {
        assert_eq!(
            first_kind("SELECT STDDEV_SAMP(price) FROM trades"),
            AggregateType::StdDev
        );
    }

    #[test]
    fn test_alias_var_samp() {
        assert_eq!(
            first_kind("SELECT VAR_SAMP(price) FROM trades"),
            AggregateType::Variance
        );
    }

    #[test]
    fn test_alias_every() {
        assert_eq!(
            first_kind("SELECT EVERY(is_valid) FROM checks"),
            AggregateType::BoolAnd
        );
    }

    #[test]
    fn test_alias_listagg() {
        assert_eq!(
            first_kind("SELECT LISTAGG(name, ',') FROM users"),
            AggregateType::StringAgg
        );
    }

    #[test]
    fn test_alias_group_concat() {
        assert_eq!(
            first_kind("SELECT GROUP_CONCAT(name, ',') FROM users"),
            AggregateType::StringAgg
        );
    }

    // ---- FILTER clause ----------------------------------------------------------------------

    #[test]
    fn test_filter_clause_count() {
        let analysis = analyze("SELECT COUNT(*) FILTER (WHERE status = 'active') FROM users");
        assert_eq!(filter_flags(&analysis), [true]);
        assert!(analysis.has_any_filter());
    }

    #[test]
    fn test_filter_clause_sum() {
        let analysis =
            analyze("SELECT SUM(amount) FILTER (WHERE category = 'A') AS sum_a FROM orders");
        assert!(analysis.aggregates[0].has_filter());
        assert_eq!(analysis.aggregates[0].alias.as_deref(), Some("sum_a"));
    }

    #[test]
    fn test_filter_clause_mixed() {
        let analysis = analyze("SELECT COUNT(*), COUNT(*) FILTER (WHERE x > 0) FROM t");
        assert_eq!(filter_flags(&analysis), [false, true]);
    }

    #[test]
    fn test_no_filter() {
        let analysis = analyze("SELECT SUM(amount) FROM orders");
        assert!(!analysis.aggregates[0].has_filter());
        assert!(!analysis.has_any_filter());
    }

    // ---- WITHIN GROUP clause ----------------------------------------------------------------

    #[test]
    fn test_within_group_percentile_cont() {
        let analysis =
            analyze("SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY latency) FROM req");
        assert_eq!(analysis.aggregates.len(), 1);
        let info = &analysis.aggregates[0];
        assert!(info.has_within_group());
        assert_eq!(info.within_group.len(), 1);
        assert!(analysis.has_any_within_group());
    }

    #[test]
    fn test_within_group_string_agg() {
        let analysis =
            analyze("SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name) FROM users");
        assert!(analysis.aggregates[0].has_within_group());
    }

    #[test]
    fn test_no_within_group() {
        let analysis = analyze("SELECT SUM(amount) FROM orders");
        assert!(!analysis.aggregates[0].has_within_group());
        assert!(!analysis.has_any_within_group());
    }

    // ---- DataFusion names -------------------------------------------------------------------

    #[test]
    fn test_datafusion_name_core() {
        for (kind, expected) in [
            (AggregateType::Count, "count"),
            (AggregateType::Sum, "sum"),
            (AggregateType::Min, "min"),
            (AggregateType::Max, "max"),
            (AggregateType::Avg, "avg"),
        ] {
            assert_eq!(kind.datafusion_name(), Some(expected), "{kind:?}");
        }
    }

    #[test]
    fn test_datafusion_name_statistical() {
        for (kind, expected) in [
            (AggregateType::StdDev, "stddev"),
            (AggregateType::StdDevPop, "stddev_pop"),
            (AggregateType::Variance, "variance"),
            (AggregateType::VariancePop, "variance_pop"),
            (AggregateType::Median, "median"),
        ] {
            assert_eq!(kind.datafusion_name(), Some(expected), "{kind:?}");
        }
    }

    #[test]
    fn test_datafusion_name_approx() {
        for (kind, expected) in [
            (AggregateType::ApproxCountDistinct, "approx_distinct"),
            (AggregateType::ApproxPercentile, "approx_percentile_cont"),
            (AggregateType::ApproxMedian, "approx_median"),
        ] {
            assert_eq!(kind.datafusion_name(), Some(expected), "{kind:?}");
        }
    }

    #[test]
    fn test_datafusion_name_custom() {
        assert!(AggregateType::Custom.datafusion_name().is_none());
    }

    // ---- classification -------------------------------------------------------------------

    #[test]
    fn test_decomposable_new_types() {
        for kind in [
            AggregateType::BoolAnd,
            AggregateType::BoolOr,
            AggregateType::BitAnd,
            AggregateType::BitOr,
            AggregateType::BitXor,
        ] {
            assert!(kind.is_decomposable(), "{kind:?}");
        }
        for kind in [
            AggregateType::StdDev,
            AggregateType::Variance,
            AggregateType::Median,
            AggregateType::PercentileCont,
            AggregateType::Corr,
        ] {
            assert!(!kind.is_decomposable(), "{kind:?}");
        }
    }

    #[test]
    fn test_order_sensitive_new_types() {
        for kind in [
            AggregateType::PercentileCont,
            AggregateType::PercentileDisc,
            AggregateType::StringAgg,
            AggregateType::ArrayAgg,
        ] {
            assert!(kind.is_order_sensitive(), "{kind:?}");
        }
        for kind in [
            AggregateType::StdDev,
            AggregateType::Variance,
            AggregateType::Corr,
        ] {
            assert!(!kind.is_order_sensitive(), "{kind:?}");
        }
    }

    // ---- multi-aggregate queries ------------------------------------------------------------

    #[test]
    fn test_multi_aggregate_statistical() {
        let analysis = analyze(
            "SELECT AVG(price), STDDEV(price), VARIANCE(price), \
             MEDIAN(price) FROM trades GROUP BY symbol",
        );
        assert_eq!(
            kinds(&analysis),
            [
                AggregateType::Avg,
                AggregateType::StdDev,
                AggregateType::Variance,
                AggregateType::Median
            ]
        );
        assert!(!analysis.all_decomposable());
    }

    #[test]
    fn test_multi_aggregate_mixed_with_filter() {
        let analysis = analyze(
            "SELECT COUNT(*), \
             SUM(amount) FILTER (WHERE status = 'complete'), \
             APPROX_COUNT_DISTINCT(user_id) FROM orders",
        );
        assert_eq!(filter_flags(&analysis), [false, true, false]);
    }

    #[test]
    fn test_arity() {
        for (kind, expected) in [
            (AggregateType::Count, 1),
            (AggregateType::Sum, 1),
            (AggregateType::Covar, 2),
            (AggregateType::CovarPop, 2),
            (AggregateType::Corr, 2),
            (AggregateType::RegrSlope, 2),
            (AggregateType::RegrIntercept, 2),
        ] {
            assert_eq!(kind.arity(), expected, "{kind:?}");
        }
    }

    // ---- HAVING text ------------------------------------------------------------------------

    #[test]
    fn test_having_expr_simple() {
        let expr = having_text(
            "SELECT symbol, COUNT(*) FROM trades GROUP BY symbol HAVING COUNT(*) > 10",
        );
        assert!(expr.contains("COUNT(*)"), "expr was: {expr}");
        assert!(expr.contains("10"), "expr was: {expr}");
    }

    #[test]
    fn test_having_expr_multiple_predicates() {
        let expr = having_text(
            "SELECT symbol, SUM(volume) AS vol, AVG(price) AS avg_p \
             FROM trades GROUP BY symbol \
             HAVING SUM(volume) > 1000 AND AVG(price) < 50",
        );
        for needle in ["SUM(volume)", "AVG(price)", "AND"] {
            assert!(expr.contains(needle), "expr was: {expr}");
        }
    }

    #[test]
    fn test_having_expr_with_or() {
        let expr = having_text(
            "SELECT category, COUNT(*) FROM products GROUP BY category \
             HAVING COUNT(*) > 100 OR SUM(price) > 5000",
        );
        assert!(expr.contains("OR"), "expr was: {expr}");
    }

    #[test]
    fn test_having_expr_none_without_having() {
        let analysis = analyze("SELECT symbol, COUNT(*) FROM trades GROUP BY symbol");
        assert!(!analysis.has_having);
        assert!(analysis.having_expr.is_none());
    }

    #[test]
    fn test_having_expr_with_alias_reference() {
        let expr = having_text(
            "SELECT symbol, COUNT(*) AS cnt FROM trades GROUP BY symbol HAVING COUNT(*) >= 5",
        );
        assert!(expr.contains("COUNT(*)"), "expr was: {expr}");
        assert!(expr.contains('5'), "expr was: {expr}");
    }

    // ---- case handling ----------------------------------------------------------------------

    #[test]
    fn test_case_insensitive_detection() {
        let analysis = analyze("SELECT stddev(price), Variance(price) FROM trades");
        assert_eq!(
            kinds(&analysis),
            [AggregateType::StdDev, AggregateType::Variance]
        );
    }
}