Skip to main content

spg_engine/
aggregate.rs

1//! Aggregate executor.
2//!
3//! Handles `SELECT … <aggs> … [GROUP BY …]` queries. The planning strategy
4//! is straightforward:
5//!
6//! 1. Walk the SELECT (and ORDER BY) expressions to find every aggregate
7//!    function call. Dedupe by AST equality and assign each `__agg_<i>`.
8//! 2. Same for every `GROUP BY` expression: assign `__grp_<j>`.
9//! 3. Stream the WHERE-filtered rows, group by the tuple of GROUP BY
10//!    values, and update per-group aggregate state.
11//! 4. Materialise a synthetic per-group row containing
12//!    `[__grp_0..__grp_K, __agg_0..__agg_N]` and rewrite the user's
13//!    SELECT / ORDER BY expressions to reference those synthetic columns
14//!    instead of the originals.
15//! 5. Evaluate the rewritten expressions against the synthetic schema and
16//!    emit results.
17//!
18//! v1.8 implements `count(*)`, `count(expr)`, `sum`, `min`, `max`, `avg`.
19//! NULL semantics follow PG: aggregates skip NULL inputs (except
20//! `count(*)`, which counts rows). `sum(int)` widens to `BigInt`;
21//! `avg(int|bigint)` returns `Float`.
22
23use alloc::boxed::Box;
24use alloc::collections::{BTreeMap, BTreeSet};
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34/// True if this statement should go through the aggregate path.
35pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36    if stmt.group_by.is_some() || stmt.having.is_some() {
37        return true;
38    }
39    for item in &stmt.items {
40        if let SelectItem::Expr { expr, .. } = item
41            && contains_aggregate(expr)
42        {
43            return true;
44        }
45    }
46    for o in &stmt.order_by {
47        if contains_aggregate(&o.expr) {
48            return true;
49        }
50    }
51    if let Some(h) = &stmt.having
52        && contains_aggregate(h)
53    {
54        return true;
55    }
56    false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60    match e {
61        Expr::FunctionCall { name, args } => {
62            is_aggregate_name(name) || args.iter().any(contains_aggregate)
63        }
64        Expr::AggregateOrdered { .. } => true,
65        Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67            contains_aggregate(expr)
68        }
69        Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70        Expr::Extract { source, .. } => contains_aggregate(source),
71        // v4.10 subqueries + v4.12 window functions / Literal /
72        // Column — all non-aggregate leaves from the regular
73        // aggregate planner's POV. Window-bearing projections are
74        // routed to exec_select_with_window before this runs.
75        Expr::ScalarSubquery(_)
76        | Expr::Exists { .. }
77        | Expr::InSubquery { .. }
78        | Expr::WindowFunction { .. }
79        | Expr::Literal(_)
80        | Expr::Placeholder(_)
81        | Expr::Column(_) => false,
82        // v7.10.10 — recurse into array constructor / subscript /
83        // ANY/ALL children. Aggregates inside `ARRAY[SUM(x)]` are
84        // valid PG and must be detected here.
85        Expr::Array(items) => items.iter().any(contains_aggregate),
86        Expr::ArraySubscript { target, index } => {
87            contains_aggregate(target) || contains_aggregate(index)
88        }
89        Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90        // v7.13.0 — CASE WHEN … END. Recurse into operand,
91        // every (WHEN, THEN) pair, and the ELSE branch.
92        Expr::Case {
93            operand,
94            branches,
95            else_branch,
96        } => {
97            operand.as_deref().is_some_and(contains_aggregate)
98                || branches
99                    .iter()
100                    .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101                || else_branch.as_deref().is_some_and(contains_aggregate)
102        }
103    }
104}
105
/// Returns `true` when `name` is one of the aggregate functions this module
/// recognises, compared ASCII-case-insensitively.
///
/// `count_star` is the name matched for the zero-argument count form.
/// `every` is accepted as the SQL-standard alias of `bool_and`.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 11] = [
        // Core numeric / extreme aggregates.
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        // Collection aggregates.
        "string_agg",
        "array_agg",
        // Boolean aggregates.
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
128
129/// Per-aggregate running state.
130#[derive(Debug, Default, Clone)]
131struct AggState {
132    count: i64,
133    sum_int: i64,
134    sum_float: f64,
135    extreme: Option<Value>,
136    use_float: bool,
137    /// v7.17.0 — running collection for string_agg / array_agg.
138    /// Each entry is one row's contribution (NULL preserved as
139    /// `Value::Null`; string_agg's finalize step drops them, but
140    /// array_agg keeps them). Pushing in insertion order matches
141    /// PG behaviour when no `ORDER BY` is given inside the
142    /// aggregate call.
143    items: Vec<Value>,
144    /// v7.25 (round-17) — per-group dedupe set for DISTINCT
145    /// aggregates (encoded values; NULLs never reach it because
146    /// the caller's skip runs after the per-aggregate NULL rules).
147    seen: BTreeSet<String>,
148    /// v7.24 (round-16 A) — per-item ORDER BY key tuples, parallel
149    /// to `items` (pushed under the same skip/keep conditions).
150    /// Empty when the aggregate carries no internal ordering.
151    item_keys: Vec<Vec<Value>>,
152    /// v7.17.0 — captured separator for string_agg. PG accepts a
153    /// non-constant separator expression but in practice every
154    /// caller passes a literal; the engine snapshots the last
155    /// non-NULL text it sees, which matches PG's "use the latest
156    /// row's value" behaviour.
157    separator: Option<String>,
158    /// v7.17.0 — running boolean accumulator for bool_and /
159    /// bool_or / every. `None` until the first non-NULL input;
160    /// at finalize None → SQL NULL.
161    bool_acc: Option<bool>,
162}
163
164#[derive(Debug, Clone)]
165struct AggSpec {
166    name: String, // lowercased
167    /// First argument (value expression) for every aggregate
168    /// except `count(*)`. `None` for `count_star`.
169    arg: Option<Expr>,
170    /// v7.17.0 — second argument. Only `string_agg(value, sep)`
171    /// uses it today. `None` for every other aggregate (or for
172    /// `array_agg`, which is single-arg). Carried in the spec so
173    /// per-row evaluation can re-use the same separator
174    /// expression across calls.
175    arg2: Option<Expr>,
176    /// v7.25 (round-17) — `COUNT(DISTINCT x)` & friends: dedupe
177    /// the input stream per group before accumulation.
178    distinct: bool,
179    /// v7.24 (round-16 A) — aggregate-internal ORDER BY keys
180    /// (`array_agg(x ORDER BY y DESC NULLS LAST)`). Empty for the
181    /// plain form. Only the collection aggregates honour it;
182    /// other aggregates are order-insensitive and ignore it (PG
183    /// accepts the syntax everywhere too).
184    order_by: Vec<spg_sql::ast::OrderBy>,
185}
186
187/// Output of running the aggregate path. Schema describes one row per
188/// group; rows are not yet ORDER BY-sorted (caller does it).
189#[derive(Debug)]
190pub struct AggResult {
191    pub columns: Vec<ColumnSchema>,
192    pub rows: Vec<Row>,
193}
194
195/// Execute aggregate logic against an already-WHERE-filtered iterator of
196/// rows. `table_alias` is the alias accepted by column resolution.
197#[allow(clippy::too_many_lines)]
198pub fn run(
199    stmt: &SelectStatement,
200    rows: &[&Row],
201    schema_cols: &[ColumnSchema],
202    table_alias: Option<&str>,
203) -> Result<AggResult, EvalError> {
204    let ctx = EvalContext::new(schema_cols, table_alias);
205    let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
206
207    // Collect aggregate sub-expressions across items + order_by.
208    let mut agg_specs: Vec<AggSpec> = Vec::new();
209    for item in &stmt.items {
210        if let SelectItem::Expr { expr, .. } = item {
211            collect_aggregates(expr, &mut agg_specs);
212        }
213    }
214    for o in &stmt.order_by {
215        collect_aggregates(&o.expr, &mut agg_specs);
216    }
217    if let Some(h) = &stmt.having {
218        collect_aggregates(h, &mut agg_specs);
219    }
220    // v7.17.0 — arity validation. The collector tolerates an
221    // arbitrary positional-arg count; here we enforce the
222    // per-aggregate contract so a malformed call (e.g.
223    // `array_agg()` or `string_agg(x)`) surfaces as a SQL error
224    // rather than silently coercing to a degenerate aggregate.
225    validate_agg_arities(stmt, &agg_specs)?;
226
227    // Map group key (vec of values, encoded as canonical string) -> group state.
228    // Order of insertion is preserved via a parallel Vec of keys.
229    let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
230    let mut key_order: Vec<String> = Vec::new();
231    // When there are no GROUP BY exprs *and* there is at least one aggregate,
232    // every row collapses into a single anonymous group keyed by "".
233    if rows.is_empty() && group_exprs.is_empty() {
234        // Single empty-aggregate group: count=0, sum=0, max=NULL, etc.
235        let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
236        groups.insert(String::new(), (Vec::new(), init));
237        key_order.push(String::new());
238    }
239
240    for row in rows {
241        let group_vals: Vec<Value> = group_exprs
242            .iter()
243            .map(|g| eval::eval_expr(g, row, &ctx))
244            .collect::<Result<_, _>>()?;
245        // v7.17.0 Phase 2.5b — case-insensitive group keying.
246        // For each group_expr that's a column reference on a
247        // CaseInsensitive text column, fold the corresponding
248        // value before encoding the key. Display value
249        // (`group_vals`) stays original — only the key folds.
250        let mut key_vals = group_vals.clone();
251        for (i, g) in group_exprs.iter().enumerate() {
252            if matches!(
253                eval::column_collation(g, &ctx),
254                Some(spg_storage::Collation::CaseInsensitive)
255            ) {
256                if let Value::Text(s) = &key_vals[i] {
257                    key_vals[i] = Value::Text(s.to_ascii_lowercase());
258                }
259            }
260        }
261        let key = encode_key(&key_vals);
262        let entry = groups.entry(key.clone()).or_insert_with(|| {
263            key_order.push(key.clone());
264            let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
265            (group_vals.clone(), init)
266        });
267        for (i, spec) in agg_specs.iter().enumerate() {
268            let arg_val = match &spec.arg {
269                None => Value::Bool(true), // count_star: sentinel non-null
270                Some(e) => eval::eval_expr(e, row, &ctx)?,
271            };
272            // v7.17.0 — `string_agg(value, separator)` evaluates the
273            // separator per row but PG treats it as constant; we
274            // pass the per-row value into update_state so a future
275            // varying-separator caller still sees correct output,
276            // even though SPG (like PG) only uses the most recent.
277            let arg2_val = match &spec.arg2 {
278                None => None,
279                Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
280            };
281            // v7.24 (round-16 A) — aggregate-internal ORDER BY:
282            // evaluate the key tuple against the source row.
283            let order_keys = if spec.order_by.is_empty() {
284                None
285            } else {
286                let mut keys = Vec::with_capacity(spec.order_by.len());
287                for o in &spec.order_by {
288                    keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
289                }
290                Some(keys)
291            };
292            // v7.25 (round-17) — DISTINCT: drop repeated inputs
293            // before they reach the accumulator. NULLs flow through
294            // (each aggregate's own NULL rule applies; PG also
295            // treats NULL as a single distinct value for array_agg).
296            if spec.distinct {
297                let key = encode_key(core::slice::from_ref(&arg_val));
298                if !entry.1[i].seen.insert(key) {
299                    continue;
300                }
301            }
302            update_state(
303                &mut entry.1[i],
304                &spec.name,
305                &arg_val,
306                arg2_val.as_ref(),
307                order_keys,
308            )?;
309        }
310    }
311
312    // Build synthetic schema: __grp_0..K then __agg_0..N.
313    let group_types: Vec<DataType> = if rows.is_empty() {
314        // Use Text as a safe stand-in — empty result means schema isn't
315        // observable. Avoids needing to evaluate group exprs on no row.
316        group_exprs.iter().map(|_| DataType::Text).collect()
317    } else {
318        let probe = rows[0];
319        group_exprs
320            .iter()
321            .map(|g| {
322                eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
323            })
324            .collect::<Result<_, _>>()?
325    };
326    let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
327    let mut synth_schema: Vec<ColumnSchema> = Vec::new();
328    for (i, ty) in group_types.iter().enumerate() {
329        synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
330    }
331    for (i, ty) in agg_types.iter().enumerate() {
332        synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
333    }
334
335    // Materialise synthetic rows.
336    let mut synth_rows: Vec<Row> = Vec::new();
337    for k in &key_order {
338        let (gvals, states) = &groups[k];
339        let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
340        values.extend(gvals.iter().cloned());
341        for (i, st) in states.iter().enumerate() {
342            // v7.24 (round-16 A) — order the collected items per the
343            // aggregate-internal ORDER BY before finalize consumes
344            // them.
345            let st_sorted;
346            let st_final: &AggState =
347                if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
348                    let mut idx: Vec<usize> = (0..st.items.len()).collect();
349                    let ob = &agg_specs[i].order_by;
350                    idx.sort_by(|&x, &y| {
351                        for (k, o) in ob.iter().enumerate() {
352                            let cmp = crate::order_by_value_cmp(
353                                o.desc,
354                                o.nulls_first,
355                                &st.item_keys[x][k],
356                                &st.item_keys[y][k],
357                            );
358                            if cmp != core::cmp::Ordering::Equal {
359                                return cmp;
360                            }
361                        }
362                        core::cmp::Ordering::Equal
363                    });
364                    let mut sorted = st.clone();
365                    sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
366                    st_sorted = sorted;
367                    &st_sorted
368                } else {
369                    st
370                };
371            values.push(finalize(&agg_specs[i].name, st_final));
372        }
373        synth_rows.push(Row::new(values));
374    }
375
376    // Rewrite the user's SELECT items + ORDER BY to reference synthetic
377    // columns. After rewriting, every remaining `Expr::Column` must
378    // resolve against the synthetic schema (i.e. must have been a GROUP
379    // BY expression).
380    let columns: Vec<ColumnSchema> = stmt
381        .items
382        .iter()
383        .map(|item| match item {
384            SelectItem::Wildcard => Err(EvalError::TypeMismatch {
385                detail: "SELECT * with aggregates is not supported".into(),
386            }),
387            SelectItem::Expr { expr, alias } => {
388                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
389                let name = alias.clone().unwrap_or_else(|| expr.to_string());
390                Ok(ColumnSchema::new(
391                    name,
392                    agg_or_group_type(&rewritten, &synth_schema),
393                    true,
394                ))
395            }
396        })
397        .collect::<Result<_, _>>()?;
398
399    // Project per synthetic row. HAVING filters out groups *before*
400    // we keep the projected row — same semantics as PG: HAVING runs
401    // against the aggregated row (so `HAVING count(*) > 1` works) and
402    // sees only group-by'd columns plus aggregate values.
403    let synth_ctx = EvalContext::new(&synth_schema, None);
404    let having_rewritten = stmt
405        .having
406        .as_ref()
407        .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
408    let mut kept_synth: Vec<Row> = Vec::new();
409    let mut out_rows: Vec<Row> = Vec::new();
410    for srow in synth_rows {
411        if let Some(h) = &having_rewritten {
412            let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
413            if !matches!(cond, Value::Bool(true)) {
414                continue;
415            }
416        }
417        let mut values: Vec<Value> = Vec::with_capacity(columns.len());
418        for item in &stmt.items {
419            if let SelectItem::Expr { expr, .. } = item {
420                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
421                values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
422            }
423        }
424        kept_synth.push(srow);
425        out_rows.push(Row::new(values));
426    }
427
428    // ORDER BY: evaluate the rewritten order_by against each synth row,
429    // sort, then drop the keys. Limit is applied by the caller.
430    if !stmt.order_by.is_empty() {
431        // v6.4.0 — multi-key ORDER BY on aggregate output. Each key
432        // gets its own rewrite + per-key DESC flag.
433        let rewritten: Vec<Expr> = stmt
434            .order_by
435            .iter()
436            .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
437            .collect();
438        let keys_meta: Vec<(bool, Option<bool>)> = stmt
439            .order_by
440            .iter()
441            .map(|o| (o.desc, o.nulls_first))
442            .collect();
443        let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
444            .into_iter()
445            .zip(out_rows)
446            .map(|(s, o)| {
447                let mut keys = Vec::with_capacity(rewritten.len());
448                for e in &rewritten {
449                    keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
450                }
451                Ok::<_, EvalError>((keys, o))
452            })
453            .collect::<Result<_, _>>()?;
454        tagged.sort_by(|a, b| {
455            use core::cmp::Ordering;
456            for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
457                let (desc, nf) = keys_meta[i];
458                let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
459                if cmp != Ordering::Equal {
460                    return cmp;
461                }
462            }
463            Ordering::Equal
464        });
465        out_rows = tagged.into_iter().map(|(_, o)| o).collect();
466    }
467
468    Ok(AggResult {
469        columns,
470        rows: out_rows,
471    })
472}
473
474/// v7.17.0 — walk the statement again to validate the positional
475/// arity of every aggregate call site. Done after AST collection
476/// rather than inside `collect_aggregates` so the collector stays
477/// infallible; callers in `run()` can do a single early-error
478/// exit before any per-row work.
479fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
480    fn walk(e: &Expr) -> Result<(), EvalError> {
481        if let Expr::FunctionCall { name, args } = e {
482            let lower = name.to_ascii_lowercase();
483            let expected: Option<usize> = match lower.as_str() {
484                "count_star" => Some(0),
485                "count" | "sum" | "avg" | "min" | "max" | "array_agg"
486                // v7.17.0 — boolean aggregates also take exactly
487                // one arg. `every` is an alias normalised inside
488                // collect_aggregates / rewrite_expr.
489                | "bool_and" | "bool_or" | "every" => Some(1),
490                "string_agg" => Some(2),
491                _ => None,
492            };
493            if let Some(want) = expected
494                && args.len() != want
495            {
496                return Err(EvalError::TypeMismatch {
497                    detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
498                });
499            }
500            for a in args {
501                walk(a)?;
502            }
503        } else if let Expr::Binary { lhs, rhs, .. } = e {
504            walk(lhs)?;
505            walk(rhs)?;
506        } else if let Expr::Unary { expr, .. }
507        | Expr::Cast { expr, .. }
508        | Expr::IsNull { expr, .. } = e
509        {
510            walk(expr)?;
511        }
512        Ok(())
513    }
514    for item in &stmt.items {
515        if let SelectItem::Expr { expr, .. } = item {
516            walk(expr)?;
517        }
518    }
519    for o in &stmt.order_by {
520        walk(&o.expr)?;
521    }
522    if let Some(h) = &stmt.having {
523        walk(h)?;
524    }
525    Ok(())
526}
527
528fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
529    match e {
530        // v7.24 (round-16 A) — ordered aggregate: register the inner
531        // call's spec with the ordering attached.
532        Expr::AggregateOrdered {
533            call,
534            order_by,
535            distinct,
536        } => {
537            if let Expr::FunctionCall { name, args } = call.as_ref() {
538                let lower = name.to_ascii_lowercase();
539                if is_aggregate_name(&lower) {
540                    let canonical = if lower == "every" {
541                        "bool_and".to_string()
542                    } else {
543                        lower
544                    };
545                    let spec = AggSpec {
546                        name: canonical,
547                        arg: args.first().cloned(),
548                        arg2: if name.eq_ignore_ascii_case("string_agg") {
549                            args.get(1).cloned()
550                        } else {
551                            None
552                        },
553                        distinct: *distinct,
554                        order_by: order_by.clone(),
555                    };
556                    if !out.iter().any(|s| {
557                        s.name == spec.name
558                            && s.arg == spec.arg
559                            && s.arg2 == spec.arg2
560                            && s.distinct == spec.distinct
561                            && s.order_by == spec.order_by
562                    }) {
563                        out.push(spec);
564                    }
565                    return;
566                }
567            }
568            collect_aggregates(call, out);
569            for o in order_by {
570                collect_aggregates(&o.expr, out);
571            }
572        }
573        Expr::FunctionCall { name, args } => {
574            let lower = name.to_ascii_lowercase();
575            if is_aggregate_name(&lower) {
576                let arg = if lower == "count_star" {
577                    None
578                } else {
579                    args.first().cloned()
580                };
581                // v7.17.0 — second positional arg for
582                // `string_agg(value, separator)`. Everything else
583                // ignores it.
584                let arg2 = if lower == "string_agg" {
585                    args.get(1).cloned()
586                } else {
587                    None
588                };
589                // v7.17.0 — `every` is the SQL-standard alias for
590                // `bool_and`; collapse at collection time so
591                // update_state / finalize need only one arm.
592                let canonical = if lower == "every" {
593                    "bool_and".to_string()
594                } else {
595                    lower
596                };
597                let spec = AggSpec {
598                    name: canonical,
599                    arg: arg.clone(),
600                    arg2: arg2.clone(),
601                    distinct: false,
602                    order_by: Vec::new(),
603                };
604                if !out.iter().any(|s| {
605                    s.name == spec.name
606                        && s.arg == spec.arg
607                        && s.arg2 == spec.arg2
608                        && !s.distinct
609                        && s.order_by == spec.order_by
610                }) {
611                    out.push(spec);
612                }
613                // Don't recurse into the arg — nested aggregates are
614                // illegal in standard SQL.
615            } else {
616                for a in args {
617                    collect_aggregates(a, out);
618                }
619            }
620        }
621        Expr::Binary { lhs, rhs, .. } => {
622            collect_aggregates(lhs, out);
623            collect_aggregates(rhs, out);
624        }
625        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
626            collect_aggregates(expr, out);
627        }
628        Expr::Like { expr, pattern, .. } => {
629            collect_aggregates(expr, out);
630            collect_aggregates(pattern, out);
631        }
632        Expr::Extract { source, .. } => collect_aggregates(source, out),
633        // v4.10 subquery + v4.12 window / Literal / Column —
634        // non-recursing leaves for the aggregate collector.
635        Expr::ScalarSubquery(_)
636        | Expr::Exists { .. }
637        | Expr::InSubquery { .. }
638        | Expr::WindowFunction { .. }
639        | Expr::Literal(_)
640        | Expr::Placeholder(_)
641        | Expr::Column(_) => {}
642        // v7.10.10 — recurse into array constructor children +
643        // subscript / ANY/ALL operands.
644        Expr::Array(items) => {
645            for elem in items {
646                collect_aggregates(elem, out);
647            }
648        }
649        Expr::ArraySubscript { target, index } => {
650            collect_aggregates(target, out);
651            collect_aggregates(index, out);
652        }
653        Expr::AnyAll { expr, array, .. } => {
654            collect_aggregates(expr, out);
655            collect_aggregates(array, out);
656        }
657        Expr::Case {
658            operand,
659            branches,
660            else_branch,
661        } => {
662            if let Some(o) = operand {
663                collect_aggregates(o, out);
664            }
665            for (w, t) in branches {
666                collect_aggregates(w, out);
667                collect_aggregates(t, out);
668            }
669            if let Some(e) = else_branch {
670                collect_aggregates(e, out);
671            }
672        }
673    }
674}
675
676fn update_state(
677    st: &mut AggState,
678    name: &str,
679    v: &Value,
680    arg2: Option<&Value>,
681    order_keys: Option<Vec<Value>>,
682) -> Result<(), EvalError> {
683    let is_null = matches!(v, Value::Null);
684    match name {
685        "count_star" => st.count += 1,
686        "count" => {
687            if !is_null {
688                st.count += 1;
689            }
690        }
691        "sum" | "avg" => {
692            if is_null {
693                return Ok(());
694            }
695            st.count += 1;
696            match v {
697                Value::Int(n) => st.sum_int += i64::from(*n),
698                Value::BigInt(n) => st.sum_int += *n,
699                Value::Float(x) => {
700                    st.use_float = true;
701                    st.sum_float += *x;
702                }
703                other => {
704                    return Err(EvalError::TypeMismatch {
705                        detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
706                    });
707                }
708            }
709        }
710        "min" => {
711            if is_null {
712                return Ok(());
713            }
714            match &st.extreme {
715                None => st.extreme = Some(v.clone()),
716                Some(cur) => {
717                    if value_cmp(v, cur) == core::cmp::Ordering::Less {
718                        st.extreme = Some(v.clone());
719                    }
720                }
721            }
722        }
723        "max" => {
724            if is_null {
725                return Ok(());
726            }
727            match &st.extreme {
728                None => st.extreme = Some(v.clone()),
729                Some(cur) => {
730                    if value_cmp(v, cur) == core::cmp::Ordering::Greater {
731                        st.extreme = Some(v.clone());
732                    }
733                }
734            }
735        }
736        // v7.17.0 — string_agg(value, separator). NULL value is
737        // skipped (PG aggregate-skip-null). Separator captured
738        // from the latest row that flows through; matches PG's
739        // semantics of evaluating the separator per row but using
740        // the last value at finalize time (in practice it's
741        // constant). count is bumped so we can distinguish "empty
742        // group → NULL" from "all-NULL group → NULL".
743        "string_agg" => {
744            if let Some(sep) = arg2
745                && let Value::Text(s) = sep
746            {
747                st.separator = Some(s.clone());
748            }
749            if is_null {
750                return Ok(());
751            }
752            if let Value::Text(s) = v {
753                st.items.push(Value::Text(s.clone()));
754                if let Some(k) = order_keys {
755                    st.item_keys.push(k);
756                }
757                st.count += 1;
758            } else {
759                return Err(EvalError::TypeMismatch {
760                    detail: format!("string_agg requires text value, got {:?}", v.data_type()),
761                });
762            }
763        }
764        // v7.17.0 — array_agg(value). Unlike string_agg, NULL
765        // elements are KEPT in the array (PG behaviour); the
766        // result is NULL only when ZERO rows fed in. Element type
767        // is locked from the first row's value type; subsequent
768        // rows must match (PG also rejects mixed-type array_agg).
769        "array_agg" => {
770            st.items.push(v.clone());
771            if let Some(k) = order_keys {
772                st.item_keys.push(k);
773            }
774            st.count += 1;
775        }
776        // v7.17.0 — bool_and(p): TRUE iff every non-NULL input is
777        // TRUE. NULL skipped; running accumulator stays at TRUE
778        // until the first non-NULL FALSE.
779        "bool_and" => {
780            if is_null {
781                return Ok(());
782            }
783            let b = match v {
784                Value::Bool(b) => *b,
785                other => {
786                    return Err(EvalError::TypeMismatch {
787                        detail: format!("bool_and requires bool, got {:?}", other.data_type()),
788                    });
789                }
790            };
791            st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
792        }
793        // v7.17.0 — bool_or(p): TRUE iff any non-NULL input is
794        // TRUE. NULL skipped.
795        "bool_or" => {
796            if is_null {
797                return Ok(());
798            }
799            let b = match v {
800                Value::Bool(b) => *b,
801                other => {
802                    return Err(EvalError::TypeMismatch {
803                        detail: format!("bool_or requires bool, got {:?}", other.data_type()),
804                    });
805                }
806            };
807            st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
808        }
809        _ => unreachable!("non-aggregate {name} in update_state"),
810    }
811    Ok(())
812}
813
814#[allow(clippy::cast_precision_loss)]
815fn finalize(name: &str, st: &AggState) -> Value {
816    match name {
817        "count" | "count_star" => Value::BigInt(st.count),
818        "sum" => {
819            if st.count == 0 {
820                Value::Null
821            } else if st.use_float {
822                Value::Float(st.sum_float + (st.sum_int as f64))
823            } else {
824                Value::BigInt(st.sum_int)
825            }
826        }
827        "avg" => {
828            if st.count == 0 {
829                Value::Null
830            } else {
831                let total = if st.use_float {
832                    st.sum_float + (st.sum_int as f64)
833                } else {
834                    st.sum_int as f64
835                };
836                Value::Float(total / (st.count as f64))
837            }
838        }
839        "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
840        // v7.17.0 — string_agg: join all collected text items with
841        // the captured separator. Empty / all-NULL group → NULL
842        // (PG semantics).
843        "string_agg" => {
844            if st.items.is_empty() {
845                return Value::Null;
846            }
847            let sep = st.separator.clone().unwrap_or_default();
848            let mut out = String::new();
849            for (i, item) in st.items.iter().enumerate() {
850                if i > 0 {
851                    out.push_str(&sep);
852                }
853                if let Value::Text(s) = item {
854                    out.push_str(s);
855                }
856            }
857            Value::Text(out)
858        }
859        // v7.17.0 — array_agg: collect into a typed array. NULL
860        // elements are preserved per PG. Result type is decided
861        // by the first non-NULL element seen (or Text fallback
862        // when the whole group is NULL — PG would surface the
863        // declared input type, but SPG hasn't yet wired the
864        // aggregate's static input-type from `describe`).
865        "array_agg" => {
866            if st.items.is_empty() {
867                return Value::Null;
868            }
869            let probe = st.items.iter().find(|v| !v.is_null());
870            match probe.and_then(spg_storage::Value::data_type) {
871                Some(DataType::Int) | Some(DataType::SmallInt) => {
872                    let items: Vec<Option<i32>> = st
873                        .items
874                        .iter()
875                        .map(|v| match v {
876                            Value::Int(n) => Some(*n),
877                            Value::SmallInt(n) => Some(i32::from(*n)),
878                            _ => None,
879                        })
880                        .collect();
881                    Value::IntArray(items)
882                }
883                Some(DataType::BigInt) => {
884                    let items: Vec<Option<i64>> = st
885                        .items
886                        .iter()
887                        .map(|v| match v {
888                            Value::BigInt(n) => Some(*n),
889                            _ => None,
890                        })
891                        .collect();
892                    Value::BigIntArray(items)
893                }
894                _ => {
895                    let items: Vec<Option<String>> = st
896                        .items
897                        .iter()
898                        .map(|v| match v {
899                            Value::Text(s) => Some(s.clone()),
900                            Value::Null => None,
901                            other => Some(format!("{other:?}")),
902                        })
903                        .collect();
904                    Value::TextArray(items)
905                }
906            }
907        }
908        // v7.17.0 — bool_and / bool_or finalize: lazy-init pattern
909        // means `None` is exactly "empty group or all-NULL", which
910        // PG surfaces as SQL NULL.
911        "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
912        _ => unreachable!(),
913    }
914}
915
916fn infer_agg_type(spec: &AggSpec) -> DataType {
917    match spec.name.as_str() {
918        // count/count_star are exact integer counts; sum widens to BigInt
919        // and reports as such even for Float input (the value column is
920        // nullable so the wire layer surfaces the Float at runtime).
921        "count" | "count_star" | "sum" => DataType::BigInt,
922        "avg" => DataType::Float,
923        // v7.17.0 — string_agg always returns TEXT.
924        "string_agg" => DataType::Text,
925        // v7.17.0 — array_agg's declared output type can't be
926        // known without inspecting the argument's expression
927        // shape. Default to TextArray; finalize widens to
928        // IntArray / BigIntArray when the actual elements are
929        // numeric. Downstream column metadata reports TextArray
930        // which is the lowest common denominator.
931        "array_agg" => DataType::TextArray,
932        // v7.17.0 — boolean aggregates always return BOOL (nullable
933        // — empty / all-NULL group → NULL).
934        "bool_and" | "bool_or" => DataType::Bool,
935        // min/max: we don't know the input type without probing — default
936        // to Text and let downstream rendering coerce.
937        _ => DataType::Text,
938    }
939}
940
941fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
942    if let Expr::Column(c) = e
943        && let Some(s) = synth.iter().find(|s| s.name == c.name)
944    {
945        return s.ty;
946    }
947    // Compound expression — fall back to Text (matches build_projection
948    // behaviour for non-column expressions in the non-aggregate path).
949    DataType::Text
950}
951
952fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
953    // v7.24 (round-16 A) — ordered aggregate: match on the inner
954    // call PLUS the ordering keys.
955    if let Expr::AggregateOrdered {
956        call,
957        order_by,
958        distinct,
959    } = e
960        && let Expr::FunctionCall { name, args } = call.as_ref()
961    {
962        let lower = name.to_ascii_lowercase();
963        if is_aggregate_name(&lower) {
964            let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
965            let arg = args.first().cloned();
966            let arg2 = if lower == "string_agg" {
967                args.get(1).cloned()
968            } else {
969                None
970            };
971            for (i, spec) in aggs.iter().enumerate() {
972                if spec.name == canonical
973                    && spec.arg == arg
974                    && spec.arg2 == arg2
975                    && spec.distinct == *distinct
976                    && spec.order_by == *order_by
977                {
978                    return Expr::Column(spg_sql::ast::ColumnName {
979                        qualifier: None,
980                        name: format!("__agg_{i}"),
981                    });
982                }
983            }
984        }
985    }
986    // Match aggregate FunctionCalls first — they sit outside group_by.
987    if let Expr::FunctionCall { name, args } = e {
988        let lower = name.to_ascii_lowercase();
989        if is_aggregate_name(&lower) {
990            let arg = if lower == "count_star" {
991                None
992            } else {
993                args.first().cloned()
994            };
995            // v7.17.0 — match the spec we registered for
996            // string_agg(value, separator) on the full pair.
997            let arg2 = if lower == "string_agg" {
998                args.get(1).cloned()
999            } else {
1000                None
1001            };
1002            // v7.17.0 — `every` collapses into `bool_and` at
1003            // collection; mirror that here so the rewrite finds
1004            // the matching synth column.
1005            let canonical: &str = if lower == "every" {
1006                "bool_and"
1007            } else {
1008                lower.as_str()
1009            };
1010            for (i, spec) in aggs.iter().enumerate() {
1011                if spec.name == canonical
1012                    && spec.arg == arg
1013                    && spec.arg2 == arg2
1014                    && !spec.distinct
1015                    && spec.order_by.is_empty()
1016                {
1017                    return Expr::Column(spg_sql::ast::ColumnName {
1018                        qualifier: None,
1019                        name: format!("__agg_{i}"),
1020                    });
1021                }
1022            }
1023        }
1024    }
1025    // Match a group_by expression by AST equality.
1026    for (i, g) in group_exprs.iter().enumerate() {
1027        if g == e {
1028            return Expr::Column(spg_sql::ast::ColumnName {
1029                qualifier: None,
1030                name: format!("__grp_{i}"),
1031            });
1032        }
1033    }
1034    // Recurse into children.
1035    match e {
1036        Expr::AggregateOrdered {
1037            call,
1038            order_by,
1039            distinct,
1040        } => Expr::AggregateOrdered {
1041            call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1042            distinct: *distinct,
1043            order_by: order_by
1044                .iter()
1045                .map(|o| spg_sql::ast::OrderBy {
1046                    expr: rewrite_expr(&o.expr, group_exprs, aggs),
1047                    desc: o.desc,
1048                    nulls_first: o.nulls_first,
1049                })
1050                .collect(),
1051        },
1052        Expr::Binary { lhs, op, rhs } => Expr::Binary {
1053            lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1054            op: *op,
1055            rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1056        },
1057        Expr::Unary { op, expr } => Expr::Unary {
1058            op: *op,
1059            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1060        },
1061        Expr::Cast { expr, target } => Expr::Cast {
1062            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1063            target: *target,
1064        },
1065        Expr::IsNull { expr, negated } => Expr::IsNull {
1066            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1067            negated: *negated,
1068        },
1069        Expr::FunctionCall { name, args } => Expr::FunctionCall {
1070            name: name.clone(),
1071            args: args
1072                .iter()
1073                .map(|a| rewrite_expr(a, group_exprs, aggs))
1074                .collect(),
1075        },
1076        Expr::Like {
1077            expr,
1078            pattern,
1079            negated,
1080            case_insensitive,
1081        } => Expr::Like {
1082            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1083            pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1084            negated: *negated,
1085            case_insensitive: *case_insensitive,
1086        },
1087        Expr::Extract { field, source } => Expr::Extract {
1088            field: *field,
1089            source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1090        },
1091        // v4.10 subquery + v4.12 window / Literal / Column —
1092        // clone-pass (these don't participate in aggregate rewrite).
1093        Expr::ScalarSubquery(_)
1094        | Expr::Exists { .. }
1095        | Expr::InSubquery { .. }
1096        | Expr::WindowFunction { .. }
1097        | Expr::Literal(_)
1098        | Expr::Placeholder(_)
1099        | Expr::Column(_) => e.clone(),
1100        // v7.10.10 — recurse children for array nodes.
1101        Expr::Array(items) => Expr::Array(
1102            items
1103                .iter()
1104                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1105                .collect(),
1106        ),
1107        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1108            target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1109            index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1110        },
1111        Expr::AnyAll {
1112            expr,
1113            op,
1114            array,
1115            is_any,
1116        } => Expr::AnyAll {
1117            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1118            op: *op,
1119            array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1120            is_any: *is_any,
1121        },
1122        Expr::Case {
1123            operand,
1124            branches,
1125            else_branch,
1126        } => Expr::Case {
1127            operand: operand
1128                .as_deref()
1129                .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1130            branches: branches
1131                .iter()
1132                .map(|(w, t)| {
1133                    (
1134                        rewrite_expr(w, group_exprs, aggs),
1135                        rewrite_expr(t, group_exprs, aggs),
1136                    )
1137                })
1138                .collect(),
1139            else_branch: else_branch
1140                .as_deref()
1141                .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1142        },
1143    }
1144}
1145
1146/// Canonical string key for a tuple of group values. Used as map key.
1147fn encode_key(vals: &[Value]) -> String {
1148    let mut out = String::new();
1149    for v in vals {
1150        match v {
1151            Value::Null => out.push_str("N|"),
1152            Value::SmallInt(n) => {
1153                out.push('s');
1154                out.push_str(&n.to_string());
1155                out.push('|');
1156            }
1157            Value::Int(n) => {
1158                out.push('I');
1159                out.push_str(&n.to_string());
1160                out.push('|');
1161            }
1162            Value::BigInt(n) => {
1163                out.push('B');
1164                out.push_str(&n.to_string());
1165                out.push('|');
1166            }
1167            Value::Float(x) => {
1168                out.push('F');
1169                out.push_str(&x.to_string());
1170                out.push('|');
1171            }
1172            Value::Bool(b) => {
1173                out.push(if *b { 'T' } else { 'f' });
1174                out.push('|');
1175            }
1176            Value::Text(s) => {
1177                out.push('S');
1178                out.push_str(s);
1179                out.push('|');
1180            }
1181            Value::Vector(v) => {
1182                out.push('V');
1183                for x in v {
1184                    out.push_str(&x.to_string());
1185                    out.push(',');
1186                }
1187                out.push('|');
1188            }
1189            // v6.0.1: GROUP BY on a `VECTOR(N) USING SQ8` column.
1190            // Two cells with byte-identical `(min, max, bytes)`
1191            // share the same group; equivalence is byte-equality
1192            // (same as f32 grouping today — neither path tries to
1193            // normalise nan/-0).
1194            Value::Sq8Vector(q) => {
1195                out.push('Q');
1196                out.push_str(&q.min.to_string());
1197                out.push('@');
1198                out.push_str(&q.max.to_string());
1199                out.push(':');
1200                for b in &q.bytes {
1201                    out.push_str(&b.to_string());
1202                    out.push(',');
1203                }
1204                out.push('|');
1205            }
1206            // v6.0.3: GROUP BY on a `VECTOR(N) USING HALF` column.
1207            // Byte-equality over the raw u16 bits; matches the SQ8
1208            // path's byte-key model.
1209            Value::HalfVector(h) => {
1210                out.push('H');
1211                for b in &h.bytes {
1212                    out.push_str(&b.to_string());
1213                    out.push(',');
1214                }
1215                out.push('|');
1216            }
1217            Value::Numeric { scaled, scale } => {
1218                out.push('D');
1219                out.push_str(&scaled.to_string());
1220                out.push('@');
1221                out.push_str(&scale.to_string());
1222                out.push('|');
1223            }
1224            Value::Date(d) => {
1225                out.push('d');
1226                out.push_str(&d.to_string());
1227                out.push('|');
1228            }
1229            Value::Timestamp(t) => {
1230                out.push('t');
1231                out.push_str(&t.to_string());
1232                out.push('|');
1233            }
1234            Value::Interval { months, micros } => {
1235                out.push('i');
1236                out.push_str(&months.to_string());
1237                out.push('m');
1238                out.push_str(&micros.to_string());
1239                out.push('|');
1240            }
1241            Value::Json(s) => {
1242                out.push('j');
1243                out.push_str(s);
1244                out.push('|');
1245            }
1246            // v7.5.0 — Value is #[non_exhaustive] for downstream
1247            // forward-compat. Any future variant lacking explicit
1248            // handling here will share a debug-derived group key,
1249            // which is observably wrong but won't crash.
1250            _ => {
1251                out.push('?');
1252                out.push_str(&format!("{v:?}"));
1253                out.push('|');
1254            }
1255        }
1256    }
1257    out
1258}
1259
1260#[allow(clippy::cast_precision_loss)]
1261fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1262    use core::cmp::Ordering::Equal;
1263    match (a, b) {
1264        (Value::Null, Value::Null) => Equal,
1265        (Value::Null, _) => core::cmp::Ordering::Greater, // NULLs last
1266        (_, Value::Null) => core::cmp::Ordering::Less,
1267        (Value::Int(x), Value::Int(y)) => x.cmp(y),
1268        (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1269        (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1270        (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1271        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1272        (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1273        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1274        (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1275        (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1276        (Value::Text(x), Value::Text(y)) => x.cmp(y),
1277        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1278        _ => Equal,
1279    }
1280}