Skip to main content

spg_engine/
aggregate.rs

1//! Aggregate executor.
2//!
3//! Handles `SELECT … <aggs> … [GROUP BY …]` queries. The planning strategy
4//! is straightforward:
5//!
6//! 1. Walk the SELECT (and ORDER BY) expressions to find every aggregate
7//!    function call. Dedupe by AST equality and assign each `__agg_<i>`.
8//! 2. Same for every `GROUP BY` expression: assign `__grp_<j>`.
9//! 3. Stream the WHERE-filtered rows, group by the tuple of GROUP BY
10//!    values, and update per-group aggregate state.
11//! 4. Materialise a synthetic per-group row containing
12//!    `[__grp_0..__grp_K, __agg_0..__agg_N]` and rewrite the user's
13//!    SELECT / ORDER BY expressions to reference those synthetic columns
14//!    instead of the originals.
15//! 5. Evaluate the rewritten expressions against the synthetic schema and
16//!    emit results.
17//!
18//! v1.8 implements `count(*)`, `count(expr)`, `sum`, `min`, `max`, `avg`.
19//! NULL semantics follow PG: aggregates skip NULL inputs (except
20//! `count(*)`, which counts rows). `sum(int)` widens to `BigInt`;
21//! `avg(int|bigint)` returns `Float`.
22
23use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34/// True if this statement should go through the aggregate path.
35pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36    if stmt.group_by.is_some() || stmt.having.is_some() {
37        return true;
38    }
39    for item in &stmt.items {
40        if let SelectItem::Expr { expr, .. } = item
41            && contains_aggregate(expr)
42        {
43            return true;
44        }
45    }
46    for o in &stmt.order_by {
47        if contains_aggregate(&o.expr) {
48            return true;
49        }
50    }
51    if let Some(h) = &stmt.having
52        && contains_aggregate(h)
53    {
54        return true;
55    }
56    false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60    match e {
61        Expr::FunctionCall { name, args } => {
62            is_aggregate_name(name) || args.iter().any(contains_aggregate)
63        }
64        Expr::AggregateOrdered { .. } => true,
65        Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67            contains_aggregate(expr)
68        }
69        Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70        Expr::Extract { source, .. } => contains_aggregate(source),
71        // v4.10 subqueries + v4.12 window functions / Literal /
72        // Column — all non-aggregate leaves from the regular
73        // aggregate planner's POV. Window-bearing projections are
74        // routed to exec_select_with_window before this runs.
75        Expr::ScalarSubquery(_)
76        | Expr::Exists { .. }
77        | Expr::InSubquery { .. }
78        | Expr::WindowFunction { .. }
79        | Expr::Literal(_)
80        | Expr::Placeholder(_)
81        | Expr::Column(_) => false,
82        // v7.10.10 — recurse into array constructor / subscript /
83        // ANY/ALL children. Aggregates inside `ARRAY[SUM(x)]` are
84        // valid PG and must be detected here.
85        Expr::Array(items) => items.iter().any(contains_aggregate),
86        Expr::ArraySubscript { target, index } => {
87            contains_aggregate(target) || contains_aggregate(index)
88        }
89        Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90        // v7.13.0 — CASE WHEN … END. Recurse into operand,
91        // every (WHEN, THEN) pair, and the ELSE branch.
92        Expr::Case {
93            operand,
94            branches,
95            else_branch,
96        } => {
97            operand.as_deref().is_some_and(contains_aggregate)
98                || branches
99                    .iter()
100                    .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101                || else_branch.as_deref().is_some_and(contains_aggregate)
102        }
103    }
104}
105
/// Returns `true` when `name` is one of the aggregate functions this
/// executor implements, compared ASCII case-insensitively.
///
/// `count_star` is the internal name for `count(*)`. `every` is accepted
/// as an alias of `bool_and`; the alias is collapsed later, when the
/// aggregate specs are collected.
///
/// The comparison uses `eq_ignore_ascii_case`, so no lowercased copy of
/// `name` is allocated; this function runs for every function-call node
/// visited by the recursive expression walks.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        // Collection aggregates.
        "string_agg",
        "array_agg",
        // Boolean aggregates.
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATE_NAMES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
128
/// Per-aggregate running state.
///
/// `run` keeps one `AggState` per aggregate spec inside every group;
/// `update_state` mutates it once per input row and `finalize` converts it
/// into the output value. Each aggregate only touches the fields that are
/// relevant to it.
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Number of inputs counted so far: every row for `count_star`,
    /// non-NULL inputs for `count` / `sum` / `avg`, and every pushed item
    /// for `string_agg` / `array_agg`.
    count: i64,
    /// Running total of the `Int` / `BigInt` inputs fed to `sum` / `avg`.
    sum_int: i64,
    /// Running total of the `Float` inputs fed to `sum` / `avg`.
    sum_float: f64,
    /// Current best value for `min` / `max`; `None` until the first
    /// non-NULL input arrives.
    extreme: Option<Value>,
    /// Set once `sum` / `avg` has seen a `Float` input. `finalize` then
    /// reports `sum` as a float (`sum_float + sum_int`) instead of a
    /// `BigInt`.
    use_float: bool,
    /// v7.17.0 — running collection for string_agg / array_agg.
    /// Each entry is one row's contribution. `string_agg` skips NULL
    /// inputs before pushing, whereas `array_agg` keeps them as
    /// `Value::Null`. Items are pushed in row arrival order, which is
    /// the order used when the aggregate call has no internal
    /// `ORDER BY`.
    items: Vec<Value>,
    /// v7.24 (round-16 A) — per-item ORDER BY key tuples, parallel
    /// to `items` (pushed under the same skip/keep conditions).
    /// Empty when the aggregate carries no internal ordering.
    item_keys: Vec<Vec<Value>>,
    /// v7.17.0 — captured separator for string_agg. PG accepts a
    /// non-constant separator expression but in practice every
    /// caller passes a literal; the engine snapshots the last
    /// text separator it sees (non-text / NULL separators leave the
    /// previous snapshot untouched), which matches PG's "use the
    /// latest row's value" behaviour.
    separator: Option<String>,
    /// v7.17.0 — running boolean accumulator for bool_and /
    /// bool_or / every. `None` until the first non-NULL input;
    /// at finalize None → SQL NULL.
    bool_acc: Option<bool>,
}
159
/// Description of one distinct aggregate call found in the statement.
///
/// `collect_aggregates` registers a spec per unique combination of
/// (`name`, `arg`, `arg2`, `order_by`); the position of a spec in the
/// collected list is the index of its per-group `AggState` and of its
/// `__agg_<i>` synthetic column.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Lowercased aggregate name, with the `every` alias already
    /// collapsed to `bool_and`.
    name: String, // lowercased
    /// First argument (value expression) for every aggregate
    /// except `count(*)`. `None` for `count_star`.
    arg: Option<Expr>,
    /// v7.17.0 — second argument. Only `string_agg(value, sep)`
    /// uses it today. `None` for every other aggregate (or for
    /// `array_agg`, which is single-arg). Carried in the spec so
    /// per-row evaluation can re-use the same separator
    /// expression across calls.
    arg2: Option<Expr>,
    /// v7.24 (round-16 A) — aggregate-internal ORDER BY keys
    /// (`array_agg(x ORDER BY y DESC NULLS LAST)`). Empty for the
    /// plain form. Only the collection aggregates honour it;
    /// other aggregates are order-insensitive and ignore it (PG
    /// accepts the syntax everywhere too).
    order_by: Vec<spg_sql::ast::OrderBy>,
}
179
/// Output of running the aggregate path.
///
/// `columns` describes the projected SELECT items (one schema entry per
/// item) and `rows` holds one projected row per group that survived
/// `HAVING`. When the statement has an `ORDER BY`, `run` has already
/// sorted `rows` accordingly; applying `LIMIT` is left to the caller.
#[derive(Debug)]
pub struct AggResult {
    /// Schema of the projected output columns, in SELECT-list order.
    pub columns: Vec<ColumnSchema>,
    /// Projected output rows, one per retained group.
    pub rows: Vec<Row>,
}
187
188/// Execute aggregate logic against an already-WHERE-filtered iterator of
189/// rows. `table_alias` is the alias accepted by column resolution.
190#[allow(clippy::too_many_lines)]
191pub fn run(
192    stmt: &SelectStatement,
193    rows: &[&Row],
194    schema_cols: &[ColumnSchema],
195    table_alias: Option<&str>,
196) -> Result<AggResult, EvalError> {
197    let ctx = EvalContext::new(schema_cols, table_alias);
198    let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
199
200    // Collect aggregate sub-expressions across items + order_by.
201    let mut agg_specs: Vec<AggSpec> = Vec::new();
202    for item in &stmt.items {
203        if let SelectItem::Expr { expr, .. } = item {
204            collect_aggregates(expr, &mut agg_specs);
205        }
206    }
207    for o in &stmt.order_by {
208        collect_aggregates(&o.expr, &mut agg_specs);
209    }
210    if let Some(h) = &stmt.having {
211        collect_aggregates(h, &mut agg_specs);
212    }
213    // v7.17.0 — arity validation. The collector tolerates an
214    // arbitrary positional-arg count; here we enforce the
215    // per-aggregate contract so a malformed call (e.g.
216    // `array_agg()` or `string_agg(x)`) surfaces as a SQL error
217    // rather than silently coercing to a degenerate aggregate.
218    validate_agg_arities(stmt, &agg_specs)?;
219
220    // Map group key (vec of values, encoded as canonical string) -> group state.
221    // Order of insertion is preserved via a parallel Vec of keys.
222    let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
223    let mut key_order: Vec<String> = Vec::new();
224    // When there are no GROUP BY exprs *and* there is at least one aggregate,
225    // every row collapses into a single anonymous group keyed by "".
226    if rows.is_empty() && group_exprs.is_empty() {
227        // Single empty-aggregate group: count=0, sum=0, max=NULL, etc.
228        let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
229        groups.insert(String::new(), (Vec::new(), init));
230        key_order.push(String::new());
231    }
232
233    for row in rows {
234        let group_vals: Vec<Value> = group_exprs
235            .iter()
236            .map(|g| eval::eval_expr(g, row, &ctx))
237            .collect::<Result<_, _>>()?;
238        // v7.17.0 Phase 2.5b — case-insensitive group keying.
239        // For each group_expr that's a column reference on a
240        // CaseInsensitive text column, fold the corresponding
241        // value before encoding the key. Display value
242        // (`group_vals`) stays original — only the key folds.
243        let mut key_vals = group_vals.clone();
244        for (i, g) in group_exprs.iter().enumerate() {
245            if matches!(
246                eval::column_collation(g, &ctx),
247                Some(spg_storage::Collation::CaseInsensitive)
248            ) {
249                if let Value::Text(s) = &key_vals[i] {
250                    key_vals[i] = Value::Text(s.to_ascii_lowercase());
251                }
252            }
253        }
254        let key = encode_key(&key_vals);
255        let entry = groups.entry(key.clone()).or_insert_with(|| {
256            key_order.push(key.clone());
257            let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
258            (group_vals.clone(), init)
259        });
260        for (i, spec) in agg_specs.iter().enumerate() {
261            let arg_val = match &spec.arg {
262                None => Value::Bool(true), // count_star: sentinel non-null
263                Some(e) => eval::eval_expr(e, row, &ctx)?,
264            };
265            // v7.17.0 — `string_agg(value, separator)` evaluates the
266            // separator per row but PG treats it as constant; we
267            // pass the per-row value into update_state so a future
268            // varying-separator caller still sees correct output,
269            // even though SPG (like PG) only uses the most recent.
270            let arg2_val = match &spec.arg2 {
271                None => None,
272                Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
273            };
274            // v7.24 (round-16 A) — aggregate-internal ORDER BY:
275            // evaluate the key tuple against the source row.
276            let order_keys = if spec.order_by.is_empty() {
277                None
278            } else {
279                let mut keys = Vec::with_capacity(spec.order_by.len());
280                for o in &spec.order_by {
281                    keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
282                }
283                Some(keys)
284            };
285            update_state(
286                &mut entry.1[i],
287                &spec.name,
288                &arg_val,
289                arg2_val.as_ref(),
290                order_keys,
291            )?;
292        }
293    }
294
295    // Build synthetic schema: __grp_0..K then __agg_0..N.
296    let group_types: Vec<DataType> = if rows.is_empty() {
297        // Use Text as a safe stand-in — empty result means schema isn't
298        // observable. Avoids needing to evaluate group exprs on no row.
299        group_exprs.iter().map(|_| DataType::Text).collect()
300    } else {
301        let probe = rows[0];
302        group_exprs
303            .iter()
304            .map(|g| {
305                eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
306            })
307            .collect::<Result<_, _>>()?
308    };
309    let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
310    let mut synth_schema: Vec<ColumnSchema> = Vec::new();
311    for (i, ty) in group_types.iter().enumerate() {
312        synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
313    }
314    for (i, ty) in agg_types.iter().enumerate() {
315        synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
316    }
317
318    // Materialise synthetic rows.
319    let mut synth_rows: Vec<Row> = Vec::new();
320    for k in &key_order {
321        let (gvals, states) = &groups[k];
322        let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
323        values.extend(gvals.iter().cloned());
324        for (i, st) in states.iter().enumerate() {
325            // v7.24 (round-16 A) — order the collected items per the
326            // aggregate-internal ORDER BY before finalize consumes
327            // them.
328            let st_sorted;
329            let st_final: &AggState =
330                if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
331                    let mut idx: Vec<usize> = (0..st.items.len()).collect();
332                    let ob = &agg_specs[i].order_by;
333                    idx.sort_by(|&x, &y| {
334                        for (k, o) in ob.iter().enumerate() {
335                            let cmp = crate::order_by_value_cmp(
336                                o.desc,
337                                o.nulls_first,
338                                &st.item_keys[x][k],
339                                &st.item_keys[y][k],
340                            );
341                            if cmp != core::cmp::Ordering::Equal {
342                                return cmp;
343                            }
344                        }
345                        core::cmp::Ordering::Equal
346                    });
347                    let mut sorted = st.clone();
348                    sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
349                    st_sorted = sorted;
350                    &st_sorted
351                } else {
352                    st
353                };
354            values.push(finalize(&agg_specs[i].name, st_final));
355        }
356        synth_rows.push(Row::new(values));
357    }
358
359    // Rewrite the user's SELECT items + ORDER BY to reference synthetic
360    // columns. After rewriting, every remaining `Expr::Column` must
361    // resolve against the synthetic schema (i.e. must have been a GROUP
362    // BY expression).
363    let columns: Vec<ColumnSchema> = stmt
364        .items
365        .iter()
366        .map(|item| match item {
367            SelectItem::Wildcard => Err(EvalError::TypeMismatch {
368                detail: "SELECT * with aggregates is not supported".into(),
369            }),
370            SelectItem::Expr { expr, alias } => {
371                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
372                let name = alias.clone().unwrap_or_else(|| expr.to_string());
373                Ok(ColumnSchema::new(
374                    name,
375                    agg_or_group_type(&rewritten, &synth_schema),
376                    true,
377                ))
378            }
379        })
380        .collect::<Result<_, _>>()?;
381
382    // Project per synthetic row. HAVING filters out groups *before*
383    // we keep the projected row — same semantics as PG: HAVING runs
384    // against the aggregated row (so `HAVING count(*) > 1` works) and
385    // sees only group-by'd columns plus aggregate values.
386    let synth_ctx = EvalContext::new(&synth_schema, None);
387    let having_rewritten = stmt
388        .having
389        .as_ref()
390        .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
391    let mut kept_synth: Vec<Row> = Vec::new();
392    let mut out_rows: Vec<Row> = Vec::new();
393    for srow in synth_rows {
394        if let Some(h) = &having_rewritten {
395            let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
396            if !matches!(cond, Value::Bool(true)) {
397                continue;
398            }
399        }
400        let mut values: Vec<Value> = Vec::with_capacity(columns.len());
401        for item in &stmt.items {
402            if let SelectItem::Expr { expr, .. } = item {
403                let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
404                values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
405            }
406        }
407        kept_synth.push(srow);
408        out_rows.push(Row::new(values));
409    }
410
411    // ORDER BY: evaluate the rewritten order_by against each synth row,
412    // sort, then drop the keys. Limit is applied by the caller.
413    if !stmt.order_by.is_empty() {
414        // v6.4.0 — multi-key ORDER BY on aggregate output. Each key
415        // gets its own rewrite + per-key DESC flag.
416        let rewritten: Vec<Expr> = stmt
417            .order_by
418            .iter()
419            .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
420            .collect();
421        let keys_meta: Vec<(bool, Option<bool>)> = stmt
422            .order_by
423            .iter()
424            .map(|o| (o.desc, o.nulls_first))
425            .collect();
426        let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
427            .into_iter()
428            .zip(out_rows)
429            .map(|(s, o)| {
430                let mut keys = Vec::with_capacity(rewritten.len());
431                for e in &rewritten {
432                    keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
433                }
434                Ok::<_, EvalError>((keys, o))
435            })
436            .collect::<Result<_, _>>()?;
437        tagged.sort_by(|a, b| {
438            use core::cmp::Ordering;
439            for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
440                let (desc, nf) = keys_meta[i];
441                let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
442                if cmp != Ordering::Equal {
443                    return cmp;
444                }
445            }
446            Ordering::Equal
447        });
448        out_rows = tagged.into_iter().map(|(_, o)| o).collect();
449    }
450
451    Ok(AggResult {
452        columns,
453        rows: out_rows,
454    })
455}
456
457/// v7.17.0 — walk the statement again to validate the positional
458/// arity of every aggregate call site. Done after AST collection
459/// rather than inside `collect_aggregates` so the collector stays
460/// infallible; callers in `run()` can do a single early-error
461/// exit before any per-row work.
462fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
463    fn walk(e: &Expr) -> Result<(), EvalError> {
464        if let Expr::FunctionCall { name, args } = e {
465            let lower = name.to_ascii_lowercase();
466            let expected: Option<usize> = match lower.as_str() {
467                "count_star" => Some(0),
468                "count" | "sum" | "avg" | "min" | "max" | "array_agg"
469                // v7.17.0 — boolean aggregates also take exactly
470                // one arg. `every` is an alias normalised inside
471                // collect_aggregates / rewrite_expr.
472                | "bool_and" | "bool_or" | "every" => Some(1),
473                "string_agg" => Some(2),
474                _ => None,
475            };
476            if let Some(want) = expected
477                && args.len() != want
478            {
479                return Err(EvalError::TypeMismatch {
480                    detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
481                });
482            }
483            for a in args {
484                walk(a)?;
485            }
486        } else if let Expr::Binary { lhs, rhs, .. } = e {
487            walk(lhs)?;
488            walk(rhs)?;
489        } else if let Expr::Unary { expr, .. }
490        | Expr::Cast { expr, .. }
491        | Expr::IsNull { expr, .. } = e
492        {
493            walk(expr)?;
494        }
495        Ok(())
496    }
497    for item in &stmt.items {
498        if let SelectItem::Expr { expr, .. } = item {
499            walk(expr)?;
500        }
501    }
502    for o in &stmt.order_by {
503        walk(&o.expr)?;
504    }
505    if let Some(h) = &stmt.having {
506        walk(h)?;
507    }
508    Ok(())
509}
510
511fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
512    match e {
513        // v7.24 (round-16 A) — ordered aggregate: register the inner
514        // call's spec with the ordering attached.
515        Expr::AggregateOrdered { call, order_by } => {
516            if let Expr::FunctionCall { name, args } = call.as_ref() {
517                let lower = name.to_ascii_lowercase();
518                if is_aggregate_name(&lower) {
519                    let canonical = if lower == "every" {
520                        "bool_and".to_string()
521                    } else {
522                        lower
523                    };
524                    let spec = AggSpec {
525                        name: canonical,
526                        arg: args.first().cloned(),
527                        arg2: if name.eq_ignore_ascii_case("string_agg") {
528                            args.get(1).cloned()
529                        } else {
530                            None
531                        },
532                        order_by: order_by.clone(),
533                    };
534                    if !out.iter().any(|s| {
535                        s.name == spec.name
536                            && s.arg == spec.arg
537                            && s.arg2 == spec.arg2
538                            && s.order_by == spec.order_by
539                    }) {
540                        out.push(spec);
541                    }
542                    return;
543                }
544            }
545            collect_aggregates(call, out);
546            for o in order_by {
547                collect_aggregates(&o.expr, out);
548            }
549        }
550        Expr::FunctionCall { name, args } => {
551            let lower = name.to_ascii_lowercase();
552            if is_aggregate_name(&lower) {
553                let arg = if lower == "count_star" {
554                    None
555                } else {
556                    args.first().cloned()
557                };
558                // v7.17.0 — second positional arg for
559                // `string_agg(value, separator)`. Everything else
560                // ignores it.
561                let arg2 = if lower == "string_agg" {
562                    args.get(1).cloned()
563                } else {
564                    None
565                };
566                // v7.17.0 — `every` is the SQL-standard alias for
567                // `bool_and`; collapse at collection time so
568                // update_state / finalize need only one arm.
569                let canonical = if lower == "every" {
570                    "bool_and".to_string()
571                } else {
572                    lower
573                };
574                let spec = AggSpec {
575                    name: canonical,
576                    arg: arg.clone(),
577                    arg2: arg2.clone(),
578                    order_by: Vec::new(),
579                };
580                if !out.iter().any(|s| {
581                    s.name == spec.name
582                        && s.arg == spec.arg
583                        && s.arg2 == spec.arg2
584                        && s.order_by == spec.order_by
585                }) {
586                    out.push(spec);
587                }
588                // Don't recurse into the arg — nested aggregates are
589                // illegal in standard SQL.
590            } else {
591                for a in args {
592                    collect_aggregates(a, out);
593                }
594            }
595        }
596        Expr::Binary { lhs, rhs, .. } => {
597            collect_aggregates(lhs, out);
598            collect_aggregates(rhs, out);
599        }
600        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
601            collect_aggregates(expr, out);
602        }
603        Expr::Like { expr, pattern, .. } => {
604            collect_aggregates(expr, out);
605            collect_aggregates(pattern, out);
606        }
607        Expr::Extract { source, .. } => collect_aggregates(source, out),
608        // v4.10 subquery + v4.12 window / Literal / Column —
609        // non-recursing leaves for the aggregate collector.
610        Expr::ScalarSubquery(_)
611        | Expr::Exists { .. }
612        | Expr::InSubquery { .. }
613        | Expr::WindowFunction { .. }
614        | Expr::Literal(_)
615        | Expr::Placeholder(_)
616        | Expr::Column(_) => {}
617        // v7.10.10 — recurse into array constructor children +
618        // subscript / ANY/ALL operands.
619        Expr::Array(items) => {
620            for elem in items {
621                collect_aggregates(elem, out);
622            }
623        }
624        Expr::ArraySubscript { target, index } => {
625            collect_aggregates(target, out);
626            collect_aggregates(index, out);
627        }
628        Expr::AnyAll { expr, array, .. } => {
629            collect_aggregates(expr, out);
630            collect_aggregates(array, out);
631        }
632        Expr::Case {
633            operand,
634            branches,
635            else_branch,
636        } => {
637            if let Some(o) = operand {
638                collect_aggregates(o, out);
639            }
640            for (w, t) in branches {
641                collect_aggregates(w, out);
642                collect_aggregates(t, out);
643            }
644            if let Some(e) = else_branch {
645                collect_aggregates(e, out);
646            }
647        }
648    }
649}
650
651fn update_state(
652    st: &mut AggState,
653    name: &str,
654    v: &Value,
655    arg2: Option<&Value>,
656    order_keys: Option<Vec<Value>>,
657) -> Result<(), EvalError> {
658    let is_null = matches!(v, Value::Null);
659    match name {
660        "count_star" => st.count += 1,
661        "count" => {
662            if !is_null {
663                st.count += 1;
664            }
665        }
666        "sum" | "avg" => {
667            if is_null {
668                return Ok(());
669            }
670            st.count += 1;
671            match v {
672                Value::Int(n) => st.sum_int += i64::from(*n),
673                Value::BigInt(n) => st.sum_int += *n,
674                Value::Float(x) => {
675                    st.use_float = true;
676                    st.sum_float += *x;
677                }
678                other => {
679                    return Err(EvalError::TypeMismatch {
680                        detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
681                    });
682                }
683            }
684        }
685        "min" => {
686            if is_null {
687                return Ok(());
688            }
689            match &st.extreme {
690                None => st.extreme = Some(v.clone()),
691                Some(cur) => {
692                    if value_cmp(v, cur) == core::cmp::Ordering::Less {
693                        st.extreme = Some(v.clone());
694                    }
695                }
696            }
697        }
698        "max" => {
699            if is_null {
700                return Ok(());
701            }
702            match &st.extreme {
703                None => st.extreme = Some(v.clone()),
704                Some(cur) => {
705                    if value_cmp(v, cur) == core::cmp::Ordering::Greater {
706                        st.extreme = Some(v.clone());
707                    }
708                }
709            }
710        }
711        // v7.17.0 — string_agg(value, separator). NULL value is
712        // skipped (PG aggregate-skip-null). Separator captured
713        // from the latest row that flows through; matches PG's
714        // semantics of evaluating the separator per row but using
715        // the last value at finalize time (in practice it's
716        // constant). count is bumped so we can distinguish "empty
717        // group → NULL" from "all-NULL group → NULL".
718        "string_agg" => {
719            if let Some(sep) = arg2
720                && let Value::Text(s) = sep
721            {
722                st.separator = Some(s.clone());
723            }
724            if is_null {
725                return Ok(());
726            }
727            if let Value::Text(s) = v {
728                st.items.push(Value::Text(s.clone()));
729                if let Some(k) = order_keys {
730                    st.item_keys.push(k);
731                }
732                st.count += 1;
733            } else {
734                return Err(EvalError::TypeMismatch {
735                    detail: format!("string_agg requires text value, got {:?}", v.data_type()),
736                });
737            }
738        }
739        // v7.17.0 — array_agg(value). Unlike string_agg, NULL
740        // elements are KEPT in the array (PG behaviour); the
741        // result is NULL only when ZERO rows fed in. Element type
742        // is locked from the first row's value type; subsequent
743        // rows must match (PG also rejects mixed-type array_agg).
744        "array_agg" => {
745            st.items.push(v.clone());
746            if let Some(k) = order_keys {
747                st.item_keys.push(k);
748            }
749            st.count += 1;
750        }
751        // v7.17.0 — bool_and(p): TRUE iff every non-NULL input is
752        // TRUE. NULL skipped; running accumulator stays at TRUE
753        // until the first non-NULL FALSE.
754        "bool_and" => {
755            if is_null {
756                return Ok(());
757            }
758            let b = match v {
759                Value::Bool(b) => *b,
760                other => {
761                    return Err(EvalError::TypeMismatch {
762                        detail: format!("bool_and requires bool, got {:?}", other.data_type()),
763                    });
764                }
765            };
766            st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
767        }
768        // v7.17.0 — bool_or(p): TRUE iff any non-NULL input is
769        // TRUE. NULL skipped.
770        "bool_or" => {
771            if is_null {
772                return Ok(());
773            }
774            let b = match v {
775                Value::Bool(b) => *b,
776                other => {
777                    return Err(EvalError::TypeMismatch {
778                        detail: format!("bool_or requires bool, got {:?}", other.data_type()),
779                    });
780                }
781            };
782            st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
783        }
784        _ => unreachable!("non-aggregate {name} in update_state"),
785    }
786    Ok(())
787}
788
789#[allow(clippy::cast_precision_loss)]
790fn finalize(name: &str, st: &AggState) -> Value {
791    match name {
792        "count" | "count_star" => Value::BigInt(st.count),
793        "sum" => {
794            if st.count == 0 {
795                Value::Null
796            } else if st.use_float {
797                Value::Float(st.sum_float + (st.sum_int as f64))
798            } else {
799                Value::BigInt(st.sum_int)
800            }
801        }
802        "avg" => {
803            if st.count == 0 {
804                Value::Null
805            } else {
806                let total = if st.use_float {
807                    st.sum_float + (st.sum_int as f64)
808                } else {
809                    st.sum_int as f64
810                };
811                Value::Float(total / (st.count as f64))
812            }
813        }
814        "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
815        // v7.17.0 — string_agg: join all collected text items with
816        // the captured separator. Empty / all-NULL group → NULL
817        // (PG semantics).
818        "string_agg" => {
819            if st.items.is_empty() {
820                return Value::Null;
821            }
822            let sep = st.separator.clone().unwrap_or_default();
823            let mut out = String::new();
824            for (i, item) in st.items.iter().enumerate() {
825                if i > 0 {
826                    out.push_str(&sep);
827                }
828                if let Value::Text(s) = item {
829                    out.push_str(s);
830                }
831            }
832            Value::Text(out)
833        }
834        // v7.17.0 — array_agg: collect into a typed array. NULL
835        // elements are preserved per PG. Result type is decided
836        // by the first non-NULL element seen (or Text fallback
837        // when the whole group is NULL — PG would surface the
838        // declared input type, but SPG hasn't yet wired the
839        // aggregate's static input-type from `describe`).
840        "array_agg" => {
841            if st.items.is_empty() {
842                return Value::Null;
843            }
844            let probe = st.items.iter().find(|v| !v.is_null());
845            match probe.and_then(spg_storage::Value::data_type) {
846                Some(DataType::Int) | Some(DataType::SmallInt) => {
847                    let items: Vec<Option<i32>> = st
848                        .items
849                        .iter()
850                        .map(|v| match v {
851                            Value::Int(n) => Some(*n),
852                            Value::SmallInt(n) => Some(i32::from(*n)),
853                            _ => None,
854                        })
855                        .collect();
856                    Value::IntArray(items)
857                }
858                Some(DataType::BigInt) => {
859                    let items: Vec<Option<i64>> = st
860                        .items
861                        .iter()
862                        .map(|v| match v {
863                            Value::BigInt(n) => Some(*n),
864                            _ => None,
865                        })
866                        .collect();
867                    Value::BigIntArray(items)
868                }
869                _ => {
870                    let items: Vec<Option<String>> = st
871                        .items
872                        .iter()
873                        .map(|v| match v {
874                            Value::Text(s) => Some(s.clone()),
875                            Value::Null => None,
876                            other => Some(format!("{other:?}")),
877                        })
878                        .collect();
879                    Value::TextArray(items)
880                }
881            }
882        }
883        // v7.17.0 — bool_and / bool_or finalize: lazy-init pattern
884        // means `None` is exactly "empty group or all-NULL", which
885        // PG surfaces as SQL NULL.
886        "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
887        _ => unreachable!(),
888    }
889}
890
891fn infer_agg_type(spec: &AggSpec) -> DataType {
892    match spec.name.as_str() {
893        // count/count_star are exact integer counts; sum widens to BigInt
894        // and reports as such even for Float input (the value column is
895        // nullable so the wire layer surfaces the Float at runtime).
896        "count" | "count_star" | "sum" => DataType::BigInt,
897        "avg" => DataType::Float,
898        // v7.17.0 — string_agg always returns TEXT.
899        "string_agg" => DataType::Text,
900        // v7.17.0 — array_agg's declared output type can't be
901        // known without inspecting the argument's expression
902        // shape. Default to TextArray; finalize widens to
903        // IntArray / BigIntArray when the actual elements are
904        // numeric. Downstream column metadata reports TextArray
905        // which is the lowest common denominator.
906        "array_agg" => DataType::TextArray,
907        // v7.17.0 — boolean aggregates always return BOOL (nullable
908        // — empty / all-NULL group → NULL).
909        "bool_and" | "bool_or" => DataType::Bool,
910        // min/max: we don't know the input type without probing — default
911        // to Text and let downstream rendering coerce.
912        _ => DataType::Text,
913    }
914}
915
916fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
917    if let Expr::Column(c) = e
918        && let Some(s) = synth.iter().find(|s| s.name == c.name)
919    {
920        return s.ty;
921    }
922    // Compound expression — fall back to Text (matches build_projection
923    // behaviour for non-column expressions in the non-aggregate path).
924    DataType::Text
925}
926
927fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
928    // v7.24 (round-16 A) — ordered aggregate: match on the inner
929    // call PLUS the ordering keys.
930    if let Expr::AggregateOrdered { call, order_by } = e
931        && let Expr::FunctionCall { name, args } = call.as_ref()
932    {
933        let lower = name.to_ascii_lowercase();
934        if is_aggregate_name(&lower) {
935            let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
936            let arg = args.first().cloned();
937            let arg2 = if lower == "string_agg" {
938                args.get(1).cloned()
939            } else {
940                None
941            };
942            for (i, spec) in aggs.iter().enumerate() {
943                if spec.name == canonical
944                    && spec.arg == arg
945                    && spec.arg2 == arg2
946                    && spec.order_by == *order_by
947                {
948                    return Expr::Column(spg_sql::ast::ColumnName {
949                        qualifier: None,
950                        name: format!("__agg_{i}"),
951                    });
952                }
953            }
954        }
955    }
956    // Match aggregate FunctionCalls first — they sit outside group_by.
957    if let Expr::FunctionCall { name, args } = e {
958        let lower = name.to_ascii_lowercase();
959        if is_aggregate_name(&lower) {
960            let arg = if lower == "count_star" {
961                None
962            } else {
963                args.first().cloned()
964            };
965            // v7.17.0 — match the spec we registered for
966            // string_agg(value, separator) on the full pair.
967            let arg2 = if lower == "string_agg" {
968                args.get(1).cloned()
969            } else {
970                None
971            };
972            // v7.17.0 — `every` collapses into `bool_and` at
973            // collection; mirror that here so the rewrite finds
974            // the matching synth column.
975            let canonical: &str = if lower == "every" {
976                "bool_and"
977            } else {
978                lower.as_str()
979            };
980            for (i, spec) in aggs.iter().enumerate() {
981                if spec.name == canonical
982                    && spec.arg == arg
983                    && spec.arg2 == arg2
984                    && spec.order_by.is_empty()
985                {
986                    return Expr::Column(spg_sql::ast::ColumnName {
987                        qualifier: None,
988                        name: format!("__agg_{i}"),
989                    });
990                }
991            }
992        }
993    }
994    // Match a group_by expression by AST equality.
995    for (i, g) in group_exprs.iter().enumerate() {
996        if g == e {
997            return Expr::Column(spg_sql::ast::ColumnName {
998                qualifier: None,
999                name: format!("__grp_{i}"),
1000            });
1001        }
1002    }
1003    // Recurse into children.
1004    match e {
1005        Expr::AggregateOrdered { call, order_by } => Expr::AggregateOrdered {
1006            call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1007            order_by: order_by
1008                .iter()
1009                .map(|o| spg_sql::ast::OrderBy {
1010                    expr: rewrite_expr(&o.expr, group_exprs, aggs),
1011                    desc: o.desc,
1012                    nulls_first: o.nulls_first,
1013                })
1014                .collect(),
1015        },
1016        Expr::Binary { lhs, op, rhs } => Expr::Binary {
1017            lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1018            op: *op,
1019            rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1020        },
1021        Expr::Unary { op, expr } => Expr::Unary {
1022            op: *op,
1023            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1024        },
1025        Expr::Cast { expr, target } => Expr::Cast {
1026            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1027            target: *target,
1028        },
1029        Expr::IsNull { expr, negated } => Expr::IsNull {
1030            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1031            negated: *negated,
1032        },
1033        Expr::FunctionCall { name, args } => Expr::FunctionCall {
1034            name: name.clone(),
1035            args: args
1036                .iter()
1037                .map(|a| rewrite_expr(a, group_exprs, aggs))
1038                .collect(),
1039        },
1040        Expr::Like {
1041            expr,
1042            pattern,
1043            negated,
1044        } => Expr::Like {
1045            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1046            pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1047            negated: *negated,
1048        },
1049        Expr::Extract { field, source } => Expr::Extract {
1050            field: *field,
1051            source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1052        },
1053        // v4.10 subquery + v4.12 window / Literal / Column —
1054        // clone-pass (these don't participate in aggregate rewrite).
1055        Expr::ScalarSubquery(_)
1056        | Expr::Exists { .. }
1057        | Expr::InSubquery { .. }
1058        | Expr::WindowFunction { .. }
1059        | Expr::Literal(_)
1060        | Expr::Placeholder(_)
1061        | Expr::Column(_) => e.clone(),
1062        // v7.10.10 — recurse children for array nodes.
1063        Expr::Array(items) => Expr::Array(
1064            items
1065                .iter()
1066                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1067                .collect(),
1068        ),
1069        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1070            target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1071            index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1072        },
1073        Expr::AnyAll {
1074            expr,
1075            op,
1076            array,
1077            is_any,
1078        } => Expr::AnyAll {
1079            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1080            op: *op,
1081            array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1082            is_any: *is_any,
1083        },
1084        Expr::Case {
1085            operand,
1086            branches,
1087            else_branch,
1088        } => Expr::Case {
1089            operand: operand
1090                .as_deref()
1091                .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1092            branches: branches
1093                .iter()
1094                .map(|(w, t)| {
1095                    (
1096                        rewrite_expr(w, group_exprs, aggs),
1097                        rewrite_expr(t, group_exprs, aggs),
1098                    )
1099                })
1100                .collect(),
1101            else_branch: else_branch
1102                .as_deref()
1103                .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1104        },
1105    }
1106}
1107
1108/// Canonical string key for a tuple of group values. Used as map key.
1109fn encode_key(vals: &[Value]) -> String {
1110    let mut out = String::new();
1111    for v in vals {
1112        match v {
1113            Value::Null => out.push_str("N|"),
1114            Value::SmallInt(n) => {
1115                out.push('s');
1116                out.push_str(&n.to_string());
1117                out.push('|');
1118            }
1119            Value::Int(n) => {
1120                out.push('I');
1121                out.push_str(&n.to_string());
1122                out.push('|');
1123            }
1124            Value::BigInt(n) => {
1125                out.push('B');
1126                out.push_str(&n.to_string());
1127                out.push('|');
1128            }
1129            Value::Float(x) => {
1130                out.push('F');
1131                out.push_str(&x.to_string());
1132                out.push('|');
1133            }
1134            Value::Bool(b) => {
1135                out.push(if *b { 'T' } else { 'f' });
1136                out.push('|');
1137            }
1138            Value::Text(s) => {
1139                out.push('S');
1140                out.push_str(s);
1141                out.push('|');
1142            }
1143            Value::Vector(v) => {
1144                out.push('V');
1145                for x in v {
1146                    out.push_str(&x.to_string());
1147                    out.push(',');
1148                }
1149                out.push('|');
1150            }
1151            // v6.0.1: GROUP BY on a `VECTOR(N) USING SQ8` column.
1152            // Two cells with byte-identical `(min, max, bytes)`
1153            // share the same group; equivalence is byte-equality
1154            // (same as f32 grouping today — neither path tries to
1155            // normalise nan/-0).
1156            Value::Sq8Vector(q) => {
1157                out.push('Q');
1158                out.push_str(&q.min.to_string());
1159                out.push('@');
1160                out.push_str(&q.max.to_string());
1161                out.push(':');
1162                for b in &q.bytes {
1163                    out.push_str(&b.to_string());
1164                    out.push(',');
1165                }
1166                out.push('|');
1167            }
1168            // v6.0.3: GROUP BY on a `VECTOR(N) USING HALF` column.
1169            // Byte-equality over the raw u16 bits; matches the SQ8
1170            // path's byte-key model.
1171            Value::HalfVector(h) => {
1172                out.push('H');
1173                for b in &h.bytes {
1174                    out.push_str(&b.to_string());
1175                    out.push(',');
1176                }
1177                out.push('|');
1178            }
1179            Value::Numeric { scaled, scale } => {
1180                out.push('D');
1181                out.push_str(&scaled.to_string());
1182                out.push('@');
1183                out.push_str(&scale.to_string());
1184                out.push('|');
1185            }
1186            Value::Date(d) => {
1187                out.push('d');
1188                out.push_str(&d.to_string());
1189                out.push('|');
1190            }
1191            Value::Timestamp(t) => {
1192                out.push('t');
1193                out.push_str(&t.to_string());
1194                out.push('|');
1195            }
1196            Value::Interval { months, micros } => {
1197                out.push('i');
1198                out.push_str(&months.to_string());
1199                out.push('m');
1200                out.push_str(&micros.to_string());
1201                out.push('|');
1202            }
1203            Value::Json(s) => {
1204                out.push('j');
1205                out.push_str(s);
1206                out.push('|');
1207            }
1208            // v7.5.0 — Value is #[non_exhaustive] for downstream
1209            // forward-compat. Any future variant lacking explicit
1210            // handling here will share a debug-derived group key,
1211            // which is observably wrong but won't crash.
1212            _ => {
1213                out.push('?');
1214                out.push_str(&format!("{v:?}"));
1215                out.push('|');
1216            }
1217        }
1218    }
1219    out
1220}
1221
1222#[allow(clippy::cast_precision_loss)]
1223fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1224    use core::cmp::Ordering::Equal;
1225    match (a, b) {
1226        (Value::Null, Value::Null) => Equal,
1227        (Value::Null, _) => core::cmp::Ordering::Greater, // NULLs last
1228        (_, Value::Null) => core::cmp::Ordering::Less,
1229        (Value::Int(x), Value::Int(y)) => x.cmp(y),
1230        (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1231        (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1232        (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1233        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1234        (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1235        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1236        (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1237        (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1238        (Value::Text(x), Value::Text(y)) => x.cmp(y),
1239        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1240        _ => Equal,
1241    }
1242}