1use sqlparser::ast::{
8 Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, OrderByExpr, Select, SelectItem,
9 SetExpr, Statement,
10};
11
/// Classification of a SQL aggregate function recognised by this module.
///
/// Each variant lists the SQL function names that `resolve_aggregate_type`
/// maps to it (names are upper-cased by `extract_aggregate` before matching,
/// so detection is case-insensitive).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateType {
    /// `COUNT(...)`, e.g. `COUNT(*)` or `COUNT(col)`.
    Count,
    /// `COUNT(DISTINCT ...)`.
    CountDistinct,
    /// `SUM`.
    Sum,
    /// `MIN`.
    Min,
    /// `MAX`.
    Max,
    /// `AVG` / `MEAN`.
    Avg,
    /// `FIRST_VALUE` / `FIRST`.
    FirstValue,
    /// `LAST_VALUE` / `LAST`.
    LastValue,

    // --- statistical ---
    /// Sample standard deviation: `STDDEV` / `STDDEV_SAMP`.
    StdDev,
    /// Population standard deviation: `STDDEV_POP`.
    StdDevPop,
    /// Sample variance: `VARIANCE` / `VAR_SAMP` / `VAR`.
    Variance,
    /// Population variance: `VAR_POP` / `VARIANCE_POP`.
    VariancePop,
    /// `MEDIAN`.
    Median,

    // --- ordered-set / percentile ---
    /// `PERCENTILE_CONT`.
    PercentileCont,
    /// `PERCENTILE_DISC`.
    PercentileDisc,

    // --- boolean ---
    /// `BOOL_AND` / `EVERY`.
    BoolAnd,
    /// `BOOL_OR` / `ANY`.
    BoolOr,

    // --- collection ---
    /// `STRING_AGG` / `LISTAGG` / `GROUP_CONCAT`.
    StringAgg,
    /// `ARRAY_AGG`.
    ArrayAgg,

    // --- approximate ---
    /// `APPROX_COUNT_DISTINCT` / `APPROX_DISTINCT`.
    ApproxCountDistinct,
    /// `APPROX_PERCENTILE_CONT` / `APPROX_PERCENTILE`.
    ApproxPercentile,
    /// `APPROX_MEDIAN`.
    ApproxMedian,

    // --- two-variable statistics ---
    /// Sample covariance: `COVAR_SAMP` / `COVAR`.
    Covar,
    /// Population covariance: `COVAR_POP`.
    CovarPop,
    /// `CORR`.
    Corr,
    /// `REGR_SLOPE`.
    RegrSlope,
    /// `REGR_INTERCEPT`.
    RegrIntercept,

    // --- bitwise ---
    /// `BIT_AND`.
    BitAnd,
    /// `BIT_OR`.
    BitOr,
    /// `BIT_XOR`.
    BitXor,

    /// Placeholder for aggregates outside this list. Never produced by
    /// `resolve_aggregate_type`, and has no DataFusion name.
    Custom,
}
94
95impl AggregateType {
96 #[must_use]
99 pub fn is_order_sensitive(&self) -> bool {
100 matches!(
101 self,
102 AggregateType::FirstValue
103 | AggregateType::LastValue
104 | AggregateType::PercentileCont
105 | AggregateType::PercentileDisc
106 | AggregateType::StringAgg
107 | AggregateType::ArrayAgg
108 )
109 }
110
111 #[must_use]
116 pub fn is_decomposable(&self) -> bool {
117 matches!(
118 self,
119 AggregateType::Count
120 | AggregateType::Sum
121 | AggregateType::Min
122 | AggregateType::Max
123 | AggregateType::BoolAnd
124 | AggregateType::BoolOr
125 | AggregateType::BitAnd
126 | AggregateType::BitOr
127 | AggregateType::BitXor
128 )
129 }
130
131 #[must_use]
134 pub fn datafusion_name(&self) -> Option<&'static str> {
135 match self {
136 AggregateType::Count | AggregateType::CountDistinct => Some("count"),
137 AggregateType::Sum => Some("sum"),
138 AggregateType::Min => Some("min"),
139 AggregateType::Max => Some("max"),
140 AggregateType::Avg => Some("avg"),
141 AggregateType::FirstValue => Some("first_value"),
142 AggregateType::LastValue => Some("last_value"),
143 AggregateType::StdDev => Some("stddev"),
144 AggregateType::StdDevPop => Some("stddev_pop"),
145 AggregateType::Variance => Some("variance"),
146 AggregateType::VariancePop => Some("variance_pop"),
147 AggregateType::Median => Some("median"),
148 AggregateType::PercentileCont => Some("percentile_cont"),
149 AggregateType::PercentileDisc => Some("percentile_disc"),
150 AggregateType::BoolAnd => Some("bool_and"),
151 AggregateType::BoolOr => Some("bool_or"),
152 AggregateType::StringAgg => Some("string_agg"),
153 AggregateType::ArrayAgg => Some("array_agg"),
154 AggregateType::ApproxCountDistinct => Some("approx_distinct"),
155 AggregateType::ApproxPercentile => Some("approx_percentile_cont"),
156 AggregateType::ApproxMedian => Some("approx_median"),
157 AggregateType::Covar => Some("covar_samp"),
158 AggregateType::CovarPop => Some("covar_pop"),
159 AggregateType::Corr => Some("corr"),
160 AggregateType::RegrSlope => Some("regr_slope"),
161 AggregateType::RegrIntercept => Some("regr_intercept"),
162 AggregateType::BitAnd => Some("bit_and"),
163 AggregateType::BitOr => Some("bit_or"),
164 AggregateType::BitXor => Some("bit_xor"),
165 AggregateType::Custom => None,
166 }
167 }
168
169 #[must_use]
171 pub fn arity(&self) -> usize {
172 match self {
173 AggregateType::Covar
174 | AggregateType::CovarPop
175 | AggregateType::Corr
176 | AggregateType::RegrSlope
177 | AggregateType::RegrIntercept => 2,
178 _ => 1,
179 }
180 }
181}
182
/// One aggregate call found in a `SELECT` projection item.
#[derive(Debug, Clone)]
pub struct AggregateInfo {
    /// Which aggregate function was called.
    pub aggregate_type: AggregateType,
    /// Name of the aggregated column when it is a plain or qualified
    /// identifier (only the last path segment is kept); `None` when it is not
    /// a simple column reference (e.g. `*` or a literal).
    pub column: Option<String>,
    /// Output alias from `expr AS alias`, if any.
    pub alias: Option<String>,
    /// Whether the argument list carries a `DISTINCT` modifier
    /// (see `has_distinct_arg`).
    pub distinct: bool,
    /// Predicate from a `FILTER (WHERE ...)` clause, if present.
    pub filter: Option<Box<Expr>>,
    /// `ORDER BY` items from a `WITHIN GROUP (...)` clause; empty when absent.
    pub within_group: Vec<OrderByExpr>,
}
199
200impl AggregateInfo {
201 #[must_use]
203 pub fn new(aggregate_type: AggregateType, column: Option<String>) -> Self {
204 Self {
205 aggregate_type,
206 column,
207 alias: None,
208 distinct: false,
209 filter: None,
210 within_group: Vec::new(),
211 }
212 }
213
214 #[must_use]
216 pub fn with_alias(mut self, alias: String) -> Self {
217 self.alias = Some(alias);
218 self
219 }
220
221 #[must_use]
223 pub fn with_distinct(mut self, distinct: bool) -> Self {
224 self.distinct = distinct;
225 self
226 }
227
228 #[must_use]
230 pub fn has_filter(&self) -> bool {
231 self.filter.is_some()
232 }
233
234 #[must_use]
236 pub fn has_within_group(&self) -> bool {
237 !self.within_group.is_empty()
238 }
239}
240
/// Summary of the aggregation shape of a single `SELECT` statement.
#[derive(Debug, Clone, Default)]
pub struct AggregationAnalysis {
    /// Aggregates found in the projection, in projection order.
    pub aggregates: Vec<AggregateInfo>,
    /// `GROUP BY` items that are plain or qualified column references (last
    /// segment only); other grouping expressions are not recorded.
    pub group_by_columns: Vec<String>,
    /// Whether the statement has a `HAVING` clause.
    pub has_having: bool,
}
251
252impl AggregationAnalysis {
253 #[must_use]
255 pub fn has_aggregates(&self) -> bool {
256 !self.aggregates.is_empty()
257 }
258
259 #[must_use]
261 pub fn has_order_sensitive(&self) -> bool {
262 self.aggregates
263 .iter()
264 .any(|a| a.aggregate_type.is_order_sensitive())
265 }
266
267 #[must_use]
269 pub fn all_decomposable(&self) -> bool {
270 self.aggregates
271 .iter()
272 .all(|a| a.aggregate_type.is_decomposable())
273 }
274
275 #[must_use]
277 pub fn get_by_type(&self, agg_type: AggregateType) -> Vec<&AggregateInfo> {
278 self.aggregates
279 .iter()
280 .filter(|a| a.aggregate_type == agg_type)
281 .collect()
282 }
283
284 #[must_use]
286 pub fn has_any_filter(&self) -> bool {
287 self.aggregates.iter().any(AggregateInfo::has_filter)
288 }
289
290 #[must_use]
292 pub fn has_any_within_group(&self) -> bool {
293 self.aggregates.iter().any(AggregateInfo::has_within_group)
294 }
295}
296
297#[must_use]
299pub fn analyze_aggregates(stmt: &Statement) -> AggregationAnalysis {
300 let mut analysis = AggregationAnalysis::default();
301
302 if let Statement::Query(query) = stmt {
303 if let SetExpr::Select(select) = query.body.as_ref() {
304 analyze_select(&mut analysis, select);
305 }
306 }
307
308 analysis
309}
310
311fn analyze_select(analysis: &mut AggregationAnalysis, select: &Select) {
313 for item in &select.projection {
315 match item {
316 SelectItem::UnnamedExpr(expr) => {
317 if let Some(agg) = extract_aggregate(expr, None) {
318 analysis.aggregates.push(agg);
319 }
320 }
321 SelectItem::ExprWithAlias { expr, alias } => {
322 if let Some(agg) = extract_aggregate(expr, Some(alias.value.clone())) {
323 analysis.aggregates.push(agg);
324 }
325 }
326 SelectItem::QualifiedWildcard(_, _) | SelectItem::Wildcard(_) => {}
327 }
328 }
329
330 match &select.group_by {
332 GroupByExpr::Expressions(exprs, _modifiers) => {
333 for expr in exprs {
334 if let Some(col) = extract_column_name(expr) {
335 analysis.group_by_columns.push(col);
336 }
337 }
338 }
339 GroupByExpr::All(_) => {}
340 }
341
342 analysis.has_having = select.having.is_some();
344}
345
346fn resolve_aggregate_type(name: &str, func: &Function) -> Option<AggregateType> {
349 match name {
350 "COUNT" => {
352 if has_distinct_arg(func) {
353 Some(AggregateType::CountDistinct)
354 } else {
355 Some(AggregateType::Count)
356 }
357 }
358 "SUM" => Some(AggregateType::Sum),
359 "MIN" => Some(AggregateType::Min),
360 "MAX" => Some(AggregateType::Max),
361 "AVG" | "MEAN" => Some(AggregateType::Avg),
362 "FIRST_VALUE" | "FIRST" => Some(AggregateType::FirstValue),
363 "LAST_VALUE" | "LAST" => Some(AggregateType::LastValue),
364
365 "STDDEV" | "STDDEV_SAMP" => Some(AggregateType::StdDev),
367 "STDDEV_POP" => Some(AggregateType::StdDevPop),
368 "VARIANCE" | "VAR_SAMP" | "VAR" => Some(AggregateType::Variance),
369 "VAR_POP" | "VARIANCE_POP" => Some(AggregateType::VariancePop),
370 "MEDIAN" => Some(AggregateType::Median),
371
372 "PERCENTILE_CONT" => Some(AggregateType::PercentileCont),
374 "PERCENTILE_DISC" => Some(AggregateType::PercentileDisc),
375
376 "BOOL_AND" | "EVERY" => Some(AggregateType::BoolAnd),
378 "BOOL_OR" | "ANY" => Some(AggregateType::BoolOr),
379
380 "STRING_AGG" | "LISTAGG" | "GROUP_CONCAT" => Some(AggregateType::StringAgg),
382 "ARRAY_AGG" => Some(AggregateType::ArrayAgg),
383
384 "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(AggregateType::ApproxCountDistinct),
386 "APPROX_PERCENTILE_CONT" | "APPROX_PERCENTILE" => Some(AggregateType::ApproxPercentile),
387 "APPROX_MEDIAN" => Some(AggregateType::ApproxMedian),
388
389 "COVAR_SAMP" | "COVAR" => Some(AggregateType::Covar),
391 "COVAR_POP" => Some(AggregateType::CovarPop),
392 "CORR" => Some(AggregateType::Corr),
393 "REGR_SLOPE" => Some(AggregateType::RegrSlope),
394 "REGR_INTERCEPT" => Some(AggregateType::RegrIntercept),
395
396 "BIT_AND" => Some(AggregateType::BitAnd),
398 "BIT_OR" => Some(AggregateType::BitOr),
399 "BIT_XOR" => Some(AggregateType::BitXor),
400
401 _ => None,
402 }
403}
404
405fn extract_aggregate(expr: &Expr, alias: Option<String>) -> Option<AggregateInfo> {
407 match expr {
408 Expr::Function(func) => {
409 let func_name = func.name.to_string().to_uppercase();
410 let agg_type = resolve_aggregate_type(&func_name, func)?;
411
412 let column = extract_first_arg_column(func);
413 let distinct = has_distinct_arg(func);
414
415 let mut info = AggregateInfo::new(agg_type, column).with_distinct(distinct);
416
417 if let Some(filter_expr) = &func.filter {
419 info.filter = Some(filter_expr.clone());
420 }
421
422 if !func.within_group.is_empty() {
424 info.within_group.clone_from(&func.within_group);
425 }
426
427 if let Some(a) = alias {
428 info = info.with_alias(a);
429 }
430 Some(info)
431 }
432 Expr::Cast { expr, .. } | Expr::Nested(expr) => extract_aggregate(expr, alias),
434 _ => None,
435 }
436}
437
438fn has_distinct_arg(func: &Function) -> bool {
440 match &func.args {
442 sqlparser::ast::FunctionArguments::List(list) => list.duplicate_treatment.is_some(),
443 _ => false,
444 }
445}
446
447fn extract_first_arg_column(func: &Function) -> Option<String> {
449 match &func.args {
451 sqlparser::ast::FunctionArguments::List(list) => {
452 if list.args.is_empty() {
453 return None;
454 }
455 match &list.args[0] {
456 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => extract_column_name(expr),
457 FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => {
458 if let FunctionArgExpr::Expr(expr) = arg {
459 extract_column_name(expr)
460 } else {
461 None
462 }
463 }
464 FunctionArg::Unnamed(_) => None,
466 }
467 }
468 sqlparser::ast::FunctionArguments::Subquery(_)
469 | sqlparser::ast::FunctionArguments::None => None,
470 }
471}
472
473fn extract_column_name(expr: &Expr) -> Option<String> {
475 match expr {
476 Expr::Identifier(ident) => Some(ident.value.clone()),
477 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
478 _ => None,
479 }
480}
481
482#[must_use]
484pub fn has_aggregates(stmt: &Statement) -> bool {
485 analyze_aggregates(stmt).has_aggregates()
486}
487
488#[must_use]
490pub fn count_aggregates(stmt: &Statement) -> usize {
491 analyze_aggregates(stmt).aggregates.len()
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_statement(sql: &str) -> Statement {
        let mut statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        statements.remove(0)
    }

    /// Parses `sql` and runs [`analyze_aggregates`] on it.
    fn analyze(sql: &str) -> AggregationAnalysis {
        analyze_aggregates(&parse_statement(sql))
    }

    /// Aggregate types detected in `sql`, in projection order.
    fn types_of(sql: &str) -> Vec<AggregateType> {
        analyze(sql)
            .aggregates
            .iter()
            .map(|agg| agg.aggregate_type)
            .collect()
    }

    /// Type of the first aggregate detected in `sql` (panics if there is none).
    fn first_type(sql: &str) -> AggregateType {
        analyze(sql).aggregates[0].aggregate_type
    }

    // ---- core aggregates ----

    #[test]
    fn test_analyze_count() {
        let analysis = analyze("SELECT COUNT(*) FROM events");
        assert_eq!(analysis.aggregates.len(), 1);
        let agg = &analysis.aggregates[0];
        assert_eq!(agg.aggregate_type, AggregateType::Count);
        assert!(agg.column.is_none());
    }

    #[test]
    fn test_analyze_count_column() {
        let analysis = analyze("SELECT COUNT(id) FROM events");
        assert_eq!(analysis.aggregates.len(), 1);
        let agg = &analysis.aggregates[0];
        assert_eq!(agg.aggregate_type, AggregateType::Count);
        assert_eq!(agg.column.as_deref(), Some("id"));
    }

    #[test]
    fn test_analyze_count_distinct() {
        let analysis = analyze("SELECT COUNT(DISTINCT user_id) FROM events");
        assert_eq!(analysis.aggregates.len(), 1);
        let agg = &analysis.aggregates[0];
        assert_eq!(agg.aggregate_type, AggregateType::CountDistinct);
        assert!(agg.distinct);
    }

    #[test]
    fn test_analyze_sum() {
        let analysis = analyze("SELECT SUM(amount) FROM orders");
        assert_eq!(analysis.aggregates.len(), 1);
        let agg = &analysis.aggregates[0];
        assert_eq!(agg.aggregate_type, AggregateType::Sum);
        assert_eq!(agg.column.as_deref(), Some("amount"));
    }

    #[test]
    fn test_analyze_min_max() {
        assert_eq!(
            types_of("SELECT MIN(price), MAX(price) FROM products"),
            [AggregateType::Min, AggregateType::Max]
        );
    }

    #[test]
    fn test_analyze_avg() {
        let analysis = analyze("SELECT AVG(score) AS avg_score FROM tests");
        assert_eq!(analysis.aggregates.len(), 1);
        let agg = &analysis.aggregates[0];
        assert_eq!(agg.aggregate_type, AggregateType::Avg);
        assert_eq!(agg.alias.as_deref(), Some("avg_score"));
    }

    #[test]
    fn test_analyze_first_last() {
        let analysis = analyze(
            "SELECT FIRST_VALUE(price) AS open, LAST_VALUE(price) AS close FROM trades",
        );
        let kinds: Vec<_> = analysis
            .aggregates
            .iter()
            .map(|agg| agg.aggregate_type)
            .collect();
        assert_eq!(kinds, [AggregateType::FirstValue, AggregateType::LastValue]);
        assert!(analysis.has_order_sensitive());
    }

    // ---- GROUP BY / HAVING ----

    #[test]
    fn test_analyze_group_by() {
        let analysis = analyze("SELECT category, COUNT(*) FROM products GROUP BY category");
        assert_eq!(analysis.aggregates.len(), 1);
        assert_eq!(analysis.group_by_columns, ["category"]);
    }

    #[test]
    fn test_analyze_multiple_group_by() {
        let analysis = analyze(
            "SELECT region, category, SUM(sales) FROM orders GROUP BY region, category",
        );
        assert_eq!(analysis.group_by_columns, ["region", "category"]);
    }

    #[test]
    fn test_analyze_having() {
        let analysis = analyze(
            "SELECT category, COUNT(*) FROM products GROUP BY category HAVING COUNT(*) > 10",
        );
        assert!(analysis.has_having);
    }

    // ---- summary helpers ----

    #[test]
    fn test_no_aggregates() {
        let analysis = analyze("SELECT id, name FROM users");
        assert!(!analysis.has_aggregates());
        assert!(analysis.aggregates.is_empty());
    }

    #[test]
    fn test_has_aggregates() {
        assert!(has_aggregates(&parse_statement("SELECT COUNT(*) FROM events")));
        assert!(!has_aggregates(&parse_statement("SELECT * FROM events")));
    }

    #[test]
    fn test_count_aggregates() {
        let stmt = parse_statement(
            "SELECT COUNT(*), SUM(amount), AVG(price), MIN(qty), MAX(qty) FROM orders",
        );
        assert_eq!(count_aggregates(&stmt), 5);
    }

    #[test]
    fn test_decomposable() {
        assert!(
            analyze("SELECT COUNT(*), SUM(amount), MIN(price), MAX(price) FROM orders")
                .all_decomposable()
        );
        assert!(!analyze("SELECT AVG(price), FIRST_VALUE(price) FROM orders").all_decomposable());
    }

    #[test]
    fn test_get_by_type() {
        let analysis = analyze("SELECT COUNT(*), COUNT(id), SUM(amount) FROM orders");
        assert_eq!(analysis.get_by_type(AggregateType::Count).len(), 2);
        assert_eq!(analysis.get_by_type(AggregateType::Sum).len(), 1);
    }

    // ---- statistical / percentile / boolean / collection ----

    #[test]
    fn test_stddev() {
        assert_eq!(
            types_of("SELECT STDDEV(price) FROM trades"),
            [AggregateType::StdDev]
        );
    }

    #[test]
    fn test_stddev_pop() {
        assert_eq!(
            first_type("SELECT STDDEV_POP(latency) FROM requests"),
            AggregateType::StdDevPop
        );
    }

    #[test]
    fn test_variance() {
        assert_eq!(
            first_type("SELECT VARIANCE(price) FROM trades"),
            AggregateType::Variance
        );
    }

    #[test]
    fn test_variance_pop() {
        assert_eq!(
            first_type("SELECT VAR_POP(price) FROM trades"),
            AggregateType::VariancePop
        );
    }

    #[test]
    fn test_median() {
        assert_eq!(
            first_type("SELECT MEDIAN(response_time) FROM requests"),
            AggregateType::Median
        );
    }

    #[test]
    fn test_percentile_cont() {
        assert_eq!(
            first_type("SELECT PERCENTILE_CONT(0.95) FROM latencies"),
            AggregateType::PercentileCont
        );
    }

    #[test]
    fn test_percentile_disc() {
        assert_eq!(
            first_type("SELECT PERCENTILE_DISC(0.5) FROM scores"),
            AggregateType::PercentileDisc
        );
    }

    #[test]
    fn test_bool_and() {
        assert_eq!(
            first_type("SELECT BOOL_AND(is_active) FROM users"),
            AggregateType::BoolAnd
        );
    }

    #[test]
    fn test_bool_or() {
        assert_eq!(
            first_type("SELECT BOOL_OR(has_error) FROM events"),
            AggregateType::BoolOr
        );
    }

    #[test]
    fn test_string_agg() {
        let kind = first_type("SELECT STRING_AGG(name, ',') FROM users");
        assert_eq!(kind, AggregateType::StringAgg);
        assert!(kind.is_order_sensitive());
    }

    #[test]
    fn test_array_agg() {
        assert_eq!(
            first_type("SELECT ARRAY_AGG(id) FROM events"),
            AggregateType::ArrayAgg
        );
    }

    // ---- approximate ----

    #[test]
    fn test_approx_count_distinct() {
        assert_eq!(
            first_type("SELECT APPROX_COUNT_DISTINCT(user_id) FROM events"),
            AggregateType::ApproxCountDistinct
        );
    }

    #[test]
    fn test_approx_percentile() {
        assert_eq!(
            first_type("SELECT APPROX_PERCENTILE_CONT(latency, 0.99) FROM req"),
            AggregateType::ApproxPercentile
        );
    }

    #[test]
    fn test_approx_median() {
        assert_eq!(
            first_type("SELECT APPROX_MEDIAN(price) FROM trades"),
            AggregateType::ApproxMedian
        );
    }

    // ---- two-variable statistics ----

    #[test]
    fn test_covar_samp() {
        assert_eq!(
            first_type("SELECT COVAR_SAMP(x, y) FROM points"),
            AggregateType::Covar
        );
    }

    #[test]
    fn test_covar_pop() {
        assert_eq!(
            first_type("SELECT COVAR_POP(x, y) FROM points"),
            AggregateType::CovarPop
        );
    }

    #[test]
    fn test_corr() {
        assert_eq!(
            first_type("SELECT CORR(x, y) FROM points"),
            AggregateType::Corr
        );
    }

    #[test]
    fn test_regr_slope() {
        assert_eq!(
            first_type("SELECT REGR_SLOPE(y, x) FROM data"),
            AggregateType::RegrSlope
        );
    }

    #[test]
    fn test_regr_intercept() {
        assert_eq!(
            first_type("SELECT REGR_INTERCEPT(y, x) FROM data"),
            AggregateType::RegrIntercept
        );
    }

    // ---- bitwise ----

    #[test]
    fn test_bit_aggregates() {
        assert_eq!(
            types_of("SELECT BIT_AND(flags), BIT_OR(flags), BIT_XOR(flags) FROM events"),
            [
                AggregateType::BitAnd,
                AggregateType::BitOr,
                AggregateType::BitXor
            ]
        );
    }

    // ---- name aliases ----

    #[test]
    fn test_alias_stddev_samp() {
        assert_eq!(
            first_type("SELECT STDDEV_SAMP(price) FROM trades"),
            AggregateType::StdDev
        );
    }

    #[test]
    fn test_alias_var_samp() {
        assert_eq!(
            first_type("SELECT VAR_SAMP(price) FROM trades"),
            AggregateType::Variance
        );
    }

    #[test]
    fn test_alias_every() {
        assert_eq!(
            first_type("SELECT EVERY(is_valid) FROM checks"),
            AggregateType::BoolAnd
        );
    }

    #[test]
    fn test_alias_listagg() {
        assert_eq!(
            first_type("SELECT LISTAGG(name, ',') FROM users"),
            AggregateType::StringAgg
        );
    }

    #[test]
    fn test_alias_group_concat() {
        assert_eq!(
            first_type("SELECT GROUP_CONCAT(name, ',') FROM users"),
            AggregateType::StringAgg
        );
    }

    // ---- FILTER (WHERE ...) ----

    #[test]
    fn test_filter_clause_count() {
        let analysis = analyze("SELECT COUNT(*) FILTER (WHERE status = 'active') FROM users");
        assert_eq!(analysis.aggregates.len(), 1);
        assert!(analysis.aggregates[0].has_filter());
        assert!(analysis.has_any_filter());
    }

    #[test]
    fn test_filter_clause_sum() {
        let analysis =
            analyze("SELECT SUM(amount) FILTER (WHERE category = 'A') AS sum_a FROM orders");
        let agg = &analysis.aggregates[0];
        assert!(agg.has_filter());
        assert_eq!(agg.alias.as_deref(), Some("sum_a"));
    }

    #[test]
    fn test_filter_clause_mixed() {
        let analysis = analyze("SELECT COUNT(*), COUNT(*) FILTER (WHERE x > 0) FROM t");
        let filtered: Vec<bool> = analysis
            .aggregates
            .iter()
            .map(AggregateInfo::has_filter)
            .collect();
        assert_eq!(filtered, [false, true]);
    }

    #[test]
    fn test_no_filter() {
        let analysis = analyze("SELECT SUM(amount) FROM orders");
        assert!(!analysis.aggregates[0].has_filter());
        assert!(!analysis.has_any_filter());
    }

    // ---- WITHIN GROUP (ORDER BY ...) ----

    #[test]
    fn test_within_group_percentile_cont() {
        let analysis =
            analyze("SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY latency) FROM req");
        assert_eq!(analysis.aggregates.len(), 1);
        let agg = &analysis.aggregates[0];
        assert!(agg.has_within_group());
        assert_eq!(agg.within_group.len(), 1);
        assert!(analysis.has_any_within_group());
    }

    #[test]
    fn test_within_group_string_agg() {
        let analysis =
            analyze("SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name) FROM users");
        assert!(analysis.aggregates[0].has_within_group());
    }

    #[test]
    fn test_no_within_group() {
        let analysis = analyze("SELECT SUM(amount) FROM orders");
        assert!(!analysis.aggregates[0].has_within_group());
        assert!(!analysis.has_any_within_group());
    }

    // ---- DataFusion name mapping ----

    #[test]
    fn test_datafusion_name_core() {
        for (kind, name) in [
            (AggregateType::Count, "count"),
            (AggregateType::Sum, "sum"),
            (AggregateType::Min, "min"),
            (AggregateType::Max, "max"),
            (AggregateType::Avg, "avg"),
        ] {
            assert_eq!(kind.datafusion_name(), Some(name));
        }
    }

    #[test]
    fn test_datafusion_name_statistical() {
        for (kind, name) in [
            (AggregateType::StdDev, "stddev"),
            (AggregateType::StdDevPop, "stddev_pop"),
            (AggregateType::Variance, "variance"),
            (AggregateType::VariancePop, "variance_pop"),
            (AggregateType::Median, "median"),
        ] {
            assert_eq!(kind.datafusion_name(), Some(name));
        }
    }

    #[test]
    fn test_datafusion_name_approx() {
        for (kind, name) in [
            (AggregateType::ApproxCountDistinct, "approx_distinct"),
            (AggregateType::ApproxPercentile, "approx_percentile_cont"),
            (AggregateType::ApproxMedian, "approx_median"),
        ] {
            assert_eq!(kind.datafusion_name(), Some(name));
        }
    }

    #[test]
    fn test_datafusion_name_custom() {
        assert!(AggregateType::Custom.datafusion_name().is_none());
    }

    // ---- classification ----

    #[test]
    fn test_decomposable_new_types() {
        for kind in [
            AggregateType::BoolAnd,
            AggregateType::BoolOr,
            AggregateType::BitAnd,
            AggregateType::BitOr,
            AggregateType::BitXor,
        ] {
            assert!(kind.is_decomposable());
        }

        for kind in [
            AggregateType::StdDev,
            AggregateType::Variance,
            AggregateType::Median,
            AggregateType::PercentileCont,
            AggregateType::Corr,
        ] {
            assert!(!kind.is_decomposable());
        }
    }

    #[test]
    fn test_order_sensitive_new_types() {
        for kind in [
            AggregateType::PercentileCont,
            AggregateType::PercentileDisc,
            AggregateType::StringAgg,
            AggregateType::ArrayAgg,
        ] {
            assert!(kind.is_order_sensitive());
        }

        for kind in [
            AggregateType::StdDev,
            AggregateType::Variance,
            AggregateType::Corr,
        ] {
            assert!(!kind.is_order_sensitive());
        }
    }

    // ---- mixed queries ----

    #[test]
    fn test_multi_aggregate_statistical() {
        let analysis = analyze(
            "SELECT AVG(price), STDDEV(price), VARIANCE(price), \
             MEDIAN(price) FROM trades GROUP BY symbol",
        );
        let kinds: Vec<_> = analysis
            .aggregates
            .iter()
            .map(|agg| agg.aggregate_type)
            .collect();
        assert_eq!(
            kinds,
            [
                AggregateType::Avg,
                AggregateType::StdDev,
                AggregateType::Variance,
                AggregateType::Median
            ]
        );
        assert!(!analysis.all_decomposable());
    }

    #[test]
    fn test_multi_aggregate_mixed_with_filter() {
        let analysis = analyze(
            "SELECT COUNT(*), \
             SUM(amount) FILTER (WHERE status = 'complete'), \
             APPROX_COUNT_DISTINCT(user_id) FROM orders",
        );
        let filtered: Vec<bool> = analysis
            .aggregates
            .iter()
            .map(AggregateInfo::has_filter)
            .collect();
        assert_eq!(filtered, [false, true, false]);
    }

    #[test]
    fn test_arity() {
        for (kind, expected) in [
            (AggregateType::Count, 1),
            (AggregateType::Sum, 1),
            (AggregateType::Covar, 2),
            (AggregateType::CovarPop, 2),
            (AggregateType::Corr, 2),
            (AggregateType::RegrSlope, 2),
            (AggregateType::RegrIntercept, 2),
        ] {
            assert_eq!(kind.arity(), expected);
        }
    }

    #[test]
    fn test_case_insensitive_detection() {
        assert_eq!(
            types_of("SELECT stddev(price), Variance(price) FROM trades"),
            [AggregateType::StdDev, AggregateType::Variance]
        );
    }
}