1#![allow(
13 clippy::unnecessary_literal_bound,
14 clippy::too_many_lines,
15 clippy::cast_possible_truncation,
16 clippy::cast_possible_wrap,
17 clippy::cast_precision_loss,
18 clippy::match_same_arms,
19 clippy::items_after_statements,
20 clippy::float_cmp,
21 clippy::cast_sign_loss,
22 clippy::suboptimal_flops
23)]
24
25use fsqlite_error::{FrankenError, Result};
26use fsqlite_types::SqliteValue;
27
28use crate::{AggregateFunction, FunctionRegistry};
29
/// Adds `value` to a running compensated sum (Neumaier's variant of Kahan
/// summation).
///
/// `sum` holds the plain running total and `compensation` accumulates the
/// low-order bits that each addition rounded away. Callers obtain the
/// corrected total as `*sum + *compensation`.
///
/// The correction term is only updated while the running total is finite.
/// Once the total has become `±inf` or `NaN` (an infinite input, or overflow
/// of two large finite values), the correction would evaluate `inf - inf` and
/// poison `compensation`, turning the final `sum + compensation` into `NaN`
/// even when the true result is an infinity.
#[inline]
fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
    let previous = *sum;
    let total = previous + value;
    if total.is_finite() {
        // Recover the rounding error relative to the larger-magnitude operand,
        // whose high-order bits dominate `total`.
        *compensation += if previous.abs() > value.abs() {
            (previous - total) + value
        } else {
            (value - total) + previous
        };
    }
    *sum = total;
}
46
47pub struct AvgState {
52 sum: f64,
53 compensation: f64,
54 count: i64,
55}
56
57pub struct AvgFunc;
58
59impl AggregateFunction for AvgFunc {
60 type State = AvgState;
61
62 fn initial_state(&self) -> Self::State {
63 AvgState {
64 sum: 0.0,
65 compensation: 0.0,
66 count: 0,
67 }
68 }
69
70 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
71 if !args[0].is_null() {
72 kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
73 state.count += 1;
74 }
75 Ok(())
76 }
77
78 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
79 if state.count == 0 {
80 Ok(SqliteValue::Null)
81 } else {
82 Ok(SqliteValue::Float(
83 (state.sum + state.compensation) / state.count as f64,
84 ))
85 }
86 }
87
88 fn num_args(&self) -> i32 {
89 1
90 }
91
92 fn name(&self) -> &str {
93 "avg"
94 }
95}
96
97pub struct CountStarFunc;
103
104impl AggregateFunction for CountStarFunc {
105 type State = i64;
106
107 fn initial_state(&self) -> Self::State {
108 0
109 }
110
111 fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
112 *state += 1;
113 Ok(())
114 }
115
116 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
117 Ok(SqliteValue::Integer(state))
118 }
119
120 fn num_args(&self) -> i32 {
121 0 }
123
124 fn name(&self) -> &str {
125 "count"
126 }
127}
128
129pub struct CountFunc;
131
132impl AggregateFunction for CountFunc {
133 type State = i64;
134
135 fn initial_state(&self) -> Self::State {
136 0
137 }
138
139 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
140 if !args[0].is_null() {
141 *state += 1;
142 }
143 Ok(())
144 }
145
146 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
147 Ok(SqliteValue::Integer(state))
148 }
149
150 fn num_args(&self) -> i32 {
151 1
152 }
153
154 fn name(&self) -> &str {
155 "count"
156 }
157}
158
159pub struct GroupConcatState {
164 result: String,
168 has_value: bool,
169}
170
171pub struct GroupConcatFunc;
172
173impl AggregateFunction for GroupConcatFunc {
174 type State = GroupConcatState;
175
176 fn initial_state(&self) -> Self::State {
177 GroupConcatState {
178 result: String::new(),
179 has_value: false,
180 }
181 }
182
183 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
184 if args[0].is_null() {
185 return Ok(());
186 }
187 let sep = if args.len() > 1 {
188 if args[1].is_null() {
189 String::new()
190 } else {
191 args[1].to_text()
192 }
193 } else {
194 ",".to_owned()
195 };
196 if state.has_value {
197 state.result.push_str(&sep);
198 }
199 state.result.push_str(&args[0].to_text());
200 state.has_value = true;
201 Ok(())
202 }
203
204 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
205 if state.has_value {
206 Ok(SqliteValue::Text(state.result.into()))
207 } else {
208 Ok(SqliteValue::Null)
209 }
210 }
211
212 fn num_args(&self) -> i32 {
213 -1 }
215
216 fn name(&self) -> &str {
217 "group_concat"
218 }
219}
220
221pub struct AggMaxFunc;
226
227impl AggregateFunction for AggMaxFunc {
228 type State = Option<SqliteValue>;
229
230 fn initial_state(&self) -> Self::State {
231 None
232 }
233
234 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
235 if args[0].is_null() {
236 return Ok(());
237 }
238 let candidate = &args[0];
239 match state {
240 None => *state = Some(candidate.clone()),
241 Some(current) => {
242 if candidate > current {
243 *state = Some(candidate.clone());
244 }
245 }
246 }
247 Ok(())
248 }
249
250 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
251 Ok(state.unwrap_or(SqliteValue::Null))
252 }
253
254 fn num_args(&self) -> i32 {
255 1
256 }
257
258 fn name(&self) -> &str {
259 "max"
260 }
261}
262
263pub struct AggMinFunc;
268
269impl AggregateFunction for AggMinFunc {
270 type State = Option<SqliteValue>;
271
272 fn initial_state(&self) -> Self::State {
273 None
274 }
275
276 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
277 if args[0].is_null() {
278 return Ok(());
279 }
280 let candidate = &args[0];
281 match state {
282 None => *state = Some(candidate.clone()),
283 Some(current) => {
284 if candidate < current {
285 *state = Some(candidate.clone());
286 }
287 }
288 }
289 Ok(())
290 }
291
292 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
293 Ok(state.unwrap_or(SqliteValue::Null))
294 }
295
296 fn num_args(&self) -> i32 {
297 1
298 }
299
300 fn name(&self) -> &str {
301 "min"
302 }
303}
304
305pub struct SumState {
313 int_sum: i64,
314 float_sum: f64,
315 float_compensation: f64,
316 all_integer: bool,
317 has_values: bool,
318 overflowed: bool,
319}
320
321pub struct SumFunc;
322
323impl AggregateFunction for SumFunc {
324 type State = SumState;
325
326 fn initial_state(&self) -> Self::State {
327 SumState {
328 int_sum: 0,
329 float_sum: 0.0,
330 float_compensation: 0.0,
331 all_integer: true,
332 has_values: false,
333 overflowed: false,
334 }
335 }
336
337 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
338 if args[0].is_null() {
339 return Ok(());
340 }
341 state.has_values = true;
342 match &args[0] {
343 SqliteValue::Integer(i) => {
344 if state.all_integer && !state.overflowed {
345 match state.int_sum.checked_add(*i) {
346 Some(s) => state.int_sum = s,
347 None => state.overflowed = true,
348 }
349 }
350 kahan_add(
351 &mut state.float_sum,
352 &mut state.float_compensation,
353 *i as f64,
354 );
355 }
356 other => {
357 state.all_integer = false;
358 kahan_add(
359 &mut state.float_sum,
360 &mut state.float_compensation,
361 other.to_float(),
362 );
363 }
364 }
365 Ok(())
366 }
367
368 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
369 if !state.has_values {
370 return Ok(SqliteValue::Null);
371 }
372 if state.overflowed {
373 return Err(FrankenError::IntegerOverflow);
374 }
375 if state.all_integer {
376 Ok(SqliteValue::Integer(state.int_sum))
377 } else {
378 Ok(SqliteValue::Float(
379 state.float_sum + state.float_compensation,
380 ))
381 }
382 }
383
384 fn num_args(&self) -> i32 {
385 1
386 }
387
388 fn name(&self) -> &str {
389 "sum"
390 }
391}
392
393pub struct TotalFunc;
398
399pub struct TotalState {
401 sum: f64,
402 compensation: f64,
403}
404
405impl AggregateFunction for TotalFunc {
406 type State = TotalState;
407
408 fn initial_state(&self) -> Self::State {
409 TotalState {
410 sum: 0.0,
411 compensation: 0.0,
412 }
413 }
414
415 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
416 if !args[0].is_null() {
417 kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
418 }
419 Ok(())
420 }
421
422 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
423 Ok(SqliteValue::Float(state.sum + state.compensation))
424 }
425
426 fn num_args(&self) -> i32 {
427 1
428 }
429
430 fn name(&self) -> &str {
431 "total"
432 }
433}
434
435pub struct MedianFunc;
440
441impl AggregateFunction for MedianFunc {
442 type State = Vec<f64>;
443
444 fn initial_state(&self) -> Self::State {
445 Vec::new()
446 }
447
448 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
449 if !args[0].is_null() {
450 state.push(args[0].to_float());
451 }
452 Ok(())
453 }
454
455 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
456 if state.is_empty() {
457 return Ok(SqliteValue::Null);
458 }
459 state.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
460 let result = percentile_cont_impl(&state, 0.5);
461 Ok(SqliteValue::Float(result))
462 }
463
464 fn num_args(&self) -> i32 {
465 1
466 }
467
468 fn name(&self) -> &str {
469 "median"
470 }
471}
472
473pub struct PercentileState {
478 values: Vec<f64>,
479 p: Option<f64>,
480}
481
482pub struct PercentileFunc;
483
484impl AggregateFunction for PercentileFunc {
485 type State = PercentileState;
486
487 fn initial_state(&self) -> Self::State {
488 PercentileState {
489 values: Vec::new(),
490 p: None,
491 }
492 }
493
494 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
495 if !args[0].is_null() {
496 state.values.push(args[0].to_float());
497 }
498 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
500 state.p = Some(args[1].to_float());
501 }
502 Ok(())
503 }
504
505 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
506 if state.values.is_empty() {
507 return Ok(SqliteValue::Null);
508 }
509 let p = state.p.unwrap_or(50.0);
510 state
511 .values
512 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
513 let result = percentile_cont_impl(&state.values, p / 100.0);
515 Ok(SqliteValue::Float(result))
516 }
517
518 fn num_args(&self) -> i32 {
519 2
520 }
521
522 fn name(&self) -> &str {
523 "percentile"
524 }
525}
526
527pub struct PercentileContFunc;
532
533impl AggregateFunction for PercentileContFunc {
534 type State = PercentileState;
535
536 fn initial_state(&self) -> Self::State {
537 PercentileState {
538 values: Vec::new(),
539 p: None,
540 }
541 }
542
543 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
544 if !args[0].is_null() {
545 state.values.push(args[0].to_float());
546 }
547 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
548 state.p = Some(args[1].to_float());
549 }
550 Ok(())
551 }
552
553 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
554 if state.values.is_empty() {
555 return Ok(SqliteValue::Null);
556 }
557 let p = state.p.unwrap_or(0.5);
558 state
559 .values
560 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
561 let result = percentile_cont_impl(&state.values, p);
562 Ok(SqliteValue::Float(result))
563 }
564
565 fn num_args(&self) -> i32 {
566 2
567 }
568
569 fn name(&self) -> &str {
570 "percentile_cont"
571 }
572}
573
574pub struct PercentileDiscFunc;
579
580impl AggregateFunction for PercentileDiscFunc {
581 type State = PercentileState;
582
583 fn initial_state(&self) -> Self::State {
584 PercentileState {
585 values: Vec::new(),
586 p: None,
587 }
588 }
589
590 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
591 if !args[0].is_null() {
592 state.values.push(args[0].to_float());
593 }
594 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
595 state.p = Some(args[1].to_float());
596 }
597 Ok(())
598 }
599
600 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
601 if state.values.is_empty() {
602 return Ok(SqliteValue::Null);
603 }
604 let p = state.p.unwrap_or(0.5);
605 let p = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
606 state
607 .values
608 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
609 let n = state.values.len();
611 let idx = ((p * n as f64).ceil() as usize)
612 .saturating_sub(1)
613 .min(n - 1);
614 Ok(SqliteValue::Float(state.values[idx]))
615 }
616
617 fn num_args(&self) -> i32 {
618 2
619 }
620
621 fn name(&self) -> &str {
622 "percentile_disc"
623 }
624}
625
/// Computes the continuous (linearly interpolated) percentile of an
/// ascending-sorted slice.
///
/// `p` is a fraction: values outside `[0, 1]` are clamped and NaN is treated
/// as `0.5`. The fractional rank `p * (n - 1)` selects either a single
/// element or a weighted blend of its two neighbours.
///
/// Returns NaN for an empty slice (there is no value to report) instead of
/// underflowing `n - 1`; a one-element slice returns that element for any `p`.
fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => return f64::NAN,
        [only] => return *only,
        _ => {}
    }
    let fraction = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
    let rank = fraction * (sorted.len() - 1) as f64;
    let lower = rank.floor() as usize;
    let upper = rank.ceil() as usize;
    if lower == upper {
        // The rank lands exactly on an element.
        return sorted[lower];
    }
    let weight = rank - lower as f64;
    sorted[lower] * (1.0 - weight) + sorted[upper] * weight
}
646
647pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
651 registry.register_aggregate(AvgFunc);
652 registry.register_aggregate(CountStarFunc);
653 registry.register_aggregate(CountFunc);
654 registry.register_aggregate(GroupConcatFunc);
655 registry.register_aggregate(AggMaxFunc);
656 registry.register_aggregate(AggMinFunc);
657 registry.register_aggregate(SumFunc);
658 registry.register_aggregate(TotalFunc);
659 registry.register_aggregate(MedianFunc);
660 registry.register_aggregate(PercentileFunc);
661 registry.register_aggregate(PercentileContFunc);
662 registry.register_aggregate(PercentileDiscFunc);
663
664 struct StringAggFunc;
666 impl AggregateFunction for StringAggFunc {
667 type State = GroupConcatState;
668
669 fn initial_state(&self) -> Self::State {
670 GroupConcatState {
671 result: String::new(),
672 has_value: false,
673 }
674 }
675
676 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
677 GroupConcatFunc.step(state, args)
678 }
679
680 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
681 GroupConcatFunc.finalize(state)
682 }
683
684 fn num_args(&self) -> i32 {
685 2 }
687
688 fn name(&self) -> &str {
689 "string_agg"
690 }
691 }
692 registry.register_aggregate(StringAggFunc);
693}
694
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPS: f64 = 1e-12;

    fn int(v: i64) -> SqliteValue {
        SqliteValue::Integer(v)
    }

    fn float(v: f64) -> SqliteValue {
        SqliteValue::Float(v)
    }

    fn null() -> SqliteValue {
        SqliteValue::Null
    }

    fn text(s: &str) -> SqliteValue {
        SqliteValue::Text(s.into())
    }

    /// Builds two-argument rows that pair each integer with the same `p`.
    fn rows_with_p(values: &[i64], p: f64) -> Vec<(SqliteValue, SqliteValue)> {
        values.iter().map(|&v| (int(v), float(p))).collect()
    }

    /// Asserts that `result` is a `Float` within `EPS` of `expected`.
    #[track_caller]
    fn assert_float_eq(result: &SqliteValue, expected: f64) {
        let SqliteValue::Float(actual) = result else {
            panic!("expected Float({expected}), got {result:?}");
        };
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Runs a one-argument aggregate over `rows` and returns its final value.
    fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
        let mut state = func.initial_state();
        for row in rows {
            func.step(&mut state, std::slice::from_ref(row)).unwrap();
        }
        func.finalize(state).unwrap()
    }

    /// Runs a two-argument aggregate over `rows` and returns its final value.
    fn run_agg2<F: AggregateFunction>(
        func: &F,
        rows: &[(SqliteValue, SqliteValue)],
    ) -> SqliteValue {
        let mut state = func.initial_state();
        for (first, second) in rows {
            func.step(&mut state, &[first.clone(), second.clone()])
                .unwrap();
        }
        func.finalize(state).unwrap()
    }

    // ── avg ──────────────────────────────────────────────────────────────

    #[test]
    fn test_avg_basic() {
        let r = run_agg(&AvgFunc, &[int(1), int(2), int(3), int(4), int(5)]);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_avg_with_nulls() {
        let r = run_agg(&AvgFunc, &[int(1), null(), int(3)]);
        assert_float_eq(&r, 2.0);
    }

    #[test]
    fn test_avg_empty() {
        assert_eq!(run_agg(&AvgFunc, &[]), SqliteValue::Null);
    }

    #[test]
    fn test_avg_returns_real() {
        let r = run_agg(&AvgFunc, &[int(2), int(4)]);
        assert!(matches!(r, SqliteValue::Float(_)));
    }

    // ── count ────────────────────────────────────────────────────────────

    #[test]
    fn test_count_star() {
        let mut state = CountStarFunc.initial_state();
        for _ in 0..3 {
            CountStarFunc.step(&mut state, &[]).unwrap();
        }
        assert_eq!(CountStarFunc.finalize(state).unwrap(), int(3));
    }

    #[test]
    fn test_count_column() {
        let r = run_agg(&CountFunc, &[int(1), null(), int(3), null(), int(5)]);
        assert_eq!(r, int(3));
    }

    #[test]
    fn test_count_empty() {
        assert_eq!(run_agg(&CountFunc, &[]), int(0));
    }

    // ── group_concat ─────────────────────────────────────────────────────

    #[test]
    fn test_group_concat_basic() {
        let r = run_agg(&GroupConcatFunc, &[text("a"), text("b"), text("c")]);
        assert_eq!(r, SqliteValue::Text("a,b,c".into()));
    }

    #[test]
    fn test_group_concat_custom_sep() {
        let rows: Vec<_> = ["a", "b", "c"]
            .iter()
            .map(|&item| (text(item), text("; ")))
            .collect();
        let r = run_agg2(&GroupConcatFunc, &rows);
        assert_eq!(r, SqliteValue::Text("a; b; c".into()));
    }

    #[test]
    fn test_group_concat_null_skipped() {
        let r = run_agg(&GroupConcatFunc, &[text("a"), null(), text("c")]);
        assert_eq!(r, SqliteValue::Text("a,c".into()));
    }

    #[test]
    fn test_group_concat_empty() {
        assert_eq!(run_agg(&GroupConcatFunc, &[]), SqliteValue::Null);
    }

    #[test]
    fn test_group_concat_varying_separator() {
        // The separator written before a value is the one supplied with that
        // value's own row, so the first row's separator never appears.
        let rows = vec![
            (text("a"), text("-")),
            (text("b"), text("+")),
            (text("c"), text("*")),
        ];
        let r = run_agg2(&GroupConcatFunc, &rows);
        assert_eq!(r, SqliteValue::Text("a+b*c".into()));
    }

    #[test]
    fn test_group_concat_single_value() {
        let r = run_agg(&GroupConcatFunc, &[text("only")]);
        assert_eq!(r, SqliteValue::Text("only".into()));
    }

    // ── max ──────────────────────────────────────────────────────────────

    #[test]
    fn test_max_aggregate() {
        let r = run_agg(&AggMaxFunc, &[int(3), int(7), int(1), int(5)]);
        assert_eq!(r, int(7));
    }

    #[test]
    fn test_max_aggregate_null_skipped() {
        let r = run_agg(&AggMaxFunc, &[int(3), null(), int(7), null()]);
        assert_eq!(r, int(7));
    }

    #[test]
    fn test_max_aggregate_empty() {
        assert_eq!(run_agg(&AggMaxFunc, &[]), SqliteValue::Null);
    }

    // ── min ──────────────────────────────────────────────────────────────

    #[test]
    fn test_min_aggregate() {
        let r = run_agg(&AggMinFunc, &[int(3), int(7), int(1), int(5)]);
        assert_eq!(r, int(1));
    }

    #[test]
    fn test_min_aggregate_null_skipped() {
        let r = run_agg(&AggMinFunc, &[int(3), null(), int(1), null()]);
        assert_eq!(r, int(1));
    }

    #[test]
    fn test_min_aggregate_empty() {
        assert_eq!(run_agg(&AggMinFunc, &[]), SqliteValue::Null);
    }

    // ── sum ──────────────────────────────────────────────────────────────

    #[test]
    fn test_sum_integers() {
        let r = run_agg(&SumFunc, &[int(1), int(2), int(3)]);
        assert_eq!(r, int(6));
    }

    #[test]
    fn test_sum_reals() {
        let r = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
        assert_float_eq(&r, 4.0);
    }

    #[test]
    fn test_sum_empty_null() {
        assert_eq!(run_agg(&SumFunc, &[]), SqliteValue::Null);
    }

    #[test]
    fn test_sum_overflow_error() {
        let mut state = SumFunc.initial_state();
        SumFunc.step(&mut state, &[int(i64::MAX)]).unwrap();
        SumFunc.step(&mut state, &[int(1)]).unwrap();
        let outcome = SumFunc.finalize(state);
        assert!(
            matches!(outcome, Err(FrankenError::IntegerOverflow)),
            "sum should raise overflow error"
        );
    }

    #[test]
    fn test_sum_null_skipped() {
        let r = run_agg(&SumFunc, &[int(1), null(), int(3)]);
        assert_eq!(r, int(4));
    }

    // ── total ────────────────────────────────────────────────────────────

    #[test]
    fn test_total_basic() {
        let r = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
        assert_float_eq(&r, 6.0);
    }

    #[test]
    fn test_total_empty_zero() {
        let r = run_agg(&TotalFunc, &[]);
        assert_float_eq(&r, 0.0);
    }

    #[test]
    fn test_total_no_overflow() {
        // Unlike `sum`, `total` must not fail when the integer range is exceeded.
        let r = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
        assert!(matches!(r, SqliteValue::Float(_)));
    }

    // ── median ───────────────────────────────────────────────────────────

    #[test]
    fn test_median_basic() {
        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4), int(5)]);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_median_even() {
        let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4)]);
        assert_float_eq(&r, 2.5);
    }

    #[test]
    fn test_median_null_skipped() {
        let r = run_agg(&MedianFunc, &[int(1), null(), int(3)]);
        assert_float_eq(&r, 2.0);
    }

    #[test]
    fn test_median_empty() {
        assert_eq!(run_agg(&MedianFunc, &[]), SqliteValue::Null);
    }

    // ── percentile (0–100 scale) ─────────────────────────────────────────

    #[test]
    fn test_percentile_50() {
        let rows = rows_with_p(&[1, 2, 3, 4, 5], 50.0);
        let r = run_agg2(&PercentileFunc, &rows);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_percentile_0() {
        let rows = rows_with_p(&[10, 20, 30], 0.0);
        let r = run_agg2(&PercentileFunc, &rows);
        assert_float_eq(&r, 10.0);
    }

    #[test]
    fn test_percentile_100() {
        let rows = rows_with_p(&[10, 20, 30], 100.0);
        let r = run_agg2(&PercentileFunc, &rows);
        assert_float_eq(&r, 30.0);
    }

    // ── percentile_cont (0–1 scale) ──────────────────────────────────────

    #[test]
    fn test_percentile_cont_basic() {
        let rows = rows_with_p(&[1, 2, 3, 4, 5], 0.5);
        let r = run_agg2(&PercentileContFunc, &rows);
        assert_float_eq(&r, 3.0);
    }

    // ── percentile_disc (0–1 scale, no interpolation) ────────────────────

    #[test]
    fn test_percentile_disc_basic() {
        // n = 5, p = 0.5: rank ceil(2.5) = 3, i.e. the third sorted value.
        let rows = rows_with_p(&[1, 2, 3, 4, 5], 0.5);
        let r = run_agg2(&PercentileDiscFunc, &rows);
        assert_float_eq(&r, 3.0);
    }

    #[test]
    fn test_percentile_disc_no_interpolation() {
        // n = 4, p = 0.5: rank ceil(2.0) = 2, i.e. exactly 20, never the
        // interpolated 25 that `percentile_cont` would report.
        let rows = rows_with_p(&[10, 20, 30, 40], 0.5);
        let r = run_agg2(&PercentileDiscFunc, &rows);
        assert_float_eq(&r, 20.0);
    }

    // ── registry ─────────────────────────────────────────────────────────

    #[test]
    fn test_string_agg_alias() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);
        let string_agg = reg
            .find_aggregate("string_agg", 2)
            .expect("string_agg registered");
        let mut state = string_agg.initial_state();
        for item in ["a", "b"] {
            string_agg
                .step(&mut state, &[text(item), text(",")])
                .unwrap();
        }
        let r = string_agg.finalize(state).unwrap();
        assert_eq!(r, SqliteValue::Text("a,b".into()));
    }

    #[test]
    fn test_register_aggregate_builtins_all_present() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);

        let expected = [
            ("avg", 1),
            ("count", 0),
            ("count", 1),
            ("max", 1),
            ("min", 1),
            ("sum", 1),
            ("total", 1),
            ("median", 1),
            ("percentile", 2),
            ("percentile_cont", 2),
            ("percentile_disc", 2),
            ("string_agg", 2),
        ];

        for (name, arity) in expected {
            assert!(
                reg.find_aggregate(name, arity).is_some(),
                "aggregate '{name}/{arity}' not registered"
            );
        }

        // `group_concat` must resolve both with and without a separator.
        assert!(reg.find_aggregate("group_concat", 1).is_some());
        assert!(reg.find_aggregate("group_concat", 2).is_some());
    }

    #[test]
    fn test_e2e_registry_invoke_aggregates() {
        let mut reg = FunctionRegistry::new();
        register_aggregate_builtins(&mut reg);

        let avg = reg.find_aggregate("avg", 1).unwrap();
        let mut state = avg.initial_state();
        for v in [10, 20, 30] {
            avg.step(&mut state, &[int(v)]).unwrap();
        }
        let r = avg.finalize(state).unwrap();
        assert_float_eq(&r, 20.0);

        let sum = reg.find_aggregate("sum", 1).unwrap();
        let mut state = sum.initial_state();
        for v in [1, 2, 3] {
            sum.step(&mut state, &[int(v)]).unwrap();
        }
        let r = sum.finalize(state).unwrap();
        assert_eq!(r, int(6));
    }
}