1#![allow(
13 clippy::unnecessary_literal_bound,
14 clippy::too_many_lines,
15 clippy::cast_possible_truncation,
16 clippy::cast_possible_wrap,
17 clippy::cast_precision_loss,
18 clippy::match_same_arms,
19 clippy::items_after_statements,
20 clippy::float_cmp,
21 clippy::cast_sign_loss,
22 clippy::suboptimal_flops
23)]
24
25use fsqlite_error::{FrankenError, Result};
26use fsqlite_types::SqliteValue;
27
28use crate::{AggregateFunction, FunctionRegistry};
29
30#[inline]
36fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
37 let y = value - *compensation;
38 let t = *sum + y;
39 *compensation = (t - *sum) - y;
40 *sum = t;
41}
42
43pub struct AvgState {
48 sum: f64,
49 compensation: f64,
50 count: i64,
51}
52
53pub struct AvgFunc;
54
55impl AggregateFunction for AvgFunc {
56 type State = AvgState;
57
58 fn initial_state(&self) -> Self::State {
59 AvgState {
60 sum: 0.0,
61 compensation: 0.0,
62 count: 0,
63 }
64 }
65
66 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
67 if !args[0].is_null() {
68 kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
69 state.count += 1;
70 }
71 Ok(())
72 }
73
74 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
75 if state.count == 0 {
76 Ok(SqliteValue::Null)
77 } else {
78 Ok(SqliteValue::Float(state.sum / state.count as f64))
79 }
80 }
81
82 fn num_args(&self) -> i32 {
83 1
84 }
85
86 fn name(&self) -> &str {
87 "avg"
88 }
89}
90
91pub struct CountStarFunc;
97
98impl AggregateFunction for CountStarFunc {
99 type State = i64;
100
101 fn initial_state(&self) -> Self::State {
102 0
103 }
104
105 fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
106 *state += 1;
107 Ok(())
108 }
109
110 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
111 Ok(SqliteValue::Integer(state))
112 }
113
114 fn num_args(&self) -> i32 {
115 0 }
117
118 fn name(&self) -> &str {
119 "count"
120 }
121}
122
123pub struct CountFunc;
125
126impl AggregateFunction for CountFunc {
127 type State = i64;
128
129 fn initial_state(&self) -> Self::State {
130 0
131 }
132
133 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
134 if !args[0].is_null() {
135 *state += 1;
136 }
137 Ok(())
138 }
139
140 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
141 Ok(SqliteValue::Integer(state))
142 }
143
144 fn num_args(&self) -> i32 {
145 1
146 }
147
148 fn name(&self) -> &str {
149 "count"
150 }
151}
152
153pub struct GroupConcatState {
158 values: Vec<String>,
159 separator: String,
160}
161
162pub struct GroupConcatFunc;
163
164impl AggregateFunction for GroupConcatFunc {
165 type State = GroupConcatState;
166
167 fn initial_state(&self) -> Self::State {
168 GroupConcatState {
169 values: Vec::new(),
170 separator: ",".to_owned(),
171 }
172 }
173
174 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
175 if args[0].is_null() {
176 return Ok(());
177 }
178 if args.len() > 1 && !args[1].is_null() {
181 state.separator = args[1].to_text();
182 }
183 state.values.push(args[0].to_text());
184 Ok(())
185 }
186
187 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
188 if state.values.is_empty() {
189 Ok(SqliteValue::Null)
190 } else {
191 Ok(SqliteValue::Text(state.values.join(&state.separator)))
192 }
193 }
194
195 fn num_args(&self) -> i32 {
196 -1 }
198
199 fn name(&self) -> &str {
200 "group_concat"
201 }
202}
203
204pub struct AggMaxFunc;
209
210impl AggregateFunction for AggMaxFunc {
211 type State = Option<SqliteValue>;
212
213 fn initial_state(&self) -> Self::State {
214 None
215 }
216
217 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
218 if args[0].is_null() {
219 return Ok(());
220 }
221 let candidate = &args[0];
222 match state {
223 None => *state = Some(candidate.clone()),
224 Some(current) => {
225 if candidate > current {
226 *state = Some(candidate.clone());
227 }
228 }
229 }
230 Ok(())
231 }
232
233 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
234 Ok(state.unwrap_or(SqliteValue::Null))
235 }
236
237 fn num_args(&self) -> i32 {
238 1
239 }
240
241 fn name(&self) -> &str {
242 "max"
243 }
244}
245
246pub struct AggMinFunc;
251
252impl AggregateFunction for AggMinFunc {
253 type State = Option<SqliteValue>;
254
255 fn initial_state(&self) -> Self::State {
256 None
257 }
258
259 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
260 if args[0].is_null() {
261 return Ok(());
262 }
263 let candidate = &args[0];
264 match state {
265 None => *state = Some(candidate.clone()),
266 Some(current) => {
267 if candidate < current {
268 *state = Some(candidate.clone());
269 }
270 }
271 }
272 Ok(())
273 }
274
275 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
276 Ok(state.unwrap_or(SqliteValue::Null))
277 }
278
279 fn num_args(&self) -> i32 {
280 1
281 }
282
283 fn name(&self) -> &str {
284 "min"
285 }
286}
287
288pub struct SumState {
296 int_sum: i64,
297 float_sum: f64,
298 float_compensation: f64,
299 all_integer: bool,
300 has_values: bool,
301 overflowed: bool,
302}
303
304pub struct SumFunc;
305
306impl AggregateFunction for SumFunc {
307 type State = SumState;
308
309 fn initial_state(&self) -> Self::State {
310 SumState {
311 int_sum: 0,
312 float_sum: 0.0,
313 float_compensation: 0.0,
314 all_integer: true,
315 has_values: false,
316 overflowed: false,
317 }
318 }
319
320 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
321 if args[0].is_null() {
322 return Ok(());
323 }
324 state.has_values = true;
325 match &args[0] {
326 SqliteValue::Integer(i) => {
327 if state.all_integer && !state.overflowed {
328 match state.int_sum.checked_add(*i) {
329 Some(s) => state.int_sum = s,
330 None => state.overflowed = true,
331 }
332 }
333 kahan_add(
334 &mut state.float_sum,
335 &mut state.float_compensation,
336 *i as f64,
337 );
338 }
339 other => {
340 state.all_integer = false;
341 kahan_add(
342 &mut state.float_sum,
343 &mut state.float_compensation,
344 other.to_float(),
345 );
346 }
347 }
348 Ok(())
349 }
350
351 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
352 if !state.has_values {
353 return Ok(SqliteValue::Null);
354 }
355 if state.overflowed {
356 return Err(FrankenError::IntegerOverflow);
357 }
358 if state.all_integer {
359 Ok(SqliteValue::Integer(state.int_sum))
360 } else {
361 Ok(SqliteValue::Float(state.float_sum))
362 }
363 }
364
365 fn num_args(&self) -> i32 {
366 1
367 }
368
369 fn name(&self) -> &str {
370 "sum"
371 }
372}
373
374pub struct TotalFunc;
379
380pub struct TotalState {
382 sum: f64,
383 compensation: f64,
384}
385
386impl AggregateFunction for TotalFunc {
387 type State = TotalState;
388
389 fn initial_state(&self) -> Self::State {
390 TotalState {
391 sum: 0.0,
392 compensation: 0.0,
393 }
394 }
395
396 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
397 if !args[0].is_null() {
398 kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
399 }
400 Ok(())
401 }
402
403 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
404 Ok(SqliteValue::Float(state.sum))
405 }
406
407 fn num_args(&self) -> i32 {
408 1
409 }
410
411 fn name(&self) -> &str {
412 "total"
413 }
414}
415
416pub struct MedianFunc;
421
422impl AggregateFunction for MedianFunc {
423 type State = Vec<f64>;
424
425 fn initial_state(&self) -> Self::State {
426 Vec::new()
427 }
428
429 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
430 if !args[0].is_null() {
431 state.push(args[0].to_float());
432 }
433 Ok(())
434 }
435
436 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
437 if state.is_empty() {
438 return Ok(SqliteValue::Null);
439 }
440 state.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
441 let result = percentile_cont_impl(&state, 0.5);
442 Ok(SqliteValue::Float(result))
443 }
444
445 fn num_args(&self) -> i32 {
446 1
447 }
448
449 fn name(&self) -> &str {
450 "median"
451 }
452}
453
454pub struct PercentileState {
459 values: Vec<f64>,
460 p: Option<f64>,
461}
462
463pub struct PercentileFunc;
464
465impl AggregateFunction for PercentileFunc {
466 type State = PercentileState;
467
468 fn initial_state(&self) -> Self::State {
469 PercentileState {
470 values: Vec::new(),
471 p: None,
472 }
473 }
474
475 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
476 if !args[0].is_null() {
477 state.values.push(args[0].to_float());
478 }
479 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
481 state.p = Some(args[1].to_float());
482 }
483 Ok(())
484 }
485
486 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
487 if state.values.is_empty() {
488 return Ok(SqliteValue::Null);
489 }
490 let p = state.p.unwrap_or(50.0);
491 state
492 .values
493 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
494 let result = percentile_cont_impl(&state.values, p / 100.0);
496 Ok(SqliteValue::Float(result))
497 }
498
499 fn num_args(&self) -> i32 {
500 2
501 }
502
503 fn name(&self) -> &str {
504 "percentile"
505 }
506}
507
508pub struct PercentileContFunc;
513
514impl AggregateFunction for PercentileContFunc {
515 type State = PercentileState;
516
517 fn initial_state(&self) -> Self::State {
518 PercentileState {
519 values: Vec::new(),
520 p: None,
521 }
522 }
523
524 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
525 if !args[0].is_null() {
526 state.values.push(args[0].to_float());
527 }
528 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
529 state.p = Some(args[1].to_float());
530 }
531 Ok(())
532 }
533
534 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
535 if state.values.is_empty() {
536 return Ok(SqliteValue::Null);
537 }
538 let p = state.p.unwrap_or(0.5);
539 state
540 .values
541 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
542 let result = percentile_cont_impl(&state.values, p);
543 Ok(SqliteValue::Float(result))
544 }
545
546 fn num_args(&self) -> i32 {
547 2
548 }
549
550 fn name(&self) -> &str {
551 "percentile_cont"
552 }
553}
554
555pub struct PercentileDiscFunc;
560
561impl AggregateFunction for PercentileDiscFunc {
562 type State = PercentileState;
563
564 fn initial_state(&self) -> Self::State {
565 PercentileState {
566 values: Vec::new(),
567 p: None,
568 }
569 }
570
571 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
572 if !args[0].is_null() {
573 state.values.push(args[0].to_float());
574 }
575 if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
576 state.p = Some(args[1].to_float());
577 }
578 Ok(())
579 }
580
581 fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
582 if state.values.is_empty() {
583 return Ok(SqliteValue::Null);
584 }
585 let p = state.p.unwrap_or(0.5);
586 state
587 .values
588 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
589 let n = state.values.len();
591 let idx = ((p * n as f64).ceil() as usize)
592 .saturating_sub(1)
593 .min(n - 1);
594 Ok(SqliteValue::Float(state.values[idx]))
595 }
596
597 fn num_args(&self) -> i32 {
598 2
599 }
600
601 fn name(&self) -> &str {
602 "percentile_disc"
603 }
604}
605
606fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
611 let n = sorted.len();
612 if n == 1 {
613 return sorted[0];
614 }
615 let p = p.clamp(0.0, 1.0);
616 let rank = p * (n - 1) as f64;
617 let lower = rank.floor() as usize;
618 let upper = rank.ceil() as usize;
619 if lower == upper {
620 sorted[lower]
621 } else {
622 let frac = rank - lower as f64;
623 sorted[lower] * (1.0 - frac) + sorted[upper] * frac
624 }
625}
626
627pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
631 registry.register_aggregate(AvgFunc);
632 registry.register_aggregate(CountStarFunc);
633 registry.register_aggregate(CountFunc);
634 registry.register_aggregate(GroupConcatFunc);
635 registry.register_aggregate(AggMaxFunc);
636 registry.register_aggregate(AggMinFunc);
637 registry.register_aggregate(SumFunc);
638 registry.register_aggregate(TotalFunc);
639 registry.register_aggregate(MedianFunc);
640 registry.register_aggregate(PercentileFunc);
641 registry.register_aggregate(PercentileContFunc);
642 registry.register_aggregate(PercentileDiscFunc);
643
644 struct StringAggFunc;
646 impl AggregateFunction for StringAggFunc {
647 type State = GroupConcatState;
648
649 fn initial_state(&self) -> Self::State {
650 GroupConcatState {
651 values: Vec::new(),
652 separator: ",".to_owned(),
653 }
654 }
655
656 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
657 GroupConcatFunc.step(state, args)
658 }
659
660 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
661 GroupConcatFunc.finalize(state)
662 }
663
664 fn num_args(&self) -> i32 {
665 2 }
667
668 fn name(&self) -> &str {
669 "string_agg"
670 }
671 }
672 registry.register_aggregate(StringAggFunc);
673}
674
675#[cfg(test)]
678mod tests {
679 use super::*;
680
681 const EPS: f64 = 1e-12;
682
683 fn int(v: i64) -> SqliteValue {
684 SqliteValue::Integer(v)
685 }
686
687 fn float(v: f64) -> SqliteValue {
688 SqliteValue::Float(v)
689 }
690
691 fn null() -> SqliteValue {
692 SqliteValue::Null
693 }
694
695 fn text(s: &str) -> SqliteValue {
696 SqliteValue::Text(s.to_owned())
697 }
698
699 fn assert_float_eq(result: &SqliteValue, expected: f64) {
700 match result {
701 SqliteValue::Float(v) => {
702 assert!((v - expected).abs() < EPS, "expected {expected}, got {v}");
703 }
704 other => {
705 assert!(
706 matches!(other, SqliteValue::Float(_)),
707 "expected Float({expected}), got {other:?}"
708 );
709 }
710 }
711 }
712
713 fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
715 let mut state = func.initial_state();
716 for row in rows {
717 func.step(&mut state, std::slice::from_ref(row)).unwrap();
718 }
719 func.finalize(state).unwrap()
720 }
721
722 fn run_agg2<F: AggregateFunction>(
724 func: &F,
725 rows: &[(SqliteValue, SqliteValue)],
726 ) -> SqliteValue {
727 let mut state = func.initial_state();
728 for (a, b) in rows {
729 func.step(&mut state, &[a.clone(), b.clone()]).unwrap();
730 }
731 func.finalize(state).unwrap()
732 }
733
734 #[test]
737 fn test_avg_basic() {
738 let r = run_agg(&AvgFunc, &[int(1), int(2), int(3), int(4), int(5)]);
739 assert_float_eq(&r, 3.0);
740 }
741
742 #[test]
743 fn test_avg_with_nulls() {
744 let r = run_agg(&AvgFunc, &[int(1), null(), int(3)]);
745 assert_float_eq(&r, 2.0);
746 }
747
748 #[test]
749 fn test_avg_empty() {
750 let r = run_agg(&AvgFunc, &[]);
751 assert_eq!(r, SqliteValue::Null);
752 }
753
754 #[test]
755 fn test_avg_returns_real() {
756 let r = run_agg(&AvgFunc, &[int(2), int(4)]);
757 assert!(matches!(r, SqliteValue::Float(_)));
758 }
759
760 #[test]
763 fn test_count_star() {
764 let mut state = CountStarFunc.initial_state();
766 CountStarFunc.step(&mut state, &[]).unwrap(); CountStarFunc.step(&mut state, &[]).unwrap(); CountStarFunc.step(&mut state, &[]).unwrap(); let r = CountStarFunc.finalize(state).unwrap();
770 assert_eq!(r, int(3));
771 }
772
773 #[test]
774 fn test_count_column() {
775 let r = run_agg(&CountFunc, &[int(1), null(), int(3), null(), int(5)]);
776 assert_eq!(r, int(3));
777 }
778
779 #[test]
780 fn test_count_empty() {
781 let r = run_agg(&CountFunc, &[]);
782 assert_eq!(r, int(0));
783 }
784
785 #[test]
788 fn test_group_concat_basic() {
789 let r = run_agg(&GroupConcatFunc, &[text("a"), text("b"), text("c")]);
790 assert_eq!(r, SqliteValue::Text("a,b,c".to_owned()));
791 }
792
793 #[test]
794 fn test_group_concat_custom_sep() {
795 let rows = vec![
796 (text("a"), text("; ")),
797 (text("b"), text("; ")),
798 (text("c"), text("; ")),
799 ];
800 let r = run_agg2(&GroupConcatFunc, &rows);
801 assert_eq!(r, SqliteValue::Text("a; b; c".to_owned()));
802 }
803
804 #[test]
805 fn test_group_concat_null_skipped() {
806 let r = run_agg(&GroupConcatFunc, &[text("a"), null(), text("c")]);
807 assert_eq!(r, SqliteValue::Text("a,c".to_owned()));
808 }
809
810 #[test]
811 fn test_group_concat_empty() {
812 let r = run_agg(&GroupConcatFunc, &[]);
813 assert_eq!(r, SqliteValue::Null);
814 }
815
816 #[test]
819 fn test_max_aggregate() {
820 let r = run_agg(&AggMaxFunc, &[int(3), int(7), int(1), int(5)]);
821 assert_eq!(r, int(7));
822 }
823
824 #[test]
825 fn test_max_aggregate_null_skipped() {
826 let r = run_agg(&AggMaxFunc, &[int(3), null(), int(7), null()]);
827 assert_eq!(r, int(7));
828 }
829
830 #[test]
831 fn test_max_aggregate_empty() {
832 let r = run_agg(&AggMaxFunc, &[]);
833 assert_eq!(r, SqliteValue::Null);
834 }
835
836 #[test]
839 fn test_min_aggregate() {
840 let r = run_agg(&AggMinFunc, &[int(3), int(7), int(1), int(5)]);
841 assert_eq!(r, int(1));
842 }
843
844 #[test]
845 fn test_min_aggregate_null_skipped() {
846 let r = run_agg(&AggMinFunc, &[int(3), null(), int(1), null()]);
847 assert_eq!(r, int(1));
848 }
849
850 #[test]
851 fn test_min_aggregate_empty() {
852 let r = run_agg(&AggMinFunc, &[]);
853 assert_eq!(r, SqliteValue::Null);
854 }
855
856 #[test]
859 fn test_sum_integers() {
860 let r = run_agg(&SumFunc, &[int(1), int(2), int(3)]);
861 assert_eq!(r, int(6));
862 }
863
864 #[test]
865 fn test_sum_reals() {
866 let r = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
867 assert_float_eq(&r, 4.0);
868 }
869
870 #[test]
871 fn test_sum_empty_null() {
872 let r = run_agg(&SumFunc, &[]);
873 assert_eq!(r, SqliteValue::Null);
874 }
875
876 #[test]
877 fn test_sum_overflow_error() {
878 let mut state = SumFunc.initial_state();
879 SumFunc.step(&mut state, &[int(i64::MAX)]).unwrap();
880 SumFunc.step(&mut state, &[int(1)]).unwrap();
881 let err = SumFunc.finalize(state);
882 assert!(err.is_err(), "sum should raise overflow error");
883 }
884
885 #[test]
886 fn test_sum_null_skipped() {
887 let r = run_agg(&SumFunc, &[int(1), null(), int(3)]);
888 assert_eq!(r, int(4));
889 }
890
891 #[test]
894 fn test_total_basic() {
895 let r = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
896 assert_float_eq(&r, 6.0);
897 }
898
899 #[test]
900 fn test_total_empty_zero() {
901 let r = run_agg(&TotalFunc, &[]);
902 assert_float_eq(&r, 0.0);
903 }
904
905 #[test]
906 fn test_total_no_overflow() {
907 let r = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
909 assert!(matches!(r, SqliteValue::Float(_)));
910 }
911
912 #[test]
915 fn test_median_basic() {
916 let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4), int(5)]);
917 assert_float_eq(&r, 3.0);
918 }
919
920 #[test]
921 fn test_median_even() {
922 let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4)]);
923 assert_float_eq(&r, 2.5);
924 }
925
926 #[test]
927 fn test_median_null_skipped() {
928 let r = run_agg(&MedianFunc, &[int(1), null(), int(3)]);
929 assert_float_eq(&r, 2.0);
930 }
931
932 #[test]
933 fn test_median_empty() {
934 let r = run_agg(&MedianFunc, &[]);
935 assert_eq!(r, SqliteValue::Null);
936 }
937
938 #[test]
941 fn test_percentile_50() {
942 let rows: Vec<(SqliteValue, SqliteValue)> = vec![
944 (int(1), float(50.0)),
945 (int(2), float(50.0)),
946 (int(3), float(50.0)),
947 (int(4), float(50.0)),
948 (int(5), float(50.0)),
949 ];
950 let r = run_agg2(&PercentileFunc, &rows);
951 assert_float_eq(&r, 3.0);
952 }
953
954 #[test]
955 fn test_percentile_0() {
956 let rows: Vec<(SqliteValue, SqliteValue)> = vec![
957 (int(10), float(0.0)),
958 (int(20), float(0.0)),
959 (int(30), float(0.0)),
960 ];
961 let r = run_agg2(&PercentileFunc, &rows);
962 assert_float_eq(&r, 10.0);
963 }
964
965 #[test]
966 fn test_percentile_100() {
967 let rows: Vec<(SqliteValue, SqliteValue)> = vec![
968 (int(10), float(100.0)),
969 (int(20), float(100.0)),
970 (int(30), float(100.0)),
971 ];
972 let r = run_agg2(&PercentileFunc, &rows);
973 assert_float_eq(&r, 30.0);
974 }
975
976 #[test]
979 fn test_percentile_cont_basic() {
980 let rows: Vec<(SqliteValue, SqliteValue)> = vec![
981 (int(1), float(0.5)),
982 (int(2), float(0.5)),
983 (int(3), float(0.5)),
984 (int(4), float(0.5)),
985 (int(5), float(0.5)),
986 ];
987 let r = run_agg2(&PercentileContFunc, &rows);
988 assert_float_eq(&r, 3.0);
989 }
990
991 #[test]
994 fn test_percentile_disc_basic() {
995 let rows: Vec<(SqliteValue, SqliteValue)> = vec![
996 (int(1), float(0.5)),
997 (int(2), float(0.5)),
998 (int(3), float(0.5)),
999 (int(4), float(0.5)),
1000 (int(5), float(0.5)),
1001 ];
1002 let r = run_agg2(&PercentileDiscFunc, &rows);
1003 match r {
1005 SqliteValue::Float(v) => {
1006 assert!(
1008 [1.0, 2.0, 3.0, 4.0, 5.0].contains(&v),
1009 "expected actual value, got {v}"
1010 );
1011 }
1012 other => {
1013 assert!(
1014 matches!(other, SqliteValue::Float(_)),
1015 "expected Float, got {other:?}"
1016 );
1017 }
1018 }
1019 }
1020
1021 #[test]
1022 fn test_percentile_disc_no_interpolation() {
1023 let rows: Vec<(SqliteValue, SqliteValue)> = vec![
1025 (int(10), float(0.5)),
1026 (int(20), float(0.5)),
1027 (int(30), float(0.5)),
1028 (int(40), float(0.5)),
1029 ];
1030 let r = run_agg2(&PercentileDiscFunc, &rows);
1031 match r {
1032 SqliteValue::Float(v) => {
1033 assert!(
1035 [10.0, 20.0, 30.0, 40.0].contains(&v),
1036 "disc must not interpolate: got {v}"
1037 );
1038 }
1039 other => {
1040 assert!(
1041 matches!(other, SqliteValue::Float(_)),
1042 "expected Float, got {other:?}"
1043 );
1044 }
1045 }
1046 }
1047
1048 #[test]
1051 fn test_string_agg_alias() {
1052 let mut reg = FunctionRegistry::new();
1053 register_aggregate_builtins(&mut reg);
1054 let sa = reg
1055 .find_aggregate("string_agg", 2)
1056 .expect("string_agg registered");
1057 let mut state = sa.initial_state();
1058 sa.step(&mut state, &[text("a"), text(",")]).unwrap();
1059 sa.step(&mut state, &[text("b"), text(",")]).unwrap();
1060 let r = sa.finalize(state).unwrap();
1061 assert_eq!(r, SqliteValue::Text("a,b".to_owned()));
1062 }
1063
1064 #[test]
1067 fn test_register_aggregate_builtins_all_present() {
1068 let mut reg = FunctionRegistry::new();
1069 register_aggregate_builtins(&mut reg);
1070
1071 let expected = [
1072 ("avg", 1),
1073 ("count", 0), ("count", 1), ("max", 1),
1076 ("min", 1),
1077 ("sum", 1),
1078 ("total", 1),
1079 ("median", 1),
1080 ("percentile", 2),
1081 ("percentile_cont", 2),
1082 ("percentile_disc", 2),
1083 ("string_agg", 2),
1084 ];
1085
1086 for (name, arity) in expected {
1087 assert!(
1088 reg.find_aggregate(name, arity).is_some(),
1089 "aggregate '{name}/{arity}' not registered"
1090 );
1091 }
1092
1093 assert!(reg.find_aggregate("group_concat", 1).is_some());
1095 assert!(reg.find_aggregate("group_concat", 2).is_some());
1096 }
1097
1098 #[test]
1101 fn test_e2e_registry_invoke_aggregates() {
1102 let mut reg = FunctionRegistry::new();
1103 register_aggregate_builtins(&mut reg);
1104
1105 let avg = reg.find_aggregate("avg", 1).unwrap();
1107 let mut state = avg.initial_state();
1108 avg.step(&mut state, &[int(10)]).unwrap();
1109 avg.step(&mut state, &[int(20)]).unwrap();
1110 avg.step(&mut state, &[int(30)]).unwrap();
1111 let r = avg.finalize(state).unwrap();
1112 assert_float_eq(&r, 20.0);
1113
1114 let sum = reg.find_aggregate("sum", 1).unwrap();
1116 let mut state = sum.initial_state();
1117 sum.step(&mut state, &[int(1)]).unwrap();
1118 sum.step(&mut state, &[int(2)]).unwrap();
1119 sum.step(&mut state, &[int(3)]).unwrap();
1120 let r = sum.finalize(state).unwrap();
1121 assert_eq!(r, int(6));
1122 }
1123}