Skip to main content

laminar_sql/parser/
aggregation_parser.rs

1//! Aggregate function detection and extraction
2//!
3//! This module analyzes SQL queries to extract aggregate functions like
4//! COUNT, SUM, MIN, MAX, AVG, STDDEV, VARIANCE, PERCENTILE, and more.
5//! It determines the aggregation strategy and maps to DataFusion names.
6
7use sqlparser::ast::{
8    Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, OrderByExpr, Select, SelectItem,
9    SetExpr, Statement,
10};
11
/// Types of aggregate functions supported.
///
/// Variants are grouped by family. The SQL spellings listed on each variant
/// are the function names that `resolve_aggregate_type` maps to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateType {
    // ── Core aggregates ─────────────────────────────────────────────
    /// `COUNT` function
    Count,
    /// `COUNT(DISTINCT ...)` function
    CountDistinct,
    /// `SUM` function
    Sum,
    /// `MIN` function
    Min,
    /// `MAX` function
    Max,
    /// `AVG` / `MEAN` function
    Avg,
    /// `FIRST_VALUE` / `FIRST` function
    FirstValue,
    /// `LAST_VALUE` / `LAST` function
    LastValue,

    // ── Statistical aggregates ──────────────────────────────────────
    /// Sample standard deviation (`STDDEV` / `STDDEV_SAMP`)
    StdDev,
    /// Population standard deviation (`STDDEV_POP`)
    StdDevPop,
    /// Sample variance (`VARIANCE` / `VAR_SAMP` / `VAR`)
    Variance,
    /// Population variance (`VAR_POP` / `VARIANCE_POP`)
    VariancePop,
    /// `MEDIAN`
    Median,

    // ── Percentile aggregates ───────────────────────────────────────
    /// `PERCENTILE_CONT` (continuous interpolation)
    PercentileCont,
    /// `PERCENTILE_DISC` (discrete, nearest-rank)
    PercentileDisc,

    // ── Boolean aggregates ──────────────────────────────────────────
    /// `BOOL_AND` / `EVERY`
    BoolAnd,
    /// `BOOL_OR` / `ANY`
    BoolOr,

    // ── Collection aggregates ───────────────────────────────────────
    /// `STRING_AGG` / `LISTAGG` / `GROUP_CONCAT`
    StringAgg,
    /// `ARRAY_AGG`
    ArrayAgg,

    // ── Approximate aggregates ──────────────────────────────────────
    /// `APPROX_COUNT_DISTINCT` / `APPROX_DISTINCT`
    ApproxCountDistinct,
    /// `APPROX_PERCENTILE_CONT` / `APPROX_PERCENTILE`
    ApproxPercentile,
    /// `APPROX_MEDIAN`
    ApproxMedian,

    // ── Correlation / Regression ────────────────────────────────────
    /// Sample covariance (`COVAR_SAMP` / `COVAR`)
    Covar,
    /// Population covariance (`COVAR_POP`)
    CovarPop,
    /// Pearson correlation (`CORR`)
    Corr,
    /// Linear regression slope (`REGR_SLOPE`)
    RegrSlope,
    /// Linear regression intercept (`REGR_INTERCEPT`)
    RegrIntercept,

    // ── Bit aggregates ──────────────────────────────────────────────
    /// `BIT_AND`
    BitAnd,
    /// `BIT_OR`
    BitOr,
    /// `BIT_XOR`
    BitXor,

    /// Custom / unrecognized aggregate function
    Custom,
}

impl AggregateType {
    /// Check if this aggregate is order-sensitive.
    ///
    /// Order-sensitive aggregates require maintaining event order.
    #[must_use]
    pub fn is_order_sensitive(&self) -> bool {
        matches!(
            self,
            AggregateType::FirstValue
                | AggregateType::LastValue
                | AggregateType::PercentileCont
                | AggregateType::PercentileDisc
                | AggregateType::StringAgg
                | AggregateType::ArrayAgg
        )
    }

    /// Check if this aggregate is decomposable (can be computed incrementally).
    ///
    /// Decomposable aggregates can be split into partial and final steps,
    /// enabling parallel or distributed computation.
    #[must_use]
    pub fn is_decomposable(&self) -> bool {
        matches!(
            self,
            AggregateType::Count
                | AggregateType::Sum
                | AggregateType::Min
                | AggregateType::Max
                | AggregateType::BoolAnd
                | AggregateType::BoolOr
                | AggregateType::BitAnd
                | AggregateType::BitOr
                | AggregateType::BitXor
        )
    }

    /// Returns the DataFusion function registry name for this aggregate type,
    /// or `None` if not directly mappable.
    ///
    /// `CountDistinct` maps to plain `count`; the DISTINCT modifier is
    /// expected to be carried separately (`AggregateInfo::distinct`).
    #[must_use]
    pub fn datafusion_name(&self) -> Option<&'static str> {
        match self {
            AggregateType::Count | AggregateType::CountDistinct => Some("count"),
            AggregateType::Sum => Some("sum"),
            AggregateType::Min => Some("min"),
            AggregateType::Max => Some("max"),
            AggregateType::Avg => Some("avg"),
            AggregateType::FirstValue => Some("first_value"),
            AggregateType::LastValue => Some("last_value"),
            AggregateType::StdDev => Some("stddev"),
            AggregateType::StdDevPop => Some("stddev_pop"),
            // DataFusion has no `variance` / `variance_pop` functions: sample
            // variance is registered as `var` (aliases `var_samp`,
            // `var_sample`) and population variance as `var_pop`.
            AggregateType::Variance => Some("var"),
            AggregateType::VariancePop => Some("var_pop"),
            AggregateType::Median => Some("median"),
            AggregateType::PercentileCont => Some("percentile_cont"),
            AggregateType::PercentileDisc => Some("percentile_disc"),
            AggregateType::BoolAnd => Some("bool_and"),
            AggregateType::BoolOr => Some("bool_or"),
            AggregateType::StringAgg => Some("string_agg"),
            AggregateType::ArrayAgg => Some("array_agg"),
            AggregateType::ApproxCountDistinct => Some("approx_distinct"),
            AggregateType::ApproxPercentile => Some("approx_percentile_cont"),
            AggregateType::ApproxMedian => Some("approx_median"),
            AggregateType::Covar => Some("covar_samp"),
            AggregateType::CovarPop => Some("covar_pop"),
            AggregateType::Corr => Some("corr"),
            AggregateType::RegrSlope => Some("regr_slope"),
            AggregateType::RegrIntercept => Some("regr_intercept"),
            AggregateType::BitAnd => Some("bit_and"),
            AggregateType::BitOr => Some("bit_or"),
            AggregateType::BitXor => Some("bit_xor"),
            AggregateType::Custom => None,
        }
    }

    /// Returns the number of input columns required by this aggregate.
    ///
    /// The covariance / correlation / regression family takes two columns;
    /// every other aggregate takes one.
    #[must_use]
    pub fn arity(&self) -> usize {
        match self {
            AggregateType::Covar
            | AggregateType::CovarPop
            | AggregateType::Corr
            | AggregateType::RegrSlope
            | AggregateType::RegrIntercept => 2,
            _ => 1,
        }
    }
}
182
183/// Information about a detected aggregate function.
184#[derive(Debug, Clone)]
185pub struct AggregateInfo {
186    /// Type of aggregate
187    pub aggregate_type: AggregateType,
188    /// Column being aggregated (None for COUNT(*))
189    pub column: Option<String>,
190    /// Optional alias for the aggregate result
191    pub alias: Option<String>,
192    /// Whether DISTINCT is applied
193    pub distinct: bool,
194    /// FILTER clause expression (e.g. `COUNT(x) FILTER (WHERE x > 5)`)
195    pub filter: Option<Box<Expr>>,
196    /// WITHIN GROUP ORDER BY expressions
197    pub within_group: Vec<OrderByExpr>,
198}
199
200impl AggregateInfo {
201    /// Create a new aggregate info.
202    #[must_use]
203    pub fn new(aggregate_type: AggregateType, column: Option<String>) -> Self {
204        Self {
205            aggregate_type,
206            column,
207            alias: None,
208            distinct: false,
209            filter: None,
210            within_group: Vec::new(),
211        }
212    }
213
214    /// Set the alias.
215    #[must_use]
216    pub fn with_alias(mut self, alias: String) -> Self {
217        self.alias = Some(alias);
218        self
219    }
220
221    /// Set distinct flag.
222    #[must_use]
223    pub fn with_distinct(mut self, distinct: bool) -> Self {
224        self.distinct = distinct;
225        self
226    }
227
228    /// Check whether a FILTER clause is present.
229    #[must_use]
230    pub fn has_filter(&self) -> bool {
231        self.filter.is_some()
232    }
233
234    /// Check whether a WITHIN GROUP clause is present.
235    #[must_use]
236    pub fn has_within_group(&self) -> bool {
237        !self.within_group.is_empty()
238    }
239}
240
241/// Analysis result for aggregations in a query.
242#[derive(Debug, Clone, Default)]
243pub struct AggregationAnalysis {
244    /// List of aggregate functions found
245    pub aggregates: Vec<AggregateInfo>,
246    /// GROUP BY columns
247    pub group_by_columns: Vec<String>,
248    /// Whether the query has a HAVING clause
249    pub has_having: bool,
250}
251
252impl AggregationAnalysis {
253    /// Check if this analysis contains any aggregates.
254    #[must_use]
255    pub fn has_aggregates(&self) -> bool {
256        !self.aggregates.is_empty()
257    }
258
259    /// Check if any aggregate is order-sensitive.
260    #[must_use]
261    pub fn has_order_sensitive(&self) -> bool {
262        self.aggregates
263            .iter()
264            .any(|a| a.aggregate_type.is_order_sensitive())
265    }
266
267    /// Check if all aggregates are decomposable.
268    #[must_use]
269    pub fn all_decomposable(&self) -> bool {
270        self.aggregates
271            .iter()
272            .all(|a| a.aggregate_type.is_decomposable())
273    }
274
275    /// Get aggregates by type.
276    #[must_use]
277    pub fn get_by_type(&self, agg_type: AggregateType) -> Vec<&AggregateInfo> {
278        self.aggregates
279            .iter()
280            .filter(|a| a.aggregate_type == agg_type)
281            .collect()
282    }
283
284    /// Check if any aggregate has a FILTER clause.
285    #[must_use]
286    pub fn has_any_filter(&self) -> bool {
287        self.aggregates.iter().any(AggregateInfo::has_filter)
288    }
289
290    /// Check if any aggregate has a WITHIN GROUP clause.
291    #[must_use]
292    pub fn has_any_within_group(&self) -> bool {
293        self.aggregates.iter().any(AggregateInfo::has_within_group)
294    }
295}
296
297/// Analyze a SQL statement for aggregate functions.
298#[must_use]
299pub fn analyze_aggregates(stmt: &Statement) -> AggregationAnalysis {
300    let mut analysis = AggregationAnalysis::default();
301
302    if let Statement::Query(query) = stmt {
303        if let SetExpr::Select(select) = query.body.as_ref() {
304            analyze_select(&mut analysis, select);
305        }
306    }
307
308    analysis
309}
310
311/// Analyze a SELECT statement for aggregates.
312fn analyze_select(analysis: &mut AggregationAnalysis, select: &Select) {
313    // Check SELECT items for aggregate functions
314    for item in &select.projection {
315        match item {
316            SelectItem::UnnamedExpr(expr) => {
317                if let Some(agg) = extract_aggregate(expr, None) {
318                    analysis.aggregates.push(agg);
319                }
320            }
321            SelectItem::ExprWithAlias { expr, alias } => {
322                if let Some(agg) = extract_aggregate(expr, Some(alias.value.clone())) {
323                    analysis.aggregates.push(agg);
324                }
325            }
326            SelectItem::QualifiedWildcard(_, _) | SelectItem::Wildcard(_) => {}
327        }
328    }
329
330    // Extract GROUP BY columns
331    match &select.group_by {
332        GroupByExpr::Expressions(exprs, _modifiers) => {
333            for expr in exprs {
334                if let Some(col) = extract_column_name(expr) {
335                    analysis.group_by_columns.push(col);
336                }
337            }
338        }
339        GroupByExpr::All(_) => {}
340    }
341
342    // Check for HAVING clause
343    analysis.has_having = select.having.is_some();
344}
345
346/// Resolve a SQL function name (upper-cased) to an [`AggregateType`], handling
347/// both canonical names and common aliases.
348fn resolve_aggregate_type(name: &str, func: &Function) -> Option<AggregateType> {
349    match name {
350        // ── Core ────────────────────────────────────────────────────
351        "COUNT" => {
352            if has_distinct_arg(func) {
353                Some(AggregateType::CountDistinct)
354            } else {
355                Some(AggregateType::Count)
356            }
357        }
358        "SUM" => Some(AggregateType::Sum),
359        "MIN" => Some(AggregateType::Min),
360        "MAX" => Some(AggregateType::Max),
361        "AVG" | "MEAN" => Some(AggregateType::Avg),
362        "FIRST_VALUE" | "FIRST" => Some(AggregateType::FirstValue),
363        "LAST_VALUE" | "LAST" => Some(AggregateType::LastValue),
364
365        // ── Statistical ────────────────────────────────────────────
366        "STDDEV" | "STDDEV_SAMP" => Some(AggregateType::StdDev),
367        "STDDEV_POP" => Some(AggregateType::StdDevPop),
368        "VARIANCE" | "VAR_SAMP" | "VAR" => Some(AggregateType::Variance),
369        "VAR_POP" | "VARIANCE_POP" => Some(AggregateType::VariancePop),
370        "MEDIAN" => Some(AggregateType::Median),
371
372        // ── Percentile ─────────────────────────────────────────────
373        "PERCENTILE_CONT" => Some(AggregateType::PercentileCont),
374        "PERCENTILE_DISC" => Some(AggregateType::PercentileDisc),
375
376        // ── Boolean ────────────────────────────────────────────────
377        "BOOL_AND" | "EVERY" => Some(AggregateType::BoolAnd),
378        "BOOL_OR" | "ANY" => Some(AggregateType::BoolOr),
379
380        // ── Collection ─────────────────────────────────────────────
381        "STRING_AGG" | "LISTAGG" | "GROUP_CONCAT" => Some(AggregateType::StringAgg),
382        "ARRAY_AGG" => Some(AggregateType::ArrayAgg),
383
384        // ── Approximate ────────────────────────────────────────────
385        "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(AggregateType::ApproxCountDistinct),
386        "APPROX_PERCENTILE_CONT" | "APPROX_PERCENTILE" => Some(AggregateType::ApproxPercentile),
387        "APPROX_MEDIAN" => Some(AggregateType::ApproxMedian),
388
389        // ── Correlation / Regression ───────────────────────────────
390        "COVAR_SAMP" | "COVAR" => Some(AggregateType::Covar),
391        "COVAR_POP" => Some(AggregateType::CovarPop),
392        "CORR" => Some(AggregateType::Corr),
393        "REGR_SLOPE" => Some(AggregateType::RegrSlope),
394        "REGR_INTERCEPT" => Some(AggregateType::RegrIntercept),
395
396        // ── Bit ────────────────────────────────────────────────────
397        "BIT_AND" => Some(AggregateType::BitAnd),
398        "BIT_OR" => Some(AggregateType::BitOr),
399        "BIT_XOR" => Some(AggregateType::BitXor),
400
401        _ => None,
402    }
403}
404
405/// Extract aggregate function from an expression.
406fn extract_aggregate(expr: &Expr, alias: Option<String>) -> Option<AggregateInfo> {
407    match expr {
408        Expr::Function(func) => {
409            let func_name = func.name.to_string().to_uppercase();
410            let agg_type = resolve_aggregate_type(&func_name, func)?;
411
412            let column = extract_first_arg_column(func);
413            let distinct = has_distinct_arg(func);
414
415            let mut info = AggregateInfo::new(agg_type, column).with_distinct(distinct);
416
417            // Extract FILTER clause
418            if let Some(filter_expr) = &func.filter {
419                info.filter = Some(filter_expr.clone());
420            }
421
422            // Extract WITHIN GROUP clause
423            if !func.within_group.is_empty() {
424                info.within_group.clone_from(&func.within_group);
425            }
426
427            if let Some(a) = alias {
428                info = info.with_alias(a);
429            }
430            Some(info)
431        }
432        // Handle nested expressions (e.g., CAST(COUNT(*) AS INT))
433        Expr::Cast { expr, .. } | Expr::Nested(expr) => extract_aggregate(expr, alias),
434        _ => None,
435    }
436}
437
438/// Check if the function has a DISTINCT argument.
439fn has_distinct_arg(func: &Function) -> bool {
440    // In sqlparser 0.60, DISTINCT is part of FunctionArgumentList
441    match &func.args {
442        sqlparser::ast::FunctionArguments::List(list) => list.duplicate_treatment.is_some(),
443        _ => false,
444    }
445}
446
447/// Extract the column name from the first argument of a function.
448fn extract_first_arg_column(func: &Function) -> Option<String> {
449    // Handle FunctionArguments::List
450    match &func.args {
451        sqlparser::ast::FunctionArguments::List(list) => {
452            if list.args.is_empty() {
453                return None;
454            }
455            match &list.args[0] {
456                FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => extract_column_name(expr),
457                FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => {
458                    if let FunctionArgExpr::Expr(expr) = arg {
459                        extract_column_name(expr)
460                    } else {
461                        None
462                    }
463                }
464                // COUNT(*), QualifiedWildcard, etc.
465                FunctionArg::Unnamed(_) => None,
466            }
467        }
468        sqlparser::ast::FunctionArguments::Subquery(_)
469        | sqlparser::ast::FunctionArguments::None => None,
470    }
471}
472
473/// Extract column name from an expression.
474fn extract_column_name(expr: &Expr) -> Option<String> {
475    match expr {
476        Expr::Identifier(ident) => Some(ident.value.clone()),
477        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
478        _ => None,
479    }
480}
481
482/// Check if a SELECT statement contains any aggregate functions.
483#[must_use]
484pub fn has_aggregates(stmt: &Statement) -> bool {
485    analyze_aggregates(stmt).has_aggregates()
486}
487
488/// Count the number of aggregate functions in a statement.
489#[must_use]
490pub fn count_aggregates(stmt: &Statement) -> usize {
491    analyze_aggregates(stmt).aggregates.len()
492}
493
494#[cfg(test)]
495mod tests {
496    use super::*;
497    use sqlparser::dialect::GenericDialect;
498    use sqlparser::parser::Parser;
499
500    fn parse_statement(sql: &str) -> Statement {
501        let dialect = GenericDialect {};
502        Parser::parse_sql(&dialect, sql).unwrap().remove(0)
503    }
504
505    // ── Core aggregate tests (existing, preserved) ──────────────────
506
507    #[test]
508    fn test_analyze_count() {
509        let stmt = parse_statement("SELECT COUNT(*) FROM events");
510        let analysis = analyze_aggregates(&stmt);
511
512        assert_eq!(analysis.aggregates.len(), 1);
513        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Count);
514        assert!(analysis.aggregates[0].column.is_none());
515    }
516
517    #[test]
518    fn test_analyze_count_column() {
519        let stmt = parse_statement("SELECT COUNT(id) FROM events");
520        let analysis = analyze_aggregates(&stmt);
521
522        assert_eq!(analysis.aggregates.len(), 1);
523        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Count);
524        assert_eq!(analysis.aggregates[0].column, Some("id".to_string()));
525    }
526
527    #[test]
528    fn test_analyze_count_distinct() {
529        let stmt = parse_statement("SELECT COUNT(DISTINCT user_id) FROM events");
530        let analysis = analyze_aggregates(&stmt);
531
532        assert_eq!(analysis.aggregates.len(), 1);
533        assert_eq!(
534            analysis.aggregates[0].aggregate_type,
535            AggregateType::CountDistinct
536        );
537        assert!(analysis.aggregates[0].distinct);
538    }
539
540    #[test]
541    fn test_analyze_sum() {
542        let stmt = parse_statement("SELECT SUM(amount) FROM orders");
543        let analysis = analyze_aggregates(&stmt);
544
545        assert_eq!(analysis.aggregates.len(), 1);
546        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Sum);
547        assert_eq!(analysis.aggregates[0].column, Some("amount".to_string()));
548    }
549
550    #[test]
551    fn test_analyze_min_max() {
552        let stmt = parse_statement("SELECT MIN(price), MAX(price) FROM products");
553        let analysis = analyze_aggregates(&stmt);
554
555        assert_eq!(analysis.aggregates.len(), 2);
556        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Min);
557        assert_eq!(analysis.aggregates[1].aggregate_type, AggregateType::Max);
558    }
559
560    #[test]
561    fn test_analyze_avg() {
562        let stmt = parse_statement("SELECT AVG(score) AS avg_score FROM tests");
563        let analysis = analyze_aggregates(&stmt);
564
565        assert_eq!(analysis.aggregates.len(), 1);
566        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Avg);
567        assert_eq!(analysis.aggregates[0].alias, Some("avg_score".to_string()));
568    }
569
570    #[test]
571    fn test_analyze_first_last() {
572        let stmt = parse_statement(
573            "SELECT FIRST_VALUE(price) AS open, LAST_VALUE(price) AS close FROM trades",
574        );
575        let analysis = analyze_aggregates(&stmt);
576
577        assert_eq!(analysis.aggregates.len(), 2);
578        assert_eq!(
579            analysis.aggregates[0].aggregate_type,
580            AggregateType::FirstValue
581        );
582        assert_eq!(
583            analysis.aggregates[1].aggregate_type,
584            AggregateType::LastValue
585        );
586        assert!(analysis.has_order_sensitive());
587    }
588
589    #[test]
590    fn test_analyze_group_by() {
591        let stmt = parse_statement("SELECT category, COUNT(*) FROM products GROUP BY category");
592        let analysis = analyze_aggregates(&stmt);
593
594        assert_eq!(analysis.aggregates.len(), 1);
595        assert_eq!(analysis.group_by_columns.len(), 1);
596        assert_eq!(analysis.group_by_columns[0], "category");
597    }
598
599    #[test]
600    fn test_analyze_multiple_group_by() {
601        let stmt = parse_statement(
602            "SELECT region, category, SUM(sales) FROM orders GROUP BY region, category",
603        );
604        let analysis = analyze_aggregates(&stmt);
605
606        assert_eq!(analysis.group_by_columns.len(), 2);
607        assert_eq!(analysis.group_by_columns[0], "region");
608        assert_eq!(analysis.group_by_columns[1], "category");
609    }
610
611    #[test]
612    fn test_analyze_having() {
613        let stmt = parse_statement(
614            "SELECT category, COUNT(*) FROM products GROUP BY category HAVING COUNT(*) > 10",
615        );
616        let analysis = analyze_aggregates(&stmt);
617
618        assert!(analysis.has_having);
619    }
620
621    #[test]
622    fn test_no_aggregates() {
623        let stmt = parse_statement("SELECT id, name FROM users");
624        let analysis = analyze_aggregates(&stmt);
625
626        assert!(!analysis.has_aggregates());
627        assert_eq!(analysis.aggregates.len(), 0);
628    }
629
630    #[test]
631    fn test_has_aggregates() {
632        let with_agg = parse_statement("SELECT COUNT(*) FROM events");
633        let without_agg = parse_statement("SELECT * FROM events");
634
635        assert!(has_aggregates(&with_agg));
636        assert!(!has_aggregates(&without_agg));
637    }
638
639    #[test]
640    fn test_count_aggregates() {
641        let stmt = parse_statement(
642            "SELECT COUNT(*), SUM(amount), AVG(price), MIN(qty), MAX(qty) FROM orders",
643        );
644        assert_eq!(count_aggregates(&stmt), 5);
645    }
646
647    #[test]
648    fn test_decomposable() {
649        let stmt =
650            parse_statement("SELECT COUNT(*), SUM(amount), MIN(price), MAX(price) FROM orders");
651        let analysis = analyze_aggregates(&stmt);
652        assert!(analysis.all_decomposable());
653
654        let stmt2 = parse_statement("SELECT AVG(price), FIRST_VALUE(price) FROM orders");
655        let analysis2 = analyze_aggregates(&stmt2);
656        assert!(!analysis2.all_decomposable());
657    }
658
659    #[test]
660    fn test_get_by_type() {
661        let stmt = parse_statement("SELECT COUNT(*), COUNT(id), SUM(amount) FROM orders");
662        let analysis = analyze_aggregates(&stmt);
663
664        let counts = analysis.get_by_type(AggregateType::Count);
665        assert_eq!(counts.len(), 2);
666
667        let sums = analysis.get_by_type(AggregateType::Sum);
668        assert_eq!(sums.len(), 1);
669    }
670
671    // ── New aggregate type detection tests ──────────────────────────
672
673    #[test]
674    fn test_stddev() {
675        let stmt = parse_statement("SELECT STDDEV(price) FROM trades");
676        let analysis = analyze_aggregates(&stmt);
677        assert_eq!(analysis.aggregates.len(), 1);
678        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::StdDev);
679    }
680
681    #[test]
682    fn test_stddev_pop() {
683        let stmt = parse_statement("SELECT STDDEV_POP(latency) FROM requests");
684        let analysis = analyze_aggregates(&stmt);
685        assert_eq!(
686            analysis.aggregates[0].aggregate_type,
687            AggregateType::StdDevPop
688        );
689    }
690
691    #[test]
692    fn test_variance() {
693        let stmt = parse_statement("SELECT VARIANCE(price) FROM trades");
694        let analysis = analyze_aggregates(&stmt);
695        assert_eq!(
696            analysis.aggregates[0].aggregate_type,
697            AggregateType::Variance
698        );
699    }
700
701    #[test]
702    fn test_variance_pop() {
703        let stmt = parse_statement("SELECT VAR_POP(price) FROM trades");
704        let analysis = analyze_aggregates(&stmt);
705        assert_eq!(
706            analysis.aggregates[0].aggregate_type,
707            AggregateType::VariancePop
708        );
709    }
710
711    #[test]
712    fn test_median() {
713        let stmt = parse_statement("SELECT MEDIAN(response_time) FROM requests");
714        let analysis = analyze_aggregates(&stmt);
715        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Median);
716    }
717
718    #[test]
719    fn test_percentile_cont() {
720        let stmt = parse_statement("SELECT PERCENTILE_CONT(0.95) FROM latencies");
721        let analysis = analyze_aggregates(&stmt);
722        assert_eq!(
723            analysis.aggregates[0].aggregate_type,
724            AggregateType::PercentileCont
725        );
726    }
727
728    #[test]
729    fn test_percentile_disc() {
730        let stmt = parse_statement("SELECT PERCENTILE_DISC(0.5) FROM scores");
731        let analysis = analyze_aggregates(&stmt);
732        assert_eq!(
733            analysis.aggregates[0].aggregate_type,
734            AggregateType::PercentileDisc
735        );
736    }
737
738    #[test]
739    fn test_bool_and() {
740        let stmt = parse_statement("SELECT BOOL_AND(is_active) FROM users");
741        let analysis = analyze_aggregates(&stmt);
742        assert_eq!(
743            analysis.aggregates[0].aggregate_type,
744            AggregateType::BoolAnd
745        );
746    }
747
748    #[test]
749    fn test_bool_or() {
750        let stmt = parse_statement("SELECT BOOL_OR(has_error) FROM events");
751        let analysis = analyze_aggregates(&stmt);
752        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::BoolOr);
753    }
754
755    #[test]
756    fn test_string_agg() {
757        let stmt = parse_statement("SELECT STRING_AGG(name, ',') FROM users");
758        let analysis = analyze_aggregates(&stmt);
759        assert_eq!(
760            analysis.aggregates[0].aggregate_type,
761            AggregateType::StringAgg
762        );
763        assert!(analysis.aggregates[0].aggregate_type.is_order_sensitive());
764    }
765
766    #[test]
767    fn test_array_agg() {
768        let stmt = parse_statement("SELECT ARRAY_AGG(id) FROM events");
769        let analysis = analyze_aggregates(&stmt);
770        assert_eq!(
771            analysis.aggregates[0].aggregate_type,
772            AggregateType::ArrayAgg
773        );
774    }
775
776    #[test]
777    fn test_approx_count_distinct() {
778        let stmt = parse_statement("SELECT APPROX_COUNT_DISTINCT(user_id) FROM events");
779        let analysis = analyze_aggregates(&stmt);
780        assert_eq!(
781            analysis.aggregates[0].aggregate_type,
782            AggregateType::ApproxCountDistinct
783        );
784    }
785
786    #[test]
787    fn test_approx_percentile() {
788        let stmt = parse_statement("SELECT APPROX_PERCENTILE_CONT(latency, 0.99) FROM req");
789        let analysis = analyze_aggregates(&stmt);
790        assert_eq!(
791            analysis.aggregates[0].aggregate_type,
792            AggregateType::ApproxPercentile
793        );
794    }
795
796    #[test]
797    fn test_approx_median() {
798        let stmt = parse_statement("SELECT APPROX_MEDIAN(price) FROM trades");
799        let analysis = analyze_aggregates(&stmt);
800        assert_eq!(
801            analysis.aggregates[0].aggregate_type,
802            AggregateType::ApproxMedian
803        );
804    }
805
806    #[test]
807    fn test_covar_samp() {
808        let stmt = parse_statement("SELECT COVAR_SAMP(x, y) FROM points");
809        let analysis = analyze_aggregates(&stmt);
810        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Covar);
811    }
812
813    #[test]
814    fn test_covar_pop() {
815        let stmt = parse_statement("SELECT COVAR_POP(x, y) FROM points");
816        let analysis = analyze_aggregates(&stmt);
817        assert_eq!(
818            analysis.aggregates[0].aggregate_type,
819            AggregateType::CovarPop
820        );
821    }
822
823    #[test]
824    fn test_corr() {
825        let stmt = parse_statement("SELECT CORR(x, y) FROM points");
826        let analysis = analyze_aggregates(&stmt);
827        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Corr);
828    }
829
830    #[test]
831    fn test_regr_slope() {
832        let stmt = parse_statement("SELECT REGR_SLOPE(y, x) FROM data");
833        let analysis = analyze_aggregates(&stmt);
834        assert_eq!(
835            analysis.aggregates[0].aggregate_type,
836            AggregateType::RegrSlope
837        );
838    }
839
840    #[test]
841    fn test_regr_intercept() {
842        let stmt = parse_statement("SELECT REGR_INTERCEPT(y, x) FROM data");
843        let analysis = analyze_aggregates(&stmt);
844        assert_eq!(
845            analysis.aggregates[0].aggregate_type,
846            AggregateType::RegrIntercept
847        );
848    }
849
850    #[test]
851    fn test_bit_aggregates() {
852        let stmt =
853            parse_statement("SELECT BIT_AND(flags), BIT_OR(flags), BIT_XOR(flags) FROM events");
854        let analysis = analyze_aggregates(&stmt);
855        assert_eq!(analysis.aggregates.len(), 3);
856        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::BitAnd);
857        assert_eq!(analysis.aggregates[1].aggregate_type, AggregateType::BitOr);
858        assert_eq!(analysis.aggregates[2].aggregate_type, AggregateType::BitXor);
859    }
860
861    // ── Alias synonym tests ────────────────────────────────────────
862
863    #[test]
864    fn test_alias_stddev_samp() {
865        let stmt = parse_statement("SELECT STDDEV_SAMP(price) FROM trades");
866        let analysis = analyze_aggregates(&stmt);
867        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::StdDev);
868    }
869
870    #[test]
871    fn test_alias_var_samp() {
872        let stmt = parse_statement("SELECT VAR_SAMP(price) FROM trades");
873        let analysis = analyze_aggregates(&stmt);
874        assert_eq!(
875            analysis.aggregates[0].aggregate_type,
876            AggregateType::Variance
877        );
878    }
879
880    #[test]
881    fn test_alias_every() {
882        let stmt = parse_statement("SELECT EVERY(is_valid) FROM checks");
883        let analysis = analyze_aggregates(&stmt);
884        assert_eq!(
885            analysis.aggregates[0].aggregate_type,
886            AggregateType::BoolAnd
887        );
888    }
889
890    #[test]
891    fn test_alias_listagg() {
892        let stmt = parse_statement("SELECT LISTAGG(name, ',') FROM users");
893        let analysis = analyze_aggregates(&stmt);
894        assert_eq!(
895            analysis.aggregates[0].aggregate_type,
896            AggregateType::StringAgg
897        );
898    }
899
900    #[test]
901    fn test_alias_group_concat() {
902        let stmt = parse_statement("SELECT GROUP_CONCAT(name, ',') FROM users");
903        let analysis = analyze_aggregates(&stmt);
904        assert_eq!(
905            analysis.aggregates[0].aggregate_type,
906            AggregateType::StringAgg
907        );
908    }
909
910    // ── FILTER clause tests ────────────────────────────────────────
911
912    #[test]
913    fn test_filter_clause_count() {
914        let stmt = parse_statement("SELECT COUNT(*) FILTER (WHERE status = 'active') FROM users");
915        let analysis = analyze_aggregates(&stmt);
916        assert_eq!(analysis.aggregates.len(), 1);
917        assert!(analysis.aggregates[0].has_filter());
918        assert!(analysis.has_any_filter());
919    }
920
921    #[test]
922    fn test_filter_clause_sum() {
923        let stmt = parse_statement(
924            "SELECT SUM(amount) FILTER (WHERE category = 'A') AS sum_a FROM orders",
925        );
926        let analysis = analyze_aggregates(&stmt);
927        assert!(analysis.aggregates[0].has_filter());
928        assert_eq!(analysis.aggregates[0].alias, Some("sum_a".to_string()));
929    }
930
931    #[test]
932    fn test_filter_clause_mixed() {
933        let stmt = parse_statement("SELECT COUNT(*), COUNT(*) FILTER (WHERE x > 0) FROM t");
934        let analysis = analyze_aggregates(&stmt);
935        assert_eq!(analysis.aggregates.len(), 2);
936        assert!(!analysis.aggregates[0].has_filter());
937        assert!(analysis.aggregates[1].has_filter());
938    }
939
940    #[test]
941    fn test_no_filter() {
942        let stmt = parse_statement("SELECT SUM(amount) FROM orders");
943        let analysis = analyze_aggregates(&stmt);
944        assert!(!analysis.aggregates[0].has_filter());
945        assert!(!analysis.has_any_filter());
946    }
947
948    // ── WITHIN GROUP tests ─────────────────────────────────────────
949
950    #[test]
951    fn test_within_group_percentile_cont() {
952        let stmt =
953            parse_statement("SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY latency) FROM req");
954        let analysis = analyze_aggregates(&stmt);
955        assert_eq!(analysis.aggregates.len(), 1);
956        assert!(analysis.aggregates[0].has_within_group());
957        assert_eq!(analysis.aggregates[0].within_group.len(), 1);
958        assert!(analysis.has_any_within_group());
959    }
960
961    #[test]
962    fn test_within_group_string_agg() {
963        let stmt =
964            parse_statement("SELECT STRING_AGG(name, ',') WITHIN GROUP (ORDER BY name) FROM users");
965        let analysis = analyze_aggregates(&stmt);
966        assert!(analysis.aggregates[0].has_within_group());
967    }
968
969    #[test]
970    fn test_no_within_group() {
971        let stmt = parse_statement("SELECT SUM(amount) FROM orders");
972        let analysis = analyze_aggregates(&stmt);
973        assert!(!analysis.aggregates[0].has_within_group());
974        assert!(!analysis.has_any_within_group());
975    }
976
977    // ── datafusion_name() tests ────────────────────────────────────
978
979    #[test]
980    fn test_datafusion_name_core() {
981        assert_eq!(AggregateType::Count.datafusion_name(), Some("count"));
982        assert_eq!(AggregateType::Sum.datafusion_name(), Some("sum"));
983        assert_eq!(AggregateType::Min.datafusion_name(), Some("min"));
984        assert_eq!(AggregateType::Max.datafusion_name(), Some("max"));
985        assert_eq!(AggregateType::Avg.datafusion_name(), Some("avg"));
986    }
987
988    #[test]
989    fn test_datafusion_name_statistical() {
990        assert_eq!(AggregateType::StdDev.datafusion_name(), Some("stddev"));
991        assert_eq!(
992            AggregateType::StdDevPop.datafusion_name(),
993            Some("stddev_pop")
994        );
995        assert_eq!(AggregateType::Variance.datafusion_name(), Some("variance"));
996        assert_eq!(
997            AggregateType::VariancePop.datafusion_name(),
998            Some("variance_pop")
999        );
1000        assert_eq!(AggregateType::Median.datafusion_name(), Some("median"));
1001    }
1002
1003    #[test]
1004    fn test_datafusion_name_approx() {
1005        assert_eq!(
1006            AggregateType::ApproxCountDistinct.datafusion_name(),
1007            Some("approx_distinct")
1008        );
1009        assert_eq!(
1010            AggregateType::ApproxPercentile.datafusion_name(),
1011            Some("approx_percentile_cont")
1012        );
1013        assert_eq!(
1014            AggregateType::ApproxMedian.datafusion_name(),
1015            Some("approx_median")
1016        );
1017    }
1018
1019    #[test]
1020    fn test_datafusion_name_custom() {
1021        assert_eq!(AggregateType::Custom.datafusion_name(), None);
1022    }
1023
    // ── is_decomposable() / is_order_sensitive() for new types ─────
1025
1026    #[test]
1027    fn test_decomposable_new_types() {
1028        // Decomposable: bit aggregates, bool aggregates
1029        assert!(AggregateType::BoolAnd.is_decomposable());
1030        assert!(AggregateType::BoolOr.is_decomposable());
1031        assert!(AggregateType::BitAnd.is_decomposable());
1032        assert!(AggregateType::BitOr.is_decomposable());
1033        assert!(AggregateType::BitXor.is_decomposable());
1034
1035        // Not decomposable: statistical, percentile, approx, etc.
1036        assert!(!AggregateType::StdDev.is_decomposable());
1037        assert!(!AggregateType::Variance.is_decomposable());
1038        assert!(!AggregateType::Median.is_decomposable());
1039        assert!(!AggregateType::PercentileCont.is_decomposable());
1040        assert!(!AggregateType::Corr.is_decomposable());
1041    }
1042
1043    #[test]
1044    fn test_order_sensitive_new_types() {
1045        // Order-sensitive: percentile, string_agg, array_agg
1046        assert!(AggregateType::PercentileCont.is_order_sensitive());
1047        assert!(AggregateType::PercentileDisc.is_order_sensitive());
1048        assert!(AggregateType::StringAgg.is_order_sensitive());
1049        assert!(AggregateType::ArrayAgg.is_order_sensitive());
1050
1051        // Not order-sensitive: statistical aggregates
1052        assert!(!AggregateType::StdDev.is_order_sensitive());
1053        assert!(!AggregateType::Variance.is_order_sensitive());
1054        assert!(!AggregateType::Corr.is_order_sensitive());
1055    }
1056
1057    // ── Multi-aggregate with new types ─────────────────────────────
1058
1059    #[test]
1060    fn test_multi_aggregate_statistical() {
1061        let stmt = parse_statement(
1062            "SELECT AVG(price), STDDEV(price), VARIANCE(price), \
1063             MEDIAN(price) FROM trades GROUP BY symbol",
1064        );
1065        let analysis = analyze_aggregates(&stmt);
1066        assert_eq!(analysis.aggregates.len(), 4);
1067        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::Avg);
1068        assert_eq!(analysis.aggregates[1].aggregate_type, AggregateType::StdDev);
1069        assert_eq!(
1070            analysis.aggregates[2].aggregate_type,
1071            AggregateType::Variance
1072        );
1073        assert_eq!(analysis.aggregates[3].aggregate_type, AggregateType::Median);
1074        assert!(!analysis.all_decomposable());
1075    }
1076
1077    #[test]
1078    fn test_multi_aggregate_mixed_with_filter() {
1079        let stmt = parse_statement(
1080            "SELECT COUNT(*), \
1081             SUM(amount) FILTER (WHERE status = 'complete'), \
1082             APPROX_COUNT_DISTINCT(user_id) FROM orders",
1083        );
1084        let analysis = analyze_aggregates(&stmt);
1085        assert_eq!(analysis.aggregates.len(), 3);
1086        assert!(!analysis.aggregates[0].has_filter());
1087        assert!(analysis.aggregates[1].has_filter());
1088        assert!(!analysis.aggregates[2].has_filter());
1089    }
1090
1091    // ── Arity tests ────────────────────────────────────────────────
1092
1093    #[test]
1094    fn test_arity() {
1095        assert_eq!(AggregateType::Count.arity(), 1);
1096        assert_eq!(AggregateType::Sum.arity(), 1);
1097        assert_eq!(AggregateType::Covar.arity(), 2);
1098        assert_eq!(AggregateType::CovarPop.arity(), 2);
1099        assert_eq!(AggregateType::Corr.arity(), 2);
1100        assert_eq!(AggregateType::RegrSlope.arity(), 2);
1101        assert_eq!(AggregateType::RegrIntercept.arity(), 2);
1102    }
1103
1104    // ── Case insensitivity ─────────────────────────────────────────
1105
1106    #[test]
1107    fn test_case_insensitive_detection() {
1108        let stmt = parse_statement("SELECT stddev(price), Variance(price) FROM trades");
1109        let analysis = analyze_aggregates(&stmt);
1110        assert_eq!(analysis.aggregates.len(), 2);
1111        assert_eq!(analysis.aggregates[0].aggregate_type, AggregateType::StdDev);
1112        assert_eq!(
1113            analysis.aggregates[1].aggregate_type,
1114            AggregateType::Variance
1115        );
1116    }
1117}