Skip to main content

fsqlite_func/
agg_builtins.rs

1//! Built-in aggregate functions (§13.4).
2//!
3//! Implements: avg, count, group_concat, string_agg, max, min, sum, total,
4//! median, percentile, percentile_cont, percentile_disc.
5//!
6//! # NULL handling
7//! All aggregate functions skip NULL values (except `count(*)` which counts
8//! all rows). Empty-set behavior:
//! - avg / sum / max / min / median / percentile / percentile_cont /
//!   percentile_disc / group_concat / string_agg → NULL
10//! - total → 0.0
11//! - count → 0
12#![allow(
13    clippy::unnecessary_literal_bound,
14    clippy::too_many_lines,
15    clippy::cast_possible_truncation,
16    clippy::cast_possible_wrap,
17    clippy::cast_precision_loss,
18    clippy::match_same_arms,
19    clippy::items_after_statements,
20    clippy::float_cmp,
21    clippy::cast_sign_loss,
22    clippy::suboptimal_flops
23)]
24
25use fsqlite_error::{FrankenError, Result};
26use fsqlite_types::SqliteValue;
27
28use crate::{AggregateFunction, FunctionRegistry};
29
30// ─── Kahan compensated summation ──────────────────────────────────────────
31
/// One step of Kahan-Babuska-Neumaier compensated summation.
///
/// Adds `value` to the running `sum` and folds the rounding error of that
/// addition into `compensation`. The error is computed from whichever operand
/// has the larger magnitude, mirroring C SQLite's `kahanBabuskaNeumaierStep`
/// (func.c). The final result is `sum + compensation`.
///
/// If the sum overflows to ±Inf (or an infinite/NaN value is added), the
/// compensation becomes non-finite; callers must decide how to treat that.
#[inline]
fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
    let prev = *sum;
    let next = prev + value;
    // Low-order bits lost when forming `next`, recovered from the operand
    // with the larger magnitude so the subtraction is exact.
    let lost = if prev.abs() > value.abs() {
        (prev - next) + value
    } else {
        (value - next) + prev
    };
    *compensation += lost;
    *sum = next;
}
46
47// ═══════════════════════════════════════════════════════════════════════════
48// avg(X)
49// ═══════════════════════════════════════════════════════════════════════════
50
51pub struct AvgState {
52    sum: f64,
53    compensation: f64,
54    count: i64,
55}
56
57pub struct AvgFunc;
58
59impl AggregateFunction for AvgFunc {
60    type State = AvgState;
61
62    fn initial_state(&self) -> Self::State {
63        AvgState {
64            sum: 0.0,
65            compensation: 0.0,
66            count: 0,
67        }
68    }
69
70    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
71        if !args[0].is_null() {
72            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
73            state.count += 1;
74        }
75        Ok(())
76    }
77
78    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
79        if state.count == 0 {
80            Ok(SqliteValue::Null)
81        } else {
82            Ok(SqliteValue::Float(
83                (state.sum + state.compensation) / state.count as f64,
84            ))
85        }
86    }
87
88    fn num_args(&self) -> i32 {
89        1
90    }
91
92    fn name(&self) -> &str {
93        "avg"
94    }
95}
96
97// ═══════════════════════════════════════════════════════════════════════════
98// count(*) and count(X)
99// ═══════════════════════════════════════════════════════════════════════════
100
101/// `count(*)` — counts all rows including those with NULL values.
102pub struct CountStarFunc;
103
104impl AggregateFunction for CountStarFunc {
105    type State = i64;
106
107    fn initial_state(&self) -> Self::State {
108        0
109    }
110
111    fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
112        *state += 1;
113        Ok(())
114    }
115
116    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
117        Ok(SqliteValue::Integer(state))
118    }
119
120    fn num_args(&self) -> i32 {
121        0 // count(*) takes no column argument
122    }
123
124    fn name(&self) -> &str {
125        "count"
126    }
127}
128
129/// `count(X)` — counts non-NULL values of X.
130pub struct CountFunc;
131
132impl AggregateFunction for CountFunc {
133    type State = i64;
134
135    fn initial_state(&self) -> Self::State {
136        0
137    }
138
139    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
140        if !args[0].is_null() {
141            *state += 1;
142        }
143        Ok(())
144    }
145
146    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
147        Ok(SqliteValue::Integer(state))
148    }
149
150    fn num_args(&self) -> i32 {
151        1
152    }
153
154    fn name(&self) -> &str {
155        "count"
156    }
157}
158
159// ═══════════════════════════════════════════════════════════════════════════
160// group_concat(X [, SEP])
161// ═══════════════════════════════════════════════════════════════════════════
162
163pub struct GroupConcatState {
164    /// Incrementally built result string.  C SQLite appends
165    /// `separator + value` at each step (separator only before 2nd+ value),
166    /// using the separator from *that row's* argument, not a single global one.
167    result: String,
168    has_value: bool,
169}
170
171pub struct GroupConcatFunc;
172
173impl AggregateFunction for GroupConcatFunc {
174    type State = GroupConcatState;
175
176    fn initial_state(&self) -> Self::State {
177        GroupConcatState {
178            result: String::new(),
179            has_value: false,
180        }
181    }
182
183    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
184        if args[0].is_null() {
185            return Ok(());
186        }
187        let sep = if args.len() > 1 {
188            if args[1].is_null() {
189                String::new()
190            } else {
191                args[1].to_text()
192            }
193        } else {
194            ",".to_owned()
195        };
196        if state.has_value {
197            state.result.push_str(&sep);
198        }
199        state.result.push_str(&args[0].to_text());
200        state.has_value = true;
201        Ok(())
202    }
203
204    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
205        if state.has_value {
206            Ok(SqliteValue::Text(state.result.into()))
207        } else {
208            Ok(SqliteValue::Null)
209        }
210    }
211
212    fn num_args(&self) -> i32 {
213        -1 // 1 or 2 args
214    }
215
216    fn name(&self) -> &str {
217        "group_concat"
218    }
219}
220
221// ═══════════════════════════════════════════════════════════════════════════
222// max(X) — aggregate, single arg
223// ═══════════════════════════════════════════════════════════════════════════
224
225pub struct AggMaxFunc;
226
227impl AggregateFunction for AggMaxFunc {
228    type State = Option<SqliteValue>;
229
230    fn initial_state(&self) -> Self::State {
231        None
232    }
233
234    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
235        if args[0].is_null() {
236            return Ok(());
237        }
238        let candidate = &args[0];
239        match state {
240            None => *state = Some(candidate.clone()),
241            Some(current) => {
242                if candidate > current {
243                    *state = Some(candidate.clone());
244                }
245            }
246        }
247        Ok(())
248    }
249
250    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
251        Ok(state.unwrap_or(SqliteValue::Null))
252    }
253
254    fn num_args(&self) -> i32 {
255        1
256    }
257
258    fn name(&self) -> &str {
259        "max"
260    }
261}
262
263// ═══════════════════════════════════════════════════════════════════════════
264// min(X) — aggregate, single arg
265// ═══════════════════════════════════════════════════════════════════════════
266
267pub struct AggMinFunc;
268
269impl AggregateFunction for AggMinFunc {
270    type State = Option<SqliteValue>;
271
272    fn initial_state(&self) -> Self::State {
273        None
274    }
275
276    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
277        if args[0].is_null() {
278            return Ok(());
279        }
280        let candidate = &args[0];
281        match state {
282            None => *state = Some(candidate.clone()),
283            Some(current) => {
284                if candidate < current {
285                    *state = Some(candidate.clone());
286                }
287            }
288        }
289        Ok(())
290    }
291
292    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
293        Ok(state.unwrap_or(SqliteValue::Null))
294    }
295
296    fn num_args(&self) -> i32 {
297        1
298    }
299
300    fn name(&self) -> &str {
301        "min"
302    }
303}
304
305// ═══════════════════════════════════════════════════════════════════════════
306// sum(X)
307// ═══════════════════════════════════════════════════════════════════════════
308
309/// State for `sum()`: tracks whether all values are integers, the running
310/// integer sum, and the float sum as fallback.  Uses Kahan compensated
311/// summation for the float path to match C SQLite's precision.
312pub struct SumState {
313    int_sum: i64,
314    float_sum: f64,
315    float_compensation: f64,
316    all_integer: bool,
317    has_values: bool,
318    overflowed: bool,
319}
320
321pub struct SumFunc;
322
323impl AggregateFunction for SumFunc {
324    type State = SumState;
325
326    fn initial_state(&self) -> Self::State {
327        SumState {
328            int_sum: 0,
329            float_sum: 0.0,
330            float_compensation: 0.0,
331            all_integer: true,
332            has_values: false,
333            overflowed: false,
334        }
335    }
336
337    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
338        if args[0].is_null() {
339            return Ok(());
340        }
341        state.has_values = true;
342        match &args[0] {
343            SqliteValue::Integer(i) => {
344                if state.all_integer && !state.overflowed {
345                    match state.int_sum.checked_add(*i) {
346                        Some(s) => state.int_sum = s,
347                        None => state.overflowed = true,
348                    }
349                }
350                kahan_add(
351                    &mut state.float_sum,
352                    &mut state.float_compensation,
353                    *i as f64,
354                );
355            }
356            other => {
357                state.all_integer = false;
358                kahan_add(
359                    &mut state.float_sum,
360                    &mut state.float_compensation,
361                    other.to_float(),
362                );
363            }
364        }
365        Ok(())
366    }
367
368    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
369        if !state.has_values {
370            return Ok(SqliteValue::Null);
371        }
372        if state.overflowed {
373            return Err(FrankenError::IntegerOverflow);
374        }
375        if state.all_integer {
376            Ok(SqliteValue::Integer(state.int_sum))
377        } else {
378            Ok(SqliteValue::Float(
379                state.float_sum + state.float_compensation,
380            ))
381        }
382    }
383
384    fn num_args(&self) -> i32 {
385        1
386    }
387
388    fn name(&self) -> &str {
389        "sum"
390    }
391}
392
393// ═══════════════════════════════════════════════════════════════════════════
394// total(X) — always returns float, 0.0 for empty set, never overflows.
395// ═══════════════════════════════════════════════════════════════════════════
396
397pub struct TotalFunc;
398
399/// State for `total()`: Kahan compensated accumulator.
400pub struct TotalState {
401    sum: f64,
402    compensation: f64,
403}
404
405impl AggregateFunction for TotalFunc {
406    type State = TotalState;
407
408    fn initial_state(&self) -> Self::State {
409        TotalState {
410            sum: 0.0,
411            compensation: 0.0,
412        }
413    }
414
415    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
416        if !args[0].is_null() {
417            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
418        }
419        Ok(())
420    }
421
422    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
423        Ok(SqliteValue::Float(state.sum + state.compensation))
424    }
425
426    fn num_args(&self) -> i32 {
427        1
428    }
429
430    fn name(&self) -> &str {
431        "total"
432    }
433}
434
435// ═══════════════════════════════════════════════════════════════════════════
436// median(X) — equivalent to percentile_cont(X, 0.5)
437// ═══════════════════════════════════════════════════════════════════════════
438
439pub struct MedianFunc;
440
441impl AggregateFunction for MedianFunc {
442    type State = Vec<f64>;
443
444    fn initial_state(&self) -> Self::State {
445        Vec::new()
446    }
447
448    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
449        if !args[0].is_null() {
450            state.push(args[0].to_float());
451        }
452        Ok(())
453    }
454
455    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
456        if state.is_empty() {
457            return Ok(SqliteValue::Null);
458        }
459        state.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
460        let result = percentile_cont_impl(&state, 0.5);
461        Ok(SqliteValue::Float(result))
462    }
463
464    fn num_args(&self) -> i32 {
465        1
466    }
467
468    fn name(&self) -> &str {
469        "median"
470    }
471}
472
473// ═══════════════════════════════════════════════════════════════════════════
474// percentile(Y, P) — P in 0..100
475// ═══════════════════════════════════════════════════════════════════════════
476
477pub struct PercentileState {
478    values: Vec<f64>,
479    p: Option<f64>,
480}
481
482pub struct PercentileFunc;
483
484impl AggregateFunction for PercentileFunc {
485    type State = PercentileState;
486
487    fn initial_state(&self) -> Self::State {
488        PercentileState {
489            values: Vec::new(),
490            p: None,
491        }
492    }
493
494    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
495        if !args[0].is_null() {
496            state.values.push(args[0].to_float());
497        }
498        // Capture P from the second argument (constant expression).
499        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
500            state.p = Some(args[1].to_float());
501        }
502        Ok(())
503    }
504
505    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
506        if state.values.is_empty() {
507            return Ok(SqliteValue::Null);
508        }
509        let p = state.p.unwrap_or(50.0);
510        state
511            .values
512            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
513        // Convert P from 0-100 to 0-1 for the shared implementation.
514        let result = percentile_cont_impl(&state.values, p / 100.0);
515        Ok(SqliteValue::Float(result))
516    }
517
518    fn num_args(&self) -> i32 {
519        2
520    }
521
522    fn name(&self) -> &str {
523        "percentile"
524    }
525}
526
527// ═══════════════════════════════════════════════════════════════════════════
528// percentile_cont(Y, P) — P in 0..1, continuous interpolation
529// ═══════════════════════════════════════════════════════════════════════════
530
531pub struct PercentileContFunc;
532
533impl AggregateFunction for PercentileContFunc {
534    type State = PercentileState;
535
536    fn initial_state(&self) -> Self::State {
537        PercentileState {
538            values: Vec::new(),
539            p: None,
540        }
541    }
542
543    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
544        if !args[0].is_null() {
545            state.values.push(args[0].to_float());
546        }
547        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
548            state.p = Some(args[1].to_float());
549        }
550        Ok(())
551    }
552
553    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
554        if state.values.is_empty() {
555            return Ok(SqliteValue::Null);
556        }
557        let p = state.p.unwrap_or(0.5);
558        state
559            .values
560            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
561        let result = percentile_cont_impl(&state.values, p);
562        Ok(SqliteValue::Float(result))
563    }
564
565    fn num_args(&self) -> i32 {
566        2
567    }
568
569    fn name(&self) -> &str {
570        "percentile_cont"
571    }
572}
573
574// ═══════════════════════════════════════════════════════════════════════════
575// percentile_disc(Y, P) — P in 0..1, discrete (returns actual value)
576// ═══════════════════════════════════════════════════════════════════════════
577
578pub struct PercentileDiscFunc;
579
580impl AggregateFunction for PercentileDiscFunc {
581    type State = PercentileState;
582
583    fn initial_state(&self) -> Self::State {
584        PercentileState {
585            values: Vec::new(),
586            p: None,
587        }
588    }
589
590    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
591        if !args[0].is_null() {
592            state.values.push(args[0].to_float());
593        }
594        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
595            state.p = Some(args[1].to_float());
596        }
597        Ok(())
598    }
599
600    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
601        if state.values.is_empty() {
602            return Ok(SqliteValue::Null);
603        }
604        let p = state.p.unwrap_or(0.5);
605        let p = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
606        state
607            .values
608            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
609        // Discrete: pick the value at the ceiling index.
610        let n = state.values.len();
611        let idx = ((p * n as f64).ceil() as usize)
612            .saturating_sub(1)
613            .min(n - 1);
614        Ok(SqliteValue::Float(state.values[idx]))
615    }
616
617    fn num_args(&self) -> i32 {
618        2
619    }
620
621    fn name(&self) -> &str {
622        "percentile_disc"
623    }
624}
625
626// ── Shared percentile helper ──────────────────────────────────────────────
627
/// Linearly interpolated ("continuous") percentile of an ascending slice.
///
/// `p` is a fraction in `[0, 1]`. Values outside that range are clamped, and
/// NaN falls back to the median (`0.5`). The rank is `p * (len - 1)`, and the
/// result interpolates between the two neighbouring elements.
///
/// # Panics
/// `sorted` must be non-empty; an empty slice panics.
fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
    let len = sorted.len();
    if len == 1 {
        return sorted[0];
    }
    let fraction = match p {
        x if x.is_nan() => 0.5,
        x => x.clamp(0.0, 1.0),
    };
    let position = fraction * (len - 1) as f64;
    let lo_idx = position.floor() as usize;
    let hi_idx = position.ceil() as usize;
    if lo_idx == hi_idx {
        return sorted[lo_idx];
    }
    let weight = position - lo_idx as f64;
    sorted[lo_idx] * (1.0 - weight) + sorted[hi_idx] * weight
}
646
647// ── Registration ──────────────────────────────────────────────────────────
648
649/// Register all §13.4 aggregate functions into the given registry.
650pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
651    registry.register_aggregate(AvgFunc);
652    registry.register_aggregate(CountStarFunc);
653    registry.register_aggregate(CountFunc);
654    registry.register_aggregate(GroupConcatFunc);
655    registry.register_aggregate(AggMaxFunc);
656    registry.register_aggregate(AggMinFunc);
657    registry.register_aggregate(SumFunc);
658    registry.register_aggregate(TotalFunc);
659    registry.register_aggregate(MedianFunc);
660    registry.register_aggregate(PercentileFunc);
661    registry.register_aggregate(PercentileContFunc);
662    registry.register_aggregate(PercentileDiscFunc);
663
664    // string_agg is an alias for group_concat with mandatory separator.
665    struct StringAggFunc;
666    impl AggregateFunction for StringAggFunc {
667        type State = GroupConcatState;
668
669        fn initial_state(&self) -> Self::State {
670            GroupConcatState {
671                result: String::new(),
672                has_value: false,
673            }
674        }
675
676        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
677            GroupConcatFunc.step(state, args)
678        }
679
680        fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
681            GroupConcatFunc.finalize(state)
682        }
683
684        fn num_args(&self) -> i32 {
685            2 // string_agg requires separator
686        }
687
688        fn name(&self) -> &str {
689            "string_agg"
690        }
691    }
692    registry.register_aggregate(StringAggFunc);
693}
694
695// ── Tests ─────────────────────────────────────────────────────────────────
696
697#[cfg(test)]
698mod tests {
699    use super::*;
700
701    const EPS: f64 = 1e-12;
702
703    fn int(v: i64) -> SqliteValue {
704        SqliteValue::Integer(v)
705    }
706
707    fn float(v: f64) -> SqliteValue {
708        SqliteValue::Float(v)
709    }
710
711    fn null() -> SqliteValue {
712        SqliteValue::Null
713    }
714
715    fn text(s: &str) -> SqliteValue {
716        SqliteValue::Text(s.into())
717    }
718
719    fn assert_float_eq(result: &SqliteValue, expected: f64) {
720        match result {
721            SqliteValue::Float(v) => {
722                assert!((v - expected).abs() < EPS, "expected {expected}, got {v}");
723            }
724            other => {
725                assert!(
726                    matches!(other, SqliteValue::Float(_)),
727                    "expected Float({expected}), got {other:?}"
728                );
729            }
730        }
731    }
732
733    /// Helper: run an aggregate over a list of single-arg row values.
734    fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
735        let mut state = func.initial_state();
736        for row in rows {
737            func.step(&mut state, std::slice::from_ref(row)).unwrap();
738        }
739        func.finalize(state).unwrap()
740    }
741
742    /// Helper: run an aggregate over a list of two-arg row values.
743    fn run_agg2<F: AggregateFunction>(
744        func: &F,
745        rows: &[(SqliteValue, SqliteValue)],
746    ) -> SqliteValue {
747        let mut state = func.initial_state();
748        for (a, b) in rows {
749            func.step(&mut state, &[a.clone(), b.clone()]).unwrap();
750        }
751        func.finalize(state).unwrap()
752    }
753
754    // ── avg ───────────────────────────────────────────────────────────
755
756    #[test]
757    fn test_avg_basic() {
758        let r = run_agg(&AvgFunc, &[int(1), int(2), int(3), int(4), int(5)]);
759        assert_float_eq(&r, 3.0);
760    }
761
762    #[test]
763    fn test_avg_with_nulls() {
764        let r = run_agg(&AvgFunc, &[int(1), null(), int(3)]);
765        assert_float_eq(&r, 2.0);
766    }
767
768    #[test]
769    fn test_avg_empty() {
770        let r = run_agg(&AvgFunc, &[]);
771        assert_eq!(r, SqliteValue::Null);
772    }
773
774    #[test]
775    fn test_avg_returns_real() {
776        let r = run_agg(&AvgFunc, &[int(2), int(4)]);
777        assert!(matches!(r, SqliteValue::Float(_)));
778    }
779
780    // ── count ─────────────────────────────────────────────────────────
781
782    #[test]
783    fn test_count_star() {
784        // count(*) counts all rows including NULLs.
785        let mut state = CountStarFunc.initial_state();
786        CountStarFunc.step(&mut state, &[]).unwrap(); // row 1
787        CountStarFunc.step(&mut state, &[]).unwrap(); // row 2
788        CountStarFunc.step(&mut state, &[]).unwrap(); // row 3
789        let r = CountStarFunc.finalize(state).unwrap();
790        assert_eq!(r, int(3));
791    }
792
793    #[test]
794    fn test_count_column() {
795        let r = run_agg(&CountFunc, &[int(1), null(), int(3), null(), int(5)]);
796        assert_eq!(r, int(3));
797    }
798
799    #[test]
800    fn test_count_empty() {
801        let r = run_agg(&CountFunc, &[]);
802        assert_eq!(r, int(0));
803    }
804
805    // ── group_concat ──────────────────────────────────────────────────
806
807    #[test]
808    fn test_group_concat_basic() {
809        let r = run_agg(&GroupConcatFunc, &[text("a"), text("b"), text("c")]);
810        assert_eq!(r, SqliteValue::Text("a,b,c".into()));
811    }
812
813    #[test]
814    fn test_group_concat_custom_sep() {
815        let rows = vec![
816            (text("a"), text("; ")),
817            (text("b"), text("; ")),
818            (text("c"), text("; ")),
819        ];
820        let r = run_agg2(&GroupConcatFunc, &rows);
821        assert_eq!(r, SqliteValue::Text("a; b; c".into()));
822    }
823
824    #[test]
825    fn test_group_concat_null_skipped() {
826        let r = run_agg(&GroupConcatFunc, &[text("a"), null(), text("c")]);
827        assert_eq!(r, SqliteValue::Text("a,c".into()));
828    }
829
830    #[test]
831    fn test_group_concat_empty() {
832        let r = run_agg(&GroupConcatFunc, &[]);
833        assert_eq!(r, SqliteValue::Null);
834    }
835
836    #[test]
837    fn test_group_concat_varying_separator() {
838        // C SQLite uses the separator from each row's argument, not a single
839        // global separator. SELECT group_concat(val, sep) with varying sep
840        // produces a+b*c, not a*b*c (the old bug used the last-seen sep).
841        let rows = vec![
842            (text("a"), text("-")),
843            (text("b"), text("+")),
844            (text("c"), text("*")),
845        ];
846        let r = run_agg2(&GroupConcatFunc, &rows);
847        assert_eq!(r, SqliteValue::Text("a+b*c".into()));
848    }
849
850    #[test]
851    fn test_group_concat_single_value() {
852        let r = run_agg(&GroupConcatFunc, &[text("only")]);
853        assert_eq!(r, SqliteValue::Text("only".into()));
854    }
855
856    // ── max (aggregate) ───────────────────────────────────────────────
857
858    #[test]
859    fn test_max_aggregate() {
860        let r = run_agg(&AggMaxFunc, &[int(3), int(7), int(1), int(5)]);
861        assert_eq!(r, int(7));
862    }
863
864    #[test]
865    fn test_max_aggregate_null_skipped() {
866        let r = run_agg(&AggMaxFunc, &[int(3), null(), int(7), null()]);
867        assert_eq!(r, int(7));
868    }
869
870    #[test]
871    fn test_max_aggregate_empty() {
872        let r = run_agg(&AggMaxFunc, &[]);
873        assert_eq!(r, SqliteValue::Null);
874    }
875
876    // ── min (aggregate) ───────────────────────────────────────────────
877
878    #[test]
879    fn test_min_aggregate() {
880        let r = run_agg(&AggMinFunc, &[int(3), int(7), int(1), int(5)]);
881        assert_eq!(r, int(1));
882    }
883
884    #[test]
885    fn test_min_aggregate_null_skipped() {
886        let r = run_agg(&AggMinFunc, &[int(3), null(), int(1), null()]);
887        assert_eq!(r, int(1));
888    }
889
890    #[test]
891    fn test_min_aggregate_empty() {
892        let r = run_agg(&AggMinFunc, &[]);
893        assert_eq!(r, SqliteValue::Null);
894    }
895
896    // ── sum ───────────────────────────────────────────────────────────
897
898    #[test]
899    fn test_sum_integers() {
900        let r = run_agg(&SumFunc, &[int(1), int(2), int(3)]);
901        assert_eq!(r, int(6));
902    }
903
904    #[test]
905    fn test_sum_reals() {
906        let r = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
907        assert_float_eq(&r, 4.0);
908    }
909
910    #[test]
911    fn test_sum_empty_null() {
912        let r = run_agg(&SumFunc, &[]);
913        assert_eq!(r, SqliteValue::Null);
914    }
915
916    #[test]
917    fn test_sum_overflow_error() {
918        let mut state = SumFunc.initial_state();
919        SumFunc.step(&mut state, &[int(i64::MAX)]).unwrap();
920        SumFunc.step(&mut state, &[int(1)]).unwrap();
921        let err = SumFunc.finalize(state);
922        assert!(err.is_err(), "sum should raise overflow error");
923    }
924
925    #[test]
926    fn test_sum_null_skipped() {
927        let r = run_agg(&SumFunc, &[int(1), null(), int(3)]);
928        assert_eq!(r, int(4));
929    }
930
931    // ── total ─────────────────────────────────────────────────────────
932
933    #[test]
934    fn test_total_basic() {
935        let r = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
936        assert_float_eq(&r, 6.0);
937    }
938
939    #[test]
940    fn test_total_empty_zero() {
941        let r = run_agg(&TotalFunc, &[]);
942        assert_float_eq(&r, 0.0);
943    }
944
945    #[test]
946    fn test_total_no_overflow() {
947        // total uses f64 and never overflows.
948        let r = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
949        assert!(matches!(r, SqliteValue::Float(_)));
950    }
951
952    // ── median ────────────────────────────────────────────────────────
953
954    #[test]
955    fn test_median_basic() {
956        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4), int(5)]);
957        assert_float_eq(&r, 3.0);
958    }
959
960    #[test]
961    fn test_median_even() {
962        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4)]);
963        assert_float_eq(&r, 2.5);
964    }
965
966    #[test]
967    fn test_median_null_skipped() {
968        let r = run_agg(&MedianFunc, &[int(1), null(), int(3)]);
969        assert_float_eq(&r, 2.0);
970    }
971
972    #[test]
973    fn test_median_empty() {
974        let r = run_agg(&MedianFunc, &[]);
975        assert_eq!(r, SqliteValue::Null);
976    }
977
978    // ── percentile ────────────────────────────────────────────────────
979
980    #[test]
981    fn test_percentile_50() {
982        // percentile(col, 50) = median
983        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
984            (int(1), float(50.0)),
985            (int(2), float(50.0)),
986            (int(3), float(50.0)),
987            (int(4), float(50.0)),
988            (int(5), float(50.0)),
989        ];
990        let r = run_agg2(&PercentileFunc, &rows);
991        assert_float_eq(&r, 3.0);
992    }
993
994    #[test]
995    fn test_percentile_0() {
996        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
997            (int(10), float(0.0)),
998            (int(20), float(0.0)),
999            (int(30), float(0.0)),
1000        ];
1001        let r = run_agg2(&PercentileFunc, &rows);
1002        assert_float_eq(&r, 10.0);
1003    }
1004
1005    #[test]
1006    fn test_percentile_100() {
1007        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1008            (int(10), float(100.0)),
1009            (int(20), float(100.0)),
1010            (int(30), float(100.0)),
1011        ];
1012        let r = run_agg2(&PercentileFunc, &rows);
1013        assert_float_eq(&r, 30.0);
1014    }
1015
1016    // ── percentile_cont ───────────────────────────────────────────────
1017
1018    #[test]
1019    fn test_percentile_cont_basic() {
1020        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1021            (int(1), float(0.5)),
1022            (int(2), float(0.5)),
1023            (int(3), float(0.5)),
1024            (int(4), float(0.5)),
1025            (int(5), float(0.5)),
1026        ];
1027        let r = run_agg2(&PercentileContFunc, &rows);
1028        assert_float_eq(&r, 3.0);
1029    }
1030
1031    // ── percentile_disc ───────────────────────────────────────────────
1032
1033    #[test]
1034    fn test_percentile_disc_basic() {
1035        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1036            (int(1), float(0.5)),
1037            (int(2), float(0.5)),
1038            (int(3), float(0.5)),
1039            (int(4), float(0.5)),
1040            (int(5), float(0.5)),
1041        ];
1042        let r = run_agg2(&PercentileDiscFunc, &rows);
1043        // Discrete: returns an actual input value.
1044        match r {
1045            SqliteValue::Float(v) => {
1046                // Should be one of the actual input values (3.0 for 0.5 in 5 items).
1047                assert!(
1048                    [1.0, 2.0, 3.0, 4.0, 5.0].contains(&v),
1049                    "expected actual value, got {v}"
1050                );
1051            }
1052            other => {
1053                assert!(
1054                    matches!(other, SqliteValue::Float(_)),
1055                    "expected Float, got {other:?}"
1056                );
1057            }
1058        }
1059    }
1060
1061    #[test]
1062    fn test_percentile_disc_no_interpolation() {
1063        // With 4 items at p=0.5, cont would interpolate, disc should not.
1064        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1065            (int(10), float(0.5)),
1066            (int(20), float(0.5)),
1067            (int(30), float(0.5)),
1068            (int(40), float(0.5)),
1069        ];
1070        let r = run_agg2(&PercentileDiscFunc, &rows);
1071        match r {
1072            SqliteValue::Float(v) => {
1073                // Must be one of {10, 20, 30, 40}, not 25.0.
1074                assert!(
1075                    [10.0, 20.0, 30.0, 40.0].contains(&v),
1076                    "disc must not interpolate: got {v}"
1077                );
1078            }
1079            other => {
1080                assert!(
1081                    matches!(other, SqliteValue::Float(_)),
1082                    "expected Float, got {other:?}"
1083                );
1084            }
1085        }
1086    }
1087
1088    // ── string_agg (alias) ────────────────────────────────────────────
1089
1090    #[test]
1091    fn test_string_agg_alias() {
1092        let mut reg = FunctionRegistry::new();
1093        register_aggregate_builtins(&mut reg);
1094        let sa = reg
1095            .find_aggregate("string_agg", 2)
1096            .expect("string_agg registered");
1097        let mut state = sa.initial_state();
1098        sa.step(&mut state, &[text("a"), text(",")]).unwrap();
1099        sa.step(&mut state, &[text("b"), text(",")]).unwrap();
1100        let r = sa.finalize(state).unwrap();
1101        assert_eq!(r, SqliteValue::Text("a,b".into()));
1102    }
1103
1104    // ── registration ──────────────────────────────────────────────────
1105
1106    #[test]
1107    fn test_register_aggregate_builtins_all_present() {
1108        let mut reg = FunctionRegistry::new();
1109        register_aggregate_builtins(&mut reg);
1110
1111        let expected = [
1112            ("avg", 1),
1113            ("count", 0), // count(*)
1114            ("count", 1), // count(X)
1115            ("max", 1),
1116            ("min", 1),
1117            ("sum", 1),
1118            ("total", 1),
1119            ("median", 1),
1120            ("percentile", 2),
1121            ("percentile_cont", 2),
1122            ("percentile_disc", 2),
1123            ("string_agg", 2),
1124        ];
1125
1126        for (name, arity) in expected {
1127            assert!(
1128                reg.find_aggregate(name, arity).is_some(),
1129                "aggregate '{name}/{arity}' not registered"
1130            );
1131        }
1132
1133        // group_concat is variadic
1134        assert!(reg.find_aggregate("group_concat", 1).is_some());
1135        assert!(reg.find_aggregate("group_concat", 2).is_some());
1136    }
1137
1138    // ── E2E: full lifecycle through registry ──────────────────────────
1139
1140    #[test]
1141    fn test_e2e_registry_invoke_aggregates() {
1142        let mut reg = FunctionRegistry::new();
1143        register_aggregate_builtins(&mut reg);
1144
1145        // avg through registry
1146        let avg = reg.find_aggregate("avg", 1).unwrap();
1147        let mut state = avg.initial_state();
1148        avg.step(&mut state, &[int(10)]).unwrap();
1149        avg.step(&mut state, &[int(20)]).unwrap();
1150        avg.step(&mut state, &[int(30)]).unwrap();
1151        let r = avg.finalize(state).unwrap();
1152        assert_float_eq(&r, 20.0);
1153
1154        // sum through registry
1155        let sum = reg.find_aggregate("sum", 1).unwrap();
1156        let mut state = sum.initial_state();
1157        sum.step(&mut state, &[int(1)]).unwrap();
1158        sum.step(&mut state, &[int(2)]).unwrap();
1159        sum.step(&mut state, &[int(3)]).unwrap();
1160        let r = sum.finalize(state).unwrap();
1161        assert_eq!(r, int(6));
1162    }
1163}