Skip to main content

fsqlite_func/
agg_builtins.rs

1//! Built-in aggregate functions (§13.4).
2//!
3//! Implements: avg, count, group_concat, string_agg, max, min, sum, total,
4//! median, percentile, percentile_cont, percentile_disc.
5//!
6//! # NULL handling
7//! All aggregate functions skip NULL values (except `count(*)` which counts
8//! all rows). Empty-set behavior:
//! - avg / sum / max / min / median / percentile* / group_concat / string_agg → NULL
10//! - total → 0.0
11//! - count → 0
12#![allow(
13    clippy::unnecessary_literal_bound,
14    clippy::too_many_lines,
15    clippy::cast_possible_truncation,
16    clippy::cast_possible_wrap,
17    clippy::cast_precision_loss,
18    clippy::match_same_arms,
19    clippy::items_after_statements,
20    clippy::float_cmp,
21    clippy::cast_sign_loss,
22    clippy::suboptimal_flops
23)]
24
25use fsqlite_error::{FrankenError, Result};
26use fsqlite_types::SqliteValue;
27
28use crate::{AggregateFunction, FunctionRegistry};
29
30// ─── Kahan compensated summation ──────────────────────────────────────────
31
/// Kahan compensated summation step.
///
/// Adds `value` to the running `sum` while `compensation` carries the
/// low-order bits that a plain floating-point `+=` would drop, so long
/// aggregate runs accumulate far less rounding error.
///
/// Non-finite totals are handled explicitly.  Once the running total becomes
/// ±∞ or NaN, the usual correction term `(t - sum) - y` evaluates to NaN (or
/// ±∞) and would poison every later step — e.g. `+∞` followed by `1.0` would
/// produce NaN instead of `+∞`.  In that case the compensation is reset to
/// zero and the accumulator behaves like ordinary IEEE-754 addition.
///
/// * `sum` — running total, updated in place.
/// * `compensation` — running error term, updated in place; start it at `0.0`.
/// * `value` — the next term to add.
#[inline]
fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
    // Fold the previously lost low-order bits back into the incoming term.
    let y = value - *compensation;
    let t = *sum + y;
    *compensation = if t.is_finite() {
        // Algebraically zero; numerically the rounding error of `sum + y`.
        (t - *sum) - y
    } else {
        // No meaningful error term exists for ±∞ / NaN totals.
        0.0
    };
    *sum = t;
}
42
43// ═══════════════════════════════════════════════════════════════════════════
44// avg(X)
45// ═══════════════════════════════════════════════════════════════════════════
46
47pub struct AvgState {
48    sum: f64,
49    compensation: f64,
50    count: i64,
51}
52
53pub struct AvgFunc;
54
55impl AggregateFunction for AvgFunc {
56    type State = AvgState;
57
58    fn initial_state(&self) -> Self::State {
59        AvgState {
60            sum: 0.0,
61            compensation: 0.0,
62            count: 0,
63        }
64    }
65
66    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
67        if !args[0].is_null() {
68            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
69            state.count += 1;
70        }
71        Ok(())
72    }
73
74    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
75        if state.count == 0 {
76            Ok(SqliteValue::Null)
77        } else {
78            Ok(SqliteValue::Float(state.sum / state.count as f64))
79        }
80    }
81
82    fn num_args(&self) -> i32 {
83        1
84    }
85
86    fn name(&self) -> &str {
87        "avg"
88    }
89}
90
91// ═══════════════════════════════════════════════════════════════════════════
92// count(*) and count(X)
93// ═══════════════════════════════════════════════════════════════════════════
94
95/// `count(*)` — counts all rows including those with NULL values.
96pub struct CountStarFunc;
97
98impl AggregateFunction for CountStarFunc {
99    type State = i64;
100
101    fn initial_state(&self) -> Self::State {
102        0
103    }
104
105    fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
106        *state += 1;
107        Ok(())
108    }
109
110    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
111        Ok(SqliteValue::Integer(state))
112    }
113
114    fn num_args(&self) -> i32 {
115        0 // count(*) takes no column argument
116    }
117
118    fn name(&self) -> &str {
119        "count"
120    }
121}
122
123/// `count(X)` — counts non-NULL values of X.
124pub struct CountFunc;
125
126impl AggregateFunction for CountFunc {
127    type State = i64;
128
129    fn initial_state(&self) -> Self::State {
130        0
131    }
132
133    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
134        if !args[0].is_null() {
135            *state += 1;
136        }
137        Ok(())
138    }
139
140    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
141        Ok(SqliteValue::Integer(state))
142    }
143
144    fn num_args(&self) -> i32 {
145        1
146    }
147
148    fn name(&self) -> &str {
149        "count"
150    }
151}
152
153// ═══════════════════════════════════════════════════════════════════════════
154// group_concat(X [, SEP])
155// ═══════════════════════════════════════════════════════════════════════════
156
157pub struct GroupConcatState {
158    values: Vec<String>,
159    separator: String,
160}
161
162pub struct GroupConcatFunc;
163
164impl AggregateFunction for GroupConcatFunc {
165    type State = GroupConcatState;
166
167    fn initial_state(&self) -> Self::State {
168        GroupConcatState {
169            values: Vec::new(),
170            separator: ",".to_owned(),
171        }
172    }
173
174    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
175        if args[0].is_null() {
176            return Ok(());
177        }
178        // Set separator from second arg on EVERY call if provided, since it's an expression
179        // that could conceptually change, though usually it's a constant.
180        if args.len() > 1 && !args[1].is_null() {
181            state.separator = args[1].to_text();
182        }
183        state.values.push(args[0].to_text());
184        Ok(())
185    }
186
187    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
188        if state.values.is_empty() {
189            Ok(SqliteValue::Null)
190        } else {
191            Ok(SqliteValue::Text(state.values.join(&state.separator)))
192        }
193    }
194
195    fn num_args(&self) -> i32 {
196        -1 // 1 or 2 args
197    }
198
199    fn name(&self) -> &str {
200        "group_concat"
201    }
202}
203
204// ═══════════════════════════════════════════════════════════════════════════
205// max(X) — aggregate, single arg
206// ═══════════════════════════════════════════════════════════════════════════
207
208pub struct AggMaxFunc;
209
210impl AggregateFunction for AggMaxFunc {
211    type State = Option<SqliteValue>;
212
213    fn initial_state(&self) -> Self::State {
214        None
215    }
216
217    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
218        if args[0].is_null() {
219            return Ok(());
220        }
221        let candidate = &args[0];
222        match state {
223            None => *state = Some(candidate.clone()),
224            Some(current) => {
225                if candidate > current {
226                    *state = Some(candidate.clone());
227                }
228            }
229        }
230        Ok(())
231    }
232
233    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
234        Ok(state.unwrap_or(SqliteValue::Null))
235    }
236
237    fn num_args(&self) -> i32 {
238        1
239    }
240
241    fn name(&self) -> &str {
242        "max"
243    }
244}
245
246// ═══════════════════════════════════════════════════════════════════════════
247// min(X) — aggregate, single arg
248// ═══════════════════════════════════════════════════════════════════════════
249
250pub struct AggMinFunc;
251
252impl AggregateFunction for AggMinFunc {
253    type State = Option<SqliteValue>;
254
255    fn initial_state(&self) -> Self::State {
256        None
257    }
258
259    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
260        if args[0].is_null() {
261            return Ok(());
262        }
263        let candidate = &args[0];
264        match state {
265            None => *state = Some(candidate.clone()),
266            Some(current) => {
267                if candidate < current {
268                    *state = Some(candidate.clone());
269                }
270            }
271        }
272        Ok(())
273    }
274
275    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
276        Ok(state.unwrap_or(SqliteValue::Null))
277    }
278
279    fn num_args(&self) -> i32 {
280        1
281    }
282
283    fn name(&self) -> &str {
284        "min"
285    }
286}
287
288// ═══════════════════════════════════════════════════════════════════════════
289// sum(X)
290// ═══════════════════════════════════════════════════════════════════════════
291
292/// State for `sum()`: tracks whether all values are integers, the running
293/// integer sum, and the float sum as fallback.  Uses Kahan compensated
294/// summation for the float path to match C SQLite's precision.
295pub struct SumState {
296    int_sum: i64,
297    float_sum: f64,
298    float_compensation: f64,
299    all_integer: bool,
300    has_values: bool,
301    overflowed: bool,
302}
303
304pub struct SumFunc;
305
306impl AggregateFunction for SumFunc {
307    type State = SumState;
308
309    fn initial_state(&self) -> Self::State {
310        SumState {
311            int_sum: 0,
312            float_sum: 0.0,
313            float_compensation: 0.0,
314            all_integer: true,
315            has_values: false,
316            overflowed: false,
317        }
318    }
319
320    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
321        if args[0].is_null() {
322            return Ok(());
323        }
324        state.has_values = true;
325        match &args[0] {
326            SqliteValue::Integer(i) => {
327                if state.all_integer && !state.overflowed {
328                    match state.int_sum.checked_add(*i) {
329                        Some(s) => state.int_sum = s,
330                        None => state.overflowed = true,
331                    }
332                }
333                kahan_add(
334                    &mut state.float_sum,
335                    &mut state.float_compensation,
336                    *i as f64,
337                );
338            }
339            other => {
340                state.all_integer = false;
341                kahan_add(
342                    &mut state.float_sum,
343                    &mut state.float_compensation,
344                    other.to_float(),
345                );
346            }
347        }
348        Ok(())
349    }
350
351    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
352        if !state.has_values {
353            return Ok(SqliteValue::Null);
354        }
355        if state.overflowed {
356            return Err(FrankenError::IntegerOverflow);
357        }
358        if state.all_integer {
359            Ok(SqliteValue::Integer(state.int_sum))
360        } else {
361            Ok(SqliteValue::Float(state.float_sum))
362        }
363    }
364
365    fn num_args(&self) -> i32 {
366        1
367    }
368
369    fn name(&self) -> &str {
370        "sum"
371    }
372}
373
374// ═══════════════════════════════════════════════════════════════════════════
375// total(X) — always returns float, 0.0 for empty set, never overflows.
376// ═══════════════════════════════════════════════════════════════════════════
377
378pub struct TotalFunc;
379
380/// State for `total()`: Kahan compensated accumulator.
381pub struct TotalState {
382    sum: f64,
383    compensation: f64,
384}
385
386impl AggregateFunction for TotalFunc {
387    type State = TotalState;
388
389    fn initial_state(&self) -> Self::State {
390        TotalState {
391            sum: 0.0,
392            compensation: 0.0,
393        }
394    }
395
396    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
397        if !args[0].is_null() {
398            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
399        }
400        Ok(())
401    }
402
403    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
404        Ok(SqliteValue::Float(state.sum))
405    }
406
407    fn num_args(&self) -> i32 {
408        1
409    }
410
411    fn name(&self) -> &str {
412        "total"
413    }
414}
415
416// ═══════════════════════════════════════════════════════════════════════════
417// median(X) — equivalent to percentile_cont(X, 0.5)
418// ═══════════════════════════════════════════════════════════════════════════
419
420pub struct MedianFunc;
421
422impl AggregateFunction for MedianFunc {
423    type State = Vec<f64>;
424
425    fn initial_state(&self) -> Self::State {
426        Vec::new()
427    }
428
429    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
430        if !args[0].is_null() {
431            state.push(args[0].to_float());
432        }
433        Ok(())
434    }
435
436    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
437        if state.is_empty() {
438            return Ok(SqliteValue::Null);
439        }
440        state.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
441        let result = percentile_cont_impl(&state, 0.5);
442        Ok(SqliteValue::Float(result))
443    }
444
445    fn num_args(&self) -> i32 {
446        1
447    }
448
449    fn name(&self) -> &str {
450        "median"
451    }
452}
453
454// ═══════════════════════════════════════════════════════════════════════════
455// percentile(Y, P) — P in 0..100
456// ═══════════════════════════════════════════════════════════════════════════
457
458pub struct PercentileState {
459    values: Vec<f64>,
460    p: Option<f64>,
461}
462
463pub struct PercentileFunc;
464
465impl AggregateFunction for PercentileFunc {
466    type State = PercentileState;
467
468    fn initial_state(&self) -> Self::State {
469        PercentileState {
470            values: Vec::new(),
471            p: None,
472        }
473    }
474
475    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
476        if !args[0].is_null() {
477            state.values.push(args[0].to_float());
478        }
479        // Capture P from the second argument (constant expression).
480        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
481            state.p = Some(args[1].to_float());
482        }
483        Ok(())
484    }
485
486    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
487        if state.values.is_empty() {
488            return Ok(SqliteValue::Null);
489        }
490        let p = state.p.unwrap_or(50.0);
491        state
492            .values
493            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
494        // Convert P from 0-100 to 0-1 for the shared implementation.
495        let result = percentile_cont_impl(&state.values, p / 100.0);
496        Ok(SqliteValue::Float(result))
497    }
498
499    fn num_args(&self) -> i32 {
500        2
501    }
502
503    fn name(&self) -> &str {
504        "percentile"
505    }
506}
507
508// ═══════════════════════════════════════════════════════════════════════════
509// percentile_cont(Y, P) — P in 0..1, continuous interpolation
510// ═══════════════════════════════════════════════════════════════════════════
511
512pub struct PercentileContFunc;
513
514impl AggregateFunction for PercentileContFunc {
515    type State = PercentileState;
516
517    fn initial_state(&self) -> Self::State {
518        PercentileState {
519            values: Vec::new(),
520            p: None,
521        }
522    }
523
524    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
525        if !args[0].is_null() {
526            state.values.push(args[0].to_float());
527        }
528        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
529            state.p = Some(args[1].to_float());
530        }
531        Ok(())
532    }
533
534    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
535        if state.values.is_empty() {
536            return Ok(SqliteValue::Null);
537        }
538        let p = state.p.unwrap_or(0.5);
539        state
540            .values
541            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
542        let result = percentile_cont_impl(&state.values, p);
543        Ok(SqliteValue::Float(result))
544    }
545
546    fn num_args(&self) -> i32 {
547        2
548    }
549
550    fn name(&self) -> &str {
551        "percentile_cont"
552    }
553}
554
555// ═══════════════════════════════════════════════════════════════════════════
556// percentile_disc(Y, P) — P in 0..1, discrete (returns actual value)
557// ═══════════════════════════════════════════════════════════════════════════
558
559pub struct PercentileDiscFunc;
560
561impl AggregateFunction for PercentileDiscFunc {
562    type State = PercentileState;
563
564    fn initial_state(&self) -> Self::State {
565        PercentileState {
566            values: Vec::new(),
567            p: None,
568        }
569    }
570
571    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
572        if !args[0].is_null() {
573            state.values.push(args[0].to_float());
574        }
575        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
576            state.p = Some(args[1].to_float());
577        }
578        Ok(())
579    }
580
581    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
582        if state.values.is_empty() {
583            return Ok(SqliteValue::Null);
584        }
585        let p = state.p.unwrap_or(0.5);
586        state
587            .values
588            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
589        // Discrete: pick the value at the ceiling index.
590        let n = state.values.len();
591        let idx = ((p * n as f64).ceil() as usize)
592            .saturating_sub(1)
593            .min(n - 1);
594        Ok(SqliteValue::Float(state.values[idx]))
595    }
596
597    fn num_args(&self) -> i32 {
598        2
599    }
600
601    fn name(&self) -> &str {
602        "percentile_disc"
603    }
604}
605
606// ── Shared percentile helper ──────────────────────────────────────────────
607
/// Continuous percentile with linear interpolation between closest ranks.
///
/// * `sorted` — values in ascending order.
/// * `p` — requested fraction; values outside `[0, 1]` are clamped.
///
/// Returns the value at fractional rank `p * (len - 1)`, interpolating
/// linearly when the rank falls between two elements.  A single-element
/// slice returns that element for any `p`.  An empty slice has no percentile
/// and returns NaN instead of panicking on the `len - 1` underflow.
fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => return f64::NAN,
        [only] => return *only,
        _ => {}
    }
    let rank = p.clamp(0.0, 1.0) * (sorted.len() - 1) as f64;
    // `rank` lies in `[0, len - 1]`, so both neighbours are valid indices.
    // A NaN `p` yields a NaN rank, which casts to index 0.
    let lower = rank.floor() as usize;
    let upper = rank.ceil() as usize;
    if lower == upper {
        sorted[lower]
    } else {
        let frac = rank - lower as f64;
        sorted[lower] * (1.0 - frac) + sorted[upper] * frac
    }
}
626
627// ── Registration ──────────────────────────────────────────────────────────
628
629/// Register all §13.4 aggregate functions into the given registry.
630pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
631    registry.register_aggregate(AvgFunc);
632    registry.register_aggregate(CountStarFunc);
633    registry.register_aggregate(CountFunc);
634    registry.register_aggregate(GroupConcatFunc);
635    registry.register_aggregate(AggMaxFunc);
636    registry.register_aggregate(AggMinFunc);
637    registry.register_aggregate(SumFunc);
638    registry.register_aggregate(TotalFunc);
639    registry.register_aggregate(MedianFunc);
640    registry.register_aggregate(PercentileFunc);
641    registry.register_aggregate(PercentileContFunc);
642    registry.register_aggregate(PercentileDiscFunc);
643
644    // string_agg is an alias for group_concat with mandatory separator.
645    struct StringAggFunc;
646    impl AggregateFunction for StringAggFunc {
647        type State = GroupConcatState;
648
649        fn initial_state(&self) -> Self::State {
650            GroupConcatState {
651                values: Vec::new(),
652                separator: ",".to_owned(),
653            }
654        }
655
656        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
657            GroupConcatFunc.step(state, args)
658        }
659
660        fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
661            GroupConcatFunc.finalize(state)
662        }
663
664        fn num_args(&self) -> i32 {
665            2 // string_agg requires separator
666        }
667
668        fn name(&self) -> &str {
669            "string_agg"
670        }
671    }
672    registry.register_aggregate(StringAggFunc);
673}
674
675// ── Tests ─────────────────────────────────────────────────────────────────
676
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPS: f64 = 1e-12;

    // ── Value constructors ────────────────────────────────────────────

    fn int(v: i64) -> SqliteValue {
        SqliteValue::Integer(v)
    }

    fn float(v: f64) -> SqliteValue {
        SqliteValue::Float(v)
    }

    fn null() -> SqliteValue {
        SqliteValue::Null
    }

    fn text(s: &str) -> SqliteValue {
        SqliteValue::Text(String::from(s))
    }

    // ── Shared helpers ────────────────────────────────────────────────

    /// Asserts that `result` is a `Float` within `EPS` of `expected`.
    fn assert_float_eq(result: &SqliteValue, expected: f64) {
        let SqliteValue::Float(actual) = result else {
            panic!("expected Float({expected}), got {result:?}");
        };
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Feeds each value as a one-argument row and returns the final result.
    fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
        let mut state = func.initial_state();
        rows.iter()
            .for_each(|row| func.step(&mut state, std::slice::from_ref(row)).unwrap());
        func.finalize(state).unwrap()
    }

    /// Feeds each pair as a two-argument row and returns the final result.
    fn run_agg2<F: AggregateFunction>(
        func: &F,
        rows: &[(SqliteValue, SqliteValue)],
    ) -> SqliteValue {
        let mut state = func.initial_state();
        for (first, second) in rows {
            let row = [first.clone(), second.clone()];
            func.step(&mut state, &row).unwrap();
        }
        func.finalize(state).unwrap()
    }

    /// Builds two-argument rows pairing every integer with the same float `p`.
    fn with_p(values: &[i64], p: f64) -> Vec<(SqliteValue, SqliteValue)> {
        values.iter().map(|&v| (int(v), float(p))).collect()
    }

    // ── avg ───────────────────────────────────────────────────────────

    #[test]
    fn test_avg_basic() {
        let rows = [int(1), int(2), int(3), int(4), int(5)];
        assert_float_eq(&run_agg(&AvgFunc, &rows), 3.0);
    }

    #[test]
    fn test_avg_with_nulls() {
        let rows = [int(1), null(), int(3)];
        assert_float_eq(&run_agg(&AvgFunc, &rows), 2.0);
    }

    #[test]
    fn test_avg_empty() {
        assert_eq!(run_agg(&AvgFunc, &[]), SqliteValue::Null);
    }

    #[test]
    fn test_avg_returns_real() {
        let result = run_agg(&AvgFunc, &[int(2), int(4)]);
        assert!(matches!(result, SqliteValue::Float(_)));
    }

    // ── count ─────────────────────────────────────────────────────────

    #[test]
    fn test_count_star() {
        // count(*) receives no arguments and counts every row.
        let mut state = CountStarFunc.initial_state();
        for _row in 0..3 {
            CountStarFunc.step(&mut state, &[]).unwrap();
        }
        assert_eq!(CountStarFunc.finalize(state).unwrap(), int(3));
    }

    #[test]
    fn test_count_column() {
        let rows = [int(1), null(), int(3), null(), int(5)];
        assert_eq!(run_agg(&CountFunc, &rows), int(3));
    }

    #[test]
    fn test_count_empty() {
        assert_eq!(run_agg(&CountFunc, &[]), int(0));
    }

    // ── group_concat ──────────────────────────────────────────────────

    #[test]
    fn test_group_concat_basic() {
        let rows = [text("a"), text("b"), text("c")];
        assert_eq!(
            run_agg(&GroupConcatFunc, &rows),
            SqliteValue::Text("a,b,c".to_owned())
        );
    }

    #[test]
    fn test_group_concat_custom_sep() {
        let rows: Vec<_> = ["a", "b", "c"]
            .iter()
            .map(|item| (text(item), text("; ")))
            .collect();
        assert_eq!(
            run_agg2(&GroupConcatFunc, &rows),
            SqliteValue::Text("a; b; c".to_owned())
        );
    }

    #[test]
    fn test_group_concat_null_skipped() {
        let rows = [text("a"), null(), text("c")];
        assert_eq!(
            run_agg(&GroupConcatFunc, &rows),
            SqliteValue::Text("a,c".to_owned())
        );
    }

    #[test]
    fn test_group_concat_empty() {
        assert_eq!(run_agg(&GroupConcatFunc, &[]), SqliteValue::Null);
    }

    // ── max (aggregate) ───────────────────────────────────────────────

    #[test]
    fn test_max_aggregate() {
        let rows = [int(3), int(7), int(1), int(5)];
        assert_eq!(run_agg(&AggMaxFunc, &rows), int(7));
    }

    #[test]
    fn test_max_aggregate_null_skipped() {
        let rows = [int(3), null(), int(7), null()];
        assert_eq!(run_agg(&AggMaxFunc, &rows), int(7));
    }

    #[test]
    fn test_max_aggregate_empty() {
        assert_eq!(run_agg(&AggMaxFunc, &[]), SqliteValue::Null);
    }

    // ── min (aggregate) ───────────────────────────────────────────────

    #[test]
    fn test_min_aggregate() {
        let rows = [int(3), int(7), int(1), int(5)];
        assert_eq!(run_agg(&AggMinFunc, &rows), int(1));
    }

    #[test]
    fn test_min_aggregate_null_skipped() {
        let rows = [int(3), null(), int(1), null()];
        assert_eq!(run_agg(&AggMinFunc, &rows), int(1));
    }

    #[test]
    fn test_min_aggregate_empty() {
        assert_eq!(run_agg(&AggMinFunc, &[]), SqliteValue::Null);
    }

    // ── sum ───────────────────────────────────────────────────────────

    #[test]
    fn test_sum_integers() {
        assert_eq!(run_agg(&SumFunc, &[int(1), int(2), int(3)]), int(6));
    }

    #[test]
    fn test_sum_reals() {
        let result = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
        assert_float_eq(&result, 4.0);
    }

    #[test]
    fn test_sum_empty_null() {
        assert_eq!(run_agg(&SumFunc, &[]), SqliteValue::Null);
    }

    #[test]
    fn test_sum_overflow_error() {
        // step() itself succeeds; the overflow surfaces at finalize().
        let mut state = SumFunc.initial_state();
        for v in [i64::MAX, 1] {
            SumFunc.step(&mut state, &[int(v)]).unwrap();
        }
        let outcome = SumFunc.finalize(state);
        assert!(outcome.is_err(), "sum should raise overflow error");
    }

    #[test]
    fn test_sum_null_skipped() {
        assert_eq!(run_agg(&SumFunc, &[int(1), null(), int(3)]), int(4));
    }

    // ── total ─────────────────────────────────────────────────────────

    #[test]
    fn test_total_basic() {
        let result = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
        assert_float_eq(&result, 6.0);
    }

    #[test]
    fn test_total_empty_zero() {
        assert_float_eq(&run_agg(&TotalFunc, &[]), 0.0);
    }

    #[test]
    fn test_total_no_overflow() {
        // total accumulates in f64, so two i64::MAX values are fine.
        let result = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
        assert!(matches!(result, SqliteValue::Float(_)));
    }

    // ── median ────────────────────────────────────────────────────────

    #[test]
    fn test_median_basic() {
        let rows = [int(1), int(2), int(3), int(4), int(5)];
        assert_float_eq(&run_agg(&MedianFunc, &rows), 3.0);
    }

    #[test]
    fn test_median_even() {
        let rows = [int(1), int(2), int(3), int(4)];
        assert_float_eq(&run_agg(&MedianFunc, &rows), 2.5);
    }

    #[test]
    fn test_median_null_skipped() {
        let rows = [int(1), null(), int(3)];
        assert_float_eq(&run_agg(&MedianFunc, &rows), 2.0);
    }

    #[test]
    fn test_median_empty() {
        assert_eq!(run_agg(&MedianFunc, &[]), SqliteValue::Null);
    }

    // ── percentile ────────────────────────────────────────────────────

    #[test]
    fn test_percentile_50() {
        // percentile(col, 50) = median
        let rows = with_p(&[1, 2, 3, 4, 5], 50.0);
        assert_float_eq(&run_agg2(&PercentileFunc, &rows), 3.0);
    }

    #[test]
    fn test_percentile_0() {
        let rows = with_p(&[10, 20, 30], 0.0);
        assert_float_eq(&run_agg2(&PercentileFunc, &rows), 10.0);
    }

    #[test]
    fn test_percentile_100() {
        let rows = with_p(&[10, 20, 30], 100.0);
        assert_float_eq(&run_agg2(&PercentileFunc, &rows), 30.0);
    }

    // ── percentile_cont ───────────────────────────────────────────────

    #[test]
    fn test_percentile_cont_basic() {
        let rows = with_p(&[1, 2, 3, 4, 5], 0.5);
        assert_float_eq(&run_agg2(&PercentileContFunc, &rows), 3.0);
    }

    // ── percentile_disc ───────────────────────────────────────────────

    #[test]
    fn test_percentile_disc_basic() {
        let rows = with_p(&[1, 2, 3, 4, 5], 0.5);
        let result = run_agg2(&PercentileDiscFunc, &rows);
        // Discrete: the result must be one of the actual input values.
        let SqliteValue::Float(picked) = &result else {
            panic!("expected Float, got {result:?}");
        };
        assert!(
            [1.0, 2.0, 3.0, 4.0, 5.0].contains(picked),
            "expected actual value, got {picked}"
        );
    }

    #[test]
    fn test_percentile_disc_no_interpolation() {
        // With 4 items at p=0.5, cont would interpolate to 25.0; disc must
        // return one of {10, 20, 30, 40}.
        let rows = with_p(&[10, 20, 30, 40], 0.5);
        let result = run_agg2(&PercentileDiscFunc, &rows);
        let SqliteValue::Float(picked) = &result else {
            panic!("expected Float, got {result:?}");
        };
        assert!(
            [10.0, 20.0, 30.0, 40.0].contains(picked),
            "disc must not interpolate: got {picked}"
        );
    }

    // ── string_agg (alias) ────────────────────────────────────────────

    #[test]
    fn test_string_agg_alias() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);
        let string_agg = reg
            .find_aggregate("string_agg", 2)
            .expect("string_agg registered");

        let mut state = string_agg.initial_state();
        for item in ["a", "b"] {
            string_agg.step(&mut state, &[text(item), text(",")]).unwrap();
        }
        assert_eq!(
            string_agg.finalize(state).unwrap(),
            SqliteValue::Text("a,b".to_owned())
        );
    }

    // ── registration ──────────────────────────────────────────────────

    #[test]
    fn test_register_aggregate_builtins_all_present() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);

        // (name, arity) pairs that must resolve after registration.
        let fixed_arity = [
            ("avg", 1),
            ("count", 0), // count(*)
            ("count", 1), // count(X)
            ("max", 1),
            ("min", 1),
            ("sum", 1),
            ("total", 1),
            ("median", 1),
            ("percentile", 2),
            ("percentile_cont", 2),
            ("percentile_disc", 2),
            ("string_agg", 2),
        ];
        for (name, arity) in fixed_arity {
            let found = reg.find_aggregate(name, arity);
            assert!(found.is_some(), "aggregate '{name}/{arity}' not registered");
        }

        // group_concat is variadic, so both arities must resolve.
        assert!(reg.find_aggregate("group_concat", 1).is_some());
        assert!(reg.find_aggregate("group_concat", 2).is_some());
    }

    // ── E2E: full lifecycle through registry ──────────────────────────

    #[test]
    fn test_e2e_registry_invoke_aggregates() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);

        // avg through registry
        let avg = reg.find_aggregate("avg", 1).unwrap();
        let mut state = avg.initial_state();
        for v in [10, 20, 30] {
            avg.step(&mut state, &[int(v)]).unwrap();
        }
        assert_float_eq(&avg.finalize(state).unwrap(), 20.0);

        // sum through registry
        let sum = reg.find_aggregate("sum", 1).unwrap();
        let mut state = sum.initial_state();
        for v in [1, 2, 3] {
            sum.step(&mut state, &[int(v)]).unwrap();
        }
        assert_eq!(sum.finalize(state).unwrap(), int(6));
    }
}