datafusion_functions_aggregate_common/aggregate/groups_accumulator/
accumulate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]
19//!
20//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
21
22use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23use arrow::buffer::{BooleanBuffer, NullBuffer};
24use arrow::datatypes::ArrowPrimitiveType;
25
26use datafusion_expr_common::groups_accumulator::EmitTo;
/// Tracks the accumulator null state per group: whether any non null,
/// non filtered input value has been seen for that group.
///
/// This is part of the inner loop for many [`GroupsAccumulator`]s,
/// and thus the performance is critical and so there are multiple
/// specialized implementations, invoked depending on the specific
/// combinations of the input.
///
/// Typically there are 4 potential combinations of inputs that must be
/// special cased for performance:
///
/// * With / Without filter
/// * With / Without nulls in the input
///
/// If the input has nulls, then the accumulator must potentially
/// handle each input null value specially (e.g. for `SUM` to mark the
/// corresponding sum as null)
///
/// If there are filters present, `NullState` tracks if it has seen
/// *any* value for that group (as some values may be filtered
/// out). Without a filter, the accumulator is only passed groups that
/// had at least one value to accumulate so they do not need to track
/// if they have seen values for a particular group.
///
/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
#[derive(Debug)]
pub struct NullState {
    /// Have we seen any non null, non filtered input values for `group_index`?
    ///
    /// If `seen_values[i]` is true, we have seen at least one non null
    /// value (that passed the filter, if any) for group `i`
    ///
    /// If `seen_values[i]` is false, we have not yet seen any non null
    /// value that passes the filter for group `i`
    seen_values: BooleanBufferBuilder,
}
63
64impl Default for NullState {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl NullState {
71    pub fn new() -> Self {
72        Self {
73            seen_values: BooleanBufferBuilder::new(0),
74        }
75    }
76
77    /// return the size of all buffers allocated by this null state, not including self
78    pub fn size(&self) -> usize {
79        // capacity is in bits, so convert to bytes
80        self.seen_values.capacity() / 8
81    }
82
83    /// Invokes `value_fn(group_index, value)` for each non null, non
84    /// filtered value of `value`, while tracking which groups have
85    /// seen null inputs and which groups have seen any inputs if necessary
86    //
87    /// # Arguments:
88    ///
89    /// * `values`: the input arguments to the accumulator
90    /// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
91    /// * `opt_filter`: if present, only rows for which is Some(true) are included
92    /// * `value_fn`: function invoked for  (group_index, value) where value is non null
93    ///
94    /// See [`accumulate`], for more details on how value_fn is called
95    ///
96    /// When value_fn is called it also sets
97    ///
98    /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
99    pub fn accumulate<T, F>(
100        &mut self,
101        group_indices: &[usize],
102        values: &PrimitiveArray<T>,
103        opt_filter: Option<&BooleanArray>,
104        total_num_groups: usize,
105        mut value_fn: F,
106    ) where
107        T: ArrowPrimitiveType + Send,
108        F: FnMut(usize, T::Native) + Send,
109    {
110        // ensure the seen_values is big enough (start everything at
111        // "not seen" valid)
112        let seen_values =
113            initialize_builder(&mut self.seen_values, total_num_groups, false);
114        accumulate(group_indices, values, opt_filter, |group_index, value| {
115            seen_values.set_bit(group_index, true);
116            value_fn(group_index, value);
117        });
118    }
119
120    /// Invokes `value_fn(group_index, value)` for each non null, non
121    /// filtered value in `values`, while tracking which groups have
122    /// seen null inputs and which groups have seen any inputs, for
123    /// [`BooleanArray`]s.
124    ///
125    /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
126    /// handled specially.
127    ///
128    /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
129    /// more details on other arguments.
130    pub fn accumulate_boolean<F>(
131        &mut self,
132        group_indices: &[usize],
133        values: &BooleanArray,
134        opt_filter: Option<&BooleanArray>,
135        total_num_groups: usize,
136        mut value_fn: F,
137    ) where
138        F: FnMut(usize, bool) + Send,
139    {
140        let data = values.values();
141        assert_eq!(data.len(), group_indices.len());
142
143        // ensure the seen_values is big enough (start everything at
144        // "not seen" valid)
145        let seen_values =
146            initialize_builder(&mut self.seen_values, total_num_groups, false);
147
148        // These could be made more performant by iterating in chunks of 64 bits at a time
149        match (values.null_count() > 0, opt_filter) {
150            // no nulls, no filter,
151            (false, None) => {
152                // if we have previously seen nulls, ensure the null
153                // buffer is big enough (start everything at valid)
154                group_indices.iter().zip(data.iter()).for_each(
155                    |(&group_index, new_value)| {
156                        seen_values.set_bit(group_index, true);
157                        value_fn(group_index, new_value)
158                    },
159                )
160            }
161            // nulls, no filter
162            (true, None) => {
163                let nulls = values.nulls().unwrap();
164                group_indices
165                    .iter()
166                    .zip(data.iter())
167                    .zip(nulls.iter())
168                    .for_each(|((&group_index, new_value), is_valid)| {
169                        if is_valid {
170                            seen_values.set_bit(group_index, true);
171                            value_fn(group_index, new_value);
172                        }
173                    })
174            }
175            // no nulls, but a filter
176            (false, Some(filter)) => {
177                assert_eq!(filter.len(), group_indices.len());
178
179                group_indices
180                    .iter()
181                    .zip(data.iter())
182                    .zip(filter.iter())
183                    .for_each(|((&group_index, new_value), filter_value)| {
184                        if let Some(true) = filter_value {
185                            seen_values.set_bit(group_index, true);
186                            value_fn(group_index, new_value);
187                        }
188                    })
189            }
190            // both null values and filters
191            (true, Some(filter)) => {
192                assert_eq!(filter.len(), group_indices.len());
193                filter
194                    .iter()
195                    .zip(group_indices.iter())
196                    .zip(values.iter())
197                    .for_each(|((filter_value, &group_index), new_value)| {
198                        if let Some(true) = filter_value
199                            && let Some(new_value) = new_value
200                        {
201                            seen_values.set_bit(group_index, true);
202                            value_fn(group_index, new_value)
203                        }
204                    })
205            }
206        }
207    }
208
209    /// Creates the a [`NullBuffer`] representing which group_indices
210    /// should have null values (because they never saw any values)
211    /// for the `emit_to` rows.
212    ///
213    /// resets the internal state appropriately
214    pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
215        let nulls: BooleanBuffer = self.seen_values.finish();
216
217        let nulls = match emit_to {
218            EmitTo::All => nulls,
219            EmitTo::First(n) => {
220                // split off the first N values in seen_values
221                let first_n_null: BooleanBuffer = nulls.slice(0, n);
222                // reset the existing seen buffer
223                self.seen_values
224                    .append_buffer(&nulls.slice(n, nulls.len() - n));
225                first_n_null
226            }
227        };
228        NullBuffer::new(nulls)
229    }
230}
231
232/// Invokes `value_fn(group_index, value)` for each non null, non
233/// filtered value of `value`,
234///
235/// # Arguments:
236///
237/// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
238/// * `values`: the input arguments to the accumulator
239/// * `opt_filter`: if present, only rows for which is Some(true) are included
240/// * `value_fn`: function invoked for  (group_index, value) where value is non null
241///
242/// # Example
243///
244/// ```text
245///  ┌─────────┐   ┌─────────┐   ┌ ─ ─ ─ ─ ┐
246///  │ ┌─────┐ │   │ ┌─────┐ │     ┌─────┐
247///  │ │  2  │ │   │ │ 200 │ │   │ │  t  │ │
248///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
249///  │ │  2  │ │   │ │ 100 │ │   │ │  f  │ │
250///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
251///  │ │  0  │ │   │ │ 200 │ │   │ │  t  │ │
252///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
253///  │ │  1  │ │   │ │ 200 │ │   │ │NULL │ │
254///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
255///  │ │  0  │ │   │ │ 300 │ │   │ │  t  │ │
256///  │ └─────┘ │   │ └─────┘ │     └─────┘
257///  └─────────┘   └─────────┘   └ ─ ─ ─ ─ ┘
258///
259/// group_indices   values        opt_filter
260/// ```
261///
262/// In the example above, `value_fn` is invoked for each (group_index,
263/// value) pair where `opt_filter[i]` is true and values is non null
264///
265/// ```text
266/// value_fn(2, 200)
267/// value_fn(0, 200)
268/// value_fn(0, 300)
269/// ```
270pub fn accumulate<T, F>(
271    group_indices: &[usize],
272    values: &PrimitiveArray<T>,
273    opt_filter: Option<&BooleanArray>,
274    mut value_fn: F,
275) where
276    T: ArrowPrimitiveType + Send,
277    F: FnMut(usize, T::Native) + Send,
278{
279    let data: &[T::Native] = values.values();
280    assert_eq!(data.len(), group_indices.len());
281
282    match (values.null_count() > 0, opt_filter) {
283        // no nulls, no filter,
284        (false, None) => {
285            let iter = group_indices.iter().zip(data.iter());
286            for (&group_index, &new_value) in iter {
287                value_fn(group_index, new_value);
288            }
289        }
290        // nulls, no filter
291        (true, None) => {
292            let nulls = values.nulls().unwrap();
293            // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
294            // iterate over in chunks of 64 bits for more efficient null checking
295            let group_indices_chunks = group_indices.chunks_exact(64);
296            let data_chunks = data.chunks_exact(64);
297            let bit_chunks = nulls.inner().bit_chunks();
298
299            let group_indices_remainder = group_indices_chunks.remainder();
300            let data_remainder = data_chunks.remainder();
301
302            group_indices_chunks
303                .zip(data_chunks)
304                .zip(bit_chunks.iter())
305                .for_each(|((group_index_chunk, data_chunk), mask)| {
306                    // index_mask has value 1 << i in the loop
307                    let mut index_mask = 1;
308                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
309                        |(&group_index, &new_value)| {
310                            // valid bit was set, real value
311                            let is_valid = (mask & index_mask) != 0;
312                            if is_valid {
313                                value_fn(group_index, new_value);
314                            }
315                            index_mask <<= 1;
316                        },
317                    )
318                });
319
320            // handle any remaining bits (after the initial 64)
321            let remainder_bits = bit_chunks.remainder_bits();
322            group_indices_remainder
323                .iter()
324                .zip(data_remainder.iter())
325                .enumerate()
326                .for_each(|(i, (&group_index, &new_value))| {
327                    let is_valid = remainder_bits & (1 << i) != 0;
328                    if is_valid {
329                        value_fn(group_index, new_value);
330                    }
331                });
332        }
333        // no nulls, but a filter
334        (false, Some(filter)) => {
335            assert_eq!(filter.len(), group_indices.len());
336            // The performance with a filter could be improved by
337            // iterating over the filter in chunks, rather than a single
338            // iterator. TODO file a ticket
339            group_indices
340                .iter()
341                .zip(data.iter())
342                .zip(filter.iter())
343                .for_each(|((&group_index, &new_value), filter_value)| {
344                    if let Some(true) = filter_value {
345                        value_fn(group_index, new_value);
346                    }
347                })
348        }
349        // both null values and filters
350        (true, Some(filter)) => {
351            assert_eq!(filter.len(), group_indices.len());
352            // The performance with a filter could be improved by
353            // iterating over the filter in chunks, rather than using
354            // iterators. TODO file a ticket
355            filter
356                .iter()
357                .zip(group_indices.iter())
358                .zip(values.iter())
359                .for_each(|((filter_value, &group_index), new_value)| {
360                    if let Some(true) = filter_value
361                        && let Some(new_value) = new_value
362                    {
363                        value_fn(group_index, new_value)
364                    }
365                })
366        }
367    }
368}
369
370/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
371///
372/// This method assumes that for any input record index, if any of the value column
373/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
374/// (won't be accumulated by `value_fn`)
375///
376/// # Arguments
377///
378/// * `group_indices` - To which groups do the rows in `value_columns` belong
379/// * `value_columns` - The input arrays to accumulate
380/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included
381/// * `value_fn` - Callback function for each valid row, with parameters:
382///     * `group_idx`: The group index for the current row
383///     * `batch_idx`: The index of the current row in the input arrays
384///     * `columns`: Reference to all input arrays for accessing values
385pub fn accumulate_multiple<T, F>(
386    group_indices: &[usize],
387    value_columns: &[&PrimitiveArray<T>],
388    opt_filter: Option<&BooleanArray>,
389    mut value_fn: F,
390) where
391    T: ArrowPrimitiveType + Send,
392    F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
393{
394    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
395    // `valid_indices` is a bit mask corresponding to the `group_indices`. An index
396    // is considered valid if:
397    // 1. All columns are non-null at this index.
398    // 2. Not filtered out by `opt_filter`
399
400    // Take AND from all null buffers of `value_columns`.
401    let combined_nulls = value_columns
402        .iter()
403        .map(|arr| arr.logical_nulls())
404        .fold(None, |acc, nulls| {
405            NullBuffer::union(acc.as_ref(), nulls.as_ref())
406        });
407
408    // Take AND from previous combined nulls and `opt_filter`.
409    let valid_indices = match (combined_nulls, opt_filter) {
410        (None, None) => None,
411        (None, Some(filter)) => Some(filter.clone()),
412        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
413        (Some(nulls), Some(filter)) => {
414            let combined = nulls.inner() & filter.values();
415            Some(BooleanArray::new(combined, None))
416        }
417    };
418
419    for col in value_columns.iter() {
420        debug_assert_eq!(col.len(), group_indices.len());
421    }
422
423    match valid_indices {
424        None => {
425            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
426                value_fn(group_idx, batch_idx, value_columns);
427            }
428        }
429        Some(valid_indices) => {
430            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
431                if valid_indices.value(batch_idx) {
432                    value_fn(group_idx, batch_idx, value_columns);
433                }
434            }
435        }
436    }
437}
438
439/// This function is called to update the accumulator state per row
440/// when the value is not needed (e.g. COUNT)
441///
442/// `F`: Invoked like `value_fn(group_index) for all non null values
443/// passing the filter. Note that no tracking is done for null inputs
444/// or which groups have seen any values
445///
446/// See [`NullState::accumulate`], for more details on other
447/// arguments.
448pub fn accumulate_indices<F>(
449    group_indices: &[usize],
450    nulls: Option<&NullBuffer>,
451    opt_filter: Option<&BooleanArray>,
452    mut index_fn: F,
453) where
454    F: FnMut(usize) + Send,
455{
456    match (nulls, opt_filter) {
457        (None, None) => {
458            for &group_index in group_indices.iter() {
459                index_fn(group_index)
460            }
461        }
462        (None, Some(filter)) => {
463            debug_assert_eq!(filter.len(), group_indices.len());
464            let group_indices_chunks = group_indices.chunks_exact(64);
465            let bit_chunks = filter.values().bit_chunks();
466
467            let group_indices_remainder = group_indices_chunks.remainder();
468
469            group_indices_chunks.zip(bit_chunks.iter()).for_each(
470                |(group_index_chunk, mask)| {
471                    // index_mask has value 1 << i in the loop
472                    let mut index_mask = 1;
473                    group_index_chunk.iter().for_each(|&group_index| {
474                        // valid bit was set, real vale
475                        let is_valid = (mask & index_mask) != 0;
476                        if is_valid {
477                            index_fn(group_index);
478                        }
479                        index_mask <<= 1;
480                    })
481                },
482            );
483
484            // handle any remaining bits (after the initial 64)
485            let remainder_bits = bit_chunks.remainder_bits();
486            group_indices_remainder
487                .iter()
488                .enumerate()
489                .for_each(|(i, &group_index)| {
490                    let is_valid = remainder_bits & (1 << i) != 0;
491                    if is_valid {
492                        index_fn(group_index)
493                    }
494                });
495        }
496        (Some(valids), None) => {
497            debug_assert_eq!(valids.len(), group_indices.len());
498            // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
499            // iterate over in chunks of 64 bits for more efficient null checking
500            let group_indices_chunks = group_indices.chunks_exact(64);
501            let bit_chunks = valids.inner().bit_chunks();
502
503            let group_indices_remainder = group_indices_chunks.remainder();
504
505            group_indices_chunks.zip(bit_chunks.iter()).for_each(
506                |(group_index_chunk, mask)| {
507                    // index_mask has value 1 << i in the loop
508                    let mut index_mask = 1;
509                    group_index_chunk.iter().for_each(|&group_index| {
510                        // valid bit was set, real vale
511                        let is_valid = (mask & index_mask) != 0;
512                        if is_valid {
513                            index_fn(group_index);
514                        }
515                        index_mask <<= 1;
516                    })
517                },
518            );
519
520            // handle any remaining bits (after the initial 64)
521            let remainder_bits = bit_chunks.remainder_bits();
522            group_indices_remainder
523                .iter()
524                .enumerate()
525                .for_each(|(i, &group_index)| {
526                    let is_valid = remainder_bits & (1 << i) != 0;
527                    if is_valid {
528                        index_fn(group_index)
529                    }
530                });
531        }
532
533        (Some(valids), Some(filter)) => {
534            debug_assert_eq!(filter.len(), group_indices.len());
535            debug_assert_eq!(valids.len(), group_indices.len());
536
537            let group_indices_chunks = group_indices.chunks_exact(64);
538            let valid_bit_chunks = valids.inner().bit_chunks();
539            let filter_bit_chunks = filter.values().bit_chunks();
540
541            let group_indices_remainder = group_indices_chunks.remainder();
542
543            group_indices_chunks
544                .zip(valid_bit_chunks.iter())
545                .zip(filter_bit_chunks.iter())
546                .for_each(|((group_index_chunk, valid_mask), filter_mask)| {
547                    // index_mask has value 1 << i in the loop
548                    let mut index_mask = 1;
549                    group_index_chunk.iter().for_each(|&group_index| {
550                        // valid bit was set, real vale
551                        let is_valid = (valid_mask & filter_mask & index_mask) != 0;
552                        if is_valid {
553                            index_fn(group_index);
554                        }
555                        index_mask <<= 1;
556                    })
557                });
558
559            // handle any remaining bits (after the initial 64)
560            let remainder_valid_bits = valid_bit_chunks.remainder_bits();
561            let remainder_filter_bits = filter_bit_chunks.remainder_bits();
562            group_indices_remainder
563                .iter()
564                .enumerate()
565                .for_each(|(i, &group_index)| {
566                    let is_valid =
567                        remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
568                    if is_valid {
569                        index_fn(group_index)
570                    }
571                });
572        }
573    }
574}
575
576/// Ensures that `builder` contains a `BooleanBufferBuilder with at
577/// least `total_num_groups`.
578///
579/// All new entries are initialized to `default_value`
580fn initialize_builder(
581    builder: &mut BooleanBufferBuilder,
582    total_num_groups: usize,
583    default_value: bool,
584) -> &mut BooleanBufferBuilder {
585    if builder.len() < total_num_groups {
586        let new_groups = total_num_groups - builder.len();
587        builder.append_n(new_groups, default_value);
588    }
589    builder
590}
591
592#[cfg(test)]
593mod test {
594    use super::*;
595
596    use arrow::array::{Int32Array, UInt32Array};
597    use rand::{Rng, rngs::ThreadRng};
598    use std::collections::HashSet;
599
600    #[test]
601    fn accumulate() {
602        let group_indices = (0..100).collect();
603        let values = (0..100).map(|i| (i + 1) * 10).collect();
604        let values_with_nulls = (0..100)
605            .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
606            .collect();
607
608        // default to every fifth value being false, every even
609        // being null
610        let filter: BooleanArray = (0..100)
611            .map(|i| {
612                let is_even = i % 2 == 0;
613                let is_fifth = i % 5 == 0;
614                if is_even {
615                    None
616                } else if is_fifth {
617                    Some(false)
618                } else {
619                    Some(true)
620                }
621            })
622            .collect();
623
624        Fixture {
625            group_indices,
626            values,
627            values_with_nulls,
628            filter,
629        }
630        .run()
631    }
632
633    #[test]
634    fn accumulate_fuzz() {
635        let mut rng = rand::rng();
636        for _ in 0..100 {
637            Fixture::new_random(&mut rng).run();
638        }
639    }
640
    /// Values for testing (there are enough values to exercise the 64 bit chunks)
    struct Fixture {
        /// group index of each input row
        /// (0, 1, ... 99 in the fixed test, random in the fuzz test)
        group_indices: Vec<usize>,

        /// input values without nulls
        /// (10, 20, ... 1000 in the fixed test, random in the fuzz test)
        values: Vec<u32>,

        /// input values with nulls. In the fixed test this is the same as
        /// `values`, but every third is null:
        /// None, Some(20), Some(30), None ...
        values_with_nulls: Vec<Option<u32>>,

        /// filter used for the "with filter" cases; entries may be
        /// `Some(true)`, `Some(false)` or null
        filter: BooleanArray,
    }
656
657    impl Fixture {
658        fn new_random(rng: &mut ThreadRng) -> Self {
659            // Number of input values in a batch
660            let num_values: usize = rng.random_range(1..200);
661            // number of distinct groups
662            let num_groups: usize = rng.random_range(2..1000);
663            let max_group = num_groups - 1;
664
665            let group_indices: Vec<usize> = (0..num_values)
666                .map(|_| rng.random_range(0..max_group))
667                .collect();
668
669            let values: Vec<u32> = (0..num_values).map(|_| rng.random()).collect();
670
671            // 10% chance of false
672            // 10% change of null
673            // 80% chance of true
674            let filter: BooleanArray = (0..num_values)
675                .map(|_| {
676                    let filter_value = rng.random_range(0.0..1.0);
677                    if filter_value < 0.1 {
678                        Some(false)
679                    } else if filter_value < 0.2 {
680                        None
681                    } else {
682                        Some(true)
683                    }
684                })
685                .collect();
686
687            // random values with random number and location of nulls
688            // random null percentage
689            let null_pct: f32 = rng.random_range(0.0..1.0);
690            let values_with_nulls: Vec<Option<u32>> = (0..num_values)
691                .map(|_| {
692                    let is_null = null_pct < rng.random_range(0.0..1.0);
693                    if is_null { None } else { Some(rng.random()) }
694                })
695                .collect();
696
697            Self {
698                group_indices,
699                values,
700                values_with_nulls,
701                filter,
702            }
703        }
704
705        /// returns `Self::values` an Array
706        fn values_array(&self) -> UInt32Array {
707            UInt32Array::from(self.values.clone())
708        }
709
710        /// returns `Self::values_with_nulls` as an Array
711        fn values_with_nulls_array(&self) -> UInt32Array {
712            UInt32Array::from(self.values_with_nulls.clone())
713        }
714
715        /// Calls `NullState::accumulate` and `accumulate_indices`
716        /// with all combinations of nulls and filter values
717        fn run(&self) {
718            let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
719
720            let group_indices = &self.group_indices;
721            let values_array = self.values_array();
722            let values_with_nulls_array = self.values_with_nulls_array();
723            let filter = &self.filter;
724
725            // no null, no filters
726            Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
727
728            // nulls, no filters
729            Self::accumulate_test(
730                group_indices,
731                &values_with_nulls_array,
732                None,
733                total_num_groups,
734            );
735
736            // no nulls, filters
737            Self::accumulate_test(
738                group_indices,
739                &values_array,
740                Some(filter),
741                total_num_groups,
742            );
743
744            // nulls, filters
745            Self::accumulate_test(
746                group_indices,
747                &values_with_nulls_array,
748                Some(filter),
749                total_num_groups,
750            );
751        }
752
753        /// Calls `NullState::accumulate` and `accumulate_indices` to
754        /// ensure it generates the correct values.
755        fn accumulate_test(
756            group_indices: &[usize],
757            values: &UInt32Array,
758            opt_filter: Option<&BooleanArray>,
759            total_num_groups: usize,
760        ) {
761            Self::accumulate_values_test(
762                group_indices,
763                values,
764                opt_filter,
765                total_num_groups,
766            );
767            Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
768
769            // Convert values into a boolean array (anything above the
770            // average is true, otherwise false)
771            let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
772            let boolean_values: BooleanArray =
773                values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
774            Self::accumulate_boolean_test(
775                group_indices,
776                &boolean_values,
777                opt_filter,
778                total_num_groups,
779            );
780        }
781
782        /// This is effectively a different implementation of
783        /// accumulate that we compare with the above implementation
784        fn accumulate_values_test(
785            group_indices: &[usize],
786            values: &UInt32Array,
787            opt_filter: Option<&BooleanArray>,
788            total_num_groups: usize,
789        ) {
790            let mut accumulated_values = vec![];
791            let mut null_state = NullState::new();
792
793            null_state.accumulate(
794                group_indices,
795                values,
796                opt_filter,
797                total_num_groups,
798                |group_index, value| {
799                    accumulated_values.push((group_index, value));
800                },
801            );
802
803            // Figure out the expected values
804            let mut expected_values = vec![];
805            let mut mock = MockNullState::new();
806
807            match opt_filter {
808                None => group_indices.iter().zip(values.iter()).for_each(
809                    |(&group_index, value)| {
810                        if let Some(value) = value {
811                            mock.saw_value(group_index);
812                            expected_values.push((group_index, value));
813                        }
814                    },
815                ),
816                Some(filter) => {
817                    group_indices
818                        .iter()
819                        .zip(values.iter())
820                        .zip(filter.iter())
821                        .for_each(|((&group_index, value), is_included)| {
822                            // if value passed filter
823                            if let Some(true) = is_included
824                                && let Some(value) = value
825                            {
826                                mock.saw_value(group_index);
827                                expected_values.push((group_index, value));
828                            }
829                        });
830                }
831            }
832
833            assert_eq!(
834                accumulated_values, expected_values,
835                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
836            );
837            let seen_values = null_state.seen_values.finish_cloned();
838            mock.validate_seen_values(&seen_values);
839
840            // Validate the final buffer (one value per group)
841            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
842
843            let null_buffer = null_state.build(EmitTo::All);
844
845            assert_eq!(null_buffer, expected_null_buffer);
846        }
847
848        // Calls `accumulate_indices`
849        // and opt_filter and ensures it calls the right values
850        fn accumulate_indices_test(
851            group_indices: &[usize],
852            nulls: Option<&NullBuffer>,
853            opt_filter: Option<&BooleanArray>,
854        ) {
855            let mut accumulated_values = vec![];
856
857            accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
858                accumulated_values.push(group_index);
859            });
860
861            // Figure out the expected values
862            let mut expected_values = vec![];
863
864            match (nulls, opt_filter) {
865                (None, None) => group_indices.iter().for_each(|&group_index| {
866                    expected_values.push(group_index);
867                }),
868                (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
869                    |(&group_index, is_valid)| {
870                        if is_valid {
871                            expected_values.push(group_index);
872                        }
873                    },
874                ),
875                (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
876                    |(&group_index, is_included)| {
877                        if let Some(true) = is_included {
878                            expected_values.push(group_index);
879                        }
880                    },
881                ),
882                (Some(nulls), Some(filter)) => {
883                    group_indices
884                        .iter()
885                        .zip(nulls.iter())
886                        .zip(filter.iter())
887                        .for_each(|((&group_index, is_valid), is_included)| {
888                            // if value passed filter
889                            if let (true, Some(true)) = (is_valid, is_included) {
890                                expected_values.push(group_index);
891                            }
892                        });
893                }
894            }
895
896            assert_eq!(
897                accumulated_values, expected_values,
898                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
899            );
900        }
901
902        /// This is effectively a different implementation of
903        /// accumulate_boolean that we compare with the above implementation
904        fn accumulate_boolean_test(
905            group_indices: &[usize],
906            values: &BooleanArray,
907            opt_filter: Option<&BooleanArray>,
908            total_num_groups: usize,
909        ) {
910            let mut accumulated_values = vec![];
911            let mut null_state = NullState::new();
912
913            null_state.accumulate_boolean(
914                group_indices,
915                values,
916                opt_filter,
917                total_num_groups,
918                |group_index, value| {
919                    accumulated_values.push((group_index, value));
920                },
921            );
922
923            // Figure out the expected values
924            let mut expected_values = vec![];
925            let mut mock = MockNullState::new();
926
927            match opt_filter {
928                None => group_indices.iter().zip(values.iter()).for_each(
929                    |(&group_index, value)| {
930                        if let Some(value) = value {
931                            mock.saw_value(group_index);
932                            expected_values.push((group_index, value));
933                        }
934                    },
935                ),
936                Some(filter) => {
937                    group_indices
938                        .iter()
939                        .zip(values.iter())
940                        .zip(filter.iter())
941                        .for_each(|((&group_index, value), is_included)| {
942                            // if value passed filter
943                            if let Some(true) = is_included
944                                && let Some(value) = value
945                            {
946                                mock.saw_value(group_index);
947                                expected_values.push((group_index, value));
948                            }
949                        });
950                }
951            }
952
953            assert_eq!(
954                accumulated_values, expected_values,
955                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
956            );
957
958            let seen_values = null_state.seen_values.finish_cloned();
959            mock.validate_seen_values(&seen_values);
960
961            // Validate the final buffer (one value per group)
962            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
963
964            let null_buffer = null_state.build(EmitTo::All);
965
966            assert_eq!(null_buffer, expected_null_buffer);
967        }
968    }
969
970    /// Parallel implementation of NullState to check expected values
971    #[derive(Debug, Default)]
972    struct MockNullState {
973        /// group indices that had values that passed the filter
974        seen_values: HashSet<usize>,
975    }
976
977    impl MockNullState {
978        fn new() -> Self {
979            Default::default()
980        }
981
982        fn saw_value(&mut self, group_index: usize) {
983            self.seen_values.insert(group_index);
984        }
985
986        /// did this group index see any input?
987        fn expected_seen(&self, group_index: usize) -> bool {
988            self.seen_values.contains(&group_index)
989        }
990
991        /// Validate that the seen_values matches self.seen_values
992        fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
993            for (group_index, is_seen) in seen_values.iter().enumerate() {
994                let expected_seen = self.expected_seen(group_index);
995                assert_eq!(
996                    expected_seen, is_seen,
997                    "mismatch at for group {group_index}"
998                );
999            }
1000        }
1001
1002        /// Create the expected null buffer based on if the input had nulls and a filter
1003        fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
1004            (0..total_num_groups)
1005                .map(|group_index| self.expected_seen(group_index))
1006                .collect()
1007        }
1008    }
1009
1010    #[test]
1011    fn test_accumulate_multiple_no_nulls_no_filter() {
1012        let group_indices = vec![0, 1, 0, 1];
1013        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1014        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1015        let value_columns = [values1, values2];
1016
1017        let mut accumulated = vec![];
1018        accumulate_multiple(
1019            &group_indices,
1020            &value_columns.iter().collect::<Vec<_>>(),
1021            None,
1022            |group_idx, batch_idx, columns| {
1023                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1024                accumulated.push((group_idx, values));
1025            },
1026        );
1027
1028        let expected = vec![
1029            (0, vec![1, 10]),
1030            (1, vec![2, 20]),
1031            (0, vec![3, 30]),
1032            (1, vec![4, 40]),
1033        ];
1034        assert_eq!(accumulated, expected);
1035    }
1036
1037    #[test]
1038    fn test_accumulate_multiple_with_nulls() {
1039        let group_indices = vec![0, 1, 0, 1];
1040        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1041        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1042        let value_columns = [values1, values2];
1043
1044        let mut accumulated = vec![];
1045        accumulate_multiple(
1046            &group_indices,
1047            &value_columns.iter().collect::<Vec<_>>(),
1048            None,
1049            |group_idx, batch_idx, columns| {
1050                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1051                accumulated.push((group_idx, values));
1052            },
1053        );
1054
1055        // Only rows where both columns are non-null should be accumulated
1056        let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
1057        assert_eq!(accumulated, expected);
1058    }
1059
1060    #[test]
1061    fn test_accumulate_multiple_with_filter() {
1062        let group_indices = vec![0, 1, 0, 1];
1063        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1064        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1065        let value_columns = [values1, values2];
1066
1067        let filter = BooleanArray::from(vec![true, false, true, false]);
1068
1069        let mut accumulated = vec![];
1070        accumulate_multiple(
1071            &group_indices,
1072            &value_columns.iter().collect::<Vec<_>>(),
1073            Some(&filter),
1074            |group_idx, batch_idx, columns| {
1075                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1076                accumulated.push((group_idx, values));
1077            },
1078        );
1079
1080        // Only rows where filter is true should be accumulated
1081        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1082        assert_eq!(accumulated, expected);
1083    }
1084
1085    #[test]
1086    fn test_accumulate_multiple_with_nulls_and_filter() {
1087        let group_indices = vec![0, 1, 0, 1];
1088        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1089        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1090        let value_columns = [values1, values2];
1091
1092        let filter = BooleanArray::from(vec![true, true, true, false]);
1093
1094        let mut accumulated = vec![];
1095        accumulate_multiple(
1096            &group_indices,
1097            &value_columns.iter().collect::<Vec<_>>(),
1098            Some(&filter),
1099            |group_idx, batch_idx, columns| {
1100                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1101                accumulated.push((group_idx, values));
1102            },
1103        );
1104
1105        // Only rows where both:
1106        // 1. Filter is true
1107        // 2. Both columns are non-null
1108        // should be accumulated
1109        let expected = [(0, vec![1, 10])];
1110        assert_eq!(accumulated, expected);
1111    }
1112}