Skip to main content

laminar_sql/parser/
aggregation_parser.rs

1//! Aggregate function detection and extraction
2//!
3//! This module analyzes SQL queries to extract aggregate functions like
4//! COUNT, SUM, MIN, MAX, AVG, STDDEV, VARIANCE, PERCENTILE, and more.
5//! It determines the aggregation strategy and maps to DataFusion names.
6
7use sqlparser::ast::{
8    Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, OrderByExpr, Select, SelectItem,
9    SetExpr, Statement,
10};
11
/// The kinds of aggregate function this module can classify.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateType {
    // ── Core aggregates ─────────────────────────────────────────────
    /// `COUNT`
    Count,
    /// `COUNT(DISTINCT ...)`
    CountDistinct,
    /// `SUM`
    Sum,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `AVG`
    Avg,
    /// `FIRST_VALUE`
    FirstValue,
    /// `LAST_VALUE`
    LastValue,

    // ── Statistical aggregates ──────────────────────────────────────
    /// Sample standard deviation (`STDDEV` / `STDDEV_SAMP`)
    StdDev,
    /// Population standard deviation (`STDDEV_POP`)
    StdDevPop,
    /// Sample variance (`VARIANCE` / `VAR_SAMP`)
    Variance,
    /// Population variance (`VAR_POP` / `VARIANCE_POP`)
    VariancePop,
    /// `MEDIAN`
    Median,

    // ── Percentile aggregates ───────────────────────────────────────
    /// `PERCENTILE_CONT` (continuous interpolation)
    PercentileCont,
    /// `PERCENTILE_DISC` (discrete, nearest-rank)
    PercentileDisc,

    // ── Boolean aggregates ──────────────────────────────────────────
    /// `BOOL_AND` / `EVERY`
    BoolAnd,
    /// `BOOL_OR` / `ANY`
    BoolOr,

    // ── Collection aggregates ───────────────────────────────────────
    /// `STRING_AGG` / `LISTAGG` / `GROUP_CONCAT`
    StringAgg,
    /// `ARRAY_AGG`
    ArrayAgg,

    // ── Approximate aggregates ──────────────────────────────────────
    /// `APPROX_COUNT_DISTINCT`
    ApproxCountDistinct,
    /// `APPROX_PERCENTILE_CONT`
    ApproxPercentile,
    /// `APPROX_MEDIAN`
    ApproxMedian,

    // ── Correlation / Regression ────────────────────────────────────
    /// Sample covariance (`COVAR_SAMP`)
    Covar,
    /// Population covariance (`COVAR_POP`)
    CovarPop,
    /// Pearson correlation (`CORR`)
    Corr,
    /// Linear regression slope (`REGR_SLOPE`)
    RegrSlope,
    /// Linear regression intercept (`REGR_INTERCEPT`)
    RegrIntercept,

    // ── Bit aggregates ──────────────────────────────────────────────
    /// `BIT_AND`
    BitAnd,
    /// `BIT_OR`
    BitOr,
    /// `BIT_XOR`
    BitXor,

    /// Custom / unrecognized aggregate function
    Custom,
}

impl AggregateType {
    /// Report whether this aggregate is order-sensitive.
    ///
    /// Order-sensitive aggregates require maintaining event order. The set is
    /// `FirstValue`, `LastValue`, both percentile variants, `StringAgg` and
    /// `ArrayAgg`.
    #[must_use]
    pub fn is_order_sensitive(&self) -> bool {
        matches!(
            self,
            Self::FirstValue
                | Self::LastValue
                | Self::PercentileCont
                | Self::PercentileDisc
                | Self::StringAgg
                | Self::ArrayAgg
        )
    }

    /// Report whether this aggregate is decomposable (can be computed
    /// incrementally).
    ///
    /// Decomposable aggregates can be split into partial and final steps,
    /// enabling parallel or distributed computation. Only count, sum, min,
    /// max and the boolean / bit aggregates are listed here.
    #[must_use]
    pub fn is_decomposable(&self) -> bool {
        matches!(
            self,
            Self::Count
                | Self::Sum
                | Self::Min
                | Self::Max
                | Self::BoolAnd
                | Self::BoolOr
                | Self::BitAnd
                | Self::BitOr
                | Self::BitXor
        )
    }

    /// Map this aggregate type to its DataFusion function registry name.
    ///
    /// Returns `None` for [`AggregateType::Custom`], which has no direct
    /// mapping. `Count` and `CountDistinct` share the name `"count"`.
    #[must_use]
    pub fn datafusion_name(&self) -> Option<&'static str> {
        let registry_name = match self {
            // No direct mapping exists for user-defined aggregates.
            Self::Custom => return None,

            Self::Count | Self::CountDistinct => "count",
            Self::Sum => "sum",
            Self::Min => "min",
            Self::Max => "max",
            Self::Avg => "avg",
            Self::FirstValue => "first_value",
            Self::LastValue => "last_value",

            Self::StdDev => "stddev",
            Self::StdDevPop => "stddev_pop",
            // NOTE(review): DataFusion's function reference lists the variance
            // functions as `var` / `var_pop`; confirm that `variance` and
            // `variance_pop` resolve in the DataFusion version in use.
            Self::Variance => "variance",
            Self::VariancePop => "variance_pop",
            Self::Median => "median",

            Self::PercentileCont => "percentile_cont",
            Self::PercentileDisc => "percentile_disc",

            Self::BoolAnd => "bool_and",
            Self::BoolOr => "bool_or",

            Self::StringAgg => "string_agg",
            Self::ArrayAgg => "array_agg",

            Self::ApproxCountDistinct => "approx_distinct",
            Self::ApproxPercentile => "approx_percentile_cont",
            Self::ApproxMedian => "approx_median",

            Self::Covar => "covar_samp",
            Self::CovarPop => "covar_pop",
            Self::Corr => "corr",
            Self::RegrSlope => "regr_slope",
            Self::RegrIntercept => "regr_intercept",

            Self::BitAnd => "bit_and",
            Self::BitOr => "bit_or",
            Self::BitXor => "bit_xor",
        };
        Some(registry_name)
    }

    /// Number of input columns required by this aggregate.
    ///
    /// The covariance, correlation and regression aggregates take two
    /// columns; everything else reports one.
    #[must_use]
    pub fn arity(&self) -> usize {
        let is_bivariate = matches!(
            self,
            Self::Covar | Self::CovarPop | Self::Corr | Self::RegrSlope | Self::RegrIntercept
        );
        if is_bivariate {
            2
        } else {
            1
        }
    }
}
182
/// Information about a detected aggregate function.
///
/// One value describes a single aggregate call found in a query: its
/// classified type, the column it reads, and any modifiers attached to it.
#[derive(Debug, Clone)]
pub struct AggregateInfo {
    /// Classified type of the aggregate.
    pub aggregate_type: AggregateType,
    /// Column being aggregated (`None` for `COUNT(*)`, or when the first
    /// argument is not a plain or compound identifier).
    pub column: Option<String>,
    /// Optional alias for the aggregate result (`... AS alias`).
    pub alias: Option<String>,
    /// Whether DISTINCT is applied to the arguments.
    pub distinct: bool,
    /// FILTER clause expression (e.g. `COUNT(x) FILTER (WHERE x > 5)`).
    pub filter: Option<Box<Expr>>,
    /// WITHIN GROUP ORDER BY expressions; empty when the clause is absent.
    pub within_group: Vec<OrderByExpr>,
}
199
200impl AggregateInfo {
201    /// Create a new aggregate info.
202    #[must_use]
203    pub fn new(aggregate_type: AggregateType, column: Option<String>) -> Self {
204        Self {
205            aggregate_type,
206            column,
207            alias: None,
208            distinct: false,
209            filter: None,
210            within_group: Vec::new(),
211        }
212    }
213
214    /// Set the alias.
215    #[must_use]
216    pub fn with_alias(mut self, alias: String) -> Self {
217        self.alias = Some(alias);
218        self
219    }
220
221    /// Set distinct flag.
222    #[must_use]
223    pub fn with_distinct(mut self, distinct: bool) -> Self {
224        self.distinct = distinct;
225        self
226    }
227
228    /// Check whether a FILTER clause is present.
229    #[must_use]
230    pub fn has_filter(&self) -> bool {
231        self.filter.is_some()
232    }
233
234    /// Check whether a WITHIN GROUP clause is present.
235    #[must_use]
236    pub fn has_within_group(&self) -> bool {
237        !self.within_group.is_empty()
238    }
239}
240
/// Analysis result for aggregations in a query.
///
/// The `Default` value (no aggregates, no grouping, no HAVING) is what the
/// analyzer returns for statements it does not inspect.
#[derive(Debug, Clone, Default)]
pub struct AggregationAnalysis {
    /// Aggregate functions found in the SELECT list, in projection order.
    pub aggregates: Vec<AggregateInfo>,
    /// GROUP BY columns (only expressions that are plain or compound
    /// identifiers are recorded).
    pub group_by_columns: Vec<String>,
    /// Whether the query has a HAVING clause.
    pub has_having: bool,
    /// The HAVING expression as SQL text (for downstream evaluation).
    pub having_expr: Option<String>,
}
253
254impl AggregationAnalysis {
255    /// Check if this analysis contains any aggregates.
256    #[must_use]
257    pub fn has_aggregates(&self) -> bool {
258        !self.aggregates.is_empty()
259    }
260
261    /// Check if any aggregate is order-sensitive.
262    #[must_use]
263    pub fn has_order_sensitive(&self) -> bool {
264        self.aggregates
265            .iter()
266            .any(|a| a.aggregate_type.is_order_sensitive())
267    }
268
269    /// Check if all aggregates are decomposable.
270    #[must_use]
271    pub fn all_decomposable(&self) -> bool {
272        self.aggregates
273            .iter()
274            .all(|a| a.aggregate_type.is_decomposable())
275    }
276
277    /// Get aggregates by type.
278    #[must_use]
279    pub fn get_by_type(&self, agg_type: AggregateType) -> Vec<&AggregateInfo> {
280        self.aggregates
281            .iter()
282            .filter(|a| a.aggregate_type == agg_type)
283            .collect()
284    }
285
286    /// Check if any aggregate has a FILTER clause.
287    #[must_use]
288    pub fn has_any_filter(&self) -> bool {
289        self.aggregates.iter().any(AggregateInfo::has_filter)
290    }
291
292    /// Check if any aggregate has a WITHIN GROUP clause.
293    #[must_use]
294    pub fn has_any_within_group(&self) -> bool {
295        self.aggregates.iter().any(AggregateInfo::has_within_group)
296    }
297}
298
299/// Analyze a SQL statement for aggregate functions.
300#[must_use]
301pub fn analyze_aggregates(stmt: &Statement) -> AggregationAnalysis {
302    let mut analysis = AggregationAnalysis::default();
303
304    if let Statement::Query(query) = stmt {
305        if let SetExpr::Select(select) = query.body.as_ref() {
306            analyze_select(&mut analysis, select);
307        }
308    }
309
310    analysis
311}
312
313/// Analyze a SELECT statement for aggregates.
314fn analyze_select(analysis: &mut AggregationAnalysis, select: &Select) {
315    // Check SELECT items for aggregate functions
316    for item in &select.projection {
317        match item {
318            SelectItem::UnnamedExpr(expr) => {
319                if let Some(agg) = extract_aggregate(expr, None) {
320                    analysis.aggregates.push(agg);
321                }
322            }
323            SelectItem::ExprWithAlias { expr, alias } => {
324                if let Some(agg) = extract_aggregate(expr, Some(alias.value.clone())) {
325                    analysis.aggregates.push(agg);
326                }
327            }
328            SelectItem::QualifiedWildcard(_, _) | SelectItem::Wildcard(_) => {}
329        }
330    }
331
332    // Extract GROUP BY columns
333    match &select.group_by {
334        GroupByExpr::Expressions(exprs, _modifiers) => {
335            for expr in exprs {
336                if let Some(col) = extract_column_name(expr) {
337                    analysis.group_by_columns.push(col);
338                }
339            }
340        }
341        GroupByExpr::All(_) => {}
342    }
343
344    // Check for HAVING clause and extract expression
345    analysis.has_having = select.having.is_some();
346    analysis.having_expr = select.having.as_ref().map(std::string::ToString::to_string);
347}
348
349/// Resolve a SQL function name (upper-cased) to an [`AggregateType`], handling
350/// both canonical names and common aliases.
351fn resolve_aggregate_type(name: &str, func: &Function) -> Option<AggregateType> {
352    match name {
353        // ── Core ────────────────────────────────────────────────────
354        "COUNT" => {
355            if has_distinct_arg(func) {
356                Some(AggregateType::CountDistinct)
357            } else {
358                Some(AggregateType::Count)
359            }
360        }
361        "SUM" => Some(AggregateType::Sum),
362        "MIN" => Some(AggregateType::Min),
363        "MAX" => Some(AggregateType::Max),
364        "AVG" | "MEAN" => Some(AggregateType::Avg),
365        "FIRST_VALUE" | "FIRST" => Some(AggregateType::FirstValue),
366        "LAST_VALUE" | "LAST" => Some(AggregateType::LastValue),
367
368        // ── Statistical ────────────────────────────────────────────
369        "STDDEV" | "STDDEV_SAMP" => Some(AggregateType::StdDev),
370        "STDDEV_POP" => Some(AggregateType::StdDevPop),
371        "VARIANCE" | "VAR_SAMP" | "VAR" => Some(AggregateType::Variance),
372        "VAR_POP" | "VARIANCE_POP" => Some(AggregateType::VariancePop),
373        "MEDIAN" => Some(AggregateType::Median),
374
375        // ── Percentile ─────────────────────────────────────────────
376        "PERCENTILE_CONT" => Some(AggregateType::PercentileCont),
377        "PERCENTILE_DISC" => Some(AggregateType::PercentileDisc),
378
379        // ── Boolean ────────────────────────────────────────────────
380        "BOOL_AND" | "EVERY" => Some(AggregateType::BoolAnd),
381        "BOOL_OR" | "ANY" => Some(AggregateType::BoolOr),
382
383        // ── Collection ─────────────────────────────────────────────
384        "STRING_AGG" | "LISTAGG" | "GROUP_CONCAT" => Some(AggregateType::StringAgg),
385        "ARRAY_AGG" => Some(AggregateType::ArrayAgg),
386
387        // ── Approximate ────────────────────────────────────────────
388        "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(AggregateType::ApproxCountDistinct),
389        "APPROX_PERCENTILE_CONT" | "APPROX_PERCENTILE" => Some(AggregateType::ApproxPercentile),
390        "APPROX_MEDIAN" => Some(AggregateType::ApproxMedian),
391
392        // ── Correlation / Regression ───────────────────────────────
393        "COVAR_SAMP" | "COVAR" => Some(AggregateType::Covar),
394        "COVAR_POP" => Some(AggregateType::CovarPop),
395        "CORR" => Some(AggregateType::Corr),
396        "REGR_SLOPE" => Some(AggregateType::RegrSlope),
397        "REGR_INTERCEPT" => Some(AggregateType::RegrIntercept),
398
399        // ── Bit ────────────────────────────────────────────────────
400        "BIT_AND" => Some(AggregateType::BitAnd),
401        "BIT_OR" => Some(AggregateType::BitOr),
402        "BIT_XOR" => Some(AggregateType::BitXor),
403
404        _ => None,
405    }
406}
407
408/// Extract aggregate function from an expression.
409fn extract_aggregate(expr: &Expr, alias: Option<String>) -> Option<AggregateInfo> {
410    match expr {
411        Expr::Function(func) => {
412            let func_name = func.name.to_string().to_uppercase();
413            let agg_type = resolve_aggregate_type(&func_name, func)?;
414
415            let column = extract_first_arg_column(func);
416            let distinct = has_distinct_arg(func);
417
418            let mut info = AggregateInfo::new(agg_type, column).with_distinct(distinct);
419
420            // Extract FILTER clause
421            if let Some(filter_expr) = &func.filter {
422                info.filter = Some(filter_expr.clone());
423            }
424
425            // Extract WITHIN GROUP clause
426            if !func.within_group.is_empty() {
427                info.within_group.clone_from(&func.within_group);
428            }
429
430            if let Some(a) = alias {
431                info = info.with_alias(a);
432            }
433            Some(info)
434        }
435        // Handle nested expressions (e.g., CAST(COUNT(*) AS INT))
436        Expr::Cast { expr, .. } | Expr::Nested(expr) => extract_aggregate(expr, alias),
437        _ => None,
438    }
439}
440
441/// Check if the function has a DISTINCT argument.
442fn has_distinct_arg(func: &Function) -> bool {
443    // In sqlparser 0.60, DISTINCT is part of FunctionArgumentList
444    match &func.args {
445        sqlparser::ast::FunctionArguments::List(list) => list.duplicate_treatment.is_some(),
446        _ => false,
447    }
448}
449
450/// Extract the column name from the first argument of a function.
451fn extract_first_arg_column(func: &Function) -> Option<String> {
452    // Handle FunctionArguments::List
453    match &func.args {
454        sqlparser::ast::FunctionArguments::List(list) => {
455            if list.args.is_empty() {
456                return None;
457            }
458            match &list.args[0] {
459                FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => extract_column_name(expr),
460                FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => {
461                    if let FunctionArgExpr::Expr(expr) = arg {
462                        extract_column_name(expr)
463                    } else {
464                        None
465                    }
466                }
467                // COUNT(*), QualifiedWildcard, etc.
468                FunctionArg::Unnamed(_) => None,
469            }
470        }
471        sqlparser::ast::FunctionArguments::Subquery(_)
472        | sqlparser::ast::FunctionArguments::None => None,
473    }
474}
475
476/// Extract column name from an expression.
477fn extract_column_name(expr: &Expr) -> Option<String> {
478    match expr {
479        Expr::Identifier(ident) => Some(ident.value.clone()),
480        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
481        _ => None,
482    }
483}
484
485/// Check if a SELECT statement contains any aggregate functions.
486#[must_use]
487pub fn has_aggregates(stmt: &Statement) -> bool {
488    analyze_aggregates(stmt).has_aggregates()
489}
490
491/// Count the number of aggregate functions in a statement.
492#[must_use]
493pub fn count_aggregates(stmt: &Statement) -> usize {
494    analyze_aggregates(stmt).aggregates.len()
495}
496
497#[cfg(test)]
498mod tests {
499    use super::*;
500    use sqlparser::dialect::GenericDialect;
501    use sqlparser::parser::Parser;
502
503    fn parse_statement(sql: &str) -> Statement {
504        let dialect = GenericDialect {};
505        Parser::parse_sql(&dialect, sql).unwrap().remove(0)
506    }
507
508    // ── Core aggregate tests (existing, preserved) ──────────────────
509
510    #[test]
511    fn test_analyze_count() {
512        let stmt = parse_statement("SELECT COUNT(*) FROM events");
513        let analysis = analyze_aggregates(&stmt);
514
515        assert_eq!(analysis.aggregates.len(), 1);
516        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Count);
517        assert!(analysis.aggregates[0].column.is_none());
518    }
519
520    #[test]
521    fn test_analyze_count_column() {
522        let stmt = parse_statement("SELECT COUNT(id) FROM events");
523        let analysis = analyze_aggregates(&stmt);
524
525        assert_eq!(analysis.aggregates.len(), 1);
526        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Count);
527        assert_eq!(analysis.aggregates[0].column, Some("id".to_string()));
528    }
529
530    #[test]
531    fn test_analyze_count_distinct() {
532        let stmt = parse_statement("SELECT COUNT(DISTINCT user_id) FROM events");
533        let analysis = analyze_aggregates(&stmt);
534
535        assert_eq!(analysis.aggregates.len(), 1);
536        assert_eq!(
537            analysis.aggregates[0].aggregate_type,
538            AggregateType::CountDistinct
539        );
540        assert!(analysis.aggregates[0].distinct);
541    }
542
543    #[test]
544    fn test_analyze_sum() {
545        let stmt = parse_statement("SELECT SUM(amount) FROM orders");
546        let analysis = analyze_aggregates(&stmt);
547
548        assert_eq!(analysis.aggregates.len(), 1);
549        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Sum);
550        assert_eq!(analysis.aggregates[0].column, Some("amount".to_string()));
551    }
552
553    #[test]
554    fn test_analyze_min_max() {
555        let stmt = parse_statement("SELECT MIN(price), MAX(price) FROM products");
556        let analysis = analyze_aggregates(&stmt);
557
558        assert_eq!(analysis.aggregates.len(), 2);
559        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Min);
560        assert_eq!(analysis.aggregates[1].aggregate_type, AggregateType::Max);
561    }
562
563    #[test]
564    fn test_analyze_avg() {
565        let stmt = parse_statement("SELECT AVG(score) AS avg_score FROM tests");
566        let analysis = analyze_aggregates(&stmt);
567
568        assert_eq!(analysis.aggregates.len(), 1);
569        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Avg);
570        assert_eq!(analysis.aggregates[0].alias, Some("avg_score".to_string()));
571    }
572
573    #[test]
574    fn test_analyze_first_last() {
575        let stmt = parse_statement(
576            "SELECT FIRST_VALUE(price) AS open, LAST_VALUE(price) AS close FROM trades",
577        );
578        let analysis = analyze_aggregates(&stmt);
579
580        assert_eq!(analysis.aggregates.len(), 2);
581        assert_eq!(
582            analysis.aggregates[0].aggregate_type,
583            AggregateType::FirstValue
584        );
585        assert_eq!(
586            analysis.aggregates[1].aggregate_type,
587            AggregateType::LastValue
588        );
589        assert!(analysis.has_order_sensitive());
590    }
591
592    #[test]
593    fn test_analyze_group_by() {
594        let stmt = parse_statement("SELECT category, COUNT(*) FROM products GROUP BY category");
595        let analysis = analyze_aggregates(&stmt);
596
597        assert_eq!(analysis.aggregates.len(), 1);
598        assert_eq!(analysis.group_by_columns.len(), 1);
599        assert_eq!(analysis.group_by_columns[0], "category");
600    }
601
602    #[test]
603    fn test_analyze_multiple_group_by() {
604        let stmt = parse_statement(
605            "SELECT region, category, SUM(sales) FROM orders GROUP BY region, category",
606        );
607        let analysis = analyze_aggregates(&stmt);
608
609        assert_eq!(analysis.group_by_columns.len(), 2);
610        assert_eq!(analysis.group_by_columns[0], "region");
611        assert_eq!(analysis.group_by_columns[1], "category");
612    }
613
614    #[test]
615    fn test_analyze_having() {
616        let stmt = parse_statement(
617            "SELECT category, COUNT(*) FROM products GROUP BY category HAVING COUNT(*) > 10",
618        );
619        let analysis = analyze_aggregates(&stmt);
620
621        assert!(analysis.has_having);
622    }
623
624    #[test]
625    fn test_no_aggregates() {
626        let stmt = parse_statement("SELECT id, name FROM users");
627        let analysis = analyze_aggregates(&stmt);
628
629        assert!(!analysis.has_aggregates());
630        assert_eq!(analysis.aggregates.len(), 0);
631    }
632
633    #[test]
634    fn test_has_aggregates() {
635        let with_agg = parse_statement("SELECT COUNT(*) FROM events");
636        let without_agg = parse_statement("SELECT * FROM events");
637
638        assert!(has_aggregates(&with_agg));
639        assert!(!has_aggregates(&without_agg));
640    }
641
642    #[test]
643    fn test_count_aggregates() {
644        let stmt = parse_statement(
645            "SELECT COUNT(*), SUM(amount), AVG(price), MIN(qty), MAX(qty) FROM orders",
646        );
647        assert_eq!(count_aggregates(&stmt), 5);
648    }
649
650    #[test]
651    fn test_decomposable() {
652        let stmt =
653            parse_statement("SELECT COUNT(*), SUM(amount), MIN(price), MAX(price) FROM orders");
654        let analysis = analyze_aggregates(&stmt);
655        assert!(analysis.all_decomposable());
656
657        let stmt2 = parse_statement("SELECT AVG(price), FIRST_VALUE(price) FROM orders");
658        let analysis2 = analyze_aggregates(&stmt2);
659        assert!(!analysis2.all_decomposable());
660    }
661
662    #[test]
663    fn test_get_by_type() {
664        let stmt = parse_statement("SELECT COUNT(*), COUNT(id), SUM(amount) FROM orders");
665        let analysis = analyze_aggregates(&stmt);
666
667        let counts = analysis.get_by_type(AggregateType::Count);
668        assert_eq!(counts.len(), 2);
669
670        let sums = analysis.get_by_type(AggregateType::Sum);
671        assert_eq!(sums.len(), 1);
672    }
673
674    // ── New aggregate type detection tests ──────────────────────────
675
676    #[test]
677    fn test_stddev() {
678        let stmt = parse_statement("SELECT STDDEV(price) FROM trades");
679        let analysis = analyze_aggregates(&stmt);
680        assert_eq!(analysis.aggregates.len(), 1);
681        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::StdDev);
682    }
683
684    #[test]
685    fn test_stddev_pop() {
686        let stmt = parse_statement("SELECT STDDEV_POP(latency) FROM requests");
687        let analysis = analyze_aggregates(&stmt);
688        assert_eq!(
689            analysis.aggregates[0].aggregate_type,
690            AggregateType::StdDevPop
691        );
692    }
693
694    #[test]
695    fn test_variance() {
696        let stmt = parse_statement("SELECT VARIANCE(price) FROM trades");
697        let analysis = analyze_aggregates(&stmt);
698        assert_eq!(
699            analysis.aggregates[0].aggregate_type,
700            AggregateType::Variance
701        );
702    }
703
704    #[test]
705    fn test_variance_pop() {
706        let stmt = parse_statement("SELECT VAR_POP(price) FROM trades");
707        let analysis = analyze_aggregates(&stmt);
708        assert_eq!(
709            analysis.aggregates[0].aggregate_type,
710            AggregateType::VariancePop
711        );
712    }
713
714    #[test]
715    fn test_median() {
716        let stmt = parse_statement("SELECT MEDIAN(response_time) FROM requests");
717        let analysis = analyze_aggregates(&stmt);
718        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Median);
719    }
720
721    #[test]
722    fn test_percentile_cont() {
723        let stmt = parse_statement("SELECT PERCENTILE_CONT(0.95) FROM latencies");
724        let analysis = analyze_aggregates(&stmt);
725        assert_eq!(
726            analysis.aggregates[0].aggregate_type,
727            AggregateType::PercentileCont
728        );
729    }
730
731    #[test]
732    fn test_percentile_disc() {
733        let stmt = parse_statement("SELECT PERCENTILE_DISC(0.5) FROM scores");
734        let analysis = analyze_aggregates(&stmt);
735        assert_eq!(
736            analysis.aggregates[0].aggregate_type,
737            AggregateType::PercentileDisc
738        );
739    }
740
741    #[test]
742    fn test_bool_and() {
743        let stmt = parse_statement("SELECT BOOL_AND(is_active) FROM users");
744        let analysis = analyze_aggregates(&stmt);
745        assert_eq!(
746            analysis.aggregates[0].aggregate_type,
747            AggregateType::BoolAnd
748        );
749    }
750
751    #[test]
752    fn test_bool_or() {
753        let stmt = parse_statement("SELECT BOOL_OR(has_error) FROM events");
754        let analysis = analyze_aggregates(&stmt);
755        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::BoolOr);
756    }
757
758    #[test]
759    fn test_string_agg() {
760        let stmt = parse_statement("SELECT STRING_AGG(name, ',') FROM users");
761        let analysis = analyze_aggregates(&stmt);
762        assert_eq!(
763            analysis.aggregates[0].aggregate_type,
764            AggregateType::StringAgg
765        );
766        assert!(analysis.aggregates[0].aggregate_type.is_order_sensitive());
767    }
768
769    #[test]
770    fn test_array_agg() {
771        let stmt = parse_statement("SELECT ARRAY_AGG(id) FROM events");
772        let analysis = analyze_aggregates(&stmt);
773        assert_eq!(
774            analysis.aggregates[0].aggregate_type,
775            AggregateType::ArrayAgg
776        );
777    }
778
779    #[test]
780    fn test_approx_count_distinct() {
781        let stmt = parse_statement("SELECT APPROX_COUNT_DISTINCT(user_id) FROM events");
782        let analysis = analyze_aggregates(&stmt);
783        assert_eq!(
784            analysis.aggregates[0].aggregate_type,
785            AggregateType::ApproxCountDistinct
786        );
787    }
788
789    #[test]
790    fn test_approx_percentile() {
791        let stmt = parse_statement("SELECT APPROX_PERCENTILE_CONT(latency, 0.99) FROM req");
792        let analysis = analyze_aggregates(&stmt);
793        assert_eq!(
794            analysis.aggregates[0].aggregate_type,
795            AggregateType::ApproxPercentile
796        );
797    }
798
799    #[test]
800    fn test_approx_median() {
801        let stmt = parse_statement("SELECT APPROX_MEDIAN(price) FROM trades");
802        let analysis = analyze_aggregates(&stmt);
803        assert_eq!(
804            analysis.aggregates[0].aggregate_type,
805            AggregateType::ApproxMedian
806        );
807    }
808
809    #[test]
810    fn test_covar_samp() {
811        let stmt = parse_statement("SELECT COVAR_SAMP(x, y) FROM points");
812        let analysis = analyze_aggregates(&stmt);
813        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Covar);
814    }
815
816    #[test]
817    fn test_covar_pop() {
818        let stmt = parse_statement("SELECT COVAR_POP(x, y) FROM points");
819        let analysis = analyze_aggregates(&stmt);
820        assert_eq!(
821            analysis.aggregates[0].aggregate_type,
822            AggregateType::CovarPop
823        );
824    }
825
826    #[test]
827    fn test_corr() {
828        let stmt = parse_statement("SELECT CORR(x, y) FROM points");
829        let analysis = analyze_aggregates(&stmt);
830        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Corr);
831    }
832
833    #[test]
834    fn test_regr_slope() {
835        let stmt = parse_statement("SELECT REGR_SLOPE(y, x) FROM data");
836        let analysis = analyze_aggregates(&stmt);
837        assert_eq!(
838            analysis.aggregates[0].aggregate_type,
839            AggregateType::RegrSlope
840        );
841    }
842
843    #[test]
844    fn test_regr_intercept() {
845        let stmt = parse_statement("SELECT REGR_INTERCEPT(y, x) FROM data");
846        let analysis = analyze_aggregates(&stmt);
847        assert_eq!(
848            analysis.aggregates[0].aggregate_type,
849            AggregateType::RegrIntercept
850        );
851    }
852
853    #[test]
854    fn test_bit_aggregates() {
855        let stmt =
856            parse_statement("SELECT BIT_AND(flags), BIT_OR(flags), BIT_XOR(flags) FROM events");
857        let analysis = analyze_aggregates(&stmt);
858        assert_eq!(analysis.aggregates.len(), 3);
859        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::BitAnd);
860        assert_eq!(analysis.aggregates[1].aggregate_type, AggregateType::BitOr);
861        assert_eq!(analysis.aggregates[2].aggregate_type, AggregateType::BitXor);
862    }
863
864    // ── Alias synonym tests ────────────────────────────────────────
865
866    #[test]
867    fn test_alias_stddev_samp() {
868        let stmt = parse_statement("SELECT STDDEV_SAMP(price) FROM trades");
869        let analysis = analyze_aggregates(&stmt);
870        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::StdDev);
871    }
872
873    #[test]
874    fn test_alias_var_samp() {
875        let stmt = parse_statement("SELECT VAR_SAMP(price) FROM trades");
876        let analysis = analyze_aggregates(&stmt);
877        assert_eq!(
878            analysis.aggregates[0].aggregate_type,
879            AggregateType::Variance
880        );
881    }
882
883    #[test]
884    fn test_alias_every() {
885        let stmt = parse_statement("SELECT EVERY(is_valid) FROM checks");
886        let analysis = analyze_aggregates(&stmt);
887        assert_eq!(
888            analysis.aggregates[0].aggregate_type,
889            AggregateType::BoolAnd
890        );
891    }
892
893    #[test]
894    fn test_alias_listagg() {
895        let stmt = parse_statement("SELECT LISTAGG(name, ',') FROM users");
896        let analysis = analyze_aggregates(&stmt);
897        assert_eq!(
898            analysis.aggregates[0].aggregate_type,
899            AggregateType::StringAgg
900        );
901    }
902
903    #[test]
904    fn test_alias_group_concat() {
905        let stmt = parse_statement("SELECT GROUP_CONCAT(name, ',') FROM users");
906        let analysis = analyze_aggregates(&stmt);
907        assert_eq!(
908            analysis.aggregates[0].aggregate_type,
909            AggregateType::StringAgg
910        );
911    }
912
913    // ── FILTER clause tests ────────────────────────────────────────
914
915    #[test]
916    fn test_filter_clause_count() {
917        let stmt = parse_statement("SELECT COUNT(*) FILTER (WHERE status = 'active') FROM users");
918        let analysis = analyze_aggregates(&stmt);
919        assert_eq!(analysis.aggregates.len(), 1);
920        assert!(analysis.aggregates[0].has_filter());
921        assert!(analysis.has_any_filter());
922    }
923
924    #[test]
925    fn test_filter_clause_sum() {
926        let stmt = parse_statement(
927            "SELECT SUM(amount) FILTER (WHERE category = 'A') AS sum_a FROM orders",
928        );
929        let analysis = analyze_aggregates(&stmt);
930        assert!(analysis.aggregates[0].has_filter());
931        assert_eq!(analysis.aggregates[0].alias, Some("sum_a".to_string()));
932    }
933
934    #[test]
935    fn test_filter_clause_mixed() {
936        let stmt = parse_statement("SELECT COUNT(*), COUNT(*) FILTER (WHERE x > 0) FROM t");
937        let analysis = analyze_aggregates(&stmt);
938        assert_eq!(analysis.aggregates.len(), 2);
939        assert!(!analysis.aggregates[0].has_filter());
940        assert!(analysis.aggregates[1].has_filter());
941    }
942
943    #[test]
944    fn test_no_filter() {
945        let stmt = parse_statement("SELECT SUM(amount) FROM orders");
946        let analysis = analyze_aggregates(&stmt);
947        assert!(!analysis.aggregates[0].has_filter());
948        assert!(!analysis.has_any_filter());
949    }
950
951    // ── WITHIN GROUP tests ─────────────────────────────────────────
952
953    #[test]
954    fn test_within_group_percentile_cont() {
955        let stmt =
956            parse_statement("SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY latency) FROM req");
957        let analysis = analyze_aggregates(&stmt);
958        assert_eq!(analysis.aggregates.len(), 1);
959        assert!(analysis.aggregates[0].has_within_group());
960        assert_eq!(analysis.aggregates[0].within_group.len(), 1);
961        assert!(analysis.has_any_within_group());
962    }
963
964    #[test]
965    fn test_within_group_string_agg() {
966        let stmt =
967            parse_statement("SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name) FROM users");
968        let analysis = analyze_aggregates(&stmt);
969        assert!(analysis.aggregates[0].has_within_group());
970    }
971
972    #[test]
973    fn test_no_within_group() {
974        let stmt = parse_statement("SELECT SUM(amount) FROM orders");
975        let analysis = analyze_aggregates(&stmt);
976        assert!(!analysis.aggregates[0].has_within_group());
977        assert!(!analysis.has_any_within_group());
978    }
979
980    // ── datafusion_name() tests ────────────────────────────────────
981
982    #[test]
983    fn test_datafusion_name_core() {
984        assert_eq!(AggregateType::Count.datafusion_name(), Some("count"));
985        assert_eq!(AggregateType::Sum.datafusion_name(), Some("sum"));
986        assert_eq!(AggregateType::Min.datafusion_name(), Some("min"));
987        assert_eq!(AggregateType::Max.datafusion_name(), Some("max"));
988        assert_eq!(AggregateType::Avg.datafusion_name(), Some("avg"));
989    }
990
991    #[test]
992    fn test_datafusion_name_statistical() {
993        assert_eq!(AggregateType::StdDev.datafusion_name(), Some("stddev"));
994        assert_eq!(
995            AggregateType::StdDevPop.datafusion_name(),
996            Some("stddev_pop")
997        );
998        assert_eq!(AggregateType::Variance.datafusion_name(), Some("variance"));
999        assert_eq!(
1000            AggregateType::VariancePop.datafusion_name(),
1001            Some("variance_pop")
1002        );
1003        assert_eq!(AggregateType::Median.datafusion_name(), Some("median"));
1004    }
1005
1006    #[test]
1007    fn test_datafusion_name_approx() {
1008        assert_eq!(
1009            AggregateType::ApproxCountDistinct.datafusion_name(),
1010            Some("approx_distinct")
1011        );
1012        assert_eq!(
1013            AggregateType::ApproxPercentile.datafusion_name(),
1014            Some("approx_percentile_cont")
1015        );
1016        assert_eq!(
1017            AggregateType::ApproxMedian.datafusion_name(),
1018            Some("approx_median")
1019        );
1020    }
1021
1022    #[test]
1023    fn test_datafusion_name_custom() {
1024        assert_eq!(AggregateType::Custom.datafusion_name(), None);
1025    }
1026
1027    // ── is_decomposable() for new types ────────────────────────────
1028
1029    #[test]
1030    fn test_decomposable_new_types() {
1031        // Decomposable: bit aggregates, bool aggregates
1032        assert!(AggregateType::BoolAnd.is_decomposable());
1033        assert!(AggregateType::BoolOr.is_decomposable());
1034        assert!(AggregateType::BitAnd.is_decomposable());
1035        assert!(AggregateType::BitOr.is_decomposable());
1036        assert!(AggregateType::BitXor.is_decomposable());
1037
1038        // Not decomposable: statistical, percentile, approx, etc.
1039        assert!(!AggregateType::StdDev.is_decomposable());
1040        assert!(!AggregateType::Variance.is_decomposable());
1041        assert!(!AggregateType::Median.is_decomposable());
1042        assert!(!AggregateType::PercentileCont.is_decomposable());
1043        assert!(!AggregateType::Corr.is_decomposable());
1044    }
1045
1046    #[test]
1047    fn test_order_sensitive_new_types() {
1048        // Order-sensitive: percentile, string_agg, array_agg
1049        assert!(AggregateType::PercentileCont.is_order_sensitive());
1050        assert!(AggregateType::PercentileDisc.is_order_sensitive());
1051        assert!(AggregateType::StringAgg.is_order_sensitive());
1052        assert!(AggregateType::ArrayAgg.is_order_sensitive());
1053
1054        // Not order-sensitive: statistical aggregates
1055        assert!(!AggregateType::StdDev.is_order_sensitive());
1056        assert!(!AggregateType::Variance.is_order_sensitive());
1057        assert!(!AggregateType::Corr.is_order_sensitive());
1058    }
1059
1060    // ── Multi-aggregate with new types ─────────────────────────────
1061
1062    #[test]
1063    fn test_multi_aggregate_statistical() {
1064        let stmt = parse_statement(
1065            "SELECT AVG(price), STDDEV(price), VARIANCE(price), \
1066             MEDIAN(price) FROM trades GROUP BY symbol",
1067        );
1068        let analysis = analyze_aggregates(&stmt);
1069        assert_eq!(analysis.aggregates.len(), 4);
1070        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Avg);
1071        assert_eq!(analysis.aggregates[1].aggregate_type, AggregateType::StdDev);
1072        assert_eq!(
1073            analysis.aggregates[2].aggregate_type,
1074            AggregateType::Variance
1075        );
1076        assert_eq!(analysis.aggregates[3].aggregate_type, AggregateType::Median);
1077        assert!(!analysis.all_decomposable());
1078    }
1079
1080    #[test]
1081    fn test_multi_aggregate_mixed_with_filter() {
1082        let stmt = parse_statement(
1083            "SELECT COUNT(*), \
1084             SUM(amount) FILTER (WHERE status = 'complete'), \
1085             APPROX_COUNT_DISTINCT(user_id) FROM orders",
1086        );
1087        let analysis = analyze_aggregates(&stmt);
1088        assert_eq!(analysis.aggregates.len(), 3);
1089        assert!(!analysis.aggregates[0].has_filter());
1090        assert!(analysis.aggregates[1].has_filter());
1091        assert!(!analysis.aggregates[2].has_filter());
1092    }
1093
1094    // ── Arity tests ────────────────────────────────────────────────
1095
1096    #[test]
1097    fn test_arity() {
1098        assert_eq!(AggregateType::Count.arity(), 1);
1099        assert_eq!(AggregateType::Sum.arity(), 1);
1100        assert_eq!(AggregateType::Covar.arity(), 2);
1101        assert_eq!(AggregateType::CovarPop.arity(), 2);
1102        assert_eq!(AggregateType::Corr.arity(), 2);
1103        assert_eq!(AggregateType::RegrSlope.arity(), 2);
1104        assert_eq!(AggregateType::RegrIntercept.arity(), 2);
1105    }
1106
    // ── Case insensitivity: see `test_case_insensitive_detection` at the end of this module ──
1108
1109    // ── HAVING expression extraction tests ──────────────────────────
1110
1111    #[test]
1112    fn test_having_expr_simple() {
1113        let stmt = parse_statement(
1114            "SELECT symbol, COUNT(*) FROM trades GROUP BY symbol HAVING COUNT(*) > 10",
1115        );
1116        let analysis = analyze_aggregates(&stmt);
1117        assert!(analysis.has_having);
1118        assert!(analysis.having_expr.is_some());
1119        let expr = analysis.having_expr.unwrap();
1120        assert!(expr.contains("COUNT(*)"), "expr was: {expr}");
1121        assert!(expr.contains("10"), "expr was: {expr}");
1122    }
1123
1124    #[test]
1125    fn test_having_expr_multiple_predicates() {
1126        let stmt = parse_statement(
1127            "SELECT symbol, SUM(volume) AS vol, AVG(price) AS avg_p \
1128             FROM trades GROUP BY symbol \
1129             HAVING SUM(volume) > 1000 AND AVG(price) < 50",
1130        );
1131        let analysis = analyze_aggregates(&stmt);
1132        assert!(analysis.has_having);
1133        let expr = analysis.having_expr.unwrap();
1134        assert!(expr.contains("SUM(volume)"), "expr was: {expr}");
1135        assert!(expr.contains("AVG(price)"), "expr was: {expr}");
1136        assert!(expr.contains("AND"), "expr was: {expr}");
1137    }
1138
1139    #[test]
1140    fn test_having_expr_with_or() {
1141        let stmt = parse_statement(
1142            "SELECT category, COUNT(*) FROM products GROUP BY category \
1143             HAVING COUNT(*) > 100 OR SUM(price) > 5000",
1144        );
1145        let analysis = analyze_aggregates(&stmt);
1146        assert!(analysis.has_having);
1147        let expr = analysis.having_expr.unwrap();
1148        assert!(expr.contains("OR"), "expr was: {expr}");
1149    }
1150
1151    #[test]
1152    fn test_having_expr_none_without_having() {
1153        let stmt = parse_statement("SELECT symbol, COUNT(*) FROM trades GROUP BY symbol");
1154        let analysis = analyze_aggregates(&stmt);
1155        assert!(!analysis.has_having);
1156        assert!(analysis.having_expr.is_none());
1157    }
1158
1159    #[test]
1160    fn test_having_expr_with_alias_reference() {
1161        let stmt = parse_statement(
1162            "SELECT symbol, COUNT(*) AS cnt FROM trades GROUP BY symbol HAVING COUNT(*) >= 5",
1163        );
1164        let analysis = analyze_aggregates(&stmt);
1165        assert!(analysis.has_having);
1166        let expr = analysis.having_expr.unwrap();
1167        assert!(expr.contains("COUNT(*)"), "expr was: {expr}");
1168        assert!(expr.contains('5'), "expr was: {expr}");
1169    }
1170
1171    #[test]
1172    fn test_case_insensitive_detection() {
1173        let stmt = parse_statement("SELECT stddev(price), Variance(price) FROM trades");
1174        let analysis = analyze_aggregates(&stmt);
1175        assert_eq!(analysis.aggregates.len(), 2);
1176        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::StdDev);
1177        assert_eq!(
1178            analysis.aggregates[1].aggregate_type,
1179            AggregateType::Variance
1180        );
1181    }
1182}