1#![allow(
22 clippy::unnecessary_literal_bound,
23 clippy::cast_possible_truncation,
24 clippy::cast_possible_wrap,
25 clippy::cast_precision_loss,
26 clippy::cast_sign_loss,
27 clippy::items_after_statements,
28 clippy::float_cmp,
29 clippy::match_same_arms,
30 clippy::similar_names
31)]
32
33use std::collections::VecDeque;
34
35use fsqlite_error::{FrankenError, Result};
36use fsqlite_types::SqliteValue;
37
38use crate::{FunctionRegistry, WindowFunction};
39
40pub struct RowNumberState {
45 counter: i64,
46}
47
48pub struct RowNumberFunc;
49
50impl WindowFunction for RowNumberFunc {
51 type State = RowNumberState;
52
53 fn initial_state(&self) -> Self::State {
54 RowNumberState { counter: 0 }
55 }
56
57 fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
58 state.counter += 1;
59 Ok(())
60 }
61
62 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
63 state.counter -= 1;
64 Ok(())
65 }
66
67 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
68 Ok(SqliteValue::Integer(state.counter))
69 }
70
71 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
72 Ok(SqliteValue::Integer(state.counter))
73 }
74
75 fn num_args(&self) -> i32 {
76 0
77 }
78
79 fn name(&self) -> &str {
80 "row_number"
81 }
82}
83
84pub struct RankState {
89 row_number: i64,
90 rank: i64,
91 last_order_value: Option<SqliteValue>,
92}
93
94pub struct RankFunc;
95
96impl WindowFunction for RankFunc {
97 type State = RankState;
98
99 fn initial_state(&self) -> Self::State {
100 RankState {
101 row_number: 0,
102 rank: 0,
103 last_order_value: None,
104 }
105 }
106
107 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
108 state.row_number += 1;
109 let current = args.first().cloned().unwrap_or(SqliteValue::Null);
110 let is_new_peer = match &state.last_order_value {
111 None => true,
112 Some(last) => ¤t != last,
113 };
114 if is_new_peer {
115 state.rank = state.row_number;
116 state.last_order_value = Some(current);
117 }
118 Ok(())
119 }
120
121 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
122 Ok(())
124 }
125
126 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
127 Ok(SqliteValue::Integer(state.rank))
128 }
129
130 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
131 Ok(SqliteValue::Integer(state.rank))
132 }
133
134 fn num_args(&self) -> i32 {
135 -1
136 }
137
138 fn name(&self) -> &str {
139 "rank"
140 }
141}
142
143pub struct DenseRankState {
148 dense_rank: i64,
149 last_order_value: Option<SqliteValue>,
150}
151
152pub struct DenseRankFunc;
153
154impl WindowFunction for DenseRankFunc {
155 type State = DenseRankState;
156
157 fn initial_state(&self) -> Self::State {
158 DenseRankState {
159 dense_rank: 0,
160 last_order_value: None,
161 }
162 }
163
164 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
165 let current = args.first().cloned().unwrap_or(SqliteValue::Null);
166 let is_new_peer = match &state.last_order_value {
167 None => true,
168 Some(last) => ¤t != last,
169 };
170 if is_new_peer {
171 state.dense_rank += 1;
172 state.last_order_value = Some(current);
173 }
174 Ok(())
175 }
176
177 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
178 Ok(())
179 }
180
181 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
182 Ok(SqliteValue::Integer(state.dense_rank))
183 }
184
185 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
186 Ok(SqliteValue::Integer(state.dense_rank))
187 }
188
189 fn num_args(&self) -> i32 {
190 -1
191 }
192
193 fn name(&self) -> &str {
194 "dense_rank"
195 }
196}
197
198pub struct PercentRankState {
208 partition_size: i64,
209 ranks: Vec<i64>,
210 cursor: usize,
211 step_row_number: i64,
212 current_rank: i64,
213 last_order_value: Option<SqliteValue>,
214}
215
216pub struct PercentRankFunc;
217
218impl WindowFunction for PercentRankFunc {
219 type State = PercentRankState;
220
221 fn initial_state(&self) -> Self::State {
222 PercentRankState {
223 partition_size: 0,
224 ranks: Vec::new(),
225 cursor: 0,
226 step_row_number: 0,
227 current_rank: 0,
228 last_order_value: None,
229 }
230 }
231
232 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
233 state.step_row_number += 1;
234 state.partition_size += 1;
235 let current = args.first().cloned().unwrap_or(SqliteValue::Null);
236 let is_new_peer = match &state.last_order_value {
237 None => true,
238 Some(last) => ¤t != last,
239 };
240 if is_new_peer {
241 state.current_rank = state.step_row_number;
242 state.last_order_value = Some(current);
243 }
244 state.ranks.push(state.current_rank);
245 Ok(())
246 }
247
248 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
249 state.cursor += 1;
250 Ok(())
251 }
252
253 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
254 if state.partition_size <= 1 {
255 return Ok(SqliteValue::Float(0.0));
256 }
257 let rank = state.ranks.get(state.cursor).copied().unwrap_or(1);
258 let pr = (rank - 1) as f64 / (state.partition_size - 1) as f64;
259 Ok(SqliteValue::Float(pr))
260 }
261
262 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
263 self.value(&state)
264 }
265
266 fn num_args(&self) -> i32 {
267 -1
268 }
269
270 fn name(&self) -> &str {
271 "percent_rank"
272 }
273}
274
275pub struct CumeDistState {
284 partition_size: i64,
285 current_row: i64,
286}
287
288pub struct CumeDistFunc;
289
290impl WindowFunction for CumeDistFunc {
291 type State = CumeDistState;
292
293 fn initial_state(&self) -> Self::State {
294 CumeDistState {
295 partition_size: 0,
296 current_row: 0,
297 }
298 }
299
300 fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
301 state.partition_size += 1;
302 Ok(())
303 }
304
305 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
306 state.current_row += 1;
307 Ok(())
308 }
309
310 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
311 if state.partition_size == 0 {
312 return Ok(SqliteValue::Float(0.0));
313 }
314 let cd = (state.current_row + 1) as f64 / state.partition_size as f64;
315 Ok(SqliteValue::Float(cd))
316 }
317
318 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
319 self.value(&state)
320 }
321
322 fn num_args(&self) -> i32 {
323 -1
324 }
325
326 fn name(&self) -> &str {
327 "cume_dist"
328 }
329}
330
331pub struct NtileState {
341 partition_size: i64,
342 n: i64,
343 current_row: i64,
344}
345
346pub struct NtileFunc;
347
348impl WindowFunction for NtileFunc {
349 type State = NtileState;
350
351 fn initial_state(&self) -> Self::State {
352 NtileState {
353 partition_size: 0,
354 n: 1,
355 current_row: 0,
356 }
357 }
358
359 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
360 state.partition_size += 1;
361 if state.partition_size == 1 {
362 let n = args.first().map_or(1, |v| v.to_integer().max(1));
363 state.n = n;
364 }
365 Ok(())
366 }
367
368 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
369 state.current_row += 1;
370 Ok(())
371 }
372
373 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
374 if state.n <= 0 || state.partition_size == 0 {
375 return Ok(SqliteValue::Integer(1));
376 }
377 let n = state.n;
378 let sz = state.partition_size;
379 let row = state.current_row + 1; let base = sz / n;
384 let extra = sz % n;
385 let large_rows = extra * (base + 1);
387
388 let bucket = if row <= large_rows {
389 (row - 1) / (base + 1) + 1
391 } else {
392 let adjusted = row - large_rows;
394 if base == 0 {
395 extra + adjusted
397 } else {
398 extra + (adjusted - 1) / base + 1
399 }
400 };
401 Ok(SqliteValue::Integer(bucket))
402 }
403
404 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
405 self.value(&state)
406 }
407
408 fn num_args(&self) -> i32 {
409 1
410 }
411
412 fn name(&self) -> &str {
413 "ntile"
414 }
415}
416
417pub struct LagState {
423 buffer: VecDeque<SqliteValue>,
424 offset: i64,
425 default_val: SqliteValue,
426 row_number: i64,
427}
428
429pub struct LagFunc;
430
431impl WindowFunction for LagFunc {
432 type State = LagState;
433
434 fn initial_state(&self) -> Self::State {
435 LagState {
436 buffer: VecDeque::new(),
437 offset: 1,
438 default_val: SqliteValue::Null,
439 row_number: 0,
440 }
441 }
442
443 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
444 let val = args.first().cloned().unwrap_or(SqliteValue::Null);
445 if state.row_number == 0 {
447 if let Some(off) = args.get(1) {
448 state.offset = off.to_integer().max(0);
449 }
450 if let Some(def) = args.get(2) {
451 state.default_val = def.clone();
452 }
453 }
454 state.row_number += 1;
455 state.buffer.push_back(val);
456 Ok(())
457 }
458
459 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
460 Ok(())
461 }
462
463 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
464 let idx = state.row_number - state.offset;
465 if idx < 1 || idx > state.buffer.len() as i64 {
466 return Ok(state.default_val.clone());
467 }
468 Ok(state.buffer[(idx - 1) as usize].clone())
469 }
470
471 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
472 self.value(&state)
473 }
474
475 fn num_args(&self) -> i32 {
476 -1 }
478
479 fn name(&self) -> &str {
480 "lag"
481 }
482}
483
484pub struct LeadState {
490 buffer: Vec<SqliteValue>,
491 offset: i64,
492 default_val: SqliteValue,
493 current_row: i64,
494}
495
496pub struct LeadFunc;
497
498impl WindowFunction for LeadFunc {
499 type State = LeadState;
500
501 fn initial_state(&self) -> Self::State {
502 LeadState {
503 buffer: Vec::new(),
504 offset: 1,
505 default_val: SqliteValue::Null,
506 current_row: 0,
507 }
508 }
509
510 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
511 let val = args.first().cloned().unwrap_or(SqliteValue::Null);
512 if state.buffer.is_empty() {
513 if let Some(off) = args.get(1) {
514 state.offset = off.to_integer().max(0);
515 }
516 if let Some(def) = args.get(2) {
517 state.default_val = def.clone();
518 }
519 }
520 state.buffer.push(val);
521 Ok(())
522 }
523
524 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
525 state.current_row += 1;
526 Ok(())
527 }
528
529 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
530 let target = state.current_row + state.offset;
531 if target < 0 || target >= state.buffer.len() as i64 {
532 return Ok(state.default_val.clone());
533 }
534 Ok(state.buffer[target as usize].clone())
535 }
536
537 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
538 self.value(&state)
539 }
540
541 fn num_args(&self) -> i32 {
542 -1
543 }
544
545 fn name(&self) -> &str {
546 "lead"
547 }
548}
549
550pub struct FirstValueState {
555 first: Option<SqliteValue>,
556}
557
558pub struct FirstValueFunc;
559
560impl WindowFunction for FirstValueFunc {
561 type State = FirstValueState;
562
563 fn initial_state(&self) -> Self::State {
564 FirstValueState { first: None }
565 }
566
567 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
568 if state.first.is_none() {
569 state.first = Some(args.first().cloned().unwrap_or(SqliteValue::Null));
570 }
571 Ok(())
572 }
573
574 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
575 state.first = None;
580 Ok(())
581 }
582
583 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
584 Ok(state.first.clone().unwrap_or(SqliteValue::Null))
585 }
586
587 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
588 Ok(state.first.unwrap_or(SqliteValue::Null))
589 }
590
591 fn num_args(&self) -> i32 {
592 1
593 }
594
595 fn name(&self) -> &str {
596 "first_value"
597 }
598}
599
600pub struct LastValueState {
605 frame: VecDeque<SqliteValue>,
606}
607
608pub struct LastValueFunc;
609
610impl WindowFunction for LastValueFunc {
611 type State = LastValueState;
612
613 fn initial_state(&self) -> Self::State {
614 LastValueState {
615 frame: VecDeque::new(),
616 }
617 }
618
619 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
620 state
621 .frame
622 .push_back(args.first().cloned().unwrap_or(SqliteValue::Null));
623 Ok(())
624 }
625
626 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
627 state.frame.pop_front();
628 Ok(())
629 }
630
631 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
632 Ok(state.frame.back().cloned().unwrap_or(SqliteValue::Null))
633 }
634
635 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
636 Ok(state.frame.back().cloned().unwrap_or(SqliteValue::Null))
637 }
638
639 fn num_args(&self) -> i32 {
640 1
641 }
642
643 fn name(&self) -> &str {
644 "last_value"
645 }
646}
647
648pub struct NthValueState {
653 frame: VecDeque<SqliteValue>,
654 n: i64,
655}
656
657pub struct NthValueFunc;
658
659impl WindowFunction for NthValueFunc {
660 type State = NthValueState;
661
662 fn initial_state(&self) -> Self::State {
663 NthValueState {
664 frame: VecDeque::new(),
665 n: 1,
666 }
667 }
668
669 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
670 let val = args.first().cloned().unwrap_or(SqliteValue::Null);
671 if state.frame.is_empty() {
673 if let Some(n_arg) = args.get(1) {
674 state.n = n_arg.to_integer();
675 }
676 }
677 state.frame.push_back(val);
678 Ok(())
679 }
680
681 fn inverse(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
682 state.frame.pop_front();
683 Ok(())
684 }
685
686 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
687 if state.n <= 0 {
690 return Ok(SqliteValue::Null);
691 }
692 let idx = (state.n - 1) as usize;
693 Ok(state.frame.get(idx).cloned().unwrap_or(SqliteValue::Null))
694 }
695
696 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
697 self.value(&state)
698 }
699
700 fn num_args(&self) -> i32 {
701 2
702 }
703
704 fn name(&self) -> &str {
705 "nth_value"
706 }
707}
708
709pub struct WindowSumState {
715 sum: f64,
716 err: f64,
717 has_value: bool,
718 is_int: bool,
719 int_sum: i64,
720 overflowed: bool,
721}
722
723#[inline]
725fn kbn_step(sum: &mut f64, err: &mut f64, value: f64) {
726 let s = *sum;
727 let t = s + value;
728 if s.abs() > value.abs() {
729 *err += (s - t) + value;
730 } else {
731 *err += (value - t) + s;
732 }
733 *sum = t;
734}
735
736pub struct WindowSumFunc;
737
738impl WindowFunction for WindowSumFunc {
739 type State = WindowSumState;
740
741 fn initial_state(&self) -> Self::State {
742 WindowSumState {
743 sum: 0.0,
744 err: 0.0,
745 has_value: false,
746 is_int: true,
747 int_sum: 0,
748 overflowed: false,
749 }
750 }
751
752 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
753 if args.is_empty() || args[0].is_null() {
754 return Ok(());
755 }
756 state.has_value = true;
757 match &args[0] {
758 SqliteValue::Integer(i) => {
759 if state.is_int && !state.overflowed {
760 match state.int_sum.checked_add(*i) {
761 Some(s) => state.int_sum = s,
762 None => state.overflowed = true,
763 }
764 }
765 kbn_step(&mut state.sum, &mut state.err, *i as f64);
766 }
767 SqliteValue::Float(f) => {
768 state.is_int = false;
769 kbn_step(&mut state.sum, &mut state.err, *f);
770 }
771 other => {
772 let f = other.to_float();
773 if f != 0.0 || other.to_text() == "0" {
774 state.is_int = false;
775 kbn_step(&mut state.sum, &mut state.err, f);
776 }
777 }
778 }
779 Ok(())
780 }
781
782 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
783 if args.is_empty() || args[0].is_null() {
784 return Ok(());
785 }
786 match &args[0] {
787 SqliteValue::Integer(i) => {
788 if state.is_int && !state.overflowed {
789 match state.int_sum.checked_sub(*i) {
790 Some(s) => state.int_sum = s,
791 None => state.overflowed = true,
792 }
793 }
794 kbn_step(&mut state.sum, &mut state.err, -(*i as f64));
795 }
796 _ => {
797 kbn_step(&mut state.sum, &mut state.err, -args[0].to_float());
798 }
799 }
800 Ok(())
801 }
802
803 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
804 if !state.has_value {
805 return Ok(SqliteValue::Null);
806 }
807 if state.overflowed {
808 return Err(FrankenError::IntegerOverflow);
809 }
810 if state.is_int {
811 Ok(SqliteValue::Integer(state.int_sum))
812 } else {
813 Ok(SqliteValue::Float(state.sum + state.err))
814 }
815 }
816
817 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
818 self.value(&state)
819 }
820
821 fn num_args(&self) -> i32 {
822 1
823 }
824
825 fn name(&self) -> &str {
826 "SUM"
827 }
828}
829
830pub struct WindowTotalFunc;
831
832impl WindowFunction for WindowTotalFunc {
833 type State = f64;
834
835 fn initial_state(&self) -> Self::State {
836 0.0
837 }
838
839 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
840 if !args.is_empty() && !args[0].is_null() {
841 *state += args[0].to_float();
842 }
843 Ok(())
844 }
845
846 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
847 if !args.is_empty() && !args[0].is_null() {
848 *state -= args[0].to_float();
849 }
850 Ok(())
851 }
852
853 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
854 Ok(SqliteValue::Float(*state))
855 }
856
857 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
858 Ok(SqliteValue::Float(state))
859 }
860
861 fn num_args(&self) -> i32 {
862 1
863 }
864
865 fn name(&self) -> &str {
866 "TOTAL"
867 }
868}
869
870pub struct WindowCountState {
871 count: i64,
872}
873
874pub struct WindowCountFunc;
875
876impl WindowFunction for WindowCountFunc {
877 type State = WindowCountState;
878
879 fn initial_state(&self) -> Self::State {
880 WindowCountState { count: 0 }
881 }
882
883 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
884 if args.is_empty() || !args[0].is_null() {
886 state.count += 1;
887 }
888 Ok(())
889 }
890
891 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
892 if args.is_empty() || !args[0].is_null() {
893 state.count -= 1;
894 }
895 Ok(())
896 }
897
898 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
899 Ok(SqliteValue::Integer(state.count))
900 }
901
902 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
903 Ok(SqliteValue::Integer(state.count))
904 }
905
906 fn num_args(&self) -> i32 {
907 -1 }
909
910 fn name(&self) -> &str {
911 "COUNT"
912 }
913}
914
915pub struct WindowMinState {
916 min: Option<SqliteValue>,
917}
918
919pub struct WindowMinFunc;
920
921impl WindowFunction for WindowMinFunc {
922 type State = WindowMinState;
923
924 fn initial_state(&self) -> Self::State {
925 WindowMinState { min: None }
926 }
927
928 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
929 if args.is_empty() || args[0].is_null() {
930 return Ok(());
931 }
932 state.min = Some(match state.min.take() {
933 None => args[0].clone(),
934 Some(cur) => {
935 if cmp_values(&args[0], &cur) == std::cmp::Ordering::Less {
936 args[0].clone()
937 } else {
938 cur
939 }
940 }
941 });
942 Ok(())
943 }
944
945 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
946 Ok(())
948 }
949
950 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
951 Ok(state.min.clone().unwrap_or(SqliteValue::Null))
952 }
953
954 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
955 Ok(state.min.unwrap_or(SqliteValue::Null))
956 }
957
958 fn num_args(&self) -> i32 {
959 1
960 }
961
962 fn name(&self) -> &str {
963 "MIN"
964 }
965}
966
967pub struct WindowMaxState {
968 max: Option<SqliteValue>,
969}
970
971pub struct WindowMaxFunc;
972
973impl WindowFunction for WindowMaxFunc {
974 type State = WindowMaxState;
975
976 fn initial_state(&self) -> Self::State {
977 WindowMaxState { max: None }
978 }
979
980 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
981 if args.is_empty() || args[0].is_null() {
982 return Ok(());
983 }
984 state.max = Some(match state.max.take() {
985 None => args[0].clone(),
986 Some(cur) => {
987 if cmp_values(&args[0], &cur) == std::cmp::Ordering::Greater {
988 args[0].clone()
989 } else {
990 cur
991 }
992 }
993 });
994 Ok(())
995 }
996
997 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
998 Ok(())
999 }
1000
1001 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
1002 Ok(state.max.clone().unwrap_or(SqliteValue::Null))
1003 }
1004
1005 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
1006 Ok(state.max.unwrap_or(SqliteValue::Null))
1007 }
1008
1009 fn num_args(&self) -> i32 {
1010 1
1011 }
1012
1013 fn name(&self) -> &str {
1014 "MAX"
1015 }
1016}
1017
1018pub struct WindowAvgState {
1019 sum: f64,
1020 err: f64,
1021 count: i64,
1022}
1023
1024pub struct WindowAvgFunc;
1025
1026impl WindowFunction for WindowAvgFunc {
1027 type State = WindowAvgState;
1028
1029 fn initial_state(&self) -> Self::State {
1030 WindowAvgState {
1031 sum: 0.0,
1032 err: 0.0,
1033 count: 0,
1034 }
1035 }
1036
1037 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
1038 if args.is_empty() || args[0].is_null() {
1039 return Ok(());
1040 }
1041 kbn_step(&mut state.sum, &mut state.err, args[0].to_float());
1042 state.count += 1;
1043 Ok(())
1044 }
1045
1046 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
1047 if args.is_empty() || args[0].is_null() {
1048 return Ok(());
1049 }
1050 kbn_step(&mut state.sum, &mut state.err, -args[0].to_float());
1051 state.count -= 1;
1052 Ok(())
1053 }
1054
1055 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
1056 if state.count == 0 {
1057 Ok(SqliteValue::Null)
1058 } else {
1059 #[allow(clippy::cast_precision_loss)]
1060 Ok(SqliteValue::Float(
1061 (state.sum + state.err) / state.count as f64,
1062 ))
1063 }
1064 }
1065
1066 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
1067 self.value(&state)
1068 }
1069
1070 fn num_args(&self) -> i32 {
1071 1
1072 }
1073
1074 fn name(&self) -> &str {
1075 "AVG"
1076 }
1077}
1078
1079pub struct WindowGroupConcatState {
1080 result: String,
1081 has_value: bool,
1082}
1083
1084fn window_group_concat_step(state: &mut WindowGroupConcatState, args: &[SqliteValue]) {
1085 if args.is_empty() || args[0].is_null() {
1086 return;
1087 }
1088 let sep = if args.len() > 1 {
1089 if args[1].is_null() {
1090 String::new()
1091 } else {
1092 args[1].to_text()
1093 }
1094 } else {
1095 ",".to_owned()
1096 };
1097 if state.has_value {
1098 state.result.push_str(&sep);
1099 }
1100 state.result.push_str(&args[0].to_text());
1101 state.has_value = true;
1102}
1103
1104fn window_group_concat_value(state: &WindowGroupConcatState) -> SqliteValue {
1105 if state.has_value {
1106 SqliteValue::Text(state.result.clone().into())
1107 } else {
1108 SqliteValue::Null
1109 }
1110}
1111
1112pub struct WindowGroupConcatFunc;
1113
1114impl WindowFunction for WindowGroupConcatFunc {
1115 type State = WindowGroupConcatState;
1116
1117 fn initial_state(&self) -> Self::State {
1118 WindowGroupConcatState {
1119 result: String::new(),
1120 has_value: false,
1121 }
1122 }
1123
1124 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
1125 window_group_concat_step(state, args);
1126 Ok(())
1127 }
1128
1129 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
1130 Ok(())
1133 }
1134
1135 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
1136 Ok(window_group_concat_value(state))
1137 }
1138
1139 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
1140 Ok(window_group_concat_value(&state))
1141 }
1142
1143 fn num_args(&self) -> i32 {
1144 -1
1145 }
1146
1147 fn name(&self) -> &str {
1148 "group_concat"
1149 }
1150}
1151
1152pub struct WindowStringAggFunc;
1153
1154impl WindowFunction for WindowStringAggFunc {
1155 type State = WindowGroupConcatState;
1156
1157 fn initial_state(&self) -> Self::State {
1158 WindowGroupConcatState {
1159 result: String::new(),
1160 has_value: false,
1161 }
1162 }
1163
1164 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
1165 window_group_concat_step(state, args);
1166 Ok(())
1167 }
1168
1169 fn inverse(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
1170 Ok(())
1171 }
1172
1173 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
1174 Ok(window_group_concat_value(state))
1175 }
1176
1177 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
1178 Ok(window_group_concat_value(&state))
1179 }
1180
1181 fn num_args(&self) -> i32 {
1182 2
1183 }
1184
1185 fn name(&self) -> &str {
1186 "string_agg"
1187 }
1188}
1189
1190pub fn cmp_values(a: &SqliteValue, b: &SqliteValue) -> std::cmp::Ordering {
1192 match (a, b) {
1193 (SqliteValue::Null, SqliteValue::Null) => std::cmp::Ordering::Equal,
1194 (SqliteValue::Null, _) => std::cmp::Ordering::Less,
1195 (_, SqliteValue::Null) => std::cmp::Ordering::Greater,
1196 (SqliteValue::Integer(ai), SqliteValue::Integer(bi)) => ai.cmp(bi),
1197 (SqliteValue::Float(af), SqliteValue::Float(bf)) => af.total_cmp(bf),
1198 (SqliteValue::Integer(ai), SqliteValue::Float(bf)) => (*ai as f64).total_cmp(bf),
1199 (SqliteValue::Float(af), SqliteValue::Integer(bi)) => af.total_cmp(&(*bi as f64)),
1200 (SqliteValue::Text(at), SqliteValue::Text(bt)) => at.cmp(bt),
1201 (SqliteValue::Blob(ab), SqliteValue::Blob(bb)) => ab.cmp(bb),
1202 _ => a.to_text().cmp(&b.to_text()),
1203 }
1204}
1205
1206pub fn register_window_builtins(registry: &mut FunctionRegistry) {
1210 registry.register_window(RowNumberFunc);
1211 registry.register_window(RankFunc);
1212 registry.register_window(DenseRankFunc);
1213 registry.register_window(PercentRankFunc);
1214 registry.register_window(CumeDistFunc);
1215 registry.register_window(NtileFunc);
1216 registry.register_window(LagFunc);
1217 registry.register_window(LeadFunc);
1218 registry.register_window(FirstValueFunc);
1219 registry.register_window(LastValueFunc);
1220 registry.register_window(NthValueFunc);
1221
1222 registry.register_window(WindowSumFunc);
1224 registry.register_window(WindowTotalFunc);
1225 registry.register_window(WindowCountFunc);
1226 registry.register_window(WindowMinFunc);
1227 registry.register_window(WindowMaxFunc);
1228 registry.register_window(WindowAvgFunc);
1229 registry.register_window(WindowGroupConcatFunc);
1230 registry.register_window(WindowStringAggFunc);
1231}
1232
1233#[cfg(test)]
1236mod tests {
1237 use super::*;
1238
1239 fn int(v: i64) -> SqliteValue {
1240 SqliteValue::Integer(v)
1241 }
1242
1243 fn text(s: &str) -> SqliteValue {
1244 SqliteValue::Text(s.into())
1245 }
1246
1247 fn null() -> SqliteValue {
1248 SqliteValue::Null
1249 }
1250
1251 fn run_window_partition<F: WindowFunction>(
1255 func: &F,
1256 rows: &[Vec<SqliteValue>],
1257 ) -> Vec<SqliteValue> {
1258 let mut state = func.initial_state();
1259 let mut results = Vec::new();
1260 for row in rows {
1261 func.step(&mut state, row).unwrap();
1262 results.push(func.value(&state).unwrap());
1263 }
1264 results
1265 }
1266
1267 fn run_window_two_pass<F: WindowFunction>(
1272 func: &F,
1273 rows: &[Vec<SqliteValue>],
1274 ) -> Vec<SqliteValue> {
1275 let mut state = func.initial_state();
1276 for row in rows {
1278 func.step(&mut state, row).unwrap();
1279 }
1280 let mut results = Vec::new();
1282 for (i, _) in rows.iter().enumerate() {
1283 results.push(func.value(&state).unwrap());
1284 if i < rows.len() - 1 {
1285 func.inverse(&mut state, &[]).unwrap();
1286 }
1287 }
1288 results
1289 }
1290
1291 #[test]
1294 fn test_row_number_basic() {
1295 let results =
1296 run_window_partition(&RowNumberFunc, &[vec![], vec![], vec![], vec![], vec![]]);
1297 assert_eq!(results, vec![int(1), int(2), int(3), int(4), int(5)]);
1298 }
1299
1300 #[test]
1301 fn test_row_number_partition_reset() {
1302 let r1 = run_window_partition(&RowNumberFunc, &[vec![], vec![], vec![]]);
1304 assert_eq!(r1, vec![int(1), int(2), int(3)]);
1305
1306 let r2 = run_window_partition(&RowNumberFunc, &[vec![], vec![]]);
1308 assert_eq!(r2, vec![int(1), int(2)]);
1309 }
1310
1311 #[test]
1314 fn test_rank_with_ties() {
1315 let results = run_window_partition(
1317 &RankFunc,
1318 &[vec![int(1)], vec![int(2)], vec![int(2)], vec![int(3)]],
1319 );
1320 assert_eq!(results, vec![int(1), int(2), int(2), int(4)]);
1321 }
1322
1323 #[test]
1324 fn test_rank_no_ties() {
1325 let results =
1326 run_window_partition(&RankFunc, &[vec![int(10)], vec![int(20)], vec![int(30)]]);
1327 assert_eq!(results, vec![int(1), int(2), int(3)]);
1328 }
1329
1330 #[test]
1333 fn test_dense_rank_with_ties() {
1334 let results = run_window_partition(
1336 &DenseRankFunc,
1337 &[vec![int(1)], vec![int(2)], vec![int(2)], vec![int(3)]],
1338 );
1339 assert_eq!(results, vec![int(1), int(2), int(2), int(3)]);
1340 }
1341
1342 #[test]
1343 fn test_dense_rank_multiple_ties() {
1344 let results = run_window_partition(
1346 &DenseRankFunc,
1347 &[
1348 vec![int(1)],
1349 vec![int(1)],
1350 vec![int(2)],
1351 vec![int(2)],
1352 vec![int(3)],
1353 ],
1354 );
1355 assert_eq!(results, vec![int(1), int(1), int(2), int(2), int(3)]);
1356 }
1357
1358 #[test]
1361 fn test_percent_rank_single_row() {
1362 let results = run_window_two_pass(&PercentRankFunc, &[vec![int(1)]]);
1363 assert_eq!(results, vec![SqliteValue::Float(0.0)]);
1364 }
1365
1366 #[test]
1367 fn test_percent_rank_formula() {
1368 let results = run_window_two_pass(
1371 &PercentRankFunc,
1372 &[vec![int(1)], vec![int(2)], vec![int(2)], vec![int(3)]],
1373 );
1374 match &results[0] {
1379 SqliteValue::Float(v) => assert!((*v - 0.0).abs() < 1e-10),
1380 other => panic!("expected Float, got {other:?}"),
1381 }
1382 match &results[1] {
1383 SqliteValue::Float(v) => assert!((*v - 1.0 / 3.0).abs() < 1e-10),
1384 other => panic!("expected Float, got {other:?}"),
1385 }
1386 match &results[2] {
1387 SqliteValue::Float(v) => assert!((*v - 1.0 / 3.0).abs() < 1e-10),
1388 other => panic!("expected Float, got {other:?}"),
1389 }
1390 match &results[3] {
1391 SqliteValue::Float(v) => assert!((*v - 1.0).abs() < 1e-10),
1392 other => panic!("expected Float, got {other:?}"),
1393 }
1394 }
1395
1396 #[test]
1399 fn test_cume_dist_distinct() {
1400 let results = run_window_two_pass(
1402 &CumeDistFunc,
1403 &[vec![int(1)], vec![int(2)], vec![int(3)], vec![int(4)]],
1404 );
1405 for (i, expected) in [0.25, 0.5, 0.75, 1.0].iter().enumerate() {
1406 match &results[i] {
1407 SqliteValue::Float(v) => {
1408 assert!(
1409 (*v - expected).abs() < 1e-10,
1410 "row {i}: expected {expected}, got {v}"
1411 );
1412 }
1413 other => panic!("expected Float, got {other:?}"),
1414 }
1415 }
1416 }
1417
1418 #[test]
1421 fn test_ntile_even() {
1422 let rows: Vec<Vec<SqliteValue>> = (0..8).map(|_| vec![int(4)]).collect();
1424 let results = run_window_two_pass(&NtileFunc, &rows);
1425 assert_eq!(
1426 results,
1427 vec![
1428 int(1),
1429 int(1),
1430 int(2),
1431 int(2),
1432 int(3),
1433 int(3),
1434 int(4),
1435 int(4)
1436 ]
1437 );
1438 }
1439
1440 #[test]
1441 fn test_ntile_uneven() {
1442 let rows: Vec<Vec<SqliteValue>> = (0..10).map(|_| vec![int(3)]).collect();
1444 let results = run_window_two_pass(&NtileFunc, &rows);
1445 assert_eq!(
1446 results,
1447 vec![
1448 int(1),
1449 int(1),
1450 int(1),
1451 int(1),
1452 int(2),
1453 int(2),
1454 int(2),
1455 int(3),
1456 int(3),
1457 int(3)
1458 ]
1459 );
1460 }
1461
1462 #[test]
1463 fn test_ntile_more_buckets_than_rows() {
1464 let rows: Vec<Vec<SqliteValue>> = (0..3).map(|_| vec![int(10)]).collect();
1466 let results = run_window_two_pass(&NtileFunc, &rows);
1467 assert_eq!(results, vec![int(1), int(2), int(3)]);
1468 }
1469
1470 #[test]
1473 fn test_lag_default() {
1474 let results =
1476 run_window_partition(&LagFunc, &[vec![int(10)], vec![int(20)], vec![int(30)]]);
1477 assert_eq!(results, vec![null(), int(10), int(20)]);
1478 }
1479
/// LAG(x, 3): the first three rows look back past the partition start and
/// yield NULL; rows 4 and 5 see rows 1 and 2.
#[test]
fn test_lag_offset_3() {
    let offset = int(3);
    let rows = [10, 20, 30, 40, 50].map(|v| vec![int(v), offset.clone()]);
    let results = run_window_partition(&LagFunc, &rows);
    let expected = vec![null(), null(), null(), int(10), int(20)];
    assert_eq!(results, expected);
}
1495
/// LAG(x, 1, -1): when there is no preceding row, the supplied default (-1)
/// is returned instead of NULL.
#[test]
fn test_lag_default_value() {
    let rows = [10, 20].map(|v| vec![int(v), int(1), int(-1)]);
    let results = run_window_partition(&LagFunc, &rows);
    assert_eq!(results, vec![int(-1), int(10)]);
}
1508
/// LEAD(x) with the implicit offset of 1, driven by hand: the whole partition
/// is stepped in first, then for each row the current value is read and one
/// row is retired with `inverse`. The last row has no successor (NULL).
#[test]
fn test_lead_default() {
    let func = LeadFunc;
    let mut state = func.initial_state();
    let rows = [int(10), int(20), int(30)];

    for args in rows.iter().map(std::slice::from_ref) {
        func.step(&mut state, args).unwrap();
    }

    let results: Vec<SqliteValue> = (0..rows.len())
        .map(|_| {
            let current = func.value(&state).unwrap();
            func.inverse(&mut state, &[]).unwrap();
            current
        })
        .collect();
    assert_eq!(results, vec![int(20), int(30), null()]);
}
1533
/// LEAD(x, 2): each row sees the value two rows ahead; the final two rows run
/// past the partition end and yield NULL.
#[test]
fn test_lead_offset_2() {
    let func = LeadFunc;
    let mut state = func.initial_state();
    let rows = [int(10), int(20), int(30), int(40), int(50)];

    for value in rows.iter().cloned() {
        func.step(&mut state, &[value, int(2)]).unwrap();
    }

    // Read-then-retire once per input row.
    let mut results = Vec::with_capacity(rows.len());
    while results.len() < rows.len() {
        results.push(func.value(&state).unwrap());
        func.inverse(&mut state, &[]).unwrap();
    }
    assert_eq!(results, vec![int(30), int(40), int(50), null(), null()]);
}
1551
/// LEAD(x, 1, 'N/A'): the last row has no successor, so the supplied default
/// text is returned instead of NULL.
#[test]
fn test_lead_default_value() {
    let func = LeadFunc;
    let fallback = text("N/A");
    let mut state = func.initial_state();
    let rows = [int(10), int(20)];

    for value in &rows {
        let args = [value.clone(), int(1), fallback.clone()];
        func.step(&mut state, &args).unwrap();
    }

    let results: Vec<SqliteValue> = rows
        .iter()
        .map(|_| {
            let current = func.value(&state).unwrap();
            func.inverse(&mut state, &[]).unwrap();
            current
        })
        .collect();
    assert_eq!(results, vec![int(20), fallback]);
}
1570
/// FIRST_VALUE(x) over a running partition always reports the first row's value.
#[test]
fn test_first_value_basic() {
    let rows = [10, 20, 30].map(|v| vec![int(v)]);
    let results = run_window_partition(&FirstValueFunc, &rows);
    assert_eq!(results, vec![int(10); 3]);
}
1583
/// LAST_VALUE(x) with a running frame that ends at the current row: each row
/// reports its own value.
#[test]
fn test_last_value_default_frame() {
    let values = [10, 20, 30];
    let rows = values.map(|v| vec![int(v)]);
    let results = run_window_partition(&LastValueFunc, &rows);
    assert_eq!(results, values.map(int).to_vec());
}
1596
/// LAST_VALUE(x) after the entire partition has been stepped in (as with an
/// UNBOUNDED FOLLOWING frame end): the value is the partition's last row.
#[test]
fn test_last_value_unbounded_following() {
    let func = LastValueFunc;
    let mut state = func.initial_state();
    for v in [10, 20, 30] {
        func.step(&mut state, &[int(v)]).unwrap();
    }
    assert_eq!(func.value(&state).unwrap(), int(30));
}
1608
/// NTH_VALUE(x, 3) over five rows yields the third row's value.
#[test]
fn test_nth_value_basic() {
    let func = NthValueFunc;
    let mut state = func.initial_state();
    for v in [10, 20, 30, 40, 50] {
        func.step(&mut state, &[int(v), int(3)]).unwrap();
    }
    assert_eq!(func.value(&state).unwrap(), int(30));
}
1623
/// NTH_VALUE(x, 100) with only two rows in the frame yields NULL.
#[test]
fn test_nth_value_out_of_range() {
    let func = NthValueFunc;
    let mut state = func.initial_state();
    for v in [10, 20] {
        func.step(&mut state, &[int(v), int(100)]).unwrap();
    }
    assert_eq!(func.value(&state).unwrap(), null());
}
1633
/// NTH_VALUE(x, 0): this implementation accepts N = 0 and yields NULL.
///
/// NOTE(review): upstream SQLite rejects N <= 0 with "second argument to
/// nth_value must be a positive integer"; confirm that the NULL result pinned
/// here is the intended divergence.
#[test]
fn test_nth_value_n_zero() {
    let func = NthValueFunc;
    let mut state = func.initial_state();
    let n = int(0);
    func.step(&mut state, &[int(10), n]).unwrap();
    assert_eq!(func.value(&state).unwrap(), null());
}
1642
/// Running GROUP_CONCAT(x) with no separator argument joins with ",".
#[test]
fn test_window_group_concat_running_default_separator() {
    let rows = ["a", "b", "c"].map(|s| vec![text(s)]);
    let results = run_window_partition(&WindowGroupConcatFunc, &rows);
    assert_eq!(results, vec![text("a"), text("a,b"), text("a,b,c")]);
}
1653
/// Running GROUP_CONCAT(x, ' | ') joins with the caller-supplied separator.
#[test]
fn test_window_group_concat_running_custom_separator() {
    let rows = ["a", "b", "c"].map(|s| vec![text(s), text(" | ")]);
    let results = run_window_partition(&WindowGroupConcatFunc, &rows);
    assert_eq!(results, vec![text("a"), text("a | b"), text("a | b | c")]);
}
1666
/// Rows are `(value, separator)`. NULL values are skipped entirely, and each
/// appended value is joined using the separator from its own row.
#[test]
fn test_window_group_concat_skips_null_and_uses_current_row_separator() {
    let rows = [
        vec![text("a"), text("-")],
        vec![null(), text("?")],
        vec![text("b"), text("+")],
        vec![text("c"), text("*")],
    ];
    let expected = vec![text("a"), text("a"), text("a+b"), text("a+b*c")];
    assert_eq!(run_window_partition(&WindowGroupConcatFunc, &rows), expected);
}
1683
/// `string_agg(x, sep)` resolves through the registry and behaves as a
/// running group_concat with an explicit separator.
#[test]
fn test_window_string_agg_alias_through_registry() {
    let mut reg = FunctionRegistry::new();
    register_window_builtins(&mut reg);

    let string_agg = reg.find_window("string_agg", 2).unwrap();
    let mut state = string_agg.initial_state();
    for (value, expected) in [("a", "a"), ("b", "a;b")] {
        string_agg.step(&mut state, &[text(value), text(";")]).unwrap();
        assert_eq!(string_agg.value(&state).unwrap(), text(expected));
    }
}
1696
/// Every built-in window function is discoverable after
/// `register_window_builtins`.
#[test]
fn test_register_window_builtins_all_present() {
    let mut reg = FunctionRegistry::new();
    register_window_builtins(&mut reg);

    // These may be registered under 0 args, 1 arg, or as variadic (-1).
    let flexible = [
        "row_number",
        "rank",
        "dense_rank",
        "percent_rank",
        "cume_dist",
        "lag",
        "lead",
    ];
    for name in flexible {
        let found = [0, 1, -1]
            .iter()
            .any(|&arity| reg.find_window(name, arity).is_some());
        assert!(found, "window function '{name}' not registered");
    }

    // These must be registered under exactly the listed arity.
    let fixed_arity = [
        ("ntile", 1),
        ("first_value", 1),
        ("last_value", 1),
        ("nth_value", 2),
        ("group_concat", 1),
        ("group_concat", 2),
        ("string_agg", 2),
    ];
    for (name, arity) in fixed_arity {
        assert!(
            reg.find_window(name, arity).is_some(),
            "{name}({arity}) not registered"
        );
    }
}
1751
/// `row_number()` fetched from the registry counts 1, 2, 3 as rows are stepped.
#[test]
fn test_e2e_window_row_number_through_registry() {
    let mut reg = FunctionRegistry::new();
    register_window_builtins(&mut reg);

    let row_number = reg.find_window("row_number", 0).unwrap();
    let mut state = row_number.initial_state();
    for expected in 1..=3 {
        row_number.step(&mut state, &[]).unwrap();
        assert_eq!(row_number.value(&state).unwrap(), int(expected));
    }
}
1768
/// `rank()` fetched from the registry: peers share a rank, and the rank after
/// a tie skips ahead to the row number (1, 2, 2, 4).
#[test]
fn test_e2e_window_rank_through_registry() {
    let mut reg = FunctionRegistry::new();
    register_window_builtins(&mut reg);

    let rank = reg.find_window("rank", 1).unwrap();
    let mut state = rank.initial_state();
    // (ORDER BY key fed to step, rank expected afterwards)
    for (order_key, expected_rank) in [(1, 1), (2, 2), (2, 2), (3, 4)] {
        rank.step(&mut state, &[int(order_key)]).unwrap();
        assert_eq!(rank.value(&state).unwrap(), int(expected_rank));
    }
}
1786}