1#![allow(
13 clippy::unnecessary_literal_bound,
14 clippy::too_many_lines,
15 clippy::cast_possible_truncation,
16 clippy::cast_possible_wrap,
17 clippy::cast_precision_loss,
18 clippy::match_same_arms,
19 clippy::items_after_statements,
20 clippy::float_cmp,
21 clippy::cast_sign_loss,
22 clippy::suboptimal_flops
23)]
24
25use fsqlite_error::{FrankenError, Result};
26use fsqlite_types::SqliteValue;
27
28use crate::{AggregateFunction, FunctionRegistry};
29
/// Adds `value` to a running total using Kahan–Babuška–Neumaier compensated
/// summation.
///
/// `sum` holds the naive running total and `compensation` accumulates the
/// low-order bits each addition rounds away; the best estimate of the exact
/// total is `sum + compensation`.
///
/// Once the running sum stops being finite (±Inf or NaN) the compensation is
/// no longer updated. Otherwise the error terms evaluate `inf - inf` and poison
/// the compensation with NaN, which turns a final `sum + compensation` of
/// `+Inf` into NaN. SQLite avoids the same problem by discarding a non-finite
/// error term when it finalizes `sum()`, `total()` and `avg()`.
#[inline]
fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
    let s = *sum;
    let t = s + value;
    // A finite `t` implies both `s` and `value` are finite, so the error
    // terms below cannot involve infinities.
    if t.is_finite() {
        if s.abs() > value.abs() {
            // `s` dominates: recover the bits of `value` lost in `t`.
            *compensation += (s - t) + value;
        } else {
            // `value` dominates: recover the bits of `s` lost in `t`.
            *compensation += (value - t) + s;
        }
    }
    *sum = t;
}
46
/// Accumulator used by [`AvgFunc`].
pub struct AvgState {
    /// Naive running total of every non-NULL input, converted to `f64`.
    sum: f64,
    /// Low-order bits lost while building `sum`; the mean uses `sum + compensation`.
    compensation: f64,
    /// How many non-NULL inputs have been folded in so far.
    count: i64,
}
56
57pub struct AvgFunc;
58
59impl AggregateFunction for AvgFunc {
60 type State = AvgState;
61
62 fn initial_state(&self) -> Self::State {
63 AvgState {
64 sum: 0.0,
65 compensation: 0.0,
66 count: 0,
67 }
68 }
69
70 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
71 if !args[0].is_null() {
72 kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
73 state.count += 1;
74 }
75 Ok(())
76 }
77
78 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
79 if state.count == 0 {
80 Ok(SqliteValue::Null)
81 } else {
82 Ok(SqliteValue::Float(
83 (state.sum + state.compensation) / state.count as f64,
84 ))
85 }
86 }
87
88 fn num_args(&self) -> i32 {
89 1
90 }
91
92 fn name(&self) -> &str {
93 "avg"
94 }
95}
96
97pub struct CountStarFunc;
103
104impl AggregateFunction for CountStarFunc {
105 type State = i64;
106
107 fn initial_state(&self) -> Self::State {
108 0
109 }
110
111 fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
112 *state += 1;
113 Ok(())
114 }
115
116 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
117 Ok(SqliteValue::Integer(state))
118 }
119
120 fn num_args(&self) -> i32 {
121 0 }
123
124 fn name(&self) -> &str {
125 "count"
126 }
127}
128
129pub struct CountFunc;
131
132impl AggregateFunction for CountFunc {
133 type State = i64;
134
135 fn initial_state(&self) -> Self::State {
136 0
137 }
138
139 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
140 if !args[0].is_null() {
141 *state += 1;
142 }
143 Ok(())
144 }
145
146 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
147 Ok(SqliteValue::Integer(state))
148 }
149
150 fn num_args(&self) -> i32 {
151 1
152 }
153
154 fn name(&self) -> &str {
155 "count"
156 }
157}
158
/// Accumulator shared by `group_concat` and its `string_agg` alias.
pub struct GroupConcatState {
    /// Text built so far: appended values interleaved with separators.
    result: String,
    /// Whether at least one non-NULL value has been appended. This
    /// distinguishes "no rows" (NULL result) from an empty-string result.
    has_value: bool,
}
170
/// `group_concat(X [, SEP])`: concatenates the non-NULL values of `X` as text.
///
/// Consecutive values are joined by the `SEP` supplied on the row being
/// appended, or by `,` when no separator argument is given. A NULL separator
/// joins with nothing. Returns NULL when no value was appended.
pub struct GroupConcatFunc;
172
173#[inline]
174fn push_group_concat_text(result: &mut String, value: &SqliteValue) {
175 if let Some(text) = value.as_text_str() {
176 result.push_str(text);
177 } else {
178 result.push_str(&value.to_text());
179 }
180}
181
182impl AggregateFunction for GroupConcatFunc {
183 type State = GroupConcatState;
184
185 fn initial_state(&self) -> Self::State {
186 GroupConcatState {
187 result: String::new(),
188 has_value: false,
189 }
190 }
191
192 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
193 if args[0].is_null() {
194 return Ok(());
195 }
196 if state.has_value {
197 match args.get(1) {
198 Some(separator) if !separator.is_null() => {
199 push_group_concat_text(&mut state.result, separator);
200 }
201 Some(_) => {}
202 None => state.result.push(','),
203 }
204 }
205 push_group_concat_text(&mut state.result, &args[0]);
206 state.has_value = true;
207 Ok(())
208 }
209
210 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
211 if state.has_value {
212 Ok(SqliteValue::Text(state.result.into()))
213 } else {
214 Ok(SqliteValue::Null)
215 }
216 }
217
218 fn num_args(&self) -> i32 {
219 -1 }
221
222 fn min_args(&self) -> i32 {
223 1
224 }
225
226 fn max_args(&self) -> Option<i32> {
227 Some(2)
228 }
229
230 fn name(&self) -> &str {
231 "group_concat"
232 }
233}
234
235pub struct AggMaxFunc;
240
241impl AggregateFunction for AggMaxFunc {
242 type State = Option<SqliteValue>;
243
244 fn initial_state(&self) -> Self::State {
245 None
246 }
247
248 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
249 if args[0].is_null() {
250 return Ok(());
251 }
252 let candidate = &args[0];
253 match state {
254 None => *state = Some(candidate.clone()),
255 Some(current) => {
256 if candidate > current {
257 *state = Some(candidate.clone());
258 }
259 }
260 }
261 Ok(())
262 }
263
264 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
265 Ok(state.unwrap_or(SqliteValue::Null))
266 }
267
268 fn num_args(&self) -> i32 {
269 1
270 }
271
272 fn name(&self) -> &str {
273 "max"
274 }
275}
276
277pub struct AggMinFunc;
282
283impl AggregateFunction for AggMinFunc {
284 type State = Option<SqliteValue>;
285
286 fn initial_state(&self) -> Self::State {
287 None
288 }
289
290 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
291 if args[0].is_null() {
292 return Ok(());
293 }
294 let candidate = &args[0];
295 match state {
296 None => *state = Some(candidate.clone()),
297 Some(current) => {
298 if candidate < current {
299 *state = Some(candidate.clone());
300 }
301 }
302 }
303 Ok(())
304 }
305
306 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
307 Ok(state.unwrap_or(SqliteValue::Null))
308 }
309
310 fn num_args(&self) -> i32 {
311 1
312 }
313
314 fn name(&self) -> &str {
315 "min"
316 }
317}
318
/// Accumulator used by [`SumFunc`].
pub struct SumState {
    /// Exact integer total, maintained while every input has been an integer
    /// and no overflow has occurred.
    int_sum: i64,
    /// Compensated floating-point total of every numeric input, integers
    /// included, used once the result has to be REAL.
    float_sum: f64,
    /// Low-order bits lost while building `float_sum`.
    float_compensation: f64,
    /// `true` until the first REAL input is seen.
    all_integer: bool,
    /// `true` once any non-NULL input has been seen.
    has_values: bool,
    /// Set when adding to `int_sum` would have overflowed `i64`.
    overflowed: bool,
}
334
335pub struct SumFunc;
336
337impl AggregateFunction for SumFunc {
338 type State = SumState;
339
340 fn initial_state(&self) -> Self::State {
341 SumState {
342 int_sum: 0,
343 float_sum: 0.0,
344 float_compensation: 0.0,
345 all_integer: true,
346 has_values: false,
347 overflowed: false,
348 }
349 }
350
351 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
352 let value = args[0].to_sum_numeric_value();
353 if value.is_null() {
354 return Ok(());
355 }
356 state.has_values = true;
357 match value {
358 SqliteValue::Integer(i) => {
359 if state.all_integer && !state.overflowed {
360 match state.int_sum.checked_add(i) {
361 Some(s) => state.int_sum = s,
362 None => state.overflowed = true,
363 }
364 }
365 kahan_add(
366 &mut state.float_sum,
367 &mut state.float_compensation,
368 i as f64,
369 );
370 }
371 SqliteValue::Float(f) => {
372 state.all_integer = false;
373 kahan_add(&mut state.float_sum, &mut state.float_compensation, f);
374 }
375 SqliteValue::Null | SqliteValue::Text(_) | SqliteValue::Blob(_) => {}
376 }
377 Ok(())
378 }
379
380 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
381 if !state.has_values {
382 return Ok(SqliteValue::Null);
383 }
384 if state.all_integer && state.overflowed {
385 return Err(FrankenError::IntegerOverflow);
386 }
387 if state.all_integer {
388 Ok(SqliteValue::Integer(state.int_sum))
389 } else {
390 Ok(SqliteValue::Float(
391 state.float_sum + state.float_compensation,
392 ))
393 }
394 }
395
396 fn num_args(&self) -> i32 {
397 1
398 }
399
400 fn name(&self) -> &str {
401 "sum"
402 }
403}
404
405pub struct TotalFunc;
410
411pub struct TotalState {
413 sum: f64,
414 compensation: f64,
415}
416
417impl AggregateFunction for TotalFunc {
418 type State = TotalState;
419
420 fn initial_state(&self) -> Self::State {
421 TotalState {
422 sum: 0.0,
423 compensation: 0.0,
424 }
425 }
426
427 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
428 if !args[0].is_null() {
429 kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
430 }
431 Ok(())
432 }
433
434 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
435 Ok(SqliteValue::Float(state.sum + state.compensation))
436 }
437
438 fn num_args(&self) -> i32 {
439 1
440 }
441
442 fn name(&self) -> &str {
443 "total"
444 }
445}
446
447pub struct MedianFunc;
452
453impl AggregateFunction for MedianFunc {
454 type State = Vec<f64>;
455
456 fn initial_state(&self) -> Self::State {
457 Vec::new()
458 }
459
460 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
461 if !args[0].is_null() {
462 state.push(args[0].to_float());
463 }
464 Ok(())
465 }
466
467 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
468 if state.is_empty() {
469 return Ok(SqliteValue::Null);
470 }
471 state.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
472 let result = percentile_cont_impl(&state, 0.5);
473 Ok(SqliteValue::Float(result))
474 }
475
476 fn num_args(&self) -> i32 {
477 1
478 }
479
480 fn name(&self) -> &str {
481 "median"
482 }
483}
484
/// Accumulator shared by `percentile`, `percentile_cont` and `percentile_disc`.
pub struct PercentileState {
    /// Every non-NULL value seen, converted to `f64`; sorted only at finalize.
    values: Vec<f64>,
    /// The percentile argument from the first row where it was non-NULL.
    /// Its scale (0–100 or 0–1) depends on which function owns the state.
    p: Option<f64>,
}
493
494pub struct PercentileFunc;
495
496impl AggregateFunction for PercentileFunc {
497 type State = PercentileState;
498
499 fn initial_state(&self) -> Self::State {
500 PercentileState {
501 values: Vec::new(),
502 p: None,
503 }
504 }
505
506 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
507 if !args[0].is_null() {
508 state.values.push(args[0].to_float());
509 }
510 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
512 state.p = Some(args[1].to_float());
513 }
514 Ok(())
515 }
516
517 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
518 if state.values.is_empty() {
519 return Ok(SqliteValue::Null);
520 }
521 let p = state.p.unwrap_or(50.0);
522 state
523 .values
524 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
525 let result = percentile_cont_impl(&state.values, p / 100.0);
527 Ok(SqliteValue::Float(result))
528 }
529
530 fn num_args(&self) -> i32 {
531 2
532 }
533
534 fn name(&self) -> &str {
535 "percentile"
536 }
537}
538
539pub struct PercentileContFunc;
544
545impl AggregateFunction for PercentileContFunc {
546 type State = PercentileState;
547
548 fn initial_state(&self) -> Self::State {
549 PercentileState {
550 values: Vec::new(),
551 p: None,
552 }
553 }
554
555 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
556 if !args[0].is_null() {
557 state.values.push(args[0].to_float());
558 }
559 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
560 state.p = Some(args[1].to_float());
561 }
562 Ok(())
563 }
564
565 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
566 if state.values.is_empty() {
567 return Ok(SqliteValue::Null);
568 }
569 let p = state.p.unwrap_or(0.5);
570 state
571 .values
572 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
573 let result = percentile_cont_impl(&state.values, p);
574 Ok(SqliteValue::Float(result))
575 }
576
577 fn num_args(&self) -> i32 {
578 2
579 }
580
581 fn name(&self) -> &str {
582 "percentile_cont"
583 }
584}
585
586pub struct PercentileDiscFunc;
591
592impl AggregateFunction for PercentileDiscFunc {
593 type State = PercentileState;
594
595 fn initial_state(&self) -> Self::State {
596 PercentileState {
597 values: Vec::new(),
598 p: None,
599 }
600 }
601
602 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
603 if !args[0].is_null() {
604 state.values.push(args[0].to_float());
605 }
606 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
607 state.p = Some(args[1].to_float());
608 }
609 Ok(())
610 }
611
612 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
613 if state.values.is_empty() {
614 return Ok(SqliteValue::Null);
615 }
616 let p = state.p.unwrap_or(0.5);
617 let p = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
618 state
619 .values
620 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
621 let n = state.values.len();
623 let idx = ((p * n as f64).ceil() as usize)
624 .saturating_sub(1)
625 .min(n - 1);
626 Ok(SqliteValue::Float(state.values[idx]))
627 }
628
629 fn num_args(&self) -> i32 {
630 2
631 }
632
633 fn name(&self) -> &str {
634 "percentile_disc"
635 }
636}
637
/// Linearly interpolated `p`-quantile of an ascending-sorted slice.
///
/// `p` is a fraction: values outside `[0, 1]` are clamped and NaN is treated
/// as `0.5`. The quantile sits at fractional position `p * (len - 1)`; when
/// that position falls between two elements, their values are blended by the
/// fractional part.
///
/// Callers must pass a non-empty slice: an empty slice panics (index out of
/// bounds, or subtraction overflow in debug builds).
fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
    if let [only] = sorted {
        return *only;
    }
    let fraction = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
    let last = sorted.len() - 1;
    let position = fraction * last as f64;
    let below = position.floor() as usize;
    let above = position.ceil() as usize;
    if below != above {
        let weight = position - below as f64;
        return sorted[below] * (1.0 - weight) + sorted[above] * weight;
    }
    sorted[below]
}
658
659pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
663 registry.register_aggregate(AvgFunc);
664 registry.register_aggregate(CountStarFunc);
665 registry.register_aggregate(CountFunc);
666 registry.register_aggregate(GroupConcatFunc);
667 registry.register_aggregate(AggMaxFunc);
668 registry.register_aggregate(AggMinFunc);
669 registry.register_aggregate(SumFunc);
670 registry.register_aggregate(TotalFunc);
671 registry.register_aggregate(MedianFunc);
672 registry.register_aggregate(PercentileFunc);
673 registry.register_aggregate(PercentileContFunc);
674 registry.register_aggregate(PercentileDiscFunc);
675
676 struct StringAggFunc;
678 impl AggregateFunction for StringAggFunc {
679 type State = GroupConcatState;
680
681 fn initial_state(&self) -> Self::State {
682 GroupConcatState {
683 result: String::new(),
684 has_value: false,
685 }
686 }
687
688 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
689 GroupConcatFunc.step(state, args)
690 }
691
692 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
693 GroupConcatFunc.finalize(state)
694 }
695
696 fn num_args(&self) -> i32 {
697 2 }
699
700 fn name(&self) -> &str {
701 "string_agg"
702 }
703 }
704 registry.register_aggregate(StringAggFunc);
705}
706
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-12;

    fn int(v: i64) -> SqliteValue {
        SqliteValue::Integer(v)
    }

    fn float(v: f64) -> SqliteValue {
        SqliteValue::Float(v)
    }

    fn null() -> SqliteValue {
        SqliteValue::Null
    }

    fn text(s: &str) -> SqliteValue {
        SqliteValue::Text(s.into())
    }

    /// Asserts that `result` is a REAL within `EPS` of `expected`.
    fn assert_float_eq(result: &SqliteValue, expected: f64) {
        match result {
            SqliteValue::Float(v) => {
                assert!((v - expected).abs() < EPS, "expected {expected}, got {v}");
            }
            other => panic!("expected Float({expected}), got {other:?}"),
        }
    }

    /// Runs a one-argument aggregate over `rows` and returns its final value.
    fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
        let mut state = func.initial_state();
        for row in rows {
            func.step(&mut state, std::slice::from_ref(row)).unwrap();
        }
        func.finalize(state).unwrap()
    }

    /// Runs a two-argument aggregate over `rows` and returns its final value.
    fn run_agg2<F: AggregateFunction>(
        func: &F,
        rows: &[(SqliteValue, SqliteValue)],
    ) -> SqliteValue {
        let mut state = func.initial_state();
        for (a, b) in rows {
            func.step(&mut state, &[a.clone(), b.clone()]).unwrap();
        }
        func.finalize(state).unwrap()
    }

    #[test]
    fn test_avg_basic() {
        let r = run_agg(&AvgFunc, &[int(1), int(2), int(3), int(4), int(5)]);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_avg_with_nulls() {
        let r = run_agg(&AvgFunc, &[int(1), null(), int(3)]);
        assert_float_eq(&r, 2.0);
    }

    #[test]
    fn test_avg_empty() {
        let r = run_agg(&AvgFunc, &[]);
        assert_eq!(r, SqliteValue::Null);
    }

    #[test]
    fn test_avg_returns_real() {
        let r = run_agg(&AvgFunc, &[int(2), int(4)]);
        assert!(matches!(r, SqliteValue::Float(_)));
    }

    #[test]
    fn test_count_star() {
        let mut state = CountStarFunc.initial_state();
        CountStarFunc.step(&mut state, &[]).unwrap();
        CountStarFunc.step(&mut state, &[]).unwrap();
        CountStarFunc.step(&mut state, &[]).unwrap();
        let r = CountStarFunc.finalize(state).unwrap();
        assert_eq!(r, int(3));
    }

    #[test]
    fn test_count_column() {
        let r = run_agg(&CountFunc, &[int(1), null(), int(3), null(), int(5)]);
        assert_eq!(r, int(3));
    }

    #[test]
    fn test_count_empty() {
        let r = run_agg(&CountFunc, &[]);
        assert_eq!(r, int(0));
    }

    #[test]
    fn test_group_concat_basic() {
        let r = run_agg(&GroupConcatFunc, &[text("a"), text("b"), text("c")]);
        assert_eq!(r, SqliteValue::Text("a,b,c".into()));
    }

    #[test]
    fn test_group_concat_custom_sep() {
        let rows = vec![
            (text("a"), text("; ")),
            (text("b"), text("; ")),
            (text("c"), text("; ")),
        ];
        let r = run_agg2(&GroupConcatFunc, &rows);
        assert_eq!(r, SqliteValue::Text("a; b; c".into()));
    }

    #[test]
    fn test_group_concat_null_skipped() {
        let r = run_agg(&GroupConcatFunc, &[text("a"), null(), text("c")]);
        assert_eq!(r, SqliteValue::Text("a,c".into()));
    }

    #[test]
    fn test_group_concat_empty() {
        let r = run_agg(&GroupConcatFunc, &[]);
        assert_eq!(r, SqliteValue::Null);
    }

    #[test]
    fn test_group_concat_varying_separator() {
        // The separator placed before a value is the one supplied on that
        // value's own row; the first row's separator is never used.
        let rows = vec![
            (text("a"), text("-")),
            (text("b"), text("+")),
            (text("c"), text("*")),
        ];
        let r = run_agg2(&GroupConcatFunc, &rows);
        assert_eq!(r, SqliteValue::Text("a+b*c".into()));
    }

    #[test]
    fn test_group_concat_single_value() {
        let r = run_agg(&GroupConcatFunc, &[text("only")]);
        assert_eq!(r, SqliteValue::Text("only".into()));
    }

    #[test]
    fn test_group_concat_integer_values_coerced_to_text() {
        let r = run_agg(&GroupConcatFunc, &[int(1), int(2), int(3)]);
        assert_eq!(r, SqliteValue::Text("1,2,3".into()));
    }

    #[test]
    #[ignore = "perf-only benchmark"]
    fn perf_group_concat_text_rows() {
        use std::hint::black_box;
        use std::time::Instant;

        const ROWS: usize = 200_000;
        const REPEATS: usize = 5;

        let rows: Vec<SqliteValue> = (0..ROWS).map(|_| text("payload")).collect();
        let mut best_ns = u128::MAX;
        let mut result_len = 0usize;

        for _ in 0..REPEATS {
            let started = Instant::now();
            let result = black_box(run_agg(&GroupConcatFunc, black_box(rows.as_slice())));
            let elapsed_ns = started.elapsed().as_nanos();
            if elapsed_ns < best_ns {
                best_ns = elapsed_ns;
            }
            result_len = match result {
                SqliteValue::Text(text) => text.len(),
                SqliteValue::Null
                | SqliteValue::Integer(_)
                | SqliteValue::Float(_)
                | SqliteValue::Blob(_) => 0,
            };
        }

        println!(
            "group_concat_text_rows rows={ROWS} repeats={REPEATS} best_ns={best_ns} result_len={result_len}"
        );
    }

    #[test]
    fn test_max_aggregate() {
        let r = run_agg(&AggMaxFunc, &[int(3), int(7), int(1), int(5)]);
        assert_eq!(r, int(7));
    }

    #[test]
    fn test_max_aggregate_null_skipped() {
        let r = run_agg(&AggMaxFunc, &[int(3), null(), int(7), null()]);
        assert_eq!(r, int(7));
    }

    #[test]
    fn test_max_aggregate_empty() {
        let r = run_agg(&AggMaxFunc, &[]);
        assert_eq!(r, SqliteValue::Null);
    }

    #[test]
    fn test_min_aggregate() {
        let r = run_agg(&AggMinFunc, &[int(3), int(7), int(1), int(5)]);
        assert_eq!(r, int(1));
    }

    #[test]
    fn test_min_aggregate_null_skipped() {
        let r = run_agg(&AggMinFunc, &[int(3), null(), int(1), null()]);
        assert_eq!(r, int(1));
    }

    #[test]
    fn test_min_aggregate_empty() {
        let r = run_agg(&AggMinFunc, &[]);
        assert_eq!(r, SqliteValue::Null);
    }

    #[test]
    fn test_sum_integers() {
        let r = run_agg(&SumFunc, &[int(1), int(2), int(3)]);
        assert_eq!(r, int(6));
    }

    #[test]
    fn test_sum_reals() {
        let r = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
        assert_float_eq(&r, 4.0);
    }

    #[test]
    fn test_sum_empty_null() {
        let r = run_agg(&SumFunc, &[]);
        assert_eq!(r, SqliteValue::Null);
    }

    #[test]
    fn test_sum_overflow_error() {
        let mut state = SumFunc.initial_state();
        SumFunc.step(&mut state, &[int(i64::MAX)]).unwrap();
        SumFunc.step(&mut state, &[int(1)]).unwrap();
        let err = SumFunc.finalize(state);
        assert!(err.is_err(), "sum should raise overflow error");
    }

    #[test]
    fn test_sum_later_real_value_clears_integer_overflow_error() {
        let r = run_agg(&SumFunc, &[int(i64::MAX), int(1), float(0.5)]);
        assert_float_eq(&r, 9_223_372_036_854_776_000.0);
    }

    #[test]
    fn test_sum_integer_text_preserves_overflow_error() -> Result<()> {
        let mut state = SumFunc.initial_state();
        SumFunc.step(&mut state, &[text("9223372036854775807")])?;
        SumFunc.step(&mut state, &[text("1")])?;
        let err = SumFunc.finalize(state);
        assert!(err.is_err(), "integer-text sum should raise overflow");
        Ok(())
    }

    #[test]
    fn test_sum_integer_text_later_real_clears_overflow_error() {
        let r = run_agg(
            &SumFunc,
            &[text("9223372036854775807"), text("1"), text("0.5")],
        );
        assert_float_eq(&r, 9_223_372_036_854_776_000.0);
    }

    #[test]
    fn test_sum_prefix_text_uses_real_accumulator() {
        let r = run_agg(&SumFunc, &[text("123abc"), int(1)]);
        assert_float_eq(&r, 124.0);
    }

    #[test]
    fn test_sum_unicode_whitespace_text_uses_sqlite_ascii_space_rules() {
        let leading = run_agg(&SumFunc, &[text("\u{00a0}123"), int(1)]);
        assert_float_eq(&leading, 1.0);

        let trailing = run_agg(&SumFunc, &[text("123\u{00a0}"), int(1)]);
        assert_float_eq(&trailing, 124.0);
    }

    #[test]
    fn test_sum_null_skipped() {
        let r = run_agg(&SumFunc, &[int(1), null(), int(3)]);
        assert_eq!(r, int(4));
    }

    #[test]
    fn test_total_basic() {
        let r = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
        assert_float_eq(&r, 6.0);
    }

    #[test]
    fn test_total_empty_zero() {
        let r = run_agg(&TotalFunc, &[]);
        assert_float_eq(&r, 0.0);
    }

    #[test]
    fn test_total_no_overflow() {
        let r = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
        assert!(matches!(r, SqliteValue::Float(_)));
    }

    #[test]
    fn test_median_basic() {
        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4), int(5)]);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_median_even() {
        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4)]);
        assert_float_eq(&r, 2.5);
    }

    #[test]
    fn test_median_null_skipped() {
        let r = run_agg(&MedianFunc, &[int(1), null(), int(3)]);
        assert_float_eq(&r, 2.0);
    }

    #[test]
    fn test_median_empty() {
        let r = run_agg(&MedianFunc, &[]);
        assert_eq!(r, SqliteValue::Null);
    }

    #[test]
    fn test_percentile_50() {
        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
            (int(1), float(50.0)),
            (int(2), float(50.0)),
            (int(3), float(50.0)),
            (int(4), float(50.0)),
            (int(5), float(50.0)),
        ];
        let r = run_agg2(&PercentileFunc, &rows);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_percentile_0() {
        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
            (int(10), float(0.0)),
            (int(20), float(0.0)),
            (int(30), float(0.0)),
        ];
        let r = run_agg2(&PercentileFunc, &rows);
        assert_float_eq(&r, 10.0);
    }

    #[test]
    fn test_percentile_100() {
        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
            (int(10), float(100.0)),
            (int(20), float(100.0)),
            (int(30), float(100.0)),
        ];
        let r = run_agg2(&PercentileFunc, &rows);
        assert_float_eq(&r, 30.0);
    }

    #[test]
    fn test_percentile_cont_basic() {
        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
            (int(1), float(0.5)),
            (int(2), float(0.5)),
            (int(3), float(0.5)),
            (int(4), float(0.5)),
            (int(5), float(0.5)),
        ];
        let r = run_agg2(&PercentileContFunc, &rows);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_percentile_disc_basic() {
        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
            (int(1), float(0.5)),
            (int(2), float(0.5)),
            (int(3), float(0.5)),
            (int(4), float(0.5)),
            (int(5), float(0.5)),
        ];
        let r = run_agg2(&PercentileDiscFunc, &rows);
        match r {
            SqliteValue::Float(v) => {
                assert!(
                    [1.0, 2.0, 3.0, 4.0, 5.0].contains(&v),
                    "expected actual value, got {v}"
                );
            }
            other => panic!("expected Float, got {other:?}"),
        }
    }

    #[test]
    fn test_percentile_disc_no_interpolation() {
        let rows: Vec<(SqliteValue, SqliteValue)> = vec![
            (int(10), float(0.5)),
            (int(20), float(0.5)),
            (int(30), float(0.5)),
            (int(40), float(0.5)),
        ];
        let r = run_agg2(&PercentileDiscFunc, &rows);
        match r {
            SqliteValue::Float(v) => {
                assert!(
                    [10.0, 20.0, 30.0, 40.0].contains(&v),
                    "disc must not interpolate: got {v}"
                );
            }
            other => panic!("expected Float, got {other:?}"),
        }
    }

    #[test]
    fn test_string_agg_alias() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);
        let sa = reg
            .find_aggregate("string_agg", 2)
            .expect("string_agg registered");
        let mut state = sa.initial_state();
        sa.step(&mut state, &[text("a"), text(",")]).unwrap();
        sa.step(&mut state, &[text("b"), text(",")]).unwrap();
        let r = sa.finalize(state).unwrap();
        assert_eq!(r, SqliteValue::Text("a,b".into()));
    }

    #[test]
    fn test_register_aggregate_builtins_all_present() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);

        let expected = [
            ("avg", 1),
            ("count", 0),
            ("count", 1),
            ("max", 1),
            ("min", 1),
            ("sum", 1),
            ("total", 1),
            ("median", 1),
            ("percentile", 2),
            ("percentile_cont", 2),
            ("percentile_disc", 2),
            ("string_agg", 2),
        ];

        for (name, arity) in expected {
            assert!(
                reg.find_aggregate(name, arity).is_some(),
                "aggregate '{name}/{arity}' not registered"
            );
        }

        assert!(reg.find_aggregate("group_concat", 1).is_some());
        assert!(reg.find_aggregate("group_concat", 2).is_some());

        let group_concat_zero = reg.find_aggregate("group_concat", 0).unwrap();
        let err = group_concat_zero
            .finalize(group_concat_zero.initial_state())
            .expect_err("group_concat() should reject zero arguments");
        assert!(
            matches!(&err, FrankenError::FunctionError(message)
                if message == "wrong number of arguments to function group_concat()"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_e2e_registry_invoke_aggregates() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);

        let avg = reg.find_aggregate("avg", 1).unwrap();
        let mut state = avg.initial_state();
        avg.step(&mut state, &[int(10)]).unwrap();
        avg.step(&mut state, &[int(20)]).unwrap();
        avg.step(&mut state, &[int(30)]).unwrap();
        let r = avg.finalize(state).unwrap();
        assert_float_eq(&r, 20.0);

        let sum = reg.find_aggregate("sum", 1).unwrap();
        let mut state = sum.initial_state();
        sum.step(&mut state, &[int(1)]).unwrap();
        sum.step(&mut state, &[int(2)]).unwrap();
        sum.step(&mut state, &[int(3)]).unwrap();
        let r = sum.finalize(state).unwrap();
        assert_eq!(r, int(6));
    }
}