Skip to main content

fsqlite_func/
agg_builtins.rs

1//! Built-in aggregate functions (§13.4).
2//!
3//! Implements: avg, count, group_concat, string_agg, max, min, sum, total,
4//! median, percentile, percentile_cont, percentile_disc.
5//!
6//! # NULL handling
7//! All aggregate functions skip NULL values (except `count(*)` which counts
8//! all rows). Empty-set behavior:
//! - avg / sum / max / min / median / percentile* / group_concat / string_agg → NULL
10//! - total → 0.0
11//! - count → 0
12#![allow(
13    clippy::unnecessary_literal_bound,
14    clippy::too_many_lines,
15    clippy::cast_possible_truncation,
16    clippy::cast_possible_wrap,
17    clippy::cast_precision_loss,
18    clippy::match_same_arms,
19    clippy::items_after_statements,
20    clippy::float_cmp,
21    clippy::cast_sign_loss,
22    clippy::suboptimal_flops
23)]
24
25use fsqlite_error::{FrankenError, Result};
26use fsqlite_types::SqliteValue;
27
28use crate::{AggregateFunction, FunctionRegistry};
29
30// ─── Kahan compensated summation ──────────────────────────────────────────
31
/// One step of Kahan-Babuska-Neumaier compensated summation.
///
/// Adds `value` onto `*sum` and folds the rounding error of that single
/// addition into `*compensation`.  The error is recovered from whichever
/// operand has the larger magnitude, mirroring C SQLite's
/// `kahanBabuskaNeumaierStep` (func.c).  Callers in this file obtain the
/// corrected total as `*sum + *compensation`.
#[inline]
fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
    let previous = *sum;
    let total = previous + value;
    // Pick the dominant operand; on equal magnitudes `value` is treated as
    // the larger one.
    let (larger, smaller) = if previous.abs() > value.abs() {
        (previous, value)
    } else {
        (value, previous)
    };
    *compensation += (larger - total) + smaller;
    *sum = total;
}
46
47// ═══════════════════════════════════════════════════════════════════════════
48// avg(X)
49// ═══════════════════════════════════════════════════════════════════════════
50
51pub struct AvgState {
52    sum: f64,
53    compensation: f64,
54    count: i64,
55}
56
57pub struct AvgFunc;
58
59impl AggregateFunction for AvgFunc {
60    type State = AvgState;
61
62    fn initial_state(&self) -> Self::State {
63        AvgState {
64            sum: 0.0,
65            compensation: 0.0,
66            count: 0,
67        }
68    }
69
70    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
71        if !args[0].is_null() {
72            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
73            state.count += 1;
74        }
75        Ok(())
76    }
77
78    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
79        if state.count == 0 {
80            Ok(SqliteValue::Null)
81        } else {
82            Ok(SqliteValue::Float(
83                (state.sum + state.compensation) / state.count as f64,
84            ))
85        }
86    }
87
88    fn num_args(&self) -> i32 {
89        1
90    }
91
92    fn name(&self) -> &str {
93        "avg"
94    }
95}
96
97// ═══════════════════════════════════════════════════════════════════════════
98// count(*) and count(X)
99// ═══════════════════════════════════════════════════════════════════════════
100
101/// `count(*)` — counts all rows including those with NULL values.
102pub struct CountStarFunc;
103
104impl AggregateFunction for CountStarFunc {
105    type State = i64;
106
107    fn initial_state(&self) -> Self::State {
108        0
109    }
110
111    fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
112        *state += 1;
113        Ok(())
114    }
115
116    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
117        Ok(SqliteValue::Integer(state))
118    }
119
120    fn num_args(&self) -> i32 {
121        0 // count(*) takes no column argument
122    }
123
124    fn name(&self) -> &str {
125        "count"
126    }
127}
128
129/// `count(X)` — counts non-NULL values of X.
130pub struct CountFunc;
131
132impl AggregateFunction for CountFunc {
133    type State = i64;
134
135    fn initial_state(&self) -> Self::State {
136        0
137    }
138
139    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
140        if !args[0].is_null() {
141            *state += 1;
142        }
143        Ok(())
144    }
145
146    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
147        Ok(SqliteValue::Integer(state))
148    }
149
150    fn num_args(&self) -> i32 {
151        1
152    }
153
154    fn name(&self) -> &str {
155        "count"
156    }
157}
158
159// ═══════════════════════════════════════════════════════════════════════════
160// group_concat(X [, SEP])
161// ═══════════════════════════════════════════════════════════════════════════
162
163pub struct GroupConcatState {
164    /// Incrementally built result string.  C SQLite appends
165    /// `separator + value` at each step (separator only before 2nd+ value),
166    /// using the separator from *that row's* argument, not a single global one.
167    result: String,
168    has_value: bool,
169}
170
171pub struct GroupConcatFunc;
172
173#[inline]
174fn push_group_concat_text(result: &mut String, value: &SqliteValue) {
175    if let Some(text) = value.as_text_str() {
176        result.push_str(text);
177    } else {
178        result.push_str(&value.to_text());
179    }
180}
181
182impl AggregateFunction for GroupConcatFunc {
183    type State = GroupConcatState;
184
185    fn initial_state(&self) -> Self::State {
186        GroupConcatState {
187            result: String::new(),
188            has_value: false,
189        }
190    }
191
192    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
193        if args[0].is_null() {
194            return Ok(());
195        }
196        if state.has_value {
197            match args.get(1) {
198                Some(separator) if !separator.is_null() => {
199                    push_group_concat_text(&mut state.result, separator);
200                }
201                Some(_) => {}
202                None => state.result.push(','),
203            }
204        }
205        push_group_concat_text(&mut state.result, &args[0]);
206        state.has_value = true;
207        Ok(())
208    }
209
210    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
211        if state.has_value {
212            Ok(SqliteValue::Text(state.result.into()))
213        } else {
214            Ok(SqliteValue::Null)
215        }
216    }
217
218    fn num_args(&self) -> i32 {
219        -1 // 1 or 2 args
220    }
221
222    fn min_args(&self) -> i32 {
223        1
224    }
225
226    fn max_args(&self) -> Option<i32> {
227        Some(2)
228    }
229
230    fn name(&self) -> &str {
231        "group_concat"
232    }
233}
234
235// ═══════════════════════════════════════════════════════════════════════════
236// max(X) — aggregate, single arg
237// ═══════════════════════════════════════════════════════════════════════════
238
239pub struct AggMaxFunc;
240
241impl AggregateFunction for AggMaxFunc {
242    type State = Option<SqliteValue>;
243
244    fn initial_state(&self) -> Self::State {
245        None
246    }
247
248    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
249        if args[0].is_null() {
250            return Ok(());
251        }
252        let candidate = &args[0];
253        match state {
254            None => *state = Some(candidate.clone()),
255            Some(current) => {
256                if candidate > current {
257                    *state = Some(candidate.clone());
258                }
259            }
260        }
261        Ok(())
262    }
263
264    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
265        Ok(state.unwrap_or(SqliteValue::Null))
266    }
267
268    fn num_args(&self) -> i32 {
269        1
270    }
271
272    fn name(&self) -> &str {
273        "max"
274    }
275}
276
277// ═══════════════════════════════════════════════════════════════════════════
278// min(X) — aggregate, single arg
279// ═══════════════════════════════════════════════════════════════════════════
280
281pub struct AggMinFunc;
282
283impl AggregateFunction for AggMinFunc {
284    type State = Option<SqliteValue>;
285
286    fn initial_state(&self) -> Self::State {
287        None
288    }
289
290    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
291        if args[0].is_null() {
292            return Ok(());
293        }
294        let candidate = &args[0];
295        match state {
296            None => *state = Some(candidate.clone()),
297            Some(current) => {
298                if candidate < current {
299                    *state = Some(candidate.clone());
300                }
301            }
302        }
303        Ok(())
304    }
305
306    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
307        Ok(state.unwrap_or(SqliteValue::Null))
308    }
309
310    fn num_args(&self) -> i32 {
311        1
312    }
313
314    fn name(&self) -> &str {
315        "min"
316    }
317}
318
319// ═══════════════════════════════════════════════════════════════════════════
320// sum(X)
321// ═══════════════════════════════════════════════════════════════════════════
322
323/// State for `sum()`: tracks whether all values are integers, the running
324/// integer sum, and the float sum as fallback.  Uses Kahan compensated
325/// summation for the float path to match C SQLite's precision.
326pub struct SumState {
327    int_sum: i64,
328    float_sum: f64,
329    float_compensation: f64,
330    all_integer: bool,
331    has_values: bool,
332    overflowed: bool,
333}
334
335pub struct SumFunc;
336
337impl AggregateFunction for SumFunc {
338    type State = SumState;
339
340    fn initial_state(&self) -> Self::State {
341        SumState {
342            int_sum: 0,
343            float_sum: 0.0,
344            float_compensation: 0.0,
345            all_integer: true,
346            has_values: false,
347            overflowed: false,
348        }
349    }
350
351    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
352        let value = args[0].to_sum_numeric_value();
353        if value.is_null() {
354            return Ok(());
355        }
356        state.has_values = true;
357        match value {
358            SqliteValue::Integer(i) => {
359                if state.all_integer && !state.overflowed {
360                    match state.int_sum.checked_add(i) {
361                        Some(s) => state.int_sum = s,
362                        None => state.overflowed = true,
363                    }
364                }
365                kahan_add(
366                    &mut state.float_sum,
367                    &mut state.float_compensation,
368                    i as f64,
369                );
370            }
371            SqliteValue::Float(f) => {
372                state.all_integer = false;
373                kahan_add(&mut state.float_sum, &mut state.float_compensation, f);
374            }
375            SqliteValue::Null | SqliteValue::Text(_) | SqliteValue::Blob(_) => {}
376        }
377        Ok(())
378    }
379
380    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
381        if !state.has_values {
382            return Ok(SqliteValue::Null);
383        }
384        if state.all_integer && state.overflowed {
385            return Err(FrankenError::IntegerOverflow);
386        }
387        if state.all_integer {
388            Ok(SqliteValue::Integer(state.int_sum))
389        } else {
390            Ok(SqliteValue::Float(
391                state.float_sum + state.float_compensation,
392            ))
393        }
394    }
395
396    fn num_args(&self) -> i32 {
397        1
398    }
399
400    fn name(&self) -> &str {
401        "sum"
402    }
403}
404
405// ═══════════════════════════════════════════════════════════════════════════
406// total(X) — always returns float, 0.0 for empty set, never overflows.
407// ═══════════════════════════════════════════════════════════════════════════
408
409pub struct TotalFunc;
410
411/// State for `total()`: Kahan compensated accumulator.
412pub struct TotalState {
413    sum: f64,
414    compensation: f64,
415}
416
417impl AggregateFunction for TotalFunc {
418    type State = TotalState;
419
420    fn initial_state(&self) -> Self::State {
421        TotalState {
422            sum: 0.0,
423            compensation: 0.0,
424        }
425    }
426
427    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
428        if !args[0].is_null() {
429            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
430        }
431        Ok(())
432    }
433
434    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
435        Ok(SqliteValue::Float(state.sum + state.compensation))
436    }
437
438    fn num_args(&self) -> i32 {
439        1
440    }
441
442    fn name(&self) -> &str {
443        "total"
444    }
445}
446
447// ═══════════════════════════════════════════════════════════════════════════
448// median(X) — equivalent to percentile_cont(X, 0.5)
449// ═══════════════════════════════════════════════════════════════════════════
450
451pub struct MedianFunc;
452
453impl AggregateFunction for MedianFunc {
454    type State = Vec<f64>;
455
456    fn initial_state(&self) -> Self::State {
457        Vec::new()
458    }
459
460    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
461        if !args[0].is_null() {
462            state.push(args[0].to_float());
463        }
464        Ok(())
465    }
466
467    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
468        if state.is_empty() {
469            return Ok(SqliteValue::Null);
470        }
471        state.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
472        let result = percentile_cont_impl(&state, 0.5);
473        Ok(SqliteValue::Float(result))
474    }
475
476    fn num_args(&self) -> i32 {
477        1
478    }
479
480    fn name(&self) -> &str {
481        "median"
482    }
483}
484
485// ═══════════════════════════════════════════════════════════════════════════
486// percentile(Y, P) — P in 0..100
487// ═══════════════════════════════════════════════════════════════════════════
488
489pub struct PercentileState {
490    values: Vec<f64>,
491    p: Option<f64>,
492}
493
494pub struct PercentileFunc;
495
496impl AggregateFunction for PercentileFunc {
497    type State = PercentileState;
498
499    fn initial_state(&self) -> Self::State {
500        PercentileState {
501            values: Vec::new(),
502            p: None,
503        }
504    }
505
506    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
507        if !args[0].is_null() {
508            state.values.push(args[0].to_float());
509        }
510        // Capture P from the second argument (constant expression).
511        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
512            state.p = Some(args[1].to_float());
513        }
514        Ok(())
515    }
516
517    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
518        if state.values.is_empty() {
519            return Ok(SqliteValue::Null);
520        }
521        let p = state.p.unwrap_or(50.0);
522        state
523            .values
524            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
525        // Convert P from 0-100 to 0-1 for the shared implementation.
526        let result = percentile_cont_impl(&state.values, p / 100.0);
527        Ok(SqliteValue::Float(result))
528    }
529
530    fn num_args(&self) -> i32 {
531        2
532    }
533
534    fn name(&self) -> &str {
535        "percentile"
536    }
537}
538
539// ═══════════════════════════════════════════════════════════════════════════
540// percentile_cont(Y, P) — P in 0..1, continuous interpolation
541// ═══════════════════════════════════════════════════════════════════════════
542
543pub struct PercentileContFunc;
544
545impl AggregateFunction for PercentileContFunc {
546    type State = PercentileState;
547
548    fn initial_state(&self) -> Self::State {
549        PercentileState {
550            values: Vec::new(),
551            p: None,
552        }
553    }
554
555    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
556        if !args[0].is_null() {
557            state.values.push(args[0].to_float());
558        }
559        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
560            state.p = Some(args[1].to_float());
561        }
562        Ok(())
563    }
564
565    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
566        if state.values.is_empty() {
567            return Ok(SqliteValue::Null);
568        }
569        let p = state.p.unwrap_or(0.5);
570        state
571            .values
572            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
573        let result = percentile_cont_impl(&state.values, p);
574        Ok(SqliteValue::Float(result))
575    }
576
577    fn num_args(&self) -> i32 {
578        2
579    }
580
581    fn name(&self) -> &str {
582        "percentile_cont"
583    }
584}
585
586// ═══════════════════════════════════════════════════════════════════════════
587// percentile_disc(Y, P) — P in 0..1, discrete (returns actual value)
588// ═══════════════════════════════════════════════════════════════════════════
589
590pub struct PercentileDiscFunc;
591
592impl AggregateFunction for PercentileDiscFunc {
593    type State = PercentileState;
594
595    fn initial_state(&self) -> Self::State {
596        PercentileState {
597            values: Vec::new(),
598            p: None,
599        }
600    }
601
602    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
603        if !args[0].is_null() {
604            state.values.push(args[0].to_float());
605        }
606        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
607            state.p = Some(args[1].to_float());
608        }
609        Ok(())
610    }
611
612    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
613        if state.values.is_empty() {
614            return Ok(SqliteValue::Null);
615        }
616        let p = state.p.unwrap_or(0.5);
617        let p = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
618        state
619            .values
620            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
621        // Discrete: pick the value at the ceiling index.
622        let n = state.values.len();
623        let idx = ((p * n as f64).ceil() as usize)
624            .saturating_sub(1)
625            .min(n - 1);
626        Ok(SqliteValue::Float(state.values[idx]))
627    }
628
629    fn num_args(&self) -> i32 {
630        2
631    }
632
633    fn name(&self) -> &str {
634        "percentile_disc"
635    }
636}
637
638// ── Shared percentile helper ──────────────────────────────────────────────
639
/// Linearly interpolated percentile over an ascending-sorted slice.
///
/// `p` is a fraction in [0, 1]; values outside that range are clamped and
/// NaN is treated as 0.5.  A one-element slice returns its only element.
/// `sorted` must be non-empty — every caller in this file checks that first.
fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
    if let [only] = sorted {
        return *only;
    }
    let fraction = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
    // Fractional index into the slice: 0 is the first element, len - 1 the last.
    let position = fraction * (sorted.len() - 1) as f64;
    let below = position.floor() as usize;
    let above = position.ceil() as usize;
    if below == above {
        // Landed exactly on an element; nothing to interpolate.
        return sorted[below];
    }
    let weight = position - below as f64;
    sorted[below] * (1.0 - weight) + sorted[above] * weight
}
658
659// ── Registration ──────────────────────────────────────────────────────────
660
661/// Register all §13.4 aggregate functions into the given registry.
662pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
663    registry.register_aggregate(AvgFunc);
664    registry.register_aggregate(CountStarFunc);
665    registry.register_aggregate(CountFunc);
666    registry.register_aggregate(GroupConcatFunc);
667    registry.register_aggregate(AggMaxFunc);
668    registry.register_aggregate(AggMinFunc);
669    registry.register_aggregate(SumFunc);
670    registry.register_aggregate(TotalFunc);
671    registry.register_aggregate(MedianFunc);
672    registry.register_aggregate(PercentileFunc);
673    registry.register_aggregate(PercentileContFunc);
674    registry.register_aggregate(PercentileDiscFunc);
675
676    // string_agg is an alias for group_concat with mandatory separator.
677    struct StringAggFunc;
678    impl AggregateFunction for StringAggFunc {
679        type State = GroupConcatState;
680
681        fn initial_state(&self) -> Self::State {
682            GroupConcatState {
683                result: String::new(),
684                has_value: false,
685            }
686        }
687
688        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
689            GroupConcatFunc.step(state, args)
690        }
691
692        fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
693            GroupConcatFunc.finalize(state)
694        }
695
696        fn num_args(&self) -> i32 {
697            2 // string_agg requires separator
698        }
699
700        fn name(&self) -> &str {
701            "string_agg"
702        }
703    }
704    registry.register_aggregate(StringAggFunc);
705}
706
707// ── Tests ─────────────────────────────────────────────────────────────────
708
709#[cfg(test)]
710mod tests {
711    use super::*;
712
713    const EPS: f64 = 1e-12;
714
    /// Shorthand for an INTEGER test value.
    fn int(v: i64) -> SqliteValue {
        SqliteValue::Integer(v)
    }

    /// Shorthand for a REAL test value.
    fn float(v: f64) -> SqliteValue {
        SqliteValue::Float(v)
    }

    /// Shorthand for a NULL test value.
    fn null() -> SqliteValue {
        SqliteValue::Null
    }

    /// Shorthand for a TEXT test value built from a string slice.
    fn text(s: &str) -> SqliteValue {
        SqliteValue::Text(s.into())
    }
730
731    fn assert_float_eq(result: &SqliteValue, expected: f64) {
732        match result {
733            SqliteValue::Float(v) => {
734                assert!((v - expected).abs() < EPS, "expected {expected}, got {v}");
735            }
736            other => {
737                assert!(
738                    matches!(other, SqliteValue::Float(_)),
739                    "expected Float({expected}), got {other:?}"
740                );
741            }
742        }
743    }
744
745    /// Helper: run an aggregate over a list of single-arg row values.
746    fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
747        let mut state = func.initial_state();
748        for row in rows {
749            func.step(&mut state, std::slice::from_ref(row)).unwrap();
750        }
751        func.finalize(state).unwrap()
752    }
753
754    /// Helper: run an aggregate over a list of two-arg row values.
755    fn run_agg2<F: AggregateFunction>(
756        func: &F,
757        rows: &[(SqliteValue, SqliteValue)],
758    ) -> SqliteValue {
759        let mut state = func.initial_state();
760        for (a, b) in rows {
761            func.step(&mut state, &[a.clone(), b.clone()]).unwrap();
762        }
763        func.finalize(state).unwrap()
764    }
765
766    // ── avg ───────────────────────────────────────────────────────────
767
768    #[test]
769    fn test_avg_basic() {
770        let r = run_agg(&AvgFunc, &[int(1), int(2), int(3), int(4), int(5)]);
771        assert_float_eq(&r, 3.0);
772    }
773
774    #[test]
775    fn test_avg_with_nulls() {
776        let r = run_agg(&AvgFunc, &[int(1), null(), int(3)]);
777        assert_float_eq(&r, 2.0);
778    }
779
780    #[test]
781    fn test_avg_empty() {
782        let r = run_agg(&AvgFunc, &[]);
783        assert_eq!(r, SqliteValue::Null);
784    }
785
786    #[test]
787    fn test_avg_returns_real() {
788        let r = run_agg(&AvgFunc, &[int(2), int(4)]);
789        assert!(matches!(r, SqliteValue::Float(_)));
790    }
791
792    // ── count ─────────────────────────────────────────────────────────
793
794    #[test]
795    fn test_count_star() {
796        // count(*) counts all rows including NULLs.
797        let mut state = CountStarFunc.initial_state();
798        CountStarFunc.step(&mut state, &[]).unwrap(); // row 1
799        CountStarFunc.step(&mut state, &[]).unwrap(); // row 2
800        CountStarFunc.step(&mut state, &[]).unwrap(); // row 3
801        let r = CountStarFunc.finalize(state).unwrap();
802        assert_eq!(r, int(3));
803    }
804
805    #[test]
806    fn test_count_column() {
807        let r = run_agg(&CountFunc, &[int(1), null(), int(3), null(), int(5)]);
808        assert_eq!(r, int(3));
809    }
810
811    #[test]
812    fn test_count_empty() {
813        let r = run_agg(&CountFunc, &[]);
814        assert_eq!(r, int(0));
815    }
816
817    // ── group_concat ──────────────────────────────────────────────────
818
819    #[test]
820    fn test_group_concat_basic() {
821        let r = run_agg(&GroupConcatFunc, &[text("a"), text("b"), text("c")]);
822        assert_eq!(r, SqliteValue::Text("a,b,c".into()));
823    }
824
825    #[test]
826    fn test_group_concat_custom_sep() {
827        let rows = vec![
828            (text("a"), text("; ")),
829            (text("b"), text("; ")),
830            (text("c"), text("; ")),
831        ];
832        let r = run_agg2(&GroupConcatFunc, &rows);
833        assert_eq!(r, SqliteValue::Text("a; b; c".into()));
834    }
835
836    #[test]
837    fn test_group_concat_null_skipped() {
838        let r = run_agg(&GroupConcatFunc, &[text("a"), null(), text("c")]);
839        assert_eq!(r, SqliteValue::Text("a,c".into()));
840    }
841
842    #[test]
843    fn test_group_concat_empty() {
844        let r = run_agg(&GroupConcatFunc, &[]);
845        assert_eq!(r, SqliteValue::Null);
846    }
847
848    #[test]
849    fn test_group_concat_varying_separator() {
850        // C SQLite uses the separator from each row's argument, not a single
851        // global separator. SELECT group_concat(val, sep) with varying sep
852        // produces a+b*c, not a*b*c (the old bug used the last-seen sep).
853        let rows = vec![
854            (text("a"), text("-")),
855            (text("b"), text("+")),
856            (text("c"), text("*")),
857        ];
858        let r = run_agg2(&GroupConcatFunc, &rows);
859        assert_eq!(r, SqliteValue::Text("a+b*c".into()));
860    }
861
862    #[test]
863    fn test_group_concat_single_value() {
864        let r = run_agg(&GroupConcatFunc, &[text("only")]);
865        assert_eq!(r, SqliteValue::Text("only".into()));
866    }
867
868    #[test]
869    fn test_group_concat_integer_values_coerced_to_text() {
870        let r = run_agg(&GroupConcatFunc, &[int(1), int(2), int(3)]);
871        assert_eq!(r, SqliteValue::Text("1,2,3".into()));
872    }
873
874    #[test]
875    #[ignore = "perf-only benchmark"]
876    fn perf_group_concat_text_rows() {
877        use std::hint::black_box;
878        use std::time::Instant;
879
880        const ROWS: usize = 200_000;
881        const REPEATS: usize = 5;
882
883        let rows: Vec<SqliteValue> = (0..ROWS).map(|_| text("payload")).collect();
884        let mut best_ns = u128::MAX;
885        let mut result_len = 0usize;
886
887        for _ in 0..REPEATS {
888            let started = Instant::now();
889            let result = black_box(run_agg(&GroupConcatFunc, black_box(rows.as_slice())));
890            let elapsed_ns = started.elapsed().as_nanos();
891            if elapsed_ns < best_ns {
892                best_ns = elapsed_ns;
893            }
894            result_len = match result {
895                SqliteValue::Text(text) => text.len(),
896                SqliteValue::Null
897                | SqliteValue::Integer(_)
898                | SqliteValue::Float(_)
899                | SqliteValue::Blob(_) => 0,
900            };
901        }
902
903        println!(
904            "group_concat_text_rows rows={ROWS} repeats={REPEATS} best_ns={best_ns} result_len={result_len}"
905        );
906    }
907
908    // ── max (aggregate) ───────────────────────────────────────────────
909
910    #[test]
911    fn test_max_aggregate() {
912        let r = run_agg(&AggMaxFunc, &[int(3), int(7), int(1), int(5)]);
913        assert_eq!(r, int(7));
914    }
915
916    #[test]
917    fn test_max_aggregate_null_skipped() {
918        let r = run_agg(&AggMaxFunc, &[int(3), null(), int(7), null()]);
919        assert_eq!(r, int(7));
920    }
921
922    #[test]
923    fn test_max_aggregate_empty() {
924        let r = run_agg(&AggMaxFunc, &[]);
925        assert_eq!(r, SqliteValue::Null);
926    }
927
928    // ── min (aggregate) ───────────────────────────────────────────────
929
930    #[test]
931    fn test_min_aggregate() {
932        let r = run_agg(&AggMinFunc, &[int(3), int(7), int(1), int(5)]);
933        assert_eq!(r, int(1));
934    }
935
936    #[test]
937    fn test_min_aggregate_null_skipped() {
938        let r = run_agg(&AggMinFunc, &[int(3), null(), int(1), null()]);
939        assert_eq!(r, int(1));
940    }
941
942    #[test]
943    fn test_min_aggregate_empty() {
944        let r = run_agg(&AggMinFunc, &[]);
945        assert_eq!(r, SqliteValue::Null);
946    }
947
948    // ── sum ───────────────────────────────────────────────────────────
949
950    #[test]
951    fn test_sum_integers() {
952        let r = run_agg(&SumFunc, &[int(1), int(2), int(3)]);
953        assert_eq!(r, int(6));
954    }
955
956    #[test]
957    fn test_sum_reals() {
958        let r = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
959        assert_float_eq(&r, 4.0);
960    }
961
962    #[test]
963    fn test_sum_empty_null() {
964        let r = run_agg(&SumFunc, &[]);
965        assert_eq!(r, SqliteValue::Null);
966    }
967
968    #[test]
969    fn test_sum_overflow_error() {
970        let mut state = SumFunc.initial_state();
971        SumFunc.step(&mut state, &[int(i64::MAX)]).unwrap();
972        SumFunc.step(&mut state, &[int(1)]).unwrap();
973        let err = SumFunc.finalize(state);
974        assert!(err.is_err(), "sum should raise overflow error");
975    }
976
977    #[test]
978    fn test_sum_later_real_value_clears_integer_overflow_error() {
979        let r = run_agg(&SumFunc, &[int(i64::MAX), int(1), float(0.5)]);
980        assert_float_eq(&r, 9_223_372_036_854_776_000.0);
981    }
982
983    #[test]
984    fn test_sum_integer_text_preserves_overflow_error() -> Result<()> {
985        let mut state = SumFunc.initial_state();
986        SumFunc.step(&mut state, &[text("9223372036854775807")])?;
987        SumFunc.step(&mut state, &[text("1")])?;
988        let err = SumFunc.finalize(state);
989        assert!(err.is_err(), "integer-text sum should raise overflow");
990        Ok(())
991    }
992
993    #[test]
994    fn test_sum_integer_text_later_real_clears_overflow_error() {
995        let r = run_agg(
996            &SumFunc,
997            &[text("9223372036854775807"), text("1"), text("0.5")],
998        );
999        assert_float_eq(&r, 9_223_372_036_854_776_000.0);
1000    }
1001
1002    #[test]
1003    fn test_sum_prefix_text_uses_real_accumulator() {
1004        let r = run_agg(&SumFunc, &[text("123abc"), int(1)]);
1005        assert_float_eq(&r, 124.0);
1006    }
1007
1008    #[test]
1009    fn test_sum_unicode_whitespace_text_uses_sqlite_ascii_space_rules() {
1010        let leading = run_agg(&SumFunc, &[text("\u{00a0}123"), int(1)]);
1011        assert_float_eq(&leading, 1.0);
1012
1013        let trailing = run_agg(&SumFunc, &[text("123\u{00a0}"), int(1)]);
1014        assert_float_eq(&trailing, 124.0);
1015    }
1016
1017    #[test]
1018    fn test_sum_null_skipped() {
1019        let r = run_agg(&SumFunc, &[int(1), null(), int(3)]);
1020        assert_eq!(r, int(4));
1021    }
1022
1023    // ── total ─────────────────────────────────────────────────────────
1024
1025    #[test]
1026    fn test_total_basic() {
1027        let r = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
1028        assert_float_eq(&r, 6.0);
1029    }
1030
1031    #[test]
1032    fn test_total_empty_zero() {
1033        let r = run_agg(&TotalFunc, &[]);
1034        assert_float_eq(&r, 0.0);
1035    }
1036
1037    #[test]
1038    fn test_total_no_overflow() {
1039        // total uses f64 and never overflows.
1040        let r = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
1041        assert!(matches!(r, SqliteValue::Float(_)));
1042    }
1043
1044    // ── median ────────────────────────────────────────────────────────
1045
1046    #[test]
1047    fn test_median_basic() {
1048        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4), int(5)]);
1049        assert_float_eq(&r, 3.0);
1050    }
1051
1052    #[test]
1053    fn test_median_even() {
1054        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4)]);
1055        assert_float_eq(&r, 2.5);
1056    }
1057
1058    #[test]
1059    fn test_median_null_skipped() {
1060        let r = run_agg(&MedianFunc, &[int(1), null(), int(3)]);
1061        assert_float_eq(&r, 2.0);
1062    }
1063
1064    #[test]
1065    fn test_median_empty() {
1066        let r = run_agg(&MedianFunc, &[]);
1067        assert_eq!(r, SqliteValue::Null);
1068    }
1069
1070    // ── percentile ────────────────────────────────────────────────────
1071
1072    #[test]
1073    fn test_percentile_50() {
1074        // percentile(col, 50) = median
1075        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1076            (int(1), float(50.0)),
1077            (int(2), float(50.0)),
1078            (int(3), float(50.0)),
1079            (int(4), float(50.0)),
1080            (int(5), float(50.0)),
1081        ];
1082        let r = run_agg2(&PercentileFunc, &rows);
1083        assert_float_eq(&r, 3.0);
1084    }
1085
1086    #[test]
1087    fn test_percentile_0() {
1088        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1089            (int(10), float(0.0)),
1090            (int(20), float(0.0)),
1091            (int(30), float(0.0)),
1092        ];
1093        let r = run_agg2(&PercentileFunc, &rows);
1094        assert_float_eq(&r, 10.0);
1095    }
1096
1097    #[test]
1098    fn test_percentile_100() {
1099        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1100            (int(10), float(100.0)),
1101            (int(20), float(100.0)),
1102            (int(30), float(100.0)),
1103        ];
1104        let r = run_agg2(&PercentileFunc, &rows);
1105        assert_float_eq(&r, 30.0);
1106    }
1107
1108    // ── percentile_cont ───────────────────────────────────────────────
1109
1110    #[test]
1111    fn test_percentile_cont_basic() {
1112        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1113            (int(1), float(0.5)),
1114            (int(2), float(0.5)),
1115            (int(3), float(0.5)),
1116            (int(4), float(0.5)),
1117            (int(5), float(0.5)),
1118        ];
1119        let r = run_agg2(&PercentileContFunc, &rows);
1120        assert_float_eq(&r, 3.0);
1121    }
1122
1123    // ── percentile_disc ───────────────────────────────────────────────
1124
1125    #[test]
1126    fn test_percentile_disc_basic() {
1127        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1128            (int(1), float(0.5)),
1129            (int(2), float(0.5)),
1130            (int(3), float(0.5)),
1131            (int(4), float(0.5)),
1132            (int(5), float(0.5)),
1133        ];
1134        let r = run_agg2(&PercentileDiscFunc, &rows);
1135        // Discrete: returns an actual input value.
1136        match r {
1137            SqliteValue::Float(v) => {
1138                // Should be one of the actual input values (3.0 for 0.5 in 5 items).
1139                assert!(
1140                    [1.0, 2.0, 3.0, 4.0, 5.0].contains(&v),
1141                    "expected actual value, got {v}"
1142                );
1143            }
1144            other => {
1145                assert!(
1146                    matches!(other, SqliteValue::Float(_)),
1147                    "expected Float, got {other:?}"
1148                );
1149            }
1150        }
1151    }
1152
1153    #[test]
1154    fn test_percentile_disc_no_interpolation() {
1155        // With 4 items at p=0.5, cont would interpolate, disc should not.
1156        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1157            (int(10), float(0.5)),
1158            (int(20), float(0.5)),
1159            (int(30), float(0.5)),
1160            (int(40), float(0.5)),
1161        ];
1162        let r = run_agg2(&PercentileDiscFunc, &rows);
1163        match r {
1164            SqliteValue::Float(v) => {
1165                // Must be one of {10, 20, 30, 40}, not 25.0.
1166                assert!(
1167                    [10.0, 20.0, 30.0, 40.0].contains(&v),
1168                    "disc must not interpolate: got {v}"
1169                );
1170            }
1171            other => {
1172                assert!(
1173                    matches!(other, SqliteValue::Float(_)),
1174                    "expected Float, got {other:?}"
1175                );
1176            }
1177        }
1178    }
1179
1180    // ── string_agg (alias) ────────────────────────────────────────────
1181
1182    #[test]
1183    fn test_string_agg_alias() {
1184        let mut reg = FunctionRegistry::new();
1185        register_aggregate_builtins(&mut reg);
1186        let sa = reg
1187            .find_aggregate("string_agg", 2)
1188            .expect("string_agg registered");
1189        let mut state = sa.initial_state();
1190        sa.step(&mut state, &[text("a"), text(",")]).unwrap();
1191        sa.step(&mut state, &[text("b"), text(",")]).unwrap();
1192        let r = sa.finalize(state).unwrap();
1193        assert_eq!(r, SqliteValue::Text("a,b".into()));
1194    }
1195
1196    // ── registration ──────────────────────────────────────────────────
1197
1198    #[test]
1199    fn test_register_aggregate_builtins_all_present() {
1200        let mut reg = FunctionRegistry::new();
1201        register_aggregate_builtins(&mut reg);
1202
1203        let expected = [
1204            ("avg", 1),
1205            ("count", 0), // count(*)
1206            ("count", 1), // count(X)
1207            ("max", 1),
1208            ("min", 1),
1209            ("sum", 1),
1210            ("total", 1),
1211            ("median", 1),
1212            ("percentile", 2),
1213            ("percentile_cont", 2),
1214            ("percentile_disc", 2),
1215            ("string_agg", 2),
1216        ];
1217
1218        for (name, arity) in expected {
1219            assert!(
1220                reg.find_aggregate(name, arity).is_some(),
1221                "aggregate '{name}/{arity}' not registered"
1222            );
1223        }
1224
1225        // group_concat is variadic
1226        assert!(reg.find_aggregate("group_concat", 1).is_some());
1227        assert!(reg.find_aggregate("group_concat", 2).is_some());
1228
1229        let group_concat_zero = reg.find_aggregate("group_concat", 0).unwrap();
1230        let err = group_concat_zero
1231            .finalize(group_concat_zero.initial_state())
1232            .expect_err("group_concat() should reject zero arguments");
1233        assert!(
1234            matches!(&err, FrankenError::FunctionError(message)
1235                if message == "wrong number of arguments to function group_concat()"),
1236            "unexpected error: {err:?}"
1237        );
1238    }
1239
1240    // ── E2E: full lifecycle through registry ──────────────────────────
1241
1242    #[test]
1243    fn test_e2e_registry_invoke_aggregates() {
1244        let mut reg = FunctionRegistry::new();
1245        register_aggregate_builtins(&mut reg);
1246
1247        // avg through registry
1248        let avg = reg.find_aggregate("avg", 1).unwrap();
1249        let mut state = avg.initial_state();
1250        avg.step(&mut state, &[int(10)]).unwrap();
1251        avg.step(&mut state, &[int(20)]).unwrap();
1252        avg.step(&mut state, &[int(30)]).unwrap();
1253        let r = avg.finalize(state).unwrap();
1254        assert_float_eq(&r, 20.0);
1255
1256        // sum through registry
1257        let sum = reg.find_aggregate("sum", 1).unwrap();
1258        let mut state = sum.initial_state();
1259        sum.step(&mut state, &[int(1)]).unwrap();
1260        sum.step(&mut state, &[int(2)]).unwrap();
1261        sum.step(&mut state, &[int(3)]).unwrap();
1262        let r = sum.finalize(state).unwrap();
1263        assert_eq!(r, int(6));
1264    }
1265}