Skip to main content

sochdb_query/sql/
aggregate.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! # SQL Aggregation Executor
19//!
20//! Hash-aggregation operator for `GROUP BY` and aggregate functions.
21//!
22//! Supported aggregates: `COUNT(*)`, `COUNT(col)`, `COUNT(DISTINCT col)`,
23//! `SUM`, `AVG`, `MIN`, `MAX`, `MEDIAN`, `STDDEV` (sample, n-1, matching
24//! R's `sd()` and DuckDB's `stddev`).
25//!
26//! ## Pipeline
27//!
28//! ```text
29//! input rows (post-WHERE)
30//!   └─> group keys evaluated per row ──> hash table of group states
31//!         └─> accumulators updated per row
32//!               └─> finalize: one synthesized row per group
33//!                     └─> HAVING filter ─> ORDER BY ─> OFFSET/LIMIT ─> projection
34//! ```
35//!
36//! Semantics notes:
37//! - NULL inputs are skipped by all aggregates except `COUNT(*)` (SQL standard).
38//! - An ungrouped aggregate over zero rows yields exactly one row
39//!   (`COUNT` = 0, other aggregates NULL); a grouped aggregate over zero
40//!   rows yields zero rows.
41//! - Non-aggregate SELECT columns that are not in `GROUP BY` resolve to the
42//!   first value seen in the group (lenient mode, like SQLite / MySQL with
43//!   `ONLY_FULL_GROUP_BY` disabled).
44
45use super::ast::*;
46use super::bridge::ExecutionResult;
47use super::error::SqlResult;
48use rayon::prelude::*;
49use sochdb_core::SochValue;
50use std::collections::{HashMap, HashSet};
51
52/// Row count above which grouped accumulation runs on the rayon pool.
53const PARALLEL_THRESHOLD: usize = 100_000;
54
55// ============================================================================
56// Aggregate function identification
57// ============================================================================
58
/// Aggregate functions the executor knows how to compute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFn {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    Median,
    Stddev,
}

impl AggFn {
    /// Map a SQL function name onto an aggregate, ignoring ASCII case.
    ///
    /// `MEAN` is accepted as an alias of `AVG`; `STDDEV_SAMP`, `STDEV` and
    /// `SD` are aliases of `STDDEV`. Any other name yields `None`.
    pub fn from_name(name: &str) -> Option<Self> {
        const ALIASES: [(&str, AggFn); 11] = [
            ("COUNT", AggFn::Count),
            ("SUM", AggFn::Sum),
            ("AVG", AggFn::Avg),
            ("MEAN", AggFn::Avg),
            ("MIN", AggFn::Min),
            ("MAX", AggFn::Max),
            ("MEDIAN", AggFn::Median),
            ("STDDEV", AggFn::Stddev),
            ("STDDEV_SAMP", AggFn::Stddev),
            ("STDEV", AggFn::Stddev),
            ("SD", AggFn::Stddev),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| name.eq_ignore_ascii_case(alias))
            .map(|&(_, func)| func)
    }
}
86
/// One aggregate call discovered in the query, e.g. `sum(v1)`.
///
/// Instances are created by `collect_from_expr`, which keeps at most one
/// spec per distinct `key`.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Canonical key, e.g. `"sum(v1)"` — used to bind HAVING / ORDER BY
    /// references back to the computed value. Produced by `render_agg_key`.
    key: String,
    /// Which aggregate function is being computed.
    func: AggFn,
    /// Argument expression (`None` for `COUNT(*)`).
    arg: Option<Expr>,
    /// Whether the call carried the `DISTINCT` modifier.
    distinct: bool,
}
98
99/// Returns true if the SELECT needs the aggregation operator.
100pub fn is_aggregate_query(select: &SelectStmt) -> bool {
101    if !select.group_by.is_empty() {
102        return true;
103    }
104    select
105        .columns
106        .iter()
107        .any(|item| matches!(item, SelectItem::Expr { expr, .. } if contains_aggregate(expr)))
108}
109
110/// Recursively check whether an expression contains an aggregate call.
111fn contains_aggregate(expr: &Expr) -> bool {
112    match expr {
113        Expr::Function(f) => {
114            AggFn::from_name(f.name.name()).is_some() || f.args.iter().any(contains_aggregate)
115        }
116        Expr::BinaryOp { left, right, .. } => contains_aggregate(left) || contains_aggregate(right),
117        Expr::UnaryOp { expr, .. } => contains_aggregate(expr),
118        Expr::IsNull { expr, .. } => contains_aggregate(expr),
119        Expr::Case {
120            operand,
121            conditions,
122            else_result,
123        } => {
124            operand.as_deref().map(contains_aggregate).unwrap_or(false)
125                || conditions
126                    .iter()
127                    .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
128                || else_result
129                    .as_deref()
130                    .map(contains_aggregate)
131                    .unwrap_or(false)
132        }
133        _ => false,
134    }
135}
136
137/// Collect all distinct aggregate calls from SELECT, HAVING and ORDER BY.
138fn collect_agg_specs(select: &SelectStmt) -> Vec<AggSpec> {
139    let mut specs: Vec<AggSpec> = Vec::new();
140    let mut seen: HashSet<String> = HashSet::new();
141
142    let walk = |expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>| {
143        collect_from_expr(expr, specs, seen);
144    };
145
146    for item in &select.columns {
147        if let SelectItem::Expr { expr, .. } = item {
148            walk(expr, &mut specs, &mut seen);
149        }
150    }
151    if let Some(h) = &select.having {
152        walk(h, &mut specs, &mut seen);
153    }
154    for ob in &select.order_by {
155        walk(&ob.expr, &mut specs, &mut seen);
156    }
157    specs
158}
159
160fn collect_from_expr(expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>) {
161    match expr {
162        Expr::Function(f) => {
163            if let Some(func) = AggFn::from_name(f.name.name()) {
164                let arg = f.args.first().cloned();
165                let is_star = matches!(arg.as_ref(), Some(Expr::Column(c)) if c.column == "*");
166                let arg = if is_star { None } else { arg };
167                let key = render_agg_key(func, arg.as_ref(), f.distinct);
168                if seen.insert(key.clone()) {
169                    specs.push(AggSpec {
170                        key,
171                        func,
172                        arg,
173                        distinct: f.distinct,
174                    });
175                }
176            } else {
177                for a in &f.args {
178                    collect_from_expr(a, specs, seen);
179                }
180            }
181        }
182        Expr::BinaryOp { left, right, .. } => {
183            collect_from_expr(left, specs, seen);
184            collect_from_expr(right, specs, seen);
185        }
186        Expr::UnaryOp { expr, .. } => collect_from_expr(expr, specs, seen),
187        Expr::IsNull { expr, .. } => collect_from_expr(expr, specs, seen),
188        Expr::Case {
189            operand,
190            conditions,
191            else_result,
192        } => {
193            if let Some(op) = operand {
194                collect_from_expr(op, specs, seen);
195            }
196            for (w, t) in conditions {
197                collect_from_expr(w, specs, seen);
198                collect_from_expr(t, specs, seen);
199            }
200            if let Some(e) = else_result {
201                collect_from_expr(e, specs, seen);
202            }
203        }
204        _ => {}
205    }
206}
207
208/// Canonical name for an aggregate call: `sum(v1)`, `count(*)`,
209/// `count(distinct id)`. Lowercased so lookups are case-insensitive.
210fn render_agg_key(func: AggFn, arg: Option<&Expr>, distinct: bool) -> String {
211    let fname = match func {
212        AggFn::Count => "count",
213        AggFn::Sum => "sum",
214        AggFn::Avg => "avg",
215        AggFn::Min => "min",
216        AggFn::Max => "max",
217        AggFn::Median => "median",
218        AggFn::Stddev => "stddev",
219    };
220    let arg_s = match arg {
221        None => "*".to_string(),
222        Some(e) => render_expr_name(e),
223    };
224    if distinct {
225        format!("{}(distinct {})", fname, arg_s)
226    } else {
227        format!("{}({})", fname, arg_s)
228    }
229}
230
231/// Human-readable name for an expression, used for output column naming
232/// and canonical aggregate keys.
233pub fn render_expr_name(expr: &Expr) -> String {
234    match expr {
235        Expr::Column(c) => {
236            if let Some(t) = &c.table {
237                format!("{}.{}", t, c.column)
238            } else {
239                c.column.clone()
240            }
241        }
242        Expr::Literal(Literal::Integer(n)) => n.to_string(),
243        Expr::Literal(Literal::Float(f)) => f.to_string(),
244        Expr::Literal(Literal::String(s)) => format!("'{}'", s),
245        Expr::Literal(Literal::Boolean(b)) => b.to_string(),
246        Expr::Literal(Literal::Null) => "null".to_string(),
247        Expr::Function(f) => {
248            if let Some(func) = AggFn::from_name(f.name.name()) {
249                let arg = f.args.first();
250                let is_star = matches!(arg, Some(Expr::Column(c)) if c.column == "*");
251                render_agg_key(func, if is_star { None } else { arg }, f.distinct)
252            } else {
253                let args: Vec<String> = f.args.iter().map(render_expr_name).collect();
254                format!("{}({})", f.name.name().to_lowercase(), args.join(", "))
255            }
256        }
257        Expr::BinaryOp { left, op, right } => format!(
258            "{} {} {}",
259            render_expr_name(left),
260            binary_op_symbol(op),
261            render_expr_name(right)
262        ),
263        Expr::UnaryOp { op, expr } => match op {
264            UnaryOperator::Minus => format!("-{}", render_expr_name(expr)),
265            UnaryOperator::Plus => render_expr_name(expr),
266            UnaryOperator::Not => format!("not {}", render_expr_name(expr)),
267            UnaryOperator::BitNot => format!("~{}", render_expr_name(expr)),
268        },
269        _ => "expr".to_string(),
270    }
271}
272
273fn binary_op_symbol(op: &BinaryOperator) -> &'static str {
274    match op {
275        BinaryOperator::Plus => "+",
276        BinaryOperator::Minus => "-",
277        BinaryOperator::Multiply => "*",
278        BinaryOperator::Divide => "/",
279        BinaryOperator::Modulo => "%",
280        BinaryOperator::Eq => "=",
281        BinaryOperator::Ne => "<>",
282        BinaryOperator::Lt => "<",
283        BinaryOperator::Le => "<=",
284        BinaryOperator::Gt => ">",
285        BinaryOperator::Ge => ">=",
286        BinaryOperator::And => "and",
287        BinaryOperator::Or => "or",
288        _ => "?",
289    }
290}
291
292// ============================================================================
293// Scalar expression evaluation (over a materialized row)
294// ============================================================================
295
296/// Evaluate a scalar expression against a row map.
297///
298/// `agg_values`, when provided, resolves aggregate function calls by their
299/// canonical key — used for HAVING / ORDER BY / projection over finalized
300/// group rows.
301fn eval_scalar(expr: &Expr, row: &HashMap<String, SochValue>, params: &[SochValue]) -> SochValue {
302    match expr {
303        Expr::Column(c) => {
304            if let Some(t) = &c.table {
305                let qualified = format!("{}.{}", t, c.column);
306                if let Some(v) = row.get(&qualified) {
307                    return v.clone();
308                }
309            }
310            row.get(&c.column).cloned().unwrap_or(SochValue::Null)
311        }
312        Expr::Literal(lit) => literal_to_value(lit),
313        Expr::Placeholder(idx) => params
314            .get((*idx as usize).saturating_sub(1))
315            .cloned()
316            .unwrap_or(SochValue::Null),
317        Expr::Function(f) => {
318            // Aggregate results are pre-bound into the row map under their
319            // canonical key by `finalize_groups`.
320            let key = render_expr_name(&Expr::Function(f.clone()));
321            row.get(&key).cloned().unwrap_or(SochValue::Null)
322        }
323        Expr::BinaryOp { left, op, right } => {
324            let l = eval_scalar(left, row, params);
325            let r = eval_scalar(right, row, params);
326            eval_binary(&l, op, &r)
327        }
328        Expr::UnaryOp { op, expr } => {
329            let v = eval_scalar(expr, row, params);
330            match op {
331                UnaryOperator::Minus => match v {
332                    SochValue::Int(i) => SochValue::Int(-i),
333                    SochValue::Float(f) => SochValue::Float(-f),
334                    _ => SochValue::Null,
335                },
336                UnaryOperator::Plus => v,
337                UnaryOperator::Not => match v {
338                    SochValue::Bool(b) => SochValue::Bool(!b),
339                    _ => SochValue::Null,
340                },
341                UnaryOperator::BitNot => match v {
342                    SochValue::Int(i) => SochValue::Int(!i),
343                    _ => SochValue::Null,
344                },
345            }
346        }
347        Expr::IsNull { expr, negated } => {
348            let v = eval_scalar(expr, row, params);
349            let is_null = v.is_null();
350            SochValue::Bool(if *negated { !is_null } else { is_null })
351        }
352        _ => SochValue::Null,
353    }
354}
355
356fn literal_to_value(lit: &Literal) -> SochValue {
357    match lit {
358        Literal::Integer(i) => SochValue::Int(*i),
359        Literal::Float(f) => SochValue::Float(*f),
360        Literal::String(s) => SochValue::Text(s.clone()),
361        Literal::Boolean(b) => SochValue::Bool(*b),
362        Literal::Null => SochValue::Null,
363        _ => SochValue::Null,
364    }
365}
366
367fn numeric(v: &SochValue) -> Option<f64> {
368    match v {
369        SochValue::Int(i) => Some(*i as f64),
370        SochValue::UInt(u) => Some(*u as f64),
371        SochValue::Float(f) => Some(*f),
372        SochValue::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
373        _ => None,
374    }
375}
376
377fn eval_binary(l: &SochValue, op: &BinaryOperator, r: &SochValue) -> SochValue {
378    use BinaryOperator::*;
379    match op {
380        Plus | Minus | Multiply | Divide | Modulo => {
381            // Integer arithmetic when both sides are ints (except division).
382            if let (SochValue::Int(a), SochValue::Int(b)) = (l, r) {
383                return match op {
384                    Plus => SochValue::Int(a.wrapping_add(*b)),
385                    Minus => SochValue::Int(a.wrapping_sub(*b)),
386                    Multiply => SochValue::Int(a.wrapping_mul(*b)),
387                    Divide => {
388                        if *b == 0 {
389                            SochValue::Null
390                        } else {
391                            SochValue::Float(*a as f64 / *b as f64)
392                        }
393                    }
394                    Modulo => {
395                        if *b == 0 {
396                            SochValue::Null
397                        } else {
398                            SochValue::Int(a % b)
399                        }
400                    }
401                    _ => unreachable!(),
402                };
403            }
404            let (a, b) = match (numeric(l), numeric(r)) {
405                (Some(a), Some(b)) => (a, b),
406                _ => return SochValue::Null,
407            };
408            match op {
409                Plus => SochValue::Float(a + b),
410                Minus => SochValue::Float(a - b),
411                Multiply => SochValue::Float(a * b),
412                Divide => {
413                    if b == 0.0 {
414                        SochValue::Null
415                    } else {
416                        SochValue::Float(a / b)
417                    }
418                }
419                Modulo => {
420                    if b == 0.0 {
421                        SochValue::Null
422                    } else {
423                        SochValue::Float(a % b)
424                    }
425                }
426                _ => unreachable!(),
427            }
428        }
429        Eq | Ne | Lt | Le | Gt | Ge => {
430            if l.is_null() || r.is_null() {
431                return SochValue::Null;
432            }
433            let ord = compare_values(l, r);
434            let b = match op {
435                Eq => ord == std::cmp::Ordering::Equal,
436                Ne => ord != std::cmp::Ordering::Equal,
437                Lt => ord == std::cmp::Ordering::Less,
438                Le => ord != std::cmp::Ordering::Greater,
439                Gt => ord == std::cmp::Ordering::Greater,
440                Ge => ord != std::cmp::Ordering::Less,
441                _ => unreachable!(),
442            };
443            SochValue::Bool(b)
444        }
445        And => match (as_bool(l), as_bool(r)) {
446            (Some(a), Some(b)) => SochValue::Bool(a && b),
447            _ => SochValue::Null,
448        },
449        Or => match (as_bool(l), as_bool(r)) {
450            (Some(a), Some(b)) => SochValue::Bool(a || b),
451            _ => SochValue::Null,
452        },
453        _ => SochValue::Null,
454    }
455}
456
457fn as_bool(v: &SochValue) -> Option<bool> {
458    match v {
459        SochValue::Bool(b) => Some(*b),
460        SochValue::Int(i) => Some(*i != 0),
461        SochValue::Null => None,
462        _ => None,
463    }
464}
465
466/// Total ordering across SochValue for grouping/sorting.
467pub fn compare_values(a: &SochValue, b: &SochValue) -> std::cmp::Ordering {
468    use std::cmp::Ordering;
469    match (numeric(a), numeric(b)) {
470        (Some(x), Some(y)) => return x.partial_cmp(&y).unwrap_or(Ordering::Equal),
471        _ => {}
472    }
473    match (a, b) {
474        (SochValue::Text(x), SochValue::Text(y)) => x.cmp(y),
475        (SochValue::Null, SochValue::Null) => Ordering::Equal,
476        (SochValue::Null, _) => Ordering::Less,
477        (_, SochValue::Null) => Ordering::Greater,
478        _ => Ordering::Equal,
479    }
480}
481
482/// Canonical hash representation of a group-key value.
483/// Normalizes Int/UInt/Float-of-integral so `1`, `1u`, `1.0` group together.
484fn key_repr(v: &SochValue) -> String {
485    match v {
486        SochValue::Null => "\u{0}N".to_string(),
487        SochValue::Int(i) => format!("i{}", i),
488        SochValue::UInt(u) => format!("i{}", u),
489        SochValue::Float(f) => {
490            if f.fract() == 0.0 && f.abs() < 9.0e15 {
491                format!("i{}", *f as i64)
492            } else {
493                format!("f{}", f)
494            }
495        }
496        SochValue::Text(s) => format!("s{}", s),
497        SochValue::Bool(b) => format!("b{}", b),
498        other => format!("{:?}", other),
499    }
500}
501
502// ============================================================================
503// Accumulators
504// ============================================================================
505
506#[derive(Debug)]
507enum Acc {
508    CountStar(u64),
509    Count(u64),
510    CountDistinct(HashSet<String>),
511    /// Sum preserving integer-ness: (int_sum, float_sum, saw_float, saw_any)
512    Sum {
513        int: i64,
514        float: f64,
515        saw_float: bool,
516        saw_any: bool,
517        overflowed: bool,
518    },
519    Avg {
520        sum: f64,
521        n: u64,
522    },
523    Min(Option<SochValue>),
524    Max(Option<SochValue>),
525    Median(Vec<f64>),
526    /// Welford online variance: (n, mean, m2)
527    Stddev {
528        n: u64,
529        mean: f64,
530        m2: f64,
531    },
532}
533
534impl Acc {
535    fn new(spec: &AggSpec) -> Self {
536        match (spec.func, spec.arg.is_some(), spec.distinct) {
537            (AggFn::Count, false, _) => Acc::CountStar(0),
538            (AggFn::Count, true, true) => Acc::CountDistinct(HashSet::new()),
539            (AggFn::Count, true, false) => Acc::Count(0),
540            (AggFn::Sum, _, _) => Acc::Sum {
541                int: 0,
542                float: 0.0,
543                saw_float: false,
544                saw_any: false,
545                overflowed: false,
546            },
547            (AggFn::Avg, _, _) => Acc::Avg { sum: 0.0, n: 0 },
548            (AggFn::Min, _, _) => Acc::Min(None),
549            (AggFn::Max, _, _) => Acc::Max(None),
550            (AggFn::Median, _, _) => Acc::Median(Vec::new()),
551            (AggFn::Stddev, _, _) => Acc::Stddev {
552                n: 0,
553                mean: 0.0,
554                m2: 0.0,
555            },
556        }
557    }
558
559    /// Update with the evaluated argument value (`None` only for COUNT(*)).
560    fn update(&mut self, val: Option<&SochValue>) {
561        match self {
562            Acc::CountStar(n) => *n += 1,
563            Acc::Count(n) => {
564                if let Some(v) = val {
565                    if !v.is_null() {
566                        *n += 1;
567                    }
568                }
569            }
570            Acc::CountDistinct(set) => {
571                if let Some(v) = val {
572                    if !v.is_null() {
573                        set.insert(key_repr(v));
574                    }
575                }
576            }
577            Acc::Sum {
578                int,
579                float,
580                saw_float,
581                saw_any,
582                overflowed,
583            } => {
584                let Some(v) = val else { return };
585                match v {
586                    SochValue::Int(i) => {
587                        *saw_any = true;
588                        match int.checked_add(*i) {
589                            Some(s) => *int = s,
590                            None => *overflowed = true,
591                        }
592                        *float += *i as f64;
593                    }
594                    SochValue::UInt(u) => {
595                        *saw_any = true;
596                        match int.checked_add(*u as i64) {
597                            Some(s) => *int = s,
598                            None => *overflowed = true,
599                        }
600                        *float += *u as f64;
601                    }
602                    SochValue::Float(f) => {
603                        *saw_any = true;
604                        *saw_float = true;
605                        *float += *f;
606                    }
607                    _ => {}
608                }
609            }
610            Acc::Avg { sum, n } => {
611                if let Some(x) = val.and_then(numeric) {
612                    *sum += x;
613                    *n += 1;
614                }
615            }
616            Acc::Min(cur) => {
617                let Some(v) = val else { return };
618                if v.is_null() {
619                    return;
620                }
621                match cur {
622                    None => *cur = Some(v.clone()),
623                    Some(c) => {
624                        if compare_values(v, c) == std::cmp::Ordering::Less {
625                            *cur = Some(v.clone());
626                        }
627                    }
628                }
629            }
630            Acc::Max(cur) => {
631                let Some(v) = val else { return };
632                if v.is_null() {
633                    return;
634                }
635                match cur {
636                    None => *cur = Some(v.clone()),
637                    Some(c) => {
638                        if compare_values(v, c) == std::cmp::Ordering::Greater {
639                            *cur = Some(v.clone());
640                        }
641                    }
642                }
643            }
644            Acc::Median(vals) => {
645                if let Some(x) = val.and_then(numeric) {
646                    vals.push(x);
647                }
648            }
649            Acc::Stddev { n, mean, m2 } => {
650                if let Some(x) = val.and_then(numeric) {
651                    *n += 1;
652                    let delta = x - *mean;
653                    *mean += delta / *n as f64;
654                    let delta2 = x - *mean;
655                    *m2 += delta * delta2;
656                }
657            }
658        }
659    }
660
661    /// Merge a partial accumulator (from a parallel chunk) into self.
662    /// Both must originate from the same `AggSpec`.
663    fn merge(&mut self, other: Acc) {
664        match (self, other) {
665            (Acc::CountStar(a), Acc::CountStar(b)) => *a += b,
666            (Acc::Count(a), Acc::Count(b)) => *a += b,
667            (Acc::CountDistinct(a), Acc::CountDistinct(b)) => a.extend(b),
668            (
669                Acc::Sum {
670                    int,
671                    float,
672                    saw_float,
673                    saw_any,
674                    overflowed,
675                },
676                Acc::Sum {
677                    int: i2,
678                    float: f2,
679                    saw_float: sf2,
680                    saw_any: sa2,
681                    overflowed: of2,
682                },
683            ) => {
684                match int.checked_add(i2) {
685                    Some(s) => *int = s,
686                    None => *overflowed = true,
687                }
688                *float += f2;
689                *saw_float |= sf2;
690                *saw_any |= sa2;
691                *overflowed |= of2;
692            }
693            (Acc::Avg { sum, n }, Acc::Avg { sum: s2, n: n2 }) => {
694                *sum += s2;
695                *n += n2;
696            }
697            (Acc::Min(a), Acc::Min(Some(b))) => match a {
698                None => *a = Some(b),
699                Some(cur) => {
700                    if compare_values(&b, cur) == std::cmp::Ordering::Less {
701                        *a = Some(b);
702                    }
703                }
704            },
705            (Acc::Max(a), Acc::Max(Some(b))) => match a {
706                None => *a = Some(b),
707                Some(cur) => {
708                    if compare_values(&b, cur) == std::cmp::Ordering::Greater {
709                        *a = Some(b);
710                    }
711                }
712            },
713            (Acc::Min(_), Acc::Min(None)) | (Acc::Max(_), Acc::Max(None)) => {}
714            (Acc::Median(a), Acc::Median(b)) => a.extend(b),
715            (
716                Acc::Stddev { n, mean, m2 },
717                Acc::Stddev {
718                    n: nb,
719                    mean: mb,
720                    m2: m2b,
721                },
722            ) => {
723                // Chan et al. parallel variance merge.
724                if nb > 0 {
725                    if *n == 0 {
726                        *n = nb;
727                        *mean = mb;
728                        *m2 = m2b;
729                    } else {
730                        let na = *n as f64;
731                        let nbf = nb as f64;
732                        let delta = mb - *mean;
733                        let total = na + nbf;
734                        *mean += delta * nbf / total;
735                        *m2 += m2b + delta * delta * na * nbf / total;
736                        *n += nb;
737                    }
738                }
739            }
740            _ => unreachable!("mismatched accumulator merge"),
741        }
742    }
743
744    fn finalize(self) -> SochValue {
745        match self {
746            Acc::CountStar(n) | Acc::Count(n) => SochValue::Int(n as i64),
747            Acc::CountDistinct(set) => SochValue::Int(set.len() as i64),
748            Acc::Sum {
749                int,
750                float,
751                saw_float,
752                saw_any,
753                overflowed,
754            } => {
755                if !saw_any {
756                    SochValue::Null
757                } else if saw_float || overflowed {
758                    SochValue::Float(float)
759                } else {
760                    SochValue::Int(int)
761                }
762            }
763            Acc::Avg { sum, n } => {
764                if n == 0 {
765                    SochValue::Null
766                } else {
767                    SochValue::Float(sum / n as f64)
768                }
769            }
770            Acc::Min(v) | Acc::Max(v) => v.unwrap_or(SochValue::Null),
771            Acc::Median(mut vals) => {
772                if vals.is_empty() {
773                    return SochValue::Null;
774                }
775                let mid = vals.len() / 2;
776                if vals.len() % 2 == 1 {
777                    let (_, m, _) = vals.select_nth_unstable_by(mid, |a, b| {
778                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
779                    });
780                    SochValue::Float(*m)
781                } else {
782                    // Even count: average the two middle values.
783                    let (lo, hi_first, _) = vals.select_nth_unstable_by(mid, |a, b| {
784                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
785                    });
786                    let lo_max = lo.iter().copied().fold(f64::NEG_INFINITY, f64::max);
787                    SochValue::Float((lo_max + *hi_first) / 2.0)
788                }
789            }
790            Acc::Stddev { n, m2, .. } => {
791                if n < 2 {
792                    SochValue::Null
793                } else {
794                    SochValue::Float((m2 / (n - 1) as f64).sqrt())
795                }
796            }
797        }
798    }
799}
800
801// ============================================================================
802// The aggregation operator
803// ============================================================================
804
/// Per-group state kept while rows are being accumulated.
struct GroupState {
    /// Values of the group key for this group — presumably one entry per
    /// GROUP BY expression, in clause order (TODO confirm against the code
    /// that builds it).
    key_values: Vec<SochValue>,
    /// NOTE(review): looks like the first input row seen for the group, used
    /// to resolve non-aggregated columns as described in the module docs —
    /// confirm against the accumulation code.
    first_row: HashMap<String, SochValue>,
    /// Accumulators for this group — presumably one per `AggSpec`, in spec
    /// order (TODO confirm).
    accs: Vec<Acc>,
}
810
811// ----------------------------------------------------------------------------
812// Fast path: plain-column group keys and aggregate args
813// ----------------------------------------------------------------------------
814
815/// A group-key atom that borrows string data from the input rows —
816/// zero allocations during accumulation lookups.
817#[derive(Debug, Clone, PartialEq, Eq, Hash)]
818enum KeyAtom<'a> {
819    Null,
820    Int(i64),
821    /// Normalized f64 bits (integral floats normalize to `Int`).
822    FBits(u64),
823    Str(&'a str),
824    Bool(bool),
825}
826
827impl<'a> KeyAtom<'a> {
828    fn from_value(v: &'a SochValue) -> Self {
829        match v {
830            SochValue::Null => KeyAtom::Null,
831            SochValue::Int(i) => KeyAtom::Int(*i),
832            SochValue::UInt(u) => KeyAtom::Int(*u as i64),
833            SochValue::Float(f) => {
834                if f.fract() == 0.0 && f.abs() < 9.0e15 {
835                    KeyAtom::Int(*f as i64)
836                } else if f.is_nan() {
837                    KeyAtom::FBits(f64::NAN.to_bits())
838                } else {
839                    KeyAtom::FBits(f.to_bits())
840                }
841            }
842            SochValue::Text(s) => KeyAtom::Str(s.as_str()),
843            SochValue::Bool(b) => KeyAtom::Bool(*b),
844            _ => KeyAtom::Null,
845        }
846    }
847}
848
849#[derive(Debug, Clone, PartialEq, Eq, Hash)]
850enum GroupKey<'a> {
851    Empty,
852    One(KeyAtom<'a>),
853    Many(Vec<KeyAtom<'a>>),
854}
855
/// Shared NULL handed out when a referenced column is absent from a row.
static NULL_VALUE: SochValue = SochValue::Null;

/// Look up `col` in `row`.
///
/// The qualified `table.col` spelling (when the reference has one) is tried
/// first, then the bare column name. A column found under neither name
/// resolves to NULL.
#[inline]
fn col_get<'r>(row: &'r HashMap<String, SochValue>, col: &PlainCol) -> &'r SochValue {
    col.qualified
        .as_ref()
        .and_then(|qualified| row.get(qualified))
        .or_else(|| row.get(&col.name))
        .unwrap_or(&NULL_VALUE)
}
868
869/// Pre-resolved plain column: unqualified name + optional "table.col" form.
870struct PlainCol {
871    name: String,
872    qualified: Option<String>,
873}
874
875fn as_plain_col(expr: &Expr) -> Option<PlainCol> {
876    match expr {
877        Expr::Column(c) => Some(PlainCol {
878            name: c.column.clone(),
879            qualified: c.table.as_ref().map(|t| format!("{}.{}", t, c.column)),
880        }),
881        _ => None,
882    }
883}
884
885/// Build the borrowed group key for one row.
886fn make_group_key<'r>(
887    row: &'r HashMap<String, SochValue>,
888    group_cols: &[PlainCol],
889) -> GroupKey<'r> {
890    match group_cols.len() {
891        0 => GroupKey::Empty,
892        1 => GroupKey::One(KeyAtom::from_value(col_get(row, &group_cols[0]))),
893        _ => GroupKey::Many(
894            group_cols
895                .iter()
896                .map(|c| KeyAtom::from_value(col_get(row, c)))
897                .collect(),
898        ),
899    }
900}
901
902/// Try the optimized accumulation path. Applicable when every GROUP BY
903/// expression and every aggregate argument is a plain column reference
904/// (which covers typical analytics queries). Returns group states in
905/// first-seen order (per-chunk order under parallel execution).
906fn accumulate_fast<'a>(
907    select: &SelectStmt,
908    specs: &[AggSpec],
909    rows: &'a [HashMap<String, SochValue>],
910) -> Option<Vec<GroupState>> {
911    // Pre-resolve group-key columns.
912    let group_cols: Vec<PlainCol> = select
913        .group_by
914        .iter()
915        .map(as_plain_col)
916        .collect::<Option<Vec<_>>>()?;
917    // Pre-resolve aggregate argument columns (None = COUNT(*)).
918    let arg_cols: Vec<Option<PlainCol>> = specs
919        .iter()
920        .map(|s| match &s.arg {
921            None => Some(None),
922            Some(e) => as_plain_col(e).map(Some),
923        })
924        .collect::<Option<Vec<_>>>()?;
925
926    let accumulate_chunk =
927        |chunk: &'a [HashMap<String, SochValue>]| -> Vec<(GroupKey<'a>, GroupState)> {
928            let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
929            let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
930            for row in chunk {
931                let key = make_group_key(row, &group_cols);
932                let idx = match index.get(&key) {
933                    Some(&i) => i,
934                    None => {
935                        let state = GroupState {
936                            key_values: group_cols
937                                .iter()
938                                .map(|c| col_get(row, c).clone())
939                                .collect(),
940                            first_row: row.clone(),
941                            accs: specs.iter().map(Acc::new).collect(),
942                        };
943                        order.push((key.clone(), state));
944                        index.insert(key, order.len() - 1);
945                        order.len() - 1
946                    }
947                };
948                let accs = &mut order[idx].1.accs;
949                for (acc, arg) in accs.iter_mut().zip(arg_cols.iter()) {
950                    match arg {
951                        None => acc.update(None),
952                        Some(col) => acc.update(Some(col_get(row, col))),
953                    }
954                }
955            }
956            order
957        };
958
959    let merged: Vec<(GroupKey<'a>, GroupState)> = if rows.len() >= PARALLEL_THRESHOLD {
960        let n_threads = rayon::current_num_threads().max(1);
961        let chunk_size = (rows.len() / (n_threads * 4)).max(16_384);
962        let partials: Vec<Vec<(GroupKey<'a>, GroupState)>> =
963            rows.par_chunks(chunk_size).map(accumulate_chunk).collect();
964        // Merge chunk partials in chunk order.
965        let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
966        let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
967        for partial in partials {
968            for (key, state) in partial {
969                match index.get(&key) {
970                    Some(&i) => {
971                        let dst = &mut order[i].1;
972                        for (a, b) in dst.accs.iter_mut().zip(state.accs.into_iter()) {
973                            a.merge(b);
974                        }
975                    }
976                    None => {
977                        order.push((key.clone(), state));
978                        index.insert(key, order.len() - 1);
979                    }
980                }
981            }
982        }
983        order
984    } else {
985        accumulate_chunk(rows)
986    };
987
988    Some(merged.into_iter().map(|(_, s)| s).collect())
989}
990
991/// Execute aggregation over materialized input rows (already WHERE-filtered).
992///
993/// Handles GROUP BY, all aggregate accumulation, HAVING, ORDER BY,
994/// OFFSET/LIMIT, and final projection. Returns `ExecutionResult::Rows`.
995pub fn execute_aggregate(
996    select: &SelectStmt,
997    rows: &[HashMap<String, SochValue>],
998    params: &[SochValue],
999    limit: Option<usize>,
1000    offset: Option<usize>,
1001) -> SqlResult<ExecutionResult> {
1002    let specs = collect_agg_specs(select);
1003    let grouped = !select.group_by.is_empty();
1004
1005    // ---- accumulate ----
1006    // Fast path: plain-column keys/args, borrowed-key hashing, parallel
1007    // partitioned accumulation. Falls back to the general expression-based
1008    // path for computed keys or computed aggregate arguments.
1009    let mut order: Vec<GroupState> = match accumulate_fast(select, &specs, rows) {
1010        Some(states) => states,
1011        None => {
1012            let mut order: Vec<GroupState> = Vec::new();
1013            let mut index: HashMap<Vec<String>, usize> = HashMap::new();
1014
1015            for row in rows {
1016                let key_values: Vec<SochValue> = select
1017                    .group_by
1018                    .iter()
1019                    .map(|e| eval_scalar(e, row, params))
1020                    .collect();
1021                let hash_key: Vec<String> = key_values.iter().map(key_repr).collect();
1022
1023                let idx = match index.get(&hash_key) {
1024                    Some(&i) => i,
1025                    None => {
1026                        let state = GroupState {
1027                            key_values,
1028                            first_row: row.clone(),
1029                            accs: specs.iter().map(Acc::new).collect(),
1030                        };
1031                        order.push(state);
1032                        index.insert(hash_key, order.len() - 1);
1033                        order.len() - 1
1034                    }
1035                };
1036
1037                let state = &mut order[idx];
1038                for (acc, spec) in state.accs.iter_mut().zip(specs.iter()) {
1039                    match &spec.arg {
1040                        None => acc.update(None),
1041                        Some(arg) => {
1042                            let v = eval_scalar(arg, row, params);
1043                            acc.update(Some(&v));
1044                        }
1045                    }
1046                }
1047            }
1048            order
1049        }
1050    };
1051
1052    // Ungrouped aggregate over zero rows still yields one (empty) group.
1053    if !grouped && order.is_empty() {
1054        order.push(GroupState {
1055            key_values: Vec::new(),
1056            first_row: HashMap::new(),
1057            accs: specs.iter().map(Acc::new).collect(),
1058        });
1059    }
1060
1061    // ---- finalize: synthesize one row per group ----
1062    let group_names: Vec<String> = select.group_by.iter().map(render_expr_name).collect();
1063
1064    let mut out_rows: Vec<HashMap<String, SochValue>> = Vec::with_capacity(order.len());
1065    for state in order {
1066        // Start from the first row of the group so non-aggregate columns
1067        // (lenient mode) and qualified names still resolve.
1068        let mut row = state.first_row;
1069        for (name, val) in group_names.iter().zip(state.key_values.into_iter()) {
1070            row.insert(name.clone(), val);
1071        }
1072        for (spec, acc) in specs.iter().zip(state.accs.into_iter()) {
1073            row.insert(spec.key.clone(), acc.finalize());
1074        }
1075        out_rows.push(row);
1076    }
1077
1078    // ---- HAVING ----
1079    if let Some(having) = &select.having {
1080        out_rows.retain(|row| matches!(eval_scalar(having, row, params), SochValue::Bool(true)));
1081    }
1082
1083    // ---- ORDER BY (may reference aggregates or aliases) ----
1084    if !select.order_by.is_empty() {
1085        // Bind aliases so ORDER BY alias works.
1086        let alias_map: Vec<(String, Expr)> = select
1087            .columns
1088            .iter()
1089            .filter_map(|item| match item {
1090                SelectItem::Expr {
1091                    expr,
1092                    alias: Some(a),
1093                } => Some((a.clone(), expr.clone())),
1094                _ => None,
1095            })
1096            .collect();
1097        for row in &mut out_rows {
1098            for (alias, expr) in &alias_map {
1099                if !row.contains_key(alias) {
1100                    let v = eval_scalar(expr, row, params);
1101                    row.insert(alias.clone(), v);
1102                }
1103            }
1104        }
1105        out_rows.sort_by(|a, b| {
1106            for item in &select.order_by {
1107                let va = eval_scalar(&item.expr, a, params);
1108                let vb = eval_scalar(&item.expr, b, params);
1109                let mut cmp = compare_values(&va, &vb);
1110                if !item.asc {
1111                    cmp = cmp.reverse();
1112                }
1113                if cmp != std::cmp::Ordering::Equal {
1114                    return cmp;
1115                }
1116            }
1117            std::cmp::Ordering::Equal
1118        });
1119    }
1120
1121    // ---- OFFSET / LIMIT ----
1122    if let Some(off) = offset {
1123        if off > 0 {
1124            out_rows.drain(..off.min(out_rows.len()));
1125        }
1126    }
1127    if let Some(lim) = limit {
1128        out_rows.truncate(lim);
1129    }
1130
1131    // ---- projection ----
1132    let mut columns: Vec<String> = Vec::new();
1133    let mut projections: Vec<(String, Expr)> = Vec::new();
1134    for item in &select.columns {
1135        match item {
1136            SelectItem::Wildcard | SelectItem::QualifiedWildcard(_) => {
1137                // SELECT * with GROUP BY: project group keys then aggregates.
1138                for name in &group_names {
1139                    columns.push(name.clone());
1140                    projections.push((name.clone(), Expr::Column(ColumnRef::new(name.clone()))));
1141                }
1142                for spec in &specs {
1143                    columns.push(spec.key.clone());
1144                    projections.push((
1145                        spec.key.clone(),
1146                        Expr::Column(ColumnRef::new(spec.key.clone())),
1147                    ));
1148                }
1149            }
1150            SelectItem::Expr { expr, alias } => {
1151                let name = alias.clone().unwrap_or_else(|| render_expr_name(expr));
1152                columns.push(name.clone());
1153                projections.push((name, expr.clone()));
1154            }
1155        }
1156    }
1157
1158    let projected: Vec<HashMap<String, SochValue>> = out_rows
1159        .into_iter()
1160        .map(|row| {
1161            let mut out = HashMap::with_capacity(projections.len());
1162            for (name, expr) in &projections {
1163                let v = eval_scalar(expr, &row, params);
1164                out.insert(name.clone(), v);
1165            }
1166            out
1167        })
1168        .collect();
1169
1170    Ok(ExecutionResult::Rows {
1171        columns,
1172        rows: projected,
1173    })
1174}
1175
#[cfg(test)]
mod tests {
    use super::super::bridge::{SqlBridge, SqlConnection};
    use super::*;

    /// Build the call expression `name(arg)` with a single column argument
    /// and no DISTINCT / FILTER / OVER.
    fn fcall(name: &str, arg: &str) -> Expr {
        let column = Expr::Column(ColumnRef::new(arg));
        Expr::Function(FunctionCall {
            name: ObjectName::new(name),
            args: vec![column],
            distinct: false,
            filter: None,
            over: None,
        })
    }

    #[test]
    fn agg_fn_recognition() {
        let cases = [
            ("median", Some(AggFn::Median)),
            ("STDDEV", Some(AggFn::Stddev)),
            ("stddev_samp", Some(AggFn::Stddev)),
            ("upper", None),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(&AggFn::from_name(name), expected);
        }
    }

    #[test]
    fn canonical_keys() {
        let cases = [("SUM", "v1", "sum(v1)"), ("Median", "v3", "median(v3)")];
        for (func, arg, rendered) in cases.iter() {
            assert_eq!(render_expr_name(&fcall(func, arg)), *rendered);
        }
    }

    // ========================================================================
    // End-to-end SQL tests through SqlBridge with an in-memory connection
    // ========================================================================

    /// Minimal in-memory `SqlConnection`: tables are plain vectors of rows;
    /// every mutating operation is a no-op that reports success.
    struct DataConn {
        tables: HashMap<String, Vec<HashMap<String, SochValue>>>,
    }

    impl DataConn {
        fn new() -> Self {
            DataConn {
                tables: HashMap::new(),
            }
        }

        /// Register table `name`, pairing each row's values positionally
        /// with `cols`.
        fn with_table(mut self, name: &str, cols: &[&str], rows: Vec<Vec<SochValue>>) -> Self {
            let mut materialized = Vec::with_capacity(rows.len());
            for vals in rows {
                let mut record = HashMap::new();
                for (col, val) in cols.iter().zip(vals) {
                    record.insert(col.to_string(), val);
                }
                materialized.push(record);
            }
            self.tables.insert(name.to_string(), materialized);
            self
        }

        /// Copy of every row in `table`; an unknown table reads as empty.
        fn table_rows(&self, table: &str) -> Vec<HashMap<String, SochValue>> {
            self.tables.get(table).cloned().unwrap_or_default()
        }
    }

    impl SqlConnection for DataConn {
        fn select(
            &self,
            table: &str,
            _: &[String],
            _where_clause: Option<&Expr>,
            _: &[OrderByItem],
            _: Option<usize>,
            _: Option<usize>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            // Returns the whole table: the aggregate tests never rely on
            // WHERE being pushed down to the connection.
            Ok(ExecutionResult::Rows {
                columns: vec![],
                rows: self.table_rows(table),
            })
        }
        fn insert(
            &mut self,
            _: &str,
            _: Option<&[String]>,
            _: &[Vec<Expr>],
            _: Option<&OnConflict>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn update(
            &mut self,
            _: &str,
            _: &[Assignment],
            _: Option<&Expr>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn delete(
            &mut self,
            _: &str,
            _: Option<&Expr>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn create_table(&mut self, _: &CreateTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn drop_table(&mut self, _: &DropTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn create_index(&mut self, _: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn drop_index(&mut self, _: &DropIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn alter_table(&mut self, _: &AlterTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn begin(&mut self, _: &BeginStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }
        fn commit(&mut self) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }
        fn rollback(&mut self, _: Option<&str>) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }
        fn table_exists(&self, t: &str) -> SqlResult<bool> {
            Ok(self.tables.contains_key(t))
        }
        fn index_exists(&self, _: &str) -> SqlResult<bool> {
            Ok(false)
        }
        fn scan_all(
            &self,
            table: &str,
            _: &[String],
        ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
            Ok(self.table_rows(table))
        }
        fn eval_join_predicate(
            &self,
            expr: &Expr,
            row: &HashMap<String, SochValue>,
            params: &[SochValue],
        ) -> Option<bool> {
            // NULL counts as "no match"; a non-boolean result is undecided.
            match eval_scalar(expr, row, params) {
                SochValue::Bool(b) => Some(b),
                SochValue::Null => Some(false),
                _ => None,
            }
        }
    }

    // Terse value constructors for fixtures and expectations.
    fn i(v: i64) -> SochValue {
        SochValue::Int(v)
    }
    fn f(v: f64) -> SochValue {
        SochValue::Float(v)
    }
    fn t(v: &str) -> SochValue {
        SochValue::Text(v.to_string())
    }

    /// db-benchmark-shaped fixture: x(id1 text, id3 text, v1 int, v2 int, v3 float)
    fn bench_bridge() -> SqlBridge<DataConn> {
        let fixture = vec![
            vec![t("id001"), t("id0000001"), i(1), i(10), f(1.0)],
            vec![t("id001"), t("id0000002"), i(2), i(20), f(2.0)],
            vec![t("id002"), t("id0000001"), i(3), i(30), f(3.0)],
            vec![t("id002"), t("id0000002"), i(4), i(40), f(4.0)],
        ];
        SqlBridge::new(DataConn::new().with_table(
            "x",
            &["id1", "id3", "v1", "v2", "v3"],
            fixture,
        ))
    }

    /// Unwrap a row-producing result; any other result kind fails the test.
    fn rows_of(result: ExecutionResult) -> Vec<HashMap<String, SochValue>> {
        match result {
            ExecutionResult::Rows { rows, .. } => rows,
            other => panic!("expected rows, got {:?}", other),
        }
    }

    /// Fetch column `k` from `row`, failing the test when it is absent.
    fn get<'a>(row: &'a HashMap<String, SochValue>, k: &str) -> &'a SochValue {
        match row.get(k) {
            Some(value) => value,
            None => panic!("column '{}' missing from {:?}", k, row),
        }
    }

    /// Execute `sql` on `bridge` and return the resulting rows.
    fn run_sql(bridge: &mut SqlBridge<DataConn>, sql: &str) -> Vec<HashMap<String, SochValue>> {
        rows_of(bridge.execute(sql).unwrap())
    }

    #[test]
    fn groupby_sum_q1_shape() {
        // db-benchmark q1: SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1 ORDER BY id1",
        );
        assert_eq!(rows.len(), 2);
        let expected = [("id001", 3i64), ("id002", 7i64)];
        for (row, &(id, total)) in rows.iter().zip(expected.iter()) {
            assert_eq!(get(row, "id1"), &t(id));
            assert_eq!(get(row, "v1"), &i(total));
        }
    }

    #[test]
    fn groupby_multi_key_mean() {
        // q4-like: SELECT id1, id3, avg(v1) AS m FROM x GROUP BY id1, id3
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT id1, id3, avg(v1) AS m FROM x GROUP BY id1, id3 ORDER BY id1, id3",
        );
        assert_eq!(rows.len(), 4);
        let first = &rows[0];
        let last = &rows[3];
        assert_eq!(get(first, "m"), &f(1.0));
        assert_eq!(get(last, "m"), &f(4.0));
    }

    #[test]
    fn median_and_stddev() {
        // q7-like: SELECT median(v3), stddev(v3) FROM x
        // v3 = [1,2,3,4]: median = 2.5, sample sd = sqrt(5/3)
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT median(v3) AS med, stddev(v3) AS sd FROM x",
        );
        assert_eq!(rows.len(), 1);
        let only = &rows[0];
        assert_eq!(get(only, "med"), &f(2.5));
        let expected_sd = (5.0f64 / 3.0).sqrt();
        match get(only, "sd") {
            SochValue::Float(sd) => {
                assert!((sd - expected_sd).abs() < 1e-12, "sd={}", sd)
            }
            other => panic!("expected float sd, got {:?}", other),
        }
    }

    #[test]
    fn median_odd_count() {
        let values = vec![vec![f(5.0)], vec![f(1.0)], vec![f(3.0)]];
        let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["v"], values));
        let rows = run_sql(&mut bridge, "SELECT median(v) AS m FROM t");
        assert_eq!(get(&rows[0], "m"), &f(3.0));
    }

    #[test]
    fn range_expression_q9_shape() {
        // q9: SELECT id3, max(v1) - min(v2) AS range_v1_v2 FROM x GROUP BY id3
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT id3, max(v1) - min(v2) AS range_v1_v2 FROM x GROUP BY id3 ORDER BY id3",
        );
        assert_eq!(rows.len(), 2);
        // id0000001: max(v1)=3, min(v2)=10 -> -7
        // id0000002: max(v1)=4, min(v2)=20 -> -16
        let expected = [-7i64, -16i64];
        for (row, &range) in rows.iter().zip(expected.iter()) {
            assert_eq!(get(row, "range_v1_v2"), &i(range));
        }
    }

    #[test]
    fn count_star_vs_count_col_with_nulls() {
        let data = vec![
            vec![t("a"), i(1)],
            vec![t("a"), SochValue::Null],
            vec![t("b"), i(2)],
        ];
        let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["g", "v"], data));
        let rows = run_sql(
            &mut bridge,
            "SELECT g, count(*) AS n, count(v) AS nv FROM t GROUP BY g ORDER BY g",
        );
        assert_eq!(rows.len(), 2);
        // (count(*), count(v)) per group: the NULL in group "a" only
        // affects count(v).
        let expected = [(2i64, 1i64), (1i64, 1i64)];
        for (row, &(n, nv)) in rows.iter().zip(expected.iter()) {
            assert_eq!(get(row, "n"), &i(n));
            assert_eq!(get(row, "nv"), &i(nv));
        }
    }

    #[test]
    fn count_distinct() {
        // q6-like: SELECT id3, count(DISTINCT id1) AS u FROM x GROUP BY id3
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT id3, count(DISTINCT id1) AS u FROM x GROUP BY id3 ORDER BY id3",
        );
        assert_eq!(rows.len(), 2);
        for row in rows.iter().take(2) {
            assert_eq!(get(row, "u"), &i(2));
        }
    }

    #[test]
    fn having_filters_groups() {
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT id1, sum(v1) AS s FROM x GROUP BY id1 HAVING sum(v1) > 5",
        );
        assert_eq!(rows.len(), 1);
        let survivor = &rows[0];
        assert_eq!(get(survivor, "id1"), &t("id002"));
        assert_eq!(get(survivor, "s"), &i(7));
    }

    #[test]
    fn order_by_aggregate_desc_with_limit() {
        let mut bridge = bench_bridge();
        let rows = run_sql(
            &mut bridge,
            "SELECT id1, sum(v1) AS s FROM x GROUP BY id1 ORDER BY s DESC LIMIT 1",
        );
        assert_eq!(rows.len(), 1);
        assert_eq!(get(&rows[0], "id1"), &t("id002"));
    }

    #[test]
    fn ungrouped_aggregate_over_empty_table() {
        let mut bridge = SqlBridge::new(DataConn::new().with_table("e", &["v"], vec![]));
        let rows = run_sql(&mut bridge, "SELECT count(*) AS n, sum(v) AS s FROM e");
        assert_eq!(rows.len(), 1, "ungrouped agg over empty input = one row");
        let only = &rows[0];
        assert_eq!(get(only, "n"), &i(0));
        assert_eq!(get(only, "s"), &SochValue::Null);
    }

    #[test]
    fn grouped_aggregate_over_empty_table_yields_no_rows() {
        let mut bridge = SqlBridge::new(DataConn::new().with_table("e", &["g", "v"], vec![]));
        let rows = run_sql(&mut bridge, "SELECT g, sum(v) AS s FROM e GROUP BY g");
        assert!(rows.is_empty());
    }

    #[test]
    fn sum_overflow_promotes_to_float() {
        let data = vec![vec![i(i64::MAX)], vec![i(i64::MAX)]];
        let mut bridge = SqlBridge::new(DataConn::new().with_table("t", &["v"], data));
        let rows = run_sql(&mut bridge, "SELECT sum(v) AS s FROM t");
        match get(&rows[0], "s") {
            SochValue::Float(v) => assert!(*v > 1.8e19),
            other => panic!("expected float after overflow, got {:?}", other),
        }
    }

    #[test]
    fn aggregate_after_join() {
        // Join + group: table `a` has two k1 rows and one k2 row; table `b`
        // has one row per key.
        let table_a = vec![
            vec![t("k1"), i(1)],
            vec![t("k1"), i(2)],
            vec![t("k2"), i(3)],
        ];
        let table_b = vec![vec![t("k1"), i(10)], vec![t("k2"), i(20)]];
        let conn = DataConn::new()
            .with_table("a", &["id", "v"], table_a)
            .with_table("b", &["id", "w"], table_b);
        let mut bridge = SqlBridge::new(conn);
        let rows = run_sql(
            &mut bridge,
            "SELECT a.id, sum(a.v) AS sv, sum(b.w) AS sw \
                 FROM a JOIN b ON a.id = b.id GROUP BY a.id ORDER BY a.id",
        );
        assert_eq!(rows.len(), 2);
        // k1: v = 1 + 2, and w = 10 is joined to both k1 rows.
        assert_eq!(get(&rows[0], "sv"), &i(3));
        assert_eq!(get(&rows[0], "sw"), &i(20));
        // k2: a single joined row.
        assert_eq!(get(&rows[1], "sv"), &i(3));
        assert_eq!(get(&rows[1], "sw"), &i(20));
    }

    #[test]
    fn lowercase_function_names_parse() {
        // db-benchmark SQL uses lowercase: sum(v1), median(v3)
        let mut bridge = bench_bridge();
        let statements = [
            "SELECT id1, sum(v1) FROM x GROUP BY id1",
            "SELECT median(v3) FROM x",
            "SELECT stddev(v3) FROM x",
        ];
        for sql in statements.iter() {
            assert!(bridge.execute(sql).is_ok());
        }
    }

    #[test]
    fn parallel_path_matches_reference_computation() {
        // 150k rows (> PARALLEL_THRESHOLD) exercising the rayon merge:
        // sum, avg, count, median, stddev per group, verified against
        // values computed directly in the test.
        let n: usize = 150_000;
        let groups = 7usize;
        let sample = |idx: usize| (idx * 31 % 1000) as f64 / 4.0;

        // Table data and the per-group reference values, built in one pass.
        let mut data: Vec<Vec<SochValue>> = Vec::with_capacity(n);
        let mut per_group: Vec<Vec<f64>> = vec![Vec::new(); groups];
        for idx in 0..n {
            let value = sample(idx);
            data.push(vec![t(&format!("g{}", idx % groups)), f(value)]);
            per_group[idx % groups].push(value);
        }

        let mut bridge = SqlBridge::new(DataConn::new().with_table("big", &["g", "v"], data));
        let rows = run_sql(
            &mut bridge,
            "SELECT g, count(*) AS n, sum(v) AS s, avg(v) AS m, \
                 median(v) AS med, stddev(v) AS sd FROM big GROUP BY g ORDER BY g",
        );
        assert_eq!(rows.len(), groups);

        for (gi, row) in rows.iter().enumerate() {
            let vals = &per_group[gi];
            let cnt = vals.len() as f64;
            let sum: f64 = vals.iter().sum();
            let mean = sum / cnt;
            let var = vals.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / (cnt - 1.0);
            let mut sorted = vals.clone();
            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let mid = sorted.len() / 2;
            let med = if sorted.len() % 2 == 1 {
                sorted[mid]
            } else {
                (sorted[mid - 1] + sorted[mid]) / 2.0
            };

            assert_eq!(get(row, "g"), &t(&format!("g{}", gi)));
            assert_eq!(get(row, "n"), &i(vals.len() as i64));
            match get(row, "s") {
                SochValue::Float(v) => assert!((v - sum).abs() < 1e-6, "sum"),
                other => panic!("sum type {:?}", other),
            }
            match get(row, "m") {
                SochValue::Float(v) => assert!((v - mean).abs() < 1e-9, "mean"),
                other => panic!("mean type {:?}", other),
            }
            match get(row, "med") {
                SochValue::Float(v) => assert!((v - med).abs() < 1e-9, "median"),
                other => panic!("median type {:?}", other),
            }
            match get(row, "sd") {
                SochValue::Float(v) => {
                    assert!((v - var.sqrt()).abs() < 1e-9, "sd {} vs {}", v, var.sqrt())
                }
                other => panic!("sd type {:?}", other),
            }
        }
    }

    #[test]
    fn unaliased_aggregate_column_name_is_canonical() {
        let mut bridge = bench_bridge();
        let result = bridge
            .execute("SELECT id1, sum(v1) FROM x GROUP BY id1")
            .unwrap();
        let cols = result.columns().unwrap().clone();
        assert!(cols.contains(&"sum(v1)".to_string()), "cols={:?}", cols);
    }
}