datafusion_functions_aggregate_common/aggregate/groups_accumulator/
accumulate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]
19//!
20//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
21
22use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23use arrow::buffer::{BooleanBuffer, NullBuffer};
24use arrow::datatypes::ArrowPrimitiveType;
25
26use datafusion_expr_common::groups_accumulator::EmitTo;
27/// Track the accumulator null state per row: if any values for that
28/// group were null and if any values have been seen at all for that group.
29///
30/// This is part of the inner loop for many [`GroupsAccumulator`]s,
31/// and thus the performance is critical and so there are multiple
32/// specialized implementations, invoked depending on the specific
33/// combinations of the input.
34///
35/// Typically there are 4 potential combinations of inputs must be
36/// special cased for performance:
37///
38/// * With / Without filter
39/// * With / Without nulls in the input
40///
41/// If the input has nulls, then the accumulator must potentially
42/// handle each input null value specially (e.g. for `SUM` to mark the
43/// corresponding sum as null)
44///
45/// If there are filters present, `NullState` tracks if it has seen
46/// *any* value for that group (as some values may be filtered
47/// out). Without a filter, the accumulator is only passed groups that
48/// had at least one value to accumulate so they do not need to track
49/// if they have seen values for a particular group.
50///
51/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
52#[derive(Debug)]
53pub struct NullState {
54    /// Have we seen any non-filtered input values for `group_index`?
55    ///
56    /// If `seen_values[i]` is true, have seen at least one non null
57    /// value for group `i`
58    ///
59    /// If `seen_values[i]` is false, have not seen any values that
60    /// pass the filter yet for group `i`
61    seen_values: BooleanBufferBuilder,
62}
63
64impl Default for NullState {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl NullState {
71    pub fn new() -> Self {
72        Self {
73            seen_values: BooleanBufferBuilder::new(0),
74        }
75    }
76
77    /// return the size of all buffers allocated by this null state, not including self
78    pub fn size(&self) -> usize {
79        // capacity is in bits, so convert to bytes
80        self.seen_values.capacity() / 8
81    }
82
83    /// Invokes `value_fn(group_index, value)` for each non null, non
84    /// filtered value of `value`, while tracking which groups have
85    /// seen null inputs and which groups have seen any inputs if necessary
86    //
87    /// # Arguments:
88    ///
89    /// * `values`: the input arguments to the accumulator
90    /// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
91    /// * `opt_filter`: if present, only rows for which is Some(true) are included
92    /// * `value_fn`: function invoked for  (group_index, value) where value is non null
93    ///
94    /// See [`accumulate`], for more details on how value_fn is called
95    ///
96    /// When value_fn is called it also sets
97    ///
98    /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
99    pub fn accumulate<T, F>(
100        &mut self,
101        group_indices: &[usize],
102        values: &PrimitiveArray<T>,
103        opt_filter: Option<&BooleanArray>,
104        total_num_groups: usize,
105        mut value_fn: F,
106    ) where
107        T: ArrowPrimitiveType + Send,
108        F: FnMut(usize, T::Native) + Send,
109    {
110        // ensure the seen_values is big enough (start everything at
111        // "not seen" valid)
112        let seen_values =
113            initialize_builder(&mut self.seen_values, total_num_groups, false);
114        accumulate(group_indices, values, opt_filter, |group_index, value| {
115            seen_values.set_bit(group_index, true);
116            value_fn(group_index, value);
117        });
118    }
119
120    /// Invokes `value_fn(group_index, value)` for each non null, non
121    /// filtered value in `values`, while tracking which groups have
122    /// seen null inputs and which groups have seen any inputs, for
123    /// [`BooleanArray`]s.
124    ///
125    /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
126    /// handled specially.
127    ///
128    /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
129    /// more details on other arguments.
130    pub fn accumulate_boolean<F>(
131        &mut self,
132        group_indices: &[usize],
133        values: &BooleanArray,
134        opt_filter: Option<&BooleanArray>,
135        total_num_groups: usize,
136        mut value_fn: F,
137    ) where
138        F: FnMut(usize, bool) + Send,
139    {
140        let data = values.values();
141        assert_eq!(data.len(), group_indices.len());
142
143        // ensure the seen_values is big enough (start everything at
144        // "not seen" valid)
145        let seen_values =
146            initialize_builder(&mut self.seen_values, total_num_groups, false);
147
148        // These could be made more performant by iterating in chunks of 64 bits at a time
149        match (values.null_count() > 0, opt_filter) {
150            // no nulls, no filter,
151            (false, None) => {
152                // if we have previously seen nulls, ensure the null
153                // buffer is big enough (start everything at valid)
154                group_indices.iter().zip(data.iter()).for_each(
155                    |(&group_index, new_value)| {
156                        seen_values.set_bit(group_index, true);
157                        value_fn(group_index, new_value)
158                    },
159                )
160            }
161            // nulls, no filter
162            (true, None) => {
163                let nulls = values.nulls().unwrap();
164                group_indices
165                    .iter()
166                    .zip(data.iter())
167                    .zip(nulls.iter())
168                    .for_each(|((&group_index, new_value), is_valid)| {
169                        if is_valid {
170                            seen_values.set_bit(group_index, true);
171                            value_fn(group_index, new_value);
172                        }
173                    })
174            }
175            // no nulls, but a filter
176            (false, Some(filter)) => {
177                assert_eq!(filter.len(), group_indices.len());
178
179                group_indices
180                    .iter()
181                    .zip(data.iter())
182                    .zip(filter.iter())
183                    .for_each(|((&group_index, new_value), filter_value)| {
184                        if let Some(true) = filter_value {
185                            seen_values.set_bit(group_index, true);
186                            value_fn(group_index, new_value);
187                        }
188                    })
189            }
190            // both null values and filters
191            (true, Some(filter)) => {
192                assert_eq!(filter.len(), group_indices.len());
193                filter
194                    .iter()
195                    .zip(group_indices.iter())
196                    .zip(values.iter())
197                    .for_each(|((filter_value, &group_index), new_value)| {
198                        if let Some(true) = filter_value {
199                            if let Some(new_value) = new_value {
200                                seen_values.set_bit(group_index, true);
201                                value_fn(group_index, new_value)
202                            }
203                        }
204                    })
205            }
206        }
207    }
208
209    /// Creates the a [`NullBuffer`] representing which group_indices
210    /// should have null values (because they never saw any values)
211    /// for the `emit_to` rows.
212    ///
213    /// resets the internal state appropriately
214    pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
215        let nulls: BooleanBuffer = self.seen_values.finish();
216
217        let nulls = match emit_to {
218            EmitTo::All => nulls,
219            EmitTo::First(n) => {
220                // split off the first N values in seen_values
221                //
222                // TODO make this more efficient rather than two
223                // copies and bitwise manipulation
224                let first_n_null: BooleanBuffer = nulls.iter().take(n).collect();
225                // reset the existing seen buffer
226                for seen in nulls.iter().skip(n) {
227                    self.seen_values.append(seen);
228                }
229                first_n_null
230            }
231        };
232        NullBuffer::new(nulls)
233    }
234}
235
236/// Invokes `value_fn(group_index, value)` for each non null, non
237/// filtered value of `value`,
238///
239/// # Arguments:
240///
241/// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
242/// * `values`: the input arguments to the accumulator
243/// * `opt_filter`: if present, only rows for which is Some(true) are included
244/// * `value_fn`: function invoked for  (group_index, value) where value is non null
245///
246/// # Example
247///
248/// ```text
249///  ┌─────────┐   ┌─────────┐   ┌ ─ ─ ─ ─ ┐
250///  │ ┌─────┐ │   │ ┌─────┐ │     ┌─────┐
251///  │ │  2  │ │   │ │ 200 │ │   │ │  t  │ │
252///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
253///  │ │  2  │ │   │ │ 100 │ │   │ │  f  │ │
254///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
255///  │ │  0  │ │   │ │ 200 │ │   │ │  t  │ │
256///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
257///  │ │  1  │ │   │ │ 200 │ │   │ │NULL │ │
258///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
259///  │ │  0  │ │   │ │ 300 │ │   │ │  t  │ │
260///  │ └─────┘ │   │ └─────┘ │     └─────┘
261///  └─────────┘   └─────────┘   └ ─ ─ ─ ─ ┘
262///
263/// group_indices   values        opt_filter
264/// ```
265///
266/// In the example above, `value_fn` is invoked for each (group_index,
267/// value) pair where `opt_filter[i]` is true and values is non null
268///
269/// ```text
270/// value_fn(2, 200)
271/// value_fn(0, 200)
272/// value_fn(0, 300)
273/// ```
274pub fn accumulate<T, F>(
275    group_indices: &[usize],
276    values: &PrimitiveArray<T>,
277    opt_filter: Option<&BooleanArray>,
278    mut value_fn: F,
279) where
280    T: ArrowPrimitiveType + Send,
281    F: FnMut(usize, T::Native) + Send,
282{
283    let data: &[T::Native] = values.values();
284    assert_eq!(data.len(), group_indices.len());
285
286    match (values.null_count() > 0, opt_filter) {
287        // no nulls, no filter,
288        (false, None) => {
289            let iter = group_indices.iter().zip(data.iter());
290            for (&group_index, &new_value) in iter {
291                value_fn(group_index, new_value);
292            }
293        }
294        // nulls, no filter
295        (true, None) => {
296            let nulls = values.nulls().unwrap();
297            // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
298            // iterate over in chunks of 64 bits for more efficient null checking
299            let group_indices_chunks = group_indices.chunks_exact(64);
300            let data_chunks = data.chunks_exact(64);
301            let bit_chunks = nulls.inner().bit_chunks();
302
303            let group_indices_remainder = group_indices_chunks.remainder();
304            let data_remainder = data_chunks.remainder();
305
306            group_indices_chunks
307                .zip(data_chunks)
308                .zip(bit_chunks.iter())
309                .for_each(|((group_index_chunk, data_chunk), mask)| {
310                    // index_mask has value 1 << i in the loop
311                    let mut index_mask = 1;
312                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
313                        |(&group_index, &new_value)| {
314                            // valid bit was set, real value
315                            let is_valid = (mask & index_mask) != 0;
316                            if is_valid {
317                                value_fn(group_index, new_value);
318                            }
319                            index_mask <<= 1;
320                        },
321                    )
322                });
323
324            // handle any remaining bits (after the initial 64)
325            let remainder_bits = bit_chunks.remainder_bits();
326            group_indices_remainder
327                .iter()
328                .zip(data_remainder.iter())
329                .enumerate()
330                .for_each(|(i, (&group_index, &new_value))| {
331                    let is_valid = remainder_bits & (1 << i) != 0;
332                    if is_valid {
333                        value_fn(group_index, new_value);
334                    }
335                });
336        }
337        // no nulls, but a filter
338        (false, Some(filter)) => {
339            assert_eq!(filter.len(), group_indices.len());
340            // The performance with a filter could be improved by
341            // iterating over the filter in chunks, rather than a single
342            // iterator. TODO file a ticket
343            group_indices
344                .iter()
345                .zip(data.iter())
346                .zip(filter.iter())
347                .for_each(|((&group_index, &new_value), filter_value)| {
348                    if let Some(true) = filter_value {
349                        value_fn(group_index, new_value);
350                    }
351                })
352        }
353        // both null values and filters
354        (true, Some(filter)) => {
355            assert_eq!(filter.len(), group_indices.len());
356            // The performance with a filter could be improved by
357            // iterating over the filter in chunks, rather than using
358            // iterators. TODO file a ticket
359            filter
360                .iter()
361                .zip(group_indices.iter())
362                .zip(values.iter())
363                .for_each(|((filter_value, &group_index), new_value)| {
364                    if let Some(true) = filter_value {
365                        if let Some(new_value) = new_value {
366                            value_fn(group_index, new_value)
367                        }
368                    }
369                })
370        }
371    }
372}
373
374/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
375///
376/// This method assumes that for any input record index, if any of the value column
377/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
378/// (won't be accumulated by `value_fn`)
379///
380/// # Arguments
381///
382/// * `group_indices` - To which groups do the rows in `value_columns` belong
383/// * `value_columns` - The input arrays to accumulate
384/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included
385/// * `value_fn` - Callback function for each valid row, with parameters:
386///     * `group_idx`: The group index for the current row
387///     * `batch_idx`: The index of the current row in the input arrays
388///     * `columns`: Reference to all input arrays for accessing values
389pub fn accumulate_multiple<T, F>(
390    group_indices: &[usize],
391    value_columns: &[&PrimitiveArray<T>],
392    opt_filter: Option<&BooleanArray>,
393    mut value_fn: F,
394) where
395    T: ArrowPrimitiveType + Send,
396    F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
397{
398    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
399    // `valid_indices` is a bit mask corresponding to the `group_indices`. An index
400    // is considered valid if:
401    // 1. All columns are non-null at this index.
402    // 2. Not filtered out by `opt_filter`
403
404    // Take AND from all null buffers of `value_columns`.
405    let combined_nulls = value_columns
406        .iter()
407        .map(|arr| arr.logical_nulls())
408        .fold(None, |acc, nulls| {
409            NullBuffer::union(acc.as_ref(), nulls.as_ref())
410        });
411
412    // Take AND from previous combined nulls and `opt_filter`.
413    let valid_indices = match (combined_nulls, opt_filter) {
414        (None, None) => None,
415        (None, Some(filter)) => Some(filter.clone()),
416        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
417        (Some(nulls), Some(filter)) => {
418            let combined = nulls.inner() & filter.values();
419            Some(BooleanArray::new(combined, None))
420        }
421    };
422
423    for col in value_columns.iter() {
424        debug_assert_eq!(col.len(), group_indices.len());
425    }
426
427    match valid_indices {
428        None => {
429            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
430                value_fn(group_idx, batch_idx, value_columns);
431            }
432        }
433        Some(valid_indices) => {
434            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
435                if valid_indices.value(batch_idx) {
436                    value_fn(group_idx, batch_idx, value_columns);
437                }
438            }
439        }
440    }
441}
442
443/// This function is called to update the accumulator state per row
444/// when the value is not needed (e.g. COUNT)
445///
446/// `F`: Invoked like `value_fn(group_index) for all non null values
447/// passing the filter. Note that no tracking is done for null inputs
448/// or which groups have seen any values
449///
450/// See [`NullState::accumulate`], for more details on other
451/// arguments.
452pub fn accumulate_indices<F>(
453    group_indices: &[usize],
454    nulls: Option<&NullBuffer>,
455    opt_filter: Option<&BooleanArray>,
456    mut index_fn: F,
457) where
458    F: FnMut(usize) + Send,
459{
460    match (nulls, opt_filter) {
461        (None, None) => {
462            for &group_index in group_indices.iter() {
463                index_fn(group_index)
464            }
465        }
466        (None, Some(filter)) => {
467            debug_assert_eq!(filter.len(), group_indices.len());
468            let group_indices_chunks = group_indices.chunks_exact(64);
469            let bit_chunks = filter.values().bit_chunks();
470
471            let group_indices_remainder = group_indices_chunks.remainder();
472
473            group_indices_chunks.zip(bit_chunks.iter()).for_each(
474                |(group_index_chunk, mask)| {
475                    // index_mask has value 1 << i in the loop
476                    let mut index_mask = 1;
477                    group_index_chunk.iter().for_each(|&group_index| {
478                        // valid bit was set, real vale
479                        let is_valid = (mask & index_mask) != 0;
480                        if is_valid {
481                            index_fn(group_index);
482                        }
483                        index_mask <<= 1;
484                    })
485                },
486            );
487
488            // handle any remaining bits (after the initial 64)
489            let remainder_bits = bit_chunks.remainder_bits();
490            group_indices_remainder
491                .iter()
492                .enumerate()
493                .for_each(|(i, &group_index)| {
494                    let is_valid = remainder_bits & (1 << i) != 0;
495                    if is_valid {
496                        index_fn(group_index)
497                    }
498                });
499        }
500        (Some(valids), None) => {
501            debug_assert_eq!(valids.len(), group_indices.len());
502            // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
503            // iterate over in chunks of 64 bits for more efficient null checking
504            let group_indices_chunks = group_indices.chunks_exact(64);
505            let bit_chunks = valids.inner().bit_chunks();
506
507            let group_indices_remainder = group_indices_chunks.remainder();
508
509            group_indices_chunks.zip(bit_chunks.iter()).for_each(
510                |(group_index_chunk, mask)| {
511                    // index_mask has value 1 << i in the loop
512                    let mut index_mask = 1;
513                    group_index_chunk.iter().for_each(|&group_index| {
514                        // valid bit was set, real vale
515                        let is_valid = (mask & index_mask) != 0;
516                        if is_valid {
517                            index_fn(group_index);
518                        }
519                        index_mask <<= 1;
520                    })
521                },
522            );
523
524            // handle any remaining bits (after the initial 64)
525            let remainder_bits = bit_chunks.remainder_bits();
526            group_indices_remainder
527                .iter()
528                .enumerate()
529                .for_each(|(i, &group_index)| {
530                    let is_valid = remainder_bits & (1 << i) != 0;
531                    if is_valid {
532                        index_fn(group_index)
533                    }
534                });
535        }
536
537        (Some(valids), Some(filter)) => {
538            debug_assert_eq!(filter.len(), group_indices.len());
539            debug_assert_eq!(valids.len(), group_indices.len());
540
541            let group_indices_chunks = group_indices.chunks_exact(64);
542            let valid_bit_chunks = valids.inner().bit_chunks();
543            let filter_bit_chunks = filter.values().bit_chunks();
544
545            let group_indices_remainder = group_indices_chunks.remainder();
546
547            group_indices_chunks
548                .zip(valid_bit_chunks.iter())
549                .zip(filter_bit_chunks.iter())
550                .for_each(|((group_index_chunk, valid_mask), filter_mask)| {
551                    // index_mask has value 1 << i in the loop
552                    let mut index_mask = 1;
553                    group_index_chunk.iter().for_each(|&group_index| {
554                        // valid bit was set, real vale
555                        let is_valid = (valid_mask & filter_mask & index_mask) != 0;
556                        if is_valid {
557                            index_fn(group_index);
558                        }
559                        index_mask <<= 1;
560                    })
561                });
562
563            // handle any remaining bits (after the initial 64)
564            let remainder_valid_bits = valid_bit_chunks.remainder_bits();
565            let remainder_filter_bits = filter_bit_chunks.remainder_bits();
566            group_indices_remainder
567                .iter()
568                .enumerate()
569                .for_each(|(i, &group_index)| {
570                    let is_valid =
571                        remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
572                    if is_valid {
573                        index_fn(group_index)
574                    }
575                });
576        }
577    }
578}
579
580/// Ensures that `builder` contains a `BooleanBufferBuilder with at
581/// least `total_num_groups`.
582///
583/// All new entries are initialized to `default_value`
584fn initialize_builder(
585    builder: &mut BooleanBufferBuilder,
586    total_num_groups: usize,
587    default_value: bool,
588) -> &mut BooleanBufferBuilder {
589    if builder.len() < total_num_groups {
590        let new_groups = total_num_groups - builder.len();
591        builder.append_n(new_groups, default_value);
592    }
593    builder
594}
595
596#[cfg(test)]
597mod test {
598    use super::*;
599
600    use arrow::array::{Int32Array, UInt32Array};
601    use rand::{rngs::ThreadRng, Rng};
602    use std::collections::HashSet;
603
604    #[test]
605    fn accumulate() {
606        let group_indices = (0..100).collect();
607        let values = (0..100).map(|i| (i + 1) * 10).collect();
608        let values_with_nulls = (0..100)
609            .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
610            .collect();
611
612        // default to every fifth value being false, every even
613        // being null
614        let filter: BooleanArray = (0..100)
615            .map(|i| {
616                let is_even = i % 2 == 0;
617                let is_fifth = i % 5 == 0;
618                if is_even {
619                    None
620                } else if is_fifth {
621                    Some(false)
622                } else {
623                    Some(true)
624                }
625            })
626            .collect();
627
628        Fixture {
629            group_indices,
630            values,
631            values_with_nulls,
632            filter,
633        }
634        .run()
635    }
636
637    #[test]
638    fn accumulate_fuzz() {
639        let mut rng = rand::rng();
640        for _ in 0..100 {
641            Fixture::new_random(&mut rng).run();
642        }
643    }
644
645    /// Values for testing (there are enough values to exercise the 64 bit chunks
646    struct Fixture {
647        /// 100..0
648        group_indices: Vec<usize>,
649
650        /// 10, 20, ... 1010
651        values: Vec<u32>,
652
653        /// same as values, but every third is null:
654        /// None, Some(20), Some(30), None ...
655        values_with_nulls: Vec<Option<u32>>,
656
657        /// filter (defaults to None)
658        filter: BooleanArray,
659    }
660
661    impl Fixture {
662        fn new_random(rng: &mut ThreadRng) -> Self {
663            // Number of input values in a batch
664            let num_values: usize = rng.random_range(1..200);
665            // number of distinct groups
666            let num_groups: usize = rng.random_range(2..1000);
667            let max_group = num_groups - 1;
668
669            let group_indices: Vec<usize> = (0..num_values)
670                .map(|_| rng.random_range(0..max_group))
671                .collect();
672
673            let values: Vec<u32> = (0..num_values).map(|_| rng.random()).collect();
674
675            // 10% chance of false
676            // 10% change of null
677            // 80% chance of true
678            let filter: BooleanArray = (0..num_values)
679                .map(|_| {
680                    let filter_value = rng.random_range(0.0..1.0);
681                    if filter_value < 0.1 {
682                        Some(false)
683                    } else if filter_value < 0.2 {
684                        None
685                    } else {
686                        Some(true)
687                    }
688                })
689                .collect();
690
691            // random values with random number and location of nulls
692            // random null percentage
693            let null_pct: f32 = rng.random_range(0.0..1.0);
694            let values_with_nulls: Vec<Option<u32>> = (0..num_values)
695                .map(|_| {
696                    let is_null = null_pct < rng.random_range(0.0..1.0);
697                    if is_null {
698                        None
699                    } else {
700                        Some(rng.random())
701                    }
702                })
703                .collect();
704
705            Self {
706                group_indices,
707                values,
708                values_with_nulls,
709                filter,
710            }
711        }
712
713        /// returns `Self::values` an Array
714        fn values_array(&self) -> UInt32Array {
715            UInt32Array::from(self.values.clone())
716        }
717
718        /// returns `Self::values_with_nulls` as an Array
719        fn values_with_nulls_array(&self) -> UInt32Array {
720            UInt32Array::from(self.values_with_nulls.clone())
721        }
722
723        /// Calls `NullState::accumulate` and `accumulate_indices`
724        /// with all combinations of nulls and filter values
725        fn run(&self) {
726            let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
727
728            let group_indices = &self.group_indices;
729            let values_array = self.values_array();
730            let values_with_nulls_array = self.values_with_nulls_array();
731            let filter = &self.filter;
732
733            // no null, no filters
734            Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
735
736            // nulls, no filters
737            Self::accumulate_test(
738                group_indices,
739                &values_with_nulls_array,
740                None,
741                total_num_groups,
742            );
743
744            // no nulls, filters
745            Self::accumulate_test(
746                group_indices,
747                &values_array,
748                Some(filter),
749                total_num_groups,
750            );
751
752            // nulls, filters
753            Self::accumulate_test(
754                group_indices,
755                &values_with_nulls_array,
756                Some(filter),
757                total_num_groups,
758            );
759        }
760
761        /// Calls `NullState::accumulate` and `accumulate_indices` to
762        /// ensure it generates the correct values.
763        ///
764        fn accumulate_test(
765            group_indices: &[usize],
766            values: &UInt32Array,
767            opt_filter: Option<&BooleanArray>,
768            total_num_groups: usize,
769        ) {
770            Self::accumulate_values_test(
771                group_indices,
772                values,
773                opt_filter,
774                total_num_groups,
775            );
776            Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
777
778            // Convert values into a boolean array (anything above the
779            // average is true, otherwise false)
780            let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
781            let boolean_values: BooleanArray =
782                values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
783            Self::accumulate_boolean_test(
784                group_indices,
785                &boolean_values,
786                opt_filter,
787                total_num_groups,
788            );
789        }
790
791        /// This is effectively a different implementation of
792        /// accumulate that we compare with the above implementation
793        fn accumulate_values_test(
794            group_indices: &[usize],
795            values: &UInt32Array,
796            opt_filter: Option<&BooleanArray>,
797            total_num_groups: usize,
798        ) {
799            let mut accumulated_values = vec![];
800            let mut null_state = NullState::new();
801
802            null_state.accumulate(
803                group_indices,
804                values,
805                opt_filter,
806                total_num_groups,
807                |group_index, value| {
808                    accumulated_values.push((group_index, value));
809                },
810            );
811
812            // Figure out the expected values
813            let mut expected_values = vec![];
814            let mut mock = MockNullState::new();
815
816            match opt_filter {
817                None => group_indices.iter().zip(values.iter()).for_each(
818                    |(&group_index, value)| {
819                        if let Some(value) = value {
820                            mock.saw_value(group_index);
821                            expected_values.push((group_index, value));
822                        }
823                    },
824                ),
825                Some(filter) => {
826                    group_indices
827                        .iter()
828                        .zip(values.iter())
829                        .zip(filter.iter())
830                        .for_each(|((&group_index, value), is_included)| {
831                            // if value passed filter
832                            if let Some(true) = is_included {
833                                if let Some(value) = value {
834                                    mock.saw_value(group_index);
835                                    expected_values.push((group_index, value));
836                                }
837                            }
838                        });
839                }
840            }
841
842            assert_eq!(accumulated_values, expected_values,
843                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
844            let seen_values = null_state.seen_values.finish_cloned();
845            mock.validate_seen_values(&seen_values);
846
847            // Validate the final buffer (one value per group)
848            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
849
850            let null_buffer = null_state.build(EmitTo::All);
851
852            assert_eq!(null_buffer, expected_null_buffer);
853        }
854
855        // Calls `accumulate_indices`
856        // and opt_filter and ensures it calls the right values
857        fn accumulate_indices_test(
858            group_indices: &[usize],
859            nulls: Option<&NullBuffer>,
860            opt_filter: Option<&BooleanArray>,
861        ) {
862            let mut accumulated_values = vec![];
863
864            accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
865                accumulated_values.push(group_index);
866            });
867
868            // Figure out the expected values
869            let mut expected_values = vec![];
870
871            match (nulls, opt_filter) {
872                (None, None) => group_indices.iter().for_each(|&group_index| {
873                    expected_values.push(group_index);
874                }),
875                (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
876                    |(&group_index, is_valid)| {
877                        if is_valid {
878                            expected_values.push(group_index);
879                        }
880                    },
881                ),
882                (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
883                    |(&group_index, is_included)| {
884                        if let Some(true) = is_included {
885                            expected_values.push(group_index);
886                        }
887                    },
888                ),
889                (Some(nulls), Some(filter)) => {
890                    group_indices
891                        .iter()
892                        .zip(nulls.iter())
893                        .zip(filter.iter())
894                        .for_each(|((&group_index, is_valid), is_included)| {
895                            // if value passed filter
896                            if let (true, Some(true)) = (is_valid, is_included) {
897                                expected_values.push(group_index);
898                            }
899                        });
900                }
901            }
902
903            assert_eq!(accumulated_values, expected_values,
904                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
905        }
906
907        /// This is effectively a different implementation of
908        /// accumulate_boolean that we compare with the above implementation
909        fn accumulate_boolean_test(
910            group_indices: &[usize],
911            values: &BooleanArray,
912            opt_filter: Option<&BooleanArray>,
913            total_num_groups: usize,
914        ) {
915            let mut accumulated_values = vec![];
916            let mut null_state = NullState::new();
917
918            null_state.accumulate_boolean(
919                group_indices,
920                values,
921                opt_filter,
922                total_num_groups,
923                |group_index, value| {
924                    accumulated_values.push((group_index, value));
925                },
926            );
927
928            // Figure out the expected values
929            let mut expected_values = vec![];
930            let mut mock = MockNullState::new();
931
932            match opt_filter {
933                None => group_indices.iter().zip(values.iter()).for_each(
934                    |(&group_index, value)| {
935                        if let Some(value) = value {
936                            mock.saw_value(group_index);
937                            expected_values.push((group_index, value));
938                        }
939                    },
940                ),
941                Some(filter) => {
942                    group_indices
943                        .iter()
944                        .zip(values.iter())
945                        .zip(filter.iter())
946                        .for_each(|((&group_index, value), is_included)| {
947                            // if value passed filter
948                            if let Some(true) = is_included {
949                                if let Some(value) = value {
950                                    mock.saw_value(group_index);
951                                    expected_values.push((group_index, value));
952                                }
953                            }
954                        });
955                }
956            }
957
958            assert_eq!(accumulated_values, expected_values,
959                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
960
961            let seen_values = null_state.seen_values.finish_cloned();
962            mock.validate_seen_values(&seen_values);
963
964            // Validate the final buffer (one value per group)
965            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
966
967            let null_buffer = null_state.build(EmitTo::All);
968
969            assert_eq!(null_buffer, expected_null_buffer);
970        }
971    }
972
973    /// Parallel implementation of NullState to check expected values
974    #[derive(Debug, Default)]
975    struct MockNullState {
976        /// group indices that had values that passed the filter
977        seen_values: HashSet<usize>,
978    }
979
980    impl MockNullState {
981        fn new() -> Self {
982            Default::default()
983        }
984
985        fn saw_value(&mut self, group_index: usize) {
986            self.seen_values.insert(group_index);
987        }
988
989        /// did this group index see any input?
990        fn expected_seen(&self, group_index: usize) -> bool {
991            self.seen_values.contains(&group_index)
992        }
993
994        /// Validate that the seen_values matches self.seen_values
995        fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
996            for (group_index, is_seen) in seen_values.iter().enumerate() {
997                let expected_seen = self.expected_seen(group_index);
998                assert_eq!(
999                    expected_seen, is_seen,
1000                    "mismatch at for group {group_index}"
1001                );
1002            }
1003        }
1004
1005        /// Create the expected null buffer based on if the input had nulls and a filter
1006        fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
1007            (0..total_num_groups)
1008                .map(|group_index| self.expected_seen(group_index))
1009                .collect()
1010        }
1011    }
1012
1013    #[test]
1014    fn test_accumulate_multiple_no_nulls_no_filter() {
1015        let group_indices = vec![0, 1, 0, 1];
1016        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1017        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1018        let value_columns = [values1, values2];
1019
1020        let mut accumulated = vec![];
1021        accumulate_multiple(
1022            &group_indices,
1023            &value_columns.iter().collect::<Vec<_>>(),
1024            None,
1025            |group_idx, batch_idx, columns| {
1026                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1027                accumulated.push((group_idx, values));
1028            },
1029        );
1030
1031        let expected = vec![
1032            (0, vec![1, 10]),
1033            (1, vec![2, 20]),
1034            (0, vec![3, 30]),
1035            (1, vec![4, 40]),
1036        ];
1037        assert_eq!(accumulated, expected);
1038    }
1039
1040    #[test]
1041    fn test_accumulate_multiple_with_nulls() {
1042        let group_indices = vec![0, 1, 0, 1];
1043        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1044        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1045        let value_columns = [values1, values2];
1046
1047        let mut accumulated = vec![];
1048        accumulate_multiple(
1049            &group_indices,
1050            &value_columns.iter().collect::<Vec<_>>(),
1051            None,
1052            |group_idx, batch_idx, columns| {
1053                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1054                accumulated.push((group_idx, values));
1055            },
1056        );
1057
1058        // Only rows where both columns are non-null should be accumulated
1059        let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
1060        assert_eq!(accumulated, expected);
1061    }
1062
1063    #[test]
1064    fn test_accumulate_multiple_with_filter() {
1065        let group_indices = vec![0, 1, 0, 1];
1066        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1067        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1068        let value_columns = [values1, values2];
1069
1070        let filter = BooleanArray::from(vec![true, false, true, false]);
1071
1072        let mut accumulated = vec![];
1073        accumulate_multiple(
1074            &group_indices,
1075            &value_columns.iter().collect::<Vec<_>>(),
1076            Some(&filter),
1077            |group_idx, batch_idx, columns| {
1078                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1079                accumulated.push((group_idx, values));
1080            },
1081        );
1082
1083        // Only rows where filter is true should be accumulated
1084        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1085        assert_eq!(accumulated, expected);
1086    }
1087
1088    #[test]
1089    fn test_accumulate_multiple_with_nulls_and_filter() {
1090        let group_indices = vec![0, 1, 0, 1];
1091        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1092        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1093        let value_columns = [values1, values2];
1094
1095        let filter = BooleanArray::from(vec![true, true, true, false]);
1096
1097        let mut accumulated = vec![];
1098        accumulate_multiple(
1099            &group_indices,
1100            &value_columns.iter().collect::<Vec<_>>(),
1101            Some(&filter),
1102            |group_idx, batch_idx, columns| {
1103                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1104                accumulated.push((group_idx, values));
1105            },
1106        );
1107
1108        // Only rows where both:
1109        // 1. Filter is true
1110        // 2. Both columns are non-null
1111        // should be accumulated
1112        let expected = [(0, vec![1, 10])];
1113        assert_eq!(accumulated, expected);
1114    }
1115}