1use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23use arrow::buffer::NullBuffer;
24use arrow::datatypes::ArrowPrimitiveType;
25
26use crate::aggregate::groups_accumulator::nulls::filter_to_validity;
27use datafusion_expr_common::groups_accumulator::EmitTo;
28
/// Records which groups have received at least one value that passed the
/// null and filter checks performed by [`NullState`].
#[derive(Debug)]
pub enum SeenValues {
    /// Every group with index `< num_values` has been seen.
    ///
    /// This compact form is kept while every batch has had no nulls and no
    /// filter, so no per-group bitmap needs to be allocated.
    All {
        /// Number of leading groups that are known to have been seen
        num_values: usize,
    },
    /// Explicit per-group bitmap.
    Some {
        /// Bit `i` is `true` when group `i` has been seen
        values: BooleanBufferBuilder,
    },
}
49
50impl Default for SeenValues {
51 fn default() -> Self {
52 SeenValues::All { num_values: 0 }
53 }
54}
55
56impl SeenValues {
57 fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
65 match self {
66 SeenValues::All { num_values } => {
67 let mut builder = BooleanBufferBuilder::new(total_num_groups);
68 builder.append_n(*num_values, true);
69 if total_num_groups > *num_values {
70 builder.append_n(total_num_groups - *num_values, false);
71 }
72 *self = SeenValues::Some { values: builder };
73 match self {
74 SeenValues::Some { values } => values,
75 _ => unreachable!(),
76 }
77 }
78 SeenValues::Some { values } => {
79 if values.len() < total_num_groups {
80 values.append_n(total_num_groups - values.len(), false);
81 }
82 values
83 }
84 }
85 }
86}
87
/// Tracks, for every group, whether any value has been accumulated into it,
/// so that [`NullState::build`] can report groups that only received nulls
/// (or only filtered-out rows) as null.
#[derive(Debug)]
pub struct NullState {
    /// Which groups have seen at least one accumulated value
    seen_values: SeenValues,
}
126
127impl Default for NullState {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl NullState {
134 pub fn new() -> Self {
135 Self {
136 seen_values: SeenValues::All { num_values: 0 },
137 }
138 }
139
140 pub fn size(&self) -> usize {
142 match &self.seen_values {
143 SeenValues::All { .. } => 0,
144 SeenValues::Some { values } => values.capacity() / 8,
145 }
146 }
147
148 pub fn accumulate<T, F>(
165 &mut self,
166 group_indices: &[usize],
167 values: &PrimitiveArray<T>,
168 opt_filter: Option<&BooleanArray>,
169 total_num_groups: usize,
170 mut value_fn: F,
171 ) where
172 T: ArrowPrimitiveType + Send,
173 F: FnMut(usize, T::Native) + Send,
174 {
175 if let SeenValues::All { num_values } = &mut self.seen_values
177 && opt_filter.is_none()
178 && values.null_count() == 0
179 {
180 accumulate(group_indices, values, None, value_fn);
181 *num_values = total_num_groups;
182 return;
183 }
184
185 let seen_values = self.seen_values.get_builder(total_num_groups);
186 accumulate(group_indices, values, opt_filter, |group_index, value| {
187 seen_values.set_bit(group_index, true);
188 value_fn(group_index, value);
189 });
190 }
191
192 pub fn accumulate_boolean<F>(
203 &mut self,
204 group_indices: &[usize],
205 values: &BooleanArray,
206 opt_filter: Option<&BooleanArray>,
207 total_num_groups: usize,
208 mut value_fn: F,
209 ) where
210 F: FnMut(usize, bool) + Send,
211 {
212 let data = values.values();
213 assert_eq!(data.len(), group_indices.len());
214
215 if let SeenValues::All { num_values } = &mut self.seen_values
217 && opt_filter.is_none()
218 && values.null_count() == 0
219 {
220 group_indices
221 .iter()
222 .zip(data.iter())
223 .for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
224 *num_values = total_num_groups;
225
226 return;
227 }
228
229 let seen_values = self.seen_values.get_builder(total_num_groups);
230
231 match (values.null_count() > 0, opt_filter) {
233 (false, None) => {
235 group_indices.iter().zip(data.iter()).for_each(
238 |(&group_index, new_value)| {
239 seen_values.set_bit(group_index, true);
240 value_fn(group_index, new_value)
241 },
242 )
243 }
244 (true, None) => {
246 let nulls = values.nulls().unwrap();
247 group_indices
248 .iter()
249 .zip(data.iter())
250 .zip(nulls.iter())
251 .for_each(|((&group_index, new_value), is_valid)| {
252 if is_valid {
253 seen_values.set_bit(group_index, true);
254 value_fn(group_index, new_value);
255 }
256 })
257 }
258 (false, Some(filter)) => {
260 assert_eq!(filter.len(), group_indices.len());
261
262 group_indices
263 .iter()
264 .zip(data.iter())
265 .zip(filter.iter())
266 .for_each(|((&group_index, new_value), filter_value)| {
267 if let Some(true) = filter_value {
268 seen_values.set_bit(group_index, true);
269 value_fn(group_index, new_value);
270 }
271 })
272 }
273 (true, Some(filter)) => {
275 assert_eq!(filter.len(), group_indices.len());
276 filter
277 .iter()
278 .zip(group_indices.iter())
279 .zip(values.iter())
280 .for_each(|((filter_value, &group_index), new_value)| {
281 if let Some(true) = filter_value
282 && let Some(new_value) = new_value
283 {
284 seen_values.set_bit(group_index, true);
285 value_fn(group_index, new_value)
286 }
287 })
288 }
289 }
290 }
291
292 pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
298 match emit_to {
299 EmitTo::All => {
300 let old_seen = std::mem::take(&mut self.seen_values);
301 match old_seen {
302 SeenValues::All { .. } => None,
303 SeenValues::Some { mut values } => {
304 Some(NullBuffer::new(values.finish()))
305 }
306 }
307 }
308 EmitTo::First(n) => match &mut self.seen_values {
309 SeenValues::All { num_values } => {
310 *num_values = num_values.saturating_sub(n);
311 None
312 }
313 SeenValues::Some { .. } => {
314 let mut old_values = match std::mem::take(&mut self.seen_values) {
315 SeenValues::Some { values } => values,
316 _ => unreachable!(),
317 };
318 let nulls = old_values.finish();
319 let first_n_null = nulls.slice(0, n);
320 let remainder = nulls.slice(n, nulls.len() - n);
321 let mut new_builder = BooleanBufferBuilder::new(remainder.len());
322 new_builder.append_buffer(&remainder);
323 self.seen_values = SeenValues::Some {
324 values: new_builder,
325 };
326 Some(NullBuffer::new(first_n_null))
327 }
328 },
329 }
330 }
331}
332
/// Invokes `value_fn(group_index, value)` for every row of `values` that is
/// non-null and, when `opt_filter` is provided, whose filter entry is
/// `true`. A null filter entry excludes the row just like `false`.
///
/// Rows are visited in increasing row order.
///
/// # Panics
///
/// Panics if `values` and `group_indices` differ in length, or if a filter
/// is provided whose length differs from `group_indices`.
pub fn accumulate<T, F>(
    group_indices: &[usize],
    values: &PrimitiveArray<T>,
    opt_filter: Option<&BooleanArray>,
    mut value_fn: F,
) where
    T: ArrowPrimitiveType + Send,
    F: FnMut(usize, T::Native) + Send,
{
    let data: &[T::Native] = values.values();
    assert_eq!(data.len(), group_indices.len());

    match (values.null_count() > 0, opt_filter) {
        // No nulls and no filter: every row is accumulated
        (false, None) => {
            let iter = group_indices.iter().zip(data.iter());
            for (&group_index, &new_value) in iter {
                value_fn(group_index, new_value);
            }
        }
        // Nulls but no filter
        (true, None) => {
            // `null_count() > 0` guarantees a validity buffer exists
            let nulls = values.nulls().unwrap();
            // Process rows 64 at a time so each chunk lines up with one u64
            // word of the validity bitmap
            let group_indices_chunks = group_indices.chunks_exact(64);
            let data_chunks = data.chunks_exact(64);
            let bit_chunks = nulls.inner().bit_chunks();

            let group_indices_remainder = group_indices_chunks.remainder();
            let data_remainder = data_chunks.remainder();

            group_indices_chunks
                .zip(data_chunks)
                .zip(bit_chunks.iter())
                .for_each(|((group_index_chunk, data_chunk), mask)| {
                    // selects the bit of `mask` belonging to the current row
                    let mut index_mask = 1;
                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
                        |(&group_index, &new_value)| {
                            // a set validity bit means the value is non-null
                            let is_valid = (mask & index_mask) != 0;
                            if is_valid {
                                value_fn(group_index, new_value);
                            }
                            index_mask <<= 1;
                        },
                    )
                });

            // Rows left over after the last full 64-row chunk
            let remainder_bits = bit_chunks.remainder_bits();
            group_indices_remainder
                .iter()
                .zip(data_remainder.iter())
                .enumerate()
                .for_each(|(i, (&group_index, &new_value))| {
                    let is_valid = remainder_bits & (1 << i) != 0;
                    if is_valid {
                        value_fn(group_index, new_value);
                    }
                });
        }
        // No nulls but a filter: only rows whose filter entry is `Some(true)`
        (false, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            group_indices
                .iter()
                .zip(data.iter())
                .zip(filter.iter())
                .for_each(|((&group_index, &new_value), filter_value)| {
                    if let Some(true) = filter_value {
                        value_fn(group_index, new_value);
                    }
                })
        }
        // Both nulls and a filter: row must be selected and non-null
        (true, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            filter
                .iter()
                .zip(group_indices.iter())
                .zip(values.iter())
                .for_each(|((filter_value, &group_index), new_value)| {
                    if let Some(true) = filter_value
                        && let Some(new_value) = new_value
                    {
                        value_fn(group_index, new_value)
                    }
                })
        }
    }
}
470
471pub fn accumulate_multiple<T, F>(
487 group_indices: &[usize],
488 value_columns: &[&PrimitiveArray<T>],
489 opt_filter: Option<&BooleanArray>,
490 mut value_fn: F,
491) where
492 T: ArrowPrimitiveType + Send,
493 F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
494{
495 for col in value_columns.iter() {
496 debug_assert_eq!(col.len(), group_indices.len());
497 }
498
499 let mut valid_indices =
501 NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls()))
502 .map(NullBuffer::into_inner);
503
504 if let Some(filter) = opt_filter {
508 debug_assert_eq!(filter.len(), group_indices.len());
509 let filter_validity = filter_to_validity(filter);
510 if let Some(valid_indices) = valid_indices.as_mut() {
511 *valid_indices &= &filter_validity;
512 } else {
513 valid_indices = Some(filter_validity);
514 }
515 }
516
517 match valid_indices {
518 None => {
519 for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
520 value_fn(group_idx, batch_idx, value_columns);
521 }
522 }
523 Some(valid_indices) => {
524 for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
525 if valid_indices.value(batch_idx) {
526 value_fn(group_idx, batch_idx, value_columns);
527 }
528 }
529 }
530 }
531}
532
533pub fn accumulate_indices<F>(
543 group_indices: &[usize],
544 nulls: Option<&NullBuffer>,
545 opt_filter: Option<&BooleanArray>,
546 mut index_fn: F,
547) where
548 F: FnMut(usize) + Send,
549{
550 match (nulls, opt_filter) {
551 (None, None) => {
552 for &group_index in group_indices.iter() {
553 index_fn(group_index)
554 }
555 }
556 (None, Some(filter)) => {
557 debug_assert_eq!(filter.len(), group_indices.len());
558 let group_indices_chunks = group_indices.chunks_exact(64);
559 let filter_validity = filter_to_validity(filter);
560 let bit_chunks = filter_validity.bit_chunks();
561
562 let group_indices_remainder = group_indices_chunks.remainder();
563
564 group_indices_chunks.zip(bit_chunks.iter()).for_each(
565 |(group_index_chunk, mask)| {
566 let mut index_mask = 1;
568 group_index_chunk.iter().for_each(|&group_index| {
569 let is_valid = (mask & index_mask) != 0;
571 if is_valid {
572 index_fn(group_index);
573 }
574 index_mask <<= 1;
575 })
576 },
577 );
578
579 let remainder_bits = bit_chunks.remainder_bits();
581 group_indices_remainder
582 .iter()
583 .enumerate()
584 .for_each(|(i, &group_index)| {
585 let is_valid = remainder_bits & (1 << i) != 0;
586 if is_valid {
587 index_fn(group_index)
588 }
589 });
590 }
591 (Some(valids), None) => {
592 debug_assert_eq!(valids.len(), group_indices.len());
593 let group_indices_chunks = group_indices.chunks_exact(64);
596 let bit_chunks = valids.inner().bit_chunks();
597
598 let group_indices_remainder = group_indices_chunks.remainder();
599
600 group_indices_chunks.zip(bit_chunks.iter()).for_each(
601 |(group_index_chunk, mask)| {
602 let mut index_mask = 1;
604 group_index_chunk.iter().for_each(|&group_index| {
605 let is_valid = (mask & index_mask) != 0;
607 if is_valid {
608 index_fn(group_index);
609 }
610 index_mask <<= 1;
611 })
612 },
613 );
614
615 let remainder_bits = bit_chunks.remainder_bits();
617 group_indices_remainder
618 .iter()
619 .enumerate()
620 .for_each(|(i, &group_index)| {
621 let is_valid = remainder_bits & (1 << i) != 0;
622 if is_valid {
623 index_fn(group_index)
624 }
625 });
626 }
627
628 (Some(valids), Some(filter)) => {
629 debug_assert_eq!(filter.len(), group_indices.len());
630 debug_assert_eq!(valids.len(), group_indices.len());
631
632 let group_indices_chunks = group_indices.chunks_exact(64);
633 let valid_bit_chunks = valids.inner().bit_chunks();
634 let filter_validity = filter_to_validity(filter);
635 let filter_bit_chunks = filter_validity.bit_chunks();
636
637 let group_indices_remainder = group_indices_chunks.remainder();
638
639 group_indices_chunks
640 .zip(valid_bit_chunks.iter())
641 .zip(filter_bit_chunks.iter())
642 .for_each(|((group_index_chunk, valid_mask), filter_mask)| {
643 let mut index_mask = 1;
645 group_index_chunk.iter().for_each(|&group_index| {
646 let is_valid = (valid_mask & filter_mask & index_mask) != 0;
648 if is_valid {
649 index_fn(group_index);
650 }
651 index_mask <<= 1;
652 })
653 });
654
655 let remainder_valid_bits = valid_bit_chunks.remainder_bits();
657 let remainder_filter_bits = filter_bit_chunks.remainder_bits();
658 group_indices_remainder
659 .iter()
660 .enumerate()
661 .for_each(|(i, &group_index)| {
662 let is_valid =
663 remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
664 if is_valid {
665 index_fn(group_index)
666 }
667 });
668 }
669 }
670}
671
672#[cfg(test)]
673mod test {
674 use super::*;
675
676 use arrow::{
677 array::{Int32Array, UInt32Array},
678 buffer::BooleanBuffer,
679 };
680 use rand::{Rng, rngs::ThreadRng};
681 use std::collections::HashSet;
682
683 #[test]
684 fn accumulate() {
685 let group_indices = (0..100).collect();
686 let values = (0..100).map(|i| (i + 1) * 10).collect();
687 let values_with_nulls = (0..100)
688 .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
689 .collect();
690
691 let filter: BooleanArray = (0..100)
694 .map(|i| {
695 let is_even = i % 2 == 0;
696 let is_fifth = i % 5 == 0;
697 if is_even {
698 None
699 } else if is_fifth {
700 Some(false)
701 } else {
702 Some(true)
703 }
704 })
705 .collect();
706
707 Fixture {
708 group_indices,
709 values,
710 values_with_nulls,
711 filter,
712 }
713 .run()
714 }
715
716 #[test]
717 fn accumulate_fuzz() {
718 let mut rng = rand::rng();
719 for _ in 0..100 {
720 Fixture::new_random(&mut rng).run();
721 }
722 }
723
724 struct Fixture {
726 group_indices: Vec<usize>,
728
729 values: Vec<u32>,
731
732 values_with_nulls: Vec<Option<u32>>,
735
736 filter: BooleanArray,
738 }
739
740 impl Fixture {
741 fn new_random(rng: &mut ThreadRng) -> Self {
742 let num_values: usize = rng.random_range(1..200);
744 let num_groups: usize = rng.random_range(2..1000);
746 let max_group = num_groups - 1;
747
748 let group_indices: Vec<usize> = (0..num_values)
749 .map(|_| rng.random_range(0..max_group))
750 .collect();
751
752 let values: Vec<u32> = (0..num_values).map(|_| rng.random()).collect();
753
754 let filter: BooleanArray = (0..num_values)
758 .map(|_| {
759 let filter_value = rng.random_range(0.0..1.0);
760 if filter_value < 0.1 {
761 Some(false)
762 } else if filter_value < 0.2 {
763 None
764 } else {
765 Some(true)
766 }
767 })
768 .collect();
769
770 let null_pct: f32 = rng.random_range(0.0..1.0);
773 let values_with_nulls: Vec<Option<u32>> = (0..num_values)
774 .map(|_| {
775 let is_null = null_pct < rng.random_range(0.0..1.0);
776 if is_null { None } else { Some(rng.random()) }
777 })
778 .collect();
779
780 Self {
781 group_indices,
782 values,
783 values_with_nulls,
784 filter,
785 }
786 }
787
788 fn values_array(&self) -> UInt32Array {
790 UInt32Array::from(self.values.clone())
791 }
792
793 fn values_with_nulls_array(&self) -> UInt32Array {
795 UInt32Array::from(self.values_with_nulls.clone())
796 }
797
798 fn run(&self) {
801 let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
802
803 let group_indices = &self.group_indices;
804 let values_array = self.values_array();
805 let values_with_nulls_array = self.values_with_nulls_array();
806 let filter = &self.filter;
807
808 Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
810
811 Self::accumulate_test(
813 group_indices,
814 &values_with_nulls_array,
815 None,
816 total_num_groups,
817 );
818
819 Self::accumulate_test(
821 group_indices,
822 &values_array,
823 Some(filter),
824 total_num_groups,
825 );
826
827 Self::accumulate_test(
829 group_indices,
830 &values_with_nulls_array,
831 Some(filter),
832 total_num_groups,
833 );
834 }
835
836 fn accumulate_test(
839 group_indices: &[usize],
840 values: &UInt32Array,
841 opt_filter: Option<&BooleanArray>,
842 total_num_groups: usize,
843 ) {
844 Self::accumulate_values_test(
845 group_indices,
846 values,
847 opt_filter,
848 total_num_groups,
849 );
850 Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
851
852 let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
855 let boolean_values: BooleanArray =
856 values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
857 Self::accumulate_boolean_test(
858 group_indices,
859 &boolean_values,
860 opt_filter,
861 total_num_groups,
862 );
863 }
864
865 fn accumulate_values_test(
868 group_indices: &[usize],
869 values: &UInt32Array,
870 opt_filter: Option<&BooleanArray>,
871 total_num_groups: usize,
872 ) {
873 let mut accumulated_values = vec![];
874 let mut null_state = NullState::new();
875
876 null_state.accumulate(
877 group_indices,
878 values,
879 opt_filter,
880 total_num_groups,
881 |group_index, value| {
882 accumulated_values.push((group_index, value));
883 },
884 );
885
886 let mut expected_values = vec![];
888 let mut mock = MockNullState::new();
889
890 match opt_filter {
891 None => group_indices.iter().zip(values.iter()).for_each(
892 |(&group_index, value)| {
893 if let Some(value) = value {
894 mock.saw_value(group_index);
895 expected_values.push((group_index, value));
896 }
897 },
898 ),
899 Some(filter) => {
900 group_indices
901 .iter()
902 .zip(values.iter())
903 .zip(filter.iter())
904 .for_each(|((&group_index, value), is_included)| {
905 if let Some(true) = is_included
907 && let Some(value) = value
908 {
909 mock.saw_value(group_index);
910 expected_values.push((group_index, value));
911 }
912 });
913 }
914 }
915
916 assert_eq!(
917 accumulated_values, expected_values,
918 "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
919 );
920
921 match &null_state.seen_values {
922 SeenValues::All { num_values } => {
923 assert_eq!(*num_values, total_num_groups);
924 }
925 SeenValues::Some { values } => {
926 let seen_values = values.finish_cloned();
927 mock.validate_seen_values(&seen_values);
928 }
929 }
930
931 let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
933
934 let null_buffer = null_state.build(EmitTo::All);
935 if let Some(nulls) = &null_buffer {
936 assert_eq!(*nulls, expected_null_buffer);
937 }
938 }
939
940 fn accumulate_indices_test(
943 group_indices: &[usize],
944 nulls: Option<&NullBuffer>,
945 opt_filter: Option<&BooleanArray>,
946 ) {
947 let mut accumulated_values = vec![];
948
949 accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
950 accumulated_values.push(group_index);
951 });
952
953 let mut expected_values = vec![];
955
956 match (nulls, opt_filter) {
957 (None, None) => group_indices.iter().for_each(|&group_index| {
958 expected_values.push(group_index);
959 }),
960 (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
961 |(&group_index, is_valid)| {
962 if is_valid {
963 expected_values.push(group_index);
964 }
965 },
966 ),
967 (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
968 |(&group_index, is_included)| {
969 if let Some(true) = is_included {
970 expected_values.push(group_index);
971 }
972 },
973 ),
974 (Some(nulls), Some(filter)) => {
975 group_indices
976 .iter()
977 .zip(nulls.iter())
978 .zip(filter.iter())
979 .for_each(|((&group_index, is_valid), is_included)| {
980 if let (true, Some(true)) = (is_valid, is_included) {
982 expected_values.push(group_index);
983 }
984 });
985 }
986 }
987
988 assert_eq!(
989 accumulated_values, expected_values,
990 "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
991 );
992 }
993
994 fn accumulate_boolean_test(
997 group_indices: &[usize],
998 values: &BooleanArray,
999 opt_filter: Option<&BooleanArray>,
1000 total_num_groups: usize,
1001 ) {
1002 let mut accumulated_values = vec![];
1003 let mut null_state = NullState::new();
1004
1005 null_state.accumulate_boolean(
1006 group_indices,
1007 values,
1008 opt_filter,
1009 total_num_groups,
1010 |group_index, value| {
1011 accumulated_values.push((group_index, value));
1012 },
1013 );
1014
1015 let mut expected_values = vec![];
1017 let mut mock = MockNullState::new();
1018
1019 match opt_filter {
1020 None => group_indices.iter().zip(values.iter()).for_each(
1021 |(&group_index, value)| {
1022 if let Some(value) = value {
1023 mock.saw_value(group_index);
1024 expected_values.push((group_index, value));
1025 }
1026 },
1027 ),
1028 Some(filter) => {
1029 group_indices
1030 .iter()
1031 .zip(values.iter())
1032 .zip(filter.iter())
1033 .for_each(|((&group_index, value), is_included)| {
1034 if let Some(true) = is_included
1036 && let Some(value) = value
1037 {
1038 mock.saw_value(group_index);
1039 expected_values.push((group_index, value));
1040 }
1041 });
1042 }
1043 }
1044
1045 assert_eq!(
1046 accumulated_values, expected_values,
1047 "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
1048 );
1049
1050 match &null_state.seen_values {
1051 SeenValues::All { num_values } => {
1052 assert_eq!(*num_values, total_num_groups);
1053 }
1054 SeenValues::Some { values } => {
1055 let seen_values = values.finish_cloned();
1056 mock.validate_seen_values(&seen_values);
1057 }
1058 }
1059
1060 let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));
1062
1063 let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
1064 let null_buffer = null_state.build(EmitTo::All);
1065
1066 if !is_all_seen {
1067 assert_eq!(null_buffer, expected_null_buffer);
1068 }
1069 }
1070 }
1071
1072 #[derive(Debug, Default)]
1074 struct MockNullState {
1075 seen_values: HashSet<usize>,
1077 }
1078
1079 impl MockNullState {
1080 fn new() -> Self {
1081 Default::default()
1082 }
1083
1084 fn saw_value(&mut self, group_index: usize) {
1085 self.seen_values.insert(group_index);
1086 }
1087
1088 fn expected_seen(&self, group_index: usize) -> bool {
1090 self.seen_values.contains(&group_index)
1091 }
1092
1093 fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
1095 for (group_index, is_seen) in seen_values.iter().enumerate() {
1096 let expected_seen = self.expected_seen(group_index);
1097 assert_eq!(
1098 expected_seen, is_seen,
1099 "mismatch at for group {group_index}"
1100 );
1101 }
1102 }
1103
1104 fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
1106 (0..total_num_groups)
1107 .map(|group_index| self.expected_seen(group_index))
1108 .collect()
1109 }
1110 }
1111
1112 #[test]
1113 fn test_accumulate_multiple_no_nulls_no_filter() {
1114 let group_indices = vec![0, 1, 0, 1];
1115 let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1116 let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1117 let value_columns = [values1, values2];
1118
1119 let mut accumulated = vec![];
1120 accumulate_multiple(
1121 &group_indices,
1122 &value_columns.iter().collect::<Vec<_>>(),
1123 None,
1124 |group_idx, batch_idx, columns| {
1125 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1126 accumulated.push((group_idx, values));
1127 },
1128 );
1129
1130 let expected = vec![
1131 (0, vec![1, 10]),
1132 (1, vec![2, 20]),
1133 (0, vec![3, 30]),
1134 (1, vec![4, 40]),
1135 ];
1136 assert_eq!(accumulated, expected);
1137 }
1138
1139 #[test]
1140 fn test_accumulate_multiple_with_nulls() {
1141 let group_indices = vec![0, 1, 0, 1];
1142 let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1143 let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1144 let value_columns = [values1, values2];
1145
1146 let mut accumulated = vec![];
1147 accumulate_multiple(
1148 &group_indices,
1149 &value_columns.iter().collect::<Vec<_>>(),
1150 None,
1151 |group_idx, batch_idx, columns| {
1152 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1153 accumulated.push((group_idx, values));
1154 },
1155 );
1156
1157 let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
1159 assert_eq!(accumulated, expected);
1160 }
1161
1162 #[test]
1163 fn test_accumulate_multiple_with_filter() {
1164 let group_indices = vec![0, 1, 0, 1];
1165 let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1166 let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1167 let value_columns = [values1, values2];
1168
1169 let filter = BooleanArray::from(vec![true, false, true, false]);
1170
1171 let mut accumulated = vec![];
1172 accumulate_multiple(
1173 &group_indices,
1174 &value_columns.iter().collect::<Vec<_>>(),
1175 Some(&filter),
1176 |group_idx, batch_idx, columns| {
1177 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1178 accumulated.push((group_idx, values));
1179 },
1180 );
1181
1182 let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1184 assert_eq!(accumulated, expected);
1185 }
1186
1187 #[test]
1188 fn test_accumulate_indices_with_null_filter() {
1189 let group_indices = vec![0, 1, 0, 1];
1190 let filter = BooleanArray::new(
1191 BooleanBuffer::from(vec![true, true, true, false]),
1192 Some(NullBuffer::from(vec![true, false, true, true])),
1193 );
1194
1195 let mut accumulated = vec![];
1196 accumulate_indices(&group_indices, None, Some(&filter), |group_idx| {
1197 accumulated.push(group_idx);
1198 });
1199
1200 let expected = vec![0, 0];
1203 assert_eq!(accumulated, expected);
1204
1205 let value_validity = NullBuffer::from(vec![true, true, false, true]);
1206 let mut accumulated = vec![];
1207 accumulate_indices(
1208 &group_indices,
1209 Some(&value_validity),
1210 Some(&filter),
1211 |group_idx| {
1212 accumulated.push(group_idx);
1213 },
1214 );
1215
1216 let expected = vec![0];
1217 assert_eq!(accumulated, expected);
1218 }
1219
1220 #[test]
1221 fn test_accumulate_multiple_with_null_filter() {
1222 let group_indices = vec![0, 1, 0, 1];
1223 let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1224 let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1225 let value_columns = [values1, values2];
1226
1227 let filter = BooleanArray::new(
1228 BooleanBuffer::from(vec![true, true, true, false]),
1229 Some(NullBuffer::from(vec![true, false, true, true])),
1230 );
1231
1232 let mut accumulated = vec![];
1233 accumulate_multiple(
1234 &group_indices,
1235 &value_columns.iter().collect::<Vec<_>>(),
1236 Some(&filter),
1237 |group_idx, batch_idx, columns| {
1238 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1239 accumulated.push((group_idx, values));
1240 },
1241 );
1242
1243 let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1246 assert_eq!(accumulated, expected);
1247 }
1248
1249 #[test]
1250 fn test_accumulate_multiple_with_nulls_and_filter() {
1251 let group_indices = vec![0, 1, 0, 1];
1252 let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1253 let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1254 let value_columns = [values1, values2];
1255
1256 let filter = BooleanArray::from(vec![true, true, true, false]);
1257
1258 let mut accumulated = vec![];
1259 accumulate_multiple(
1260 &group_indices,
1261 &value_columns.iter().collect::<Vec<_>>(),
1262 Some(&filter),
1263 |group_idx, batch_idx, columns| {
1264 let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1265 accumulated.push((group_idx, values));
1266 },
1267 );
1268
1269 let expected = [(0, vec![1, 10])];
1274 assert_eq!(accumulated, expected);
1275 }
1276}