1use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23use arrow::buffer::NullBuffer;
24use arrow::datatypes::ArrowPrimitiveType;
25
26use datafusion_expr_common::groups_accumulator::EmitTo;
27
28#[derive(Debug)]
38pub enum SeenValues {
39 All {
41 num_values: usize,
42 },
43 Some {
45 values: BooleanBufferBuilder,
46 },
47}
48
49impl Default for SeenValues {
50 fn default() -> Self {
51 SeenValues::All { num_values: 0 }
52 }
53}
54
55impl SeenValues {
56 fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
64 match self {
65 SeenValues::All { num_values } => {
66 let mut builder = BooleanBufferBuilder::new(total_num_groups);
67 builder.append_n(*num_values, true);
68 if total_num_groups > *num_values {
69 builder.append_n(total_num_groups - *num_values, false);
70 }
71 *self = SeenValues::Some { values: builder };
72 match self {
73 SeenValues::Some { values } => values,
74 _ => unreachable!(),
75 }
76 }
77 SeenValues::Some { values } => {
78 if values.len() < total_num_groups {
79 values.append_n(total_num_groups - values.len(), false);
80 }
81 values
82 }
83 }
84 }
85}
86
87#[derive(Debug)]
113pub struct NullState {
114 seen_values: SeenValues,
124}
125
126impl Default for NullState {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132impl NullState {
133 pub fn new() -> Self {
134 Self {
135 seen_values: SeenValues::All { num_values: 0 },
136 }
137 }
138
139 pub fn size(&self) -> usize {
141 match &self.seen_values {
142 SeenValues::All { .. } => 0,
143 SeenValues::Some { values } => values.capacity() / 8,
144 }
145 }
146
147 pub fn accumulate<T, F>(
164 &mut self,
165 group_indices: &[usize],
166 values: &PrimitiveArray<T>,
167 opt_filter: Option<&BooleanArray>,
168 total_num_groups: usize,
169 mut value_fn: F,
170 ) where
171 T: ArrowPrimitiveType + Send,
172 F: FnMut(usize, T::Native) + Send,
173 {
174 if let SeenValues::All { num_values } = &mut self.seen_values
176 && opt_filter.is_none()
177 && values.null_count() == 0
178 {
179 accumulate(group_indices, values, None, value_fn);
180 *num_values = total_num_groups;
181 return;
182 }
183
184 let seen_values = self.seen_values.get_builder(total_num_groups);
185 accumulate(group_indices, values, opt_filter, |group_index, value| {
186 seen_values.set_bit(group_index, true);
187 value_fn(group_index, value);
188 });
189 }
190
191 pub fn accumulate_boolean<F>(
202 &mut self,
203 group_indices: &[usize],
204 values: &BooleanArray,
205 opt_filter: Option<&BooleanArray>,
206 total_num_groups: usize,
207 mut value_fn: F,
208 ) where
209 F: FnMut(usize, bool) + Send,
210 {
211 let data = values.values();
212 assert_eq!(data.len(), group_indices.len());
213
214 if let SeenValues::All { num_values } = &mut self.seen_values
216 && opt_filter.is_none()
217 && values.null_count() == 0
218 {
219 group_indices
220 .iter()
221 .zip(data.iter())
222 .for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
223 *num_values = total_num_groups;
224
225 return;
226 }
227
228 let seen_values = self.seen_values.get_builder(total_num_groups);
229
230 match (values.null_count() > 0, opt_filter) {
232 (false, None) => {
234 group_indices.iter().zip(data.iter()).for_each(
237 |(&group_index, new_value)| {
238 seen_values.set_bit(group_index, true);
239 value_fn(group_index, new_value)
240 },
241 )
242 }
243 (true, None) => {
245 let nulls = values.nulls().unwrap();
246 group_indices
247 .iter()
248 .zip(data.iter())
249 .zip(nulls.iter())
250 .for_each(|((&group_index, new_value), is_valid)| {
251 if is_valid {
252 seen_values.set_bit(group_index, true);
253 value_fn(group_index, new_value);
254 }
255 })
256 }
257 (false, Some(filter)) => {
259 assert_eq!(filter.len(), group_indices.len());
260
261 group_indices
262 .iter()
263 .zip(data.iter())
264 .zip(filter.iter())
265 .for_each(|((&group_index, new_value), filter_value)| {
266 if let Some(true) = filter_value {
267 seen_values.set_bit(group_index, true);
268 value_fn(group_index, new_value);
269 }
270 })
271 }
272 (true, Some(filter)) => {
274 assert_eq!(filter.len(), group_indices.len());
275 filter
276 .iter()
277 .zip(group_indices.iter())
278 .zip(values.iter())
279 .for_each(|((filter_value, &group_index), new_value)| {
280 if let Some(true) = filter_value
281 && let Some(new_value) = new_value
282 {
283 seen_values.set_bit(group_index, true);
284 value_fn(group_index, new_value)
285 }
286 })
287 }
288 }
289 }
290
291 pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
297 match emit_to {
298 EmitTo::All => {
299 let old_seen = std::mem::take(&mut self.seen_values);
300 match old_seen {
301 SeenValues::All { .. } => None,
302 SeenValues::Some { mut values } => {
303 Some(NullBuffer::new(values.finish()))
304 }
305 }
306 }
307 EmitTo::First(n) => match &mut self.seen_values {
308 SeenValues::All { num_values } => {
309 *num_values = num_values.saturating_sub(n);
310 None
311 }
312 SeenValues::Some { .. } => {
313 let mut old_values = match std::mem::take(&mut self.seen_values) {
314 SeenValues::Some { values } => values,
315 _ => unreachable!(),
316 };
317 let nulls = old_values.finish();
318 let first_n_null = nulls.slice(0, n);
319 let remainder = nulls.slice(n, nulls.len() - n);
320 let mut new_builder = BooleanBufferBuilder::new(remainder.len());
321 new_builder.append_buffer(&remainder);
322 self.seen_values = SeenValues::Some {
323 values: new_builder,
324 };
325 Some(NullBuffer::new(first_n_null))
326 }
327 },
328 }
329 }
330}
331
332pub fn accumulate<T, F>(
371 group_indices: &[usize],
372 values: &PrimitiveArray<T>,
373 opt_filter: Option<&BooleanArray>,
374 mut value_fn: F,
375) where
376 T: ArrowPrimitiveType + Send,
377 F: FnMut(usize, T::Native) + Send,
378{
379 let data: &[T::Native] = values.values();
380 assert_eq!(data.len(), group_indices.len());
381
382 match (values.null_count() > 0, opt_filter) {
383 (false, None) => {
385 let iter = group_indices.iter().zip(data.iter());
386 for (&group_index, &new_value) in iter {
387 value_fn(group_index, new_value);
388 }
389 }
390 (true, None) => {
392 let nulls = values.nulls().unwrap();
393 let group_indices_chunks = group_indices.chunks_exact(64);
396 let data_chunks = data.chunks_exact(64);
397 let bit_chunks = nulls.inner().bit_chunks();
398
399 let group_indices_remainder = group_indices_chunks.remainder();
400 let data_remainder = data_chunks.remainder();
401
402 group_indices_chunks
403 .zip(data_chunks)
404 .zip(bit_chunks.iter())
405 .for_each(|((group_index_chunk, data_chunk), mask)| {
406 let mut index_mask = 1;
408 group_index_chunk.iter().zip(data_chunk.iter()).for_each(
409 |(&group_index, &new_value)| {
410 let is_valid = (mask & index_mask) != 0;
412 if is_valid {
413 value_fn(group_index, new_value);
414 }
415 index_mask <<= 1;
416 },
417 )
418 });
419
420 let remainder_bits = bit_chunks.remainder_bits();
422 group_indices_remainder
423 .iter()
424 .zip(data_remainder.iter())
425 .enumerate()
426 .for_each(|(i, (&group_index, &new_value))| {
427 let is_valid = remainder_bits & (1 << i) != 0;
428 if is_valid {
429 value_fn(group_index, new_value);
430 }
431 });
432 }
433 (false, Some(filter)) => {
435 assert_eq!(filter.len(), group_indices.len());
436 group_indices
440 .iter()
441 .zip(data.iter())
442 .zip(filter.iter())
443 .for_each(|((&group_index, &new_value), filter_value)| {
444 if let Some(true) = filter_value {
445 value_fn(group_index, new_value);
446 }
447 })
448 }
449 (true, Some(filter)) => {
451 assert_eq!(filter.len(), group_indices.len());
452 filter
456 .iter()
457 .zip(group_indices.iter())
458 .zip(values.iter())
459 .for_each(|((filter_value, &group_index), new_value)| {
460 if let Some(true) = filter_value
461 && let Some(new_value) = new_value
462 {
463 value_fn(group_index, new_value)
464 }
465 })
466 }
467 }
468}
469
470pub fn accumulate_multiple<T, F>(
486 group_indices: &[usize],
487 value_columns: &[&PrimitiveArray<T>],
488 opt_filter: Option<&BooleanArray>,
489 mut value_fn: F,
490) where
491 T: ArrowPrimitiveType + Send,
492 F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
493{
494 let combined_nulls = value_columns
502 .iter()
503 .map(|arr| arr.logical_nulls())
504 .fold(None, |acc, nulls| {
505 NullBuffer::union(acc.as_ref(), nulls.as_ref())
506 });
507
508 let valid_indices = match (combined_nulls, opt_filter) {
510 (None, None) => None,
511 (None, Some(filter)) => Some(filter.clone()),
512 (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
513 (Some(nulls), Some(filter)) => {
514 let combined = nulls.inner() & filter.values();
515 Some(BooleanArray::new(combined, None))
516 }
517 };
518
519 for col in value_columns.iter() {
520 debug_assert_eq!(col.len(), group_indices.len());
521 }
522
523 match valid_indices {
524 None => {
525 for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
526 value_fn(group_idx, batch_idx, value_columns);
527 }
528 }
529 Some(valid_indices) => {
530 for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
531 if valid_indices.value(batch_idx) {
532 value_fn(group_idx, batch_idx, value_columns);
533 }
534 }
535 }
536 }
537}
538
539pub fn accumulate_indices<F>(
549 group_indices: &[usize],
550 nulls: Option<&NullBuffer>,
551 opt_filter: Option<&BooleanArray>,
552 mut index_fn: F,
553) where
554 F: FnMut(usize) + Send,
555{
556 match (nulls, opt_filter) {
557 (None, None) => {
558 for &group_index in group_indices.iter() {
559 index_fn(group_index)
560 }
561 }
562 (None, Some(filter)) => {
563 debug_assert_eq!(filter.len(), group_indices.len());
564 let group_indices_chunks = group_indices.chunks_exact(64);
565 let bit_chunks = filter.values().bit_chunks();
566
567 let group_indices_remainder = group_indices_chunks.remainder();
568
569 group_indices_chunks.zip(bit_chunks.iter()).for_each(
570 |(group_index_chunk, mask)| {
571 let mut index_mask = 1;
573 group_index_chunk.iter().for_each(|&group_index| {
574 let is_valid = (mask & index_mask) != 0;
576 if is_valid {
577 index_fn(group_index);
578 }
579 index_mask <<= 1;
580 })
581 },
582 );
583
584 let remainder_bits = bit_chunks.remainder_bits();
586 group_indices_remainder
587 .iter()
588 .enumerate()
589 .for_each(|(i, &group_index)| {
590 let is_valid = remainder_bits & (1 << i) != 0;
591 if is_valid {
592 index_fn(group_index)
593 }
594 });
595 }
596 (Some(valids), None) => {
597 debug_assert_eq!(valids.len(), group_indices.len());
598 let group_indices_chunks = group_indices.chunks_exact(64);
601 let bit_chunks = valids.inner().bit_chunks();
602
603 let group_indices_remainder = group_indices_chunks.remainder();
604
605 group_indices_chunks.zip(bit_chunks.iter()).for_each(
606 |(group_index_chunk, mask)| {
607 let mut index_mask = 1;
609 group_index_chunk.iter().for_each(|&group_index| {
610 let is_valid = (mask & index_mask) != 0;
612 if is_valid {
613 index_fn(group_index);
614 }
615 index_mask <<= 1;
616 })
617 },
618 );
619
620 let remainder_bits = bit_chunks.remainder_bits();
622 group_indices_remainder
623 .iter()
624 .enumerate()
625 .for_each(|(i, &group_index)| {
626 let is_valid = remainder_bits & (1 << i) != 0;
627 if is_valid {
628 index_fn(group_index)
629 }
630 });
631 }
632
633 (Some(valids), Some(filter)) => {
634 debug_assert_eq!(filter.len(), group_indices.len());
635 debug_assert_eq!(valids.len(), group_indices.len());
636
637 let group_indices_chunks = group_indices.chunks_exact(64);
638 let valid_bit_chunks = valids.inner().bit_chunks();
639 let filter_bit_chunks = filter.values().bit_chunks();
640
641 let group_indices_remainder = group_indices_chunks.remainder();
642
643 group_indices_chunks
644 .zip(valid_bit_chunks.iter())
645 .zip(filter_bit_chunks.iter())
646 .for_each(|((group_index_chunk, valid_mask), filter_mask)| {
647 let mut index_mask = 1;
649 group_index_chunk.iter().for_each(|&group_index| {
650 let is_valid = (valid_mask & filter_mask & index_mask) != 0;
652 if is_valid {
653 index_fn(group_index);
654 }
655 index_mask <<= 1;
656 })
657 });
658
659 let remainder_valid_bits = valid_bit_chunks.remainder_bits();
661 let remainder_filter_bits = filter_bit_chunks.remainder_bits();
662 group_indices_remainder
663 .iter()
664 .enumerate()
665 .for_each(|(i, &group_index)| {
666 let is_valid =
667 remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
668 if is_valid {
669 index_fn(group_index)
670 }
671 });
672 }
673 }
674}
675
676#[cfg(test)]
677mod test {
678 use super::*;
679
680 use arrow::{
681 array::{Int32Array, UInt32Array},
682 buffer::BooleanBuffer,
683 };
684 use rand::{Rng, rngs::ThreadRng};
685 use std::collections::HashSet;
686
687 #[test]
688 fn accumulate() {
689 let group_indices = (0..100).collect();
690 let values = (0..100).map(|i| (i + 1) * 10).collect();
691 let values_with_nulls = (0..100)
692 .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
693 .collect();
694
695 let filter: BooleanArray = (0..100)
698 .map(|i| {
699 let is_even = i % 2 == 0;
700 let is_fifth = i % 5 == 0;
701 if is_even {
702 None
703 } else if is_fifth {
704 Some(false)
705 } else {
706 Some(true)
707 }
708 })
709 .collect();
710
711 Fixture {
712 group_indices,
713 values,
714 values_with_nulls,
715 filter,
716 }
717 .run()
718 }
719
720 #[test]
721 fn accumulate_fuzz() {
722 let mut rng = rand::rng();
723 for _ in 0..100 {
724 Fixture::new_random(&mut rng).run();
725 }
726 }
727
728 struct Fixture {
730 group_indices: Vec<usize>,
732
733 values: Vec<u32>,
735
736 values_with_nulls: Vec<Option<u32>>,
739
740 filter: BooleanArray,
742 }
743
744 impl Fixture {
745 fn new_random(rng: &mut ThreadRng) -> Self {
746 let num_values: usize = rng.random_range(1..200);
748 let num_groups: usize = rng.random_range(2..1000);
750 let max_group = num_groups - 1;
751
752 let group_indices: Vec<usize> = (0..num_values)
753 .map(|_| rng.random_range(0..max_group))
754 .collect();
755
756 let values: Vec<u32> = (0..num_values).map(|_| rng.random()).collect();
757
758 let filter: BooleanArray = (0..num_values)
762 .map(|_| {
763 let filter_value = rng.random_range(0.0..1.0);
764 if filter_value < 0.1 {
765 Some(false)
766 } else if filter_value < 0.2 {
767 None
768 } else {
769 Some(true)
770 }
771 })
772 .collect();
773
774 let null_pct: f32 = rng.random_range(0.0..1.0);
777 let values_with_nulls: Vec<Option<u32>> = (0..num_values)
778 .map(|_| {
779 let is_null = null_pct < rng.random_range(0.0..1.0);
780 if is_null { None } else { Some(rng.random()) }
781 })
782 .collect();
783
784 Self {
785 group_indices,
786 values,
787 values_with_nulls,
788 filter,
789 }
790 }
791
792 fn values_array(&self) -> UInt32Array {
794 UInt32Array::from(self.values.clone())
795 }
796
797 fn values_with_nulls_array(&self) -> UInt32Array {
799 UInt32Array::from(self.values_with_nulls.clone())
800 }
801
802 fn run(&self) {
805 let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
806
807 let group_indices = &self.group_indices;
808 let values_array = self.values_array();
809 let values_with_nulls_array = self.values_with_nulls_array();
810 let filter = &self.filter;
811
812 Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
814
815 Self::accumulate_test(
817 group_indices,
818 &values_with_nulls_array,
819 None,
820 total_num_groups,
821 );
822
823 Self::accumulate_test(
825 group_indices,
826 &values_array,
827 Some(filter),
828 total_num_groups,
829 );
830
831 Self::accumulate_test(
833 group_indices,
834 &values_with_nulls_array,
835 Some(filter),
836 total_num_groups,
837 );
838 }
839
840 fn accumulate_test(
843 group_indices: &[usize],
844 values: &UInt32Array,
845 opt_filter: Option<&BooleanArray>,
846 total_num_groups: usize,
847 ) {
848 Self::accumulate_values_test(
849 group_indices,
850 values,
851 opt_filter,
852 total_num_groups,
853 );
854 Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
855
856 let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
859 let boolean_values: BooleanArray =
860 values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
861 Self::accumulate_boolean_test(
862 group_indices,
863 &boolean_values,
864 opt_filter,
865 total_num_groups,
866 );
867 }
868
869 fn accumulate_values_test(
872 group_indices: &[usize],
873 values: &UInt32Array,
874 opt_filter: Option<&BooleanArray>,
875 total_num_groups: usize,
876 ) {
877 let mut accumulated_values = vec![];
878 let mut null_state = NullState::new();
879
880 null_state.accumulate(
881 group_indices,
882 values,
883 opt_filter,
884 total_num_groups,
885 |group_index, value| {
886 accumulated_values.push((group_index, value));
887 },
888 );
889
890 let mut expected_values = vec![];
892 let mut mock = MockNullState::new();
893
894 match opt_filter {
895 None => group_indices.iter().zip(values.iter()).for_each(
896 |(&group_index, value)| {
897 if let Some(value) = value {
898 mock.saw_value(group_index);
899 expected_values.push((group_index, value));
900 }
901 },
902 ),
903 Some(filter) => {
904 group_indices
905 .iter()
906 .zip(values.iter())
907 .zip(filter.iter())
908 .for_each(|((&group_index, value), is_included)| {
909 if let Some(true) = is_included
911 && let Some(value) = value
912 {
913 mock.saw_value(group_index);
914 expected_values.push((group_index, value));
915 }
916 });
917 }
918 }
919
920 assert_eq!(
921 accumulated_values, expected_values,
922 "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
923 );
924
925 match &null_state.seen_values {
926 SeenValues::All { num_values } => {
927 assert_eq!(*num_values, total_num_groups);
928 }
929 SeenValues::Some { values } => {
930 let seen_values = values.finish_cloned();
931 mock.validate_seen_values(&seen_values);
932 }
933 }
934
935 let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
937
938 let null_buffer = null_state.build(EmitTo::All);
939 if let Some(nulls) = &null_buffer {
940 assert_eq!(*nulls, expected_null_buffer);
941 }
942 }
943
944 fn accumulate_indices_test(
947 group_indices: &[usize],
948 nulls: Option<&NullBuffer>,
949 opt_filter: Option<&BooleanArray>,
950 ) {
951 let mut accumulated_values = vec![];
952
953 accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
954 accumulated_values.push(group_index);
955 });
956
957 let mut expected_values = vec![];
959
960 match (nulls, opt_filter) {
961 (None, None) => group_indices.iter().for_each(|&group_index| {
962 expected_values.push(group_index);
963 }),
964 (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
965 |(&group_index, is_valid)| {
966 if is_valid {
967 expected_values.push(group_index);
968 }
969 },
970 ),
971 (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
972 |(&group_index, is_included)| {
973 if let Some(true) = is_included {
974 expected_values.push(group_index);
975 }
976 },
977 ),
978 (Some(nulls), Some(filter)) => {
979 group_indices
980 .iter()
981 .zip(nulls.iter())
982 .zip(filter.iter())
983 .for_each(|((&group_index, is_valid), is_included)| {
984 if let (true, Some(true)) = (is_valid, is_included) {
986 expected_values.push(group_index);
987 }
988 });
989 }
990 }
991
992 assert_eq!(
993 accumulated_values, expected_values,
994 "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
995 );
996 }
997
998 fn accumulate_boolean_test(
1001 group_indices: &[usize],
1002 values: &BooleanArray,
1003 opt_filter: Option<&BooleanArray>,
1004 total_num_groups: usize,
1005 ) {
1006 let mut accumulated_values = vec![];
1007 let mut null_state = NullState::new();
1008
1009 null_state.accumulate_boolean(
1010 group_indices,
1011 values,
1012 opt_filter,
1013 total_num_groups,
1014 |group_index, value| {
1015 accumulated_values.push((group_index, value));
1016 },
1017 );
1018
1019 let mut expected_values = vec![];
1021 let mut mock = MockNullState::new();
1022
1023 match opt_filter {
1024 None => group_indices.iter().zip(values.iter()).for_each(
1025 |(&group_index, value)| {
1026 if let Some(value) = value {
1027 mock.saw_value(group_index);
1028 expected_values.push((group_index, value));
1029 }
1030 },
1031 ),
1032 Some(filter) => {
1033 group_indices
1034 .iter()
1035 .zip(values.iter())
1036 .zip(filter.iter())
1037 .for_each(|((&group_index, value), is_included)| {
1038 if let Some(true) = is_included
1040 && let Some(value) = value
1041 {
1042 mock.saw_value(group_index);
1043 expected_values.push((group_index, value));
1044 }
1045 });
1046 }
1047 }
1048
1049 assert_eq!(
1050 accumulated_values, expected_values,
1051 "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
1052 );
1053
1054 match &null_state.seen_values {
1055 SeenValues::All { num_values } => {
1056 assert_eq!(*num_values, total_num_groups);
1057 }
1058 SeenValues::Some { values } => {
1059 let seen_values = values.finish_cloned();
1060 mock.validate_seen_values(&seen_values);
1061 }
1062 }
1063
1064 let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));
1066
1067 let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
1068 let null_buffer = null_state.build(EmitTo::All);
1069
1070 if !is_all_seen {
1071 assert_eq!(null_buffer, expected_null_buffer);
1072 }
1073 }
1074 }
1075
1076 #[derive(Debug, Default)]
1078 struct MockNullState {
1079 seen_values: HashSet<usize>,
1081 }
1082
1083 impl MockNullState {
1084 fn new() -> Self {
1085 Default::default()
1086 }
1087
1088 fn saw_value(&mut self, group_index: usize) {
1089 self.seen_values.insert(group_index);
1090 }
1091
1092 fn expected_seen(&self, group_index: usize) -> bool {
1094 self.seen_values.contains(&group_index)
1095 }
1096
1097 fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
1099 for (group_index, is_seen) in seen_values.iter().enumerate() {
1100 let expected_seen = self.expected_seen(group_index);
1101 assert_eq!(
1102 expected_seen, is_seen,
1103 "mismatch at for group {group_index}"
1104 );
1105 }
1106 }
1107
1108 fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
1110 (0..total_num_groups)
1111 .map(|group_index| self.expected_seen(group_index))
1112 .collect()
1113 }
1114 }
1115
1116 #[test]
1117 fn test_accumulate_multiple_no_nulls_no_filter() {
1118 let group_indices = vec![0, 1, 0, 1];
1119 let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1120 let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1121 let value_columns = [values1, values2];
1122
1123 let mut accumulated = vec![];
1124 accumulate_multiple(
1125 &group_indices,
1126 &value_columns.iter().collect::<Vec<_>>(),
1127 None,
1128 |group_idx, batch_idx, columns| {
1129 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1130 accumulated.push((group_idx, values));
1131 },
1132 );
1133
1134 let expected = vec![
1135 (0, vec![1, 10]),
1136 (1, vec![2, 20]),
1137 (0, vec![3, 30]),
1138 (1, vec![4, 40]),
1139 ];
1140 assert_eq!(accumulated, expected);
1141 }
1142
1143 #[test]
1144 fn test_accumulate_multiple_with_nulls() {
1145 let group_indices = vec![0, 1, 0, 1];
1146 let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1147 let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1148 let value_columns = [values1, values2];
1149
1150 let mut accumulated = vec![];
1151 accumulate_multiple(
1152 &group_indices,
1153 &value_columns.iter().collect::<Vec<_>>(),
1154 None,
1155 |group_idx, batch_idx, columns| {
1156 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1157 accumulated.push((group_idx, values));
1158 },
1159 );
1160
1161 let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
1163 assert_eq!(accumulated, expected);
1164 }
1165
1166 #[test]
1167 fn test_accumulate_multiple_with_filter() {
1168 let group_indices = vec![0, 1, 0, 1];
1169 let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1170 let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1171 let value_columns = [values1, values2];
1172
1173 let filter = BooleanArray::from(vec![true, false, true, false]);
1174
1175 let mut accumulated = vec![];
1176 accumulate_multiple(
1177 &group_indices,
1178 &value_columns.iter().collect::<Vec<_>>(),
1179 Some(&filter),
1180 |group_idx, batch_idx, columns| {
1181 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1182 accumulated.push((group_idx, values));
1183 },
1184 );
1185
1186 let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1188 assert_eq!(accumulated, expected);
1189 }
1190
1191 #[test]
1192 fn test_accumulate_multiple_with_nulls_and_filter() {
1193 let group_indices = vec![0, 1, 0, 1];
1194 let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1195 let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1196 let value_columns = [values1, values2];
1197
1198 let filter = BooleanArray::from(vec![true, true, true, false]);
1199
1200 let mut accumulated = vec![];
1201 accumulate_multiple(
1202 &group_indices,
1203 &value_columns.iter().collect::<Vec<_>>(),
1204 Some(&filter),
1205 |group_idx, batch_idx, columns| {
1206 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1207 accumulated.push((group_idx, values));
1208 },
1209 );
1210
1211 let expected = [(0, vec![1, 10])];
1216 assert_eq!(accumulated, expected);
1217 }
1218}