Skip to main content

datafusion_functions_aggregate_common/aggregate/groups_accumulator/
accumulate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! [`GroupsAccumulator`] helpers: [`NullState`], [`accumulate`],
//! [`accumulate_multiple`] and [`accumulate_indices`]
//!
//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
21
22use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23use arrow::buffer::NullBuffer;
24use arrow::datatypes::ArrowPrimitiveType;
25
26use crate::aggregate::groups_accumulator::nulls::filter_to_validity;
27use datafusion_expr_common::groups_accumulator::EmitTo;
28
29/// If the input has nulls, then the accumulator must potentially
30/// handle each input null value specially (e.g. for `SUM` to mark the
31/// corresponding sum as null)
32///
33/// If there are filters present, `NullState` tracks if it has seen
34/// *any* value for that group (as some values may be filtered
35/// out). Without a filter, the accumulator is only passed groups that
36/// had at least one value to accumulate so they do not need to track
37/// if they have seen values for a particular group.
38#[derive(Debug)]
39pub enum SeenValues {
40    /// All groups seen so far have seen at least one non-null value
41    All {
42        num_values: usize,
43    },
44    // Some groups have not yet seen a non-null value
45    Some {
46        values: BooleanBufferBuilder,
47    },
48}
49
50impl Default for SeenValues {
51    fn default() -> Self {
52        SeenValues::All { num_values: 0 }
53    }
54}
55
56impl SeenValues {
57    /// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`.
58    ///
59    /// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some`
60    /// by creating a new `BooleanBufferBuilder` where the first `num_values` are true.
61    ///
62    /// The builder is then ensured to have at least `total_num_groups` length,
63    /// with any new entries initialized to false.
64    fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
65        match self {
66            SeenValues::All { num_values } => {
67                let mut builder = BooleanBufferBuilder::new(total_num_groups);
68                builder.append_n(*num_values, true);
69                if total_num_groups > *num_values {
70                    builder.append_n(total_num_groups - *num_values, false);
71                }
72                *self = SeenValues::Some { values: builder };
73                match self {
74                    SeenValues::Some { values } => values,
75                    _ => unreachable!(),
76                }
77            }
78            SeenValues::Some { values } => {
79                if values.len() < total_num_groups {
80                    values.append_n(total_num_groups - values.len(), false);
81                }
82                values
83            }
84        }
85    }
86}
87
88/// Track the accumulator null state per row: if any values for that
89/// group were null and if any values have been seen at all for that group.
90///
91/// This is part of the inner loop for many [`GroupsAccumulator`]s,
92/// and thus the performance is critical and so there are multiple
93/// specialized implementations, invoked depending on the specific
94/// combinations of the input.
95///
96/// Typically there are 4 potential combinations of inputs must be
97/// special cased for performance:
98///
99/// * With / Without filter
100/// * With / Without nulls in the input
101///
102/// If the input has nulls, then the accumulator must potentially
103/// handle each input null value specially (e.g. for `SUM` to mark the
104/// corresponding sum as null)
105///
106/// If there are filters present, `NullState` tracks if it has seen
107/// *any* value for that group (as some values may be filtered
108/// out). Without a filter, the accumulator is only passed groups that
109/// had at least one value to accumulate so they do not need to track
110/// if they have seen values for a particular group.
111///
112/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
113#[derive(Debug)]
114pub struct NullState {
115    /// Have we seen any non-filtered input values for `group_index`?
116    ///
117    /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null
118    /// value for group `i`
119    ///
120    /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that
121    /// pass the filter yet for group `i`
122    ///
123    /// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value
124    seen_values: SeenValues,
125}
126
127impl Default for NullState {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl NullState {
134    pub fn new() -> Self {
135        Self {
136            seen_values: SeenValues::All { num_values: 0 },
137        }
138    }
139
140    /// return the size of all buffers allocated by this null state, not including self
141    pub fn size(&self) -> usize {
142        match &self.seen_values {
143            SeenValues::All { .. } => 0,
144            SeenValues::Some { values } => values.capacity() / 8,
145        }
146    }
147
148    /// Invokes `value_fn(group_index, value)` for each non null, non
149    /// filtered value of `value`, while tracking which groups have
150    /// seen null inputs and which groups have seen any inputs if necessary
151    //
152    /// # Arguments:
153    ///
154    /// * `values`: the input arguments to the accumulator
155    /// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
156    /// * `opt_filter`: if present, only rows for which is Some(true) are included
157    /// * `value_fn`: function invoked for  (group_index, value) where value is non null
158    ///
159    /// See [`accumulate`], for more details on how value_fn is called
160    ///
161    /// When value_fn is called it also sets
162    ///
163    /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
164    pub fn accumulate<T, F>(
165        &mut self,
166        group_indices: &[usize],
167        values: &PrimitiveArray<T>,
168        opt_filter: Option<&BooleanArray>,
169        total_num_groups: usize,
170        mut value_fn: F,
171    ) where
172        T: ArrowPrimitiveType + Send,
173        F: FnMut(usize, T::Native) + Send,
174    {
175        // skip null handling if no nulls in input or accumulator
176        if let SeenValues::All { num_values } = &mut self.seen_values
177            && opt_filter.is_none()
178            && values.null_count() == 0
179        {
180            accumulate(group_indices, values, None, value_fn);
181            *num_values = total_num_groups;
182            return;
183        }
184
185        let seen_values = self.seen_values.get_builder(total_num_groups);
186        accumulate(group_indices, values, opt_filter, |group_index, value| {
187            seen_values.set_bit(group_index, true);
188            value_fn(group_index, value);
189        });
190    }
191
192    /// Invokes `value_fn(group_index, value)` for each non null, non
193    /// filtered value in `values`, while tracking which groups have
194    /// seen null inputs and which groups have seen any inputs, for
195    /// [`BooleanArray`]s.
196    ///
197    /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
198    /// handled specially.
199    ///
200    /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
201    /// more details on other arguments.
202    pub fn accumulate_boolean<F>(
203        &mut self,
204        group_indices: &[usize],
205        values: &BooleanArray,
206        opt_filter: Option<&BooleanArray>,
207        total_num_groups: usize,
208        mut value_fn: F,
209    ) where
210        F: FnMut(usize, bool) + Send,
211    {
212        let data = values.values();
213        assert_eq!(data.len(), group_indices.len());
214
215        // skip null handling if no nulls in input or accumulator
216        if let SeenValues::All { num_values } = &mut self.seen_values
217            && opt_filter.is_none()
218            && values.null_count() == 0
219        {
220            group_indices
221                .iter()
222                .zip(data.iter())
223                .for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
224            *num_values = total_num_groups;
225
226            return;
227        }
228
229        let seen_values = self.seen_values.get_builder(total_num_groups);
230
231        // These could be made more performant by iterating in chunks of 64 bits at a time
232        match (values.null_count() > 0, opt_filter) {
233            // no nulls, no filter,
234            (false, None) => {
235                // if we have previously seen nulls, ensure the null
236                // buffer is big enough (start everything at valid)
237                group_indices.iter().zip(data.iter()).for_each(
238                    |(&group_index, new_value)| {
239                        seen_values.set_bit(group_index, true);
240                        value_fn(group_index, new_value)
241                    },
242                )
243            }
244            // nulls, no filter
245            (true, None) => {
246                let nulls = values.nulls().unwrap();
247                group_indices
248                    .iter()
249                    .zip(data.iter())
250                    .zip(nulls.iter())
251                    .for_each(|((&group_index, new_value), is_valid)| {
252                        if is_valid {
253                            seen_values.set_bit(group_index, true);
254                            value_fn(group_index, new_value);
255                        }
256                    })
257            }
258            // no nulls, but a filter
259            (false, Some(filter)) => {
260                assert_eq!(filter.len(), group_indices.len());
261
262                group_indices
263                    .iter()
264                    .zip(data.iter())
265                    .zip(filter.iter())
266                    .for_each(|((&group_index, new_value), filter_value)| {
267                        if let Some(true) = filter_value {
268                            seen_values.set_bit(group_index, true);
269                            value_fn(group_index, new_value);
270                        }
271                    })
272            }
273            // both null values and filters
274            (true, Some(filter)) => {
275                assert_eq!(filter.len(), group_indices.len());
276                filter
277                    .iter()
278                    .zip(group_indices.iter())
279                    .zip(values.iter())
280                    .for_each(|((filter_value, &group_index), new_value)| {
281                        if let Some(true) = filter_value
282                            && let Some(new_value) = new_value
283                        {
284                            seen_values.set_bit(group_index, true);
285                            value_fn(group_index, new_value)
286                        }
287                    })
288            }
289        }
290    }
291
292    /// Creates the a [`NullBuffer`] representing which group_indices
293    /// should have null values (because they never saw any values)
294    /// for the `emit_to` rows.
295    ///
296    /// resets the internal state appropriately
297    pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
298        match emit_to {
299            EmitTo::All => {
300                let old_seen = std::mem::take(&mut self.seen_values);
301                match old_seen {
302                    SeenValues::All { .. } => None,
303                    SeenValues::Some { mut values } => {
304                        Some(NullBuffer::new(values.finish()))
305                    }
306                }
307            }
308            EmitTo::First(n) => match &mut self.seen_values {
309                SeenValues::All { num_values } => {
310                    *num_values = num_values.saturating_sub(n);
311                    None
312                }
313                SeenValues::Some { .. } => {
314                    let mut old_values = match std::mem::take(&mut self.seen_values) {
315                        SeenValues::Some { values } => values,
316                        _ => unreachable!(),
317                    };
318                    let nulls = old_values.finish();
319                    let first_n_null = nulls.slice(0, n);
320                    let remainder = nulls.slice(n, nulls.len() - n);
321                    let mut new_builder = BooleanBufferBuilder::new(remainder.len());
322                    new_builder.append_buffer(&remainder);
323                    self.seen_values = SeenValues::Some {
324                        values: new_builder,
325                    };
326                    Some(NullBuffer::new(first_n_null))
327                }
328            },
329        }
330    }
331}
332
333/// Invokes `value_fn(group_index, value)` for each non null, non
334/// filtered value of `value`,
335///
336/// # Arguments:
337///
338/// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
339/// * `values`: the input arguments to the accumulator
340/// * `opt_filter`: if present, only rows for which is Some(true) are included
341/// * `value_fn`: function invoked for  (group_index, value) where value is non null
342///
343/// # Example
344///
345/// ```text
346///  ┌─────────┐   ┌─────────┐   ┌ ─ ─ ─ ─ ┐
347///  │ ┌─────┐ │   │ ┌─────┐ │     ┌─────┐
348///  │ │  2  │ │   │ │ 200 │ │   │ │  t  │ │
349///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
350///  │ │  2  │ │   │ │ 100 │ │   │ │  f  │ │
351///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
352///  │ │  0  │ │   │ │ 200 │ │   │ │  t  │ │
353///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
354///  │ │  1  │ │   │ │ 200 │ │   │ │NULL │ │
355///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
356///  │ │  0  │ │   │ │ 300 │ │   │ │  t  │ │
357///  │ └─────┘ │   │ └─────┘ │     └─────┘
358///  └─────────┘   └─────────┘   └ ─ ─ ─ ─ ┘
359///
360/// group_indices   values        opt_filter
361/// ```
362///
363/// In the example above, `value_fn` is invoked for each (group_index,
364/// value) pair where `opt_filter[i]` is true and values is non null
365///
366/// ```text
367/// value_fn(2, 200)
368/// value_fn(0, 200)
369/// value_fn(0, 300)
370/// ```
371pub fn accumulate<T, F>(
372    group_indices: &[usize],
373    values: &PrimitiveArray<T>,
374    opt_filter: Option<&BooleanArray>,
375    mut value_fn: F,
376) where
377    T: ArrowPrimitiveType + Send,
378    F: FnMut(usize, T::Native) + Send,
379{
380    let data: &[T::Native] = values.values();
381    assert_eq!(data.len(), group_indices.len());
382
383    match (values.null_count() > 0, opt_filter) {
384        // no nulls, no filter,
385        (false, None) => {
386            let iter = group_indices.iter().zip(data.iter());
387            for (&group_index, &new_value) in iter {
388                value_fn(group_index, new_value);
389            }
390        }
391        // nulls, no filter
392        (true, None) => {
393            let nulls = values.nulls().unwrap();
394            // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
395            // iterate over in chunks of 64 bits for more efficient null checking
396            let group_indices_chunks = group_indices.chunks_exact(64);
397            let data_chunks = data.chunks_exact(64);
398            let bit_chunks = nulls.inner().bit_chunks();
399
400            let group_indices_remainder = group_indices_chunks.remainder();
401            let data_remainder = data_chunks.remainder();
402
403            group_indices_chunks
404                .zip(data_chunks)
405                .zip(bit_chunks.iter())
406                .for_each(|((group_index_chunk, data_chunk), mask)| {
407                    // index_mask has value 1 << i in the loop
408                    let mut index_mask = 1;
409                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
410                        |(&group_index, &new_value)| {
411                            // valid bit was set, real value
412                            let is_valid = (mask & index_mask) != 0;
413                            if is_valid {
414                                value_fn(group_index, new_value);
415                            }
416                            index_mask <<= 1;
417                        },
418                    )
419                });
420
421            // handle any remaining bits (after the initial 64)
422            let remainder_bits = bit_chunks.remainder_bits();
423            group_indices_remainder
424                .iter()
425                .zip(data_remainder.iter())
426                .enumerate()
427                .for_each(|(i, (&group_index, &new_value))| {
428                    let is_valid = remainder_bits & (1 << i) != 0;
429                    if is_valid {
430                        value_fn(group_index, new_value);
431                    }
432                });
433        }
434        // no nulls, but a filter
435        (false, Some(filter)) => {
436            assert_eq!(filter.len(), group_indices.len());
437            // The performance with a filter could be improved by
438            // iterating over the filter in chunks, rather than a single
439            // iterator. TODO file a ticket
440            group_indices
441                .iter()
442                .zip(data.iter())
443                .zip(filter.iter())
444                .for_each(|((&group_index, &new_value), filter_value)| {
445                    if let Some(true) = filter_value {
446                        value_fn(group_index, new_value);
447                    }
448                })
449        }
450        // both null values and filters
451        (true, Some(filter)) => {
452            assert_eq!(filter.len(), group_indices.len());
453            // The performance with a filter could be improved by
454            // iterating over the filter in chunks, rather than using
455            // iterators. TODO file a ticket
456            filter
457                .iter()
458                .zip(group_indices.iter())
459                .zip(values.iter())
460                .for_each(|((filter_value, &group_index), new_value)| {
461                    if let Some(true) = filter_value
462                        && let Some(new_value) = new_value
463                    {
464                        value_fn(group_index, new_value)
465                    }
466                })
467        }
468    }
469}
470
471/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
472///
473/// This method assumes that for any input record index, if any of the value column
474/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
475/// (Won't be accumulated by `value_fn`)
476///
477/// # Arguments
478///
479/// * `group_indices` - To which groups do the rows in `value_columns` belong
480/// * `value_columns` - The input arrays to accumulate
481/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included
482/// * `value_fn` - Callback function for each valid row, with parameters:
483///     * `group_idx`: The group index for the current row
484///     * `batch_idx`: The index of the current row in the input arrays
485///     * `columns`: Reference to all input arrays for accessing values
486pub fn accumulate_multiple<T, F>(
487    group_indices: &[usize],
488    value_columns: &[&PrimitiveArray<T>],
489    opt_filter: Option<&BooleanArray>,
490    mut value_fn: F,
491) where
492    T: ArrowPrimitiveType + Send,
493    F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
494{
495    for col in value_columns.iter() {
496        debug_assert_eq!(col.len(), group_indices.len());
497    }
498
499    // Start with rows where all value columns are non-null.
500    let mut valid_indices =
501        NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls()))
502            .map(NullBuffer::into_inner);
503
504    // Restrict to rows where the optional filter is Some(true). Keep the filter
505    // as a raw BooleanBuffer to avoid computing a NullBuffer null_count just to
506    // test row validity below.
507    if let Some(filter) = opt_filter {
508        debug_assert_eq!(filter.len(), group_indices.len());
509        let filter_validity = filter_to_validity(filter);
510        if let Some(valid_indices) = valid_indices.as_mut() {
511            *valid_indices &= &filter_validity;
512        } else {
513            valid_indices = Some(filter_validity);
514        }
515    }
516
517    match valid_indices {
518        None => {
519            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
520                value_fn(group_idx, batch_idx, value_columns);
521            }
522        }
523        Some(valid_indices) => {
524            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
525                if valid_indices.value(batch_idx) {
526                    value_fn(group_idx, batch_idx, value_columns);
527                }
528            }
529        }
530    }
531}
532
533/// This function is called to update the accumulator state per row
534/// when the value is not needed (e.g. COUNT)
535///
536/// `F`: Invoked like `value_fn(group_index) for all non null values
537/// passing the filter. Note that no tracking is done for null inputs
538/// or which groups have seen any values
539///
540/// See [`NullState::accumulate`], for more details on other
541/// arguments.
542pub fn accumulate_indices<F>(
543    group_indices: &[usize],
544    nulls: Option<&NullBuffer>,
545    opt_filter: Option<&BooleanArray>,
546    mut index_fn: F,
547) where
548    F: FnMut(usize) + Send,
549{
550    match (nulls, opt_filter) {
551        (None, None) => {
552            for &group_index in group_indices.iter() {
553                index_fn(group_index)
554            }
555        }
556        (None, Some(filter)) => {
557            debug_assert_eq!(filter.len(), group_indices.len());
558            let group_indices_chunks = group_indices.chunks_exact(64);
559            let filter_validity = filter_to_validity(filter);
560            let bit_chunks = filter_validity.bit_chunks();
561
562            let group_indices_remainder = group_indices_chunks.remainder();
563
564            group_indices_chunks.zip(bit_chunks.iter()).for_each(
565                |(group_index_chunk, mask)| {
566                    // index_mask has value 1 << i in the loop
567                    let mut index_mask = 1;
568                    group_index_chunk.iter().for_each(|&group_index| {
569                        // valid bit was set, real vale
570                        let is_valid = (mask & index_mask) != 0;
571                        if is_valid {
572                            index_fn(group_index);
573                        }
574                        index_mask <<= 1;
575                    })
576                },
577            );
578
579            // handle any remaining bits (after the initial 64)
580            let remainder_bits = bit_chunks.remainder_bits();
581            group_indices_remainder
582                .iter()
583                .enumerate()
584                .for_each(|(i, &group_index)| {
585                    let is_valid = remainder_bits & (1 << i) != 0;
586                    if is_valid {
587                        index_fn(group_index)
588                    }
589                });
590        }
591        (Some(valids), None) => {
592            debug_assert_eq!(valids.len(), group_indices.len());
593            // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
594            // iterate over in chunks of 64 bits for more efficient null checking
595            let group_indices_chunks = group_indices.chunks_exact(64);
596            let bit_chunks = valids.inner().bit_chunks();
597
598            let group_indices_remainder = group_indices_chunks.remainder();
599
600            group_indices_chunks.zip(bit_chunks.iter()).for_each(
601                |(group_index_chunk, mask)| {
602                    // index_mask has value 1 << i in the loop
603                    let mut index_mask = 1;
604                    group_index_chunk.iter().for_each(|&group_index| {
605                        // valid bit was set, real vale
606                        let is_valid = (mask & index_mask) != 0;
607                        if is_valid {
608                            index_fn(group_index);
609                        }
610                        index_mask <<= 1;
611                    })
612                },
613            );
614
615            // handle any remaining bits (after the initial 64)
616            let remainder_bits = bit_chunks.remainder_bits();
617            group_indices_remainder
618                .iter()
619                .enumerate()
620                .for_each(|(i, &group_index)| {
621                    let is_valid = remainder_bits & (1 << i) != 0;
622                    if is_valid {
623                        index_fn(group_index)
624                    }
625                });
626        }
627
628        (Some(valids), Some(filter)) => {
629            debug_assert_eq!(filter.len(), group_indices.len());
630            debug_assert_eq!(valids.len(), group_indices.len());
631
632            let group_indices_chunks = group_indices.chunks_exact(64);
633            let valid_bit_chunks = valids.inner().bit_chunks();
634            let filter_validity = filter_to_validity(filter);
635            let filter_bit_chunks = filter_validity.bit_chunks();
636
637            let group_indices_remainder = group_indices_chunks.remainder();
638
639            group_indices_chunks
640                .zip(valid_bit_chunks.iter())
641                .zip(filter_bit_chunks.iter())
642                .for_each(|((group_index_chunk, valid_mask), filter_mask)| {
643                    // index_mask has value 1 << i in the loop
644                    let mut index_mask = 1;
645                    group_index_chunk.iter().for_each(|&group_index| {
646                        // valid bit was set, real vale
647                        let is_valid = (valid_mask & filter_mask & index_mask) != 0;
648                        if is_valid {
649                            index_fn(group_index);
650                        }
651                        index_mask <<= 1;
652                    })
653                });
654
655            // handle any remaining bits (after the initial 64)
656            let remainder_valid_bits = valid_bit_chunks.remainder_bits();
657            let remainder_filter_bits = filter_bit_chunks.remainder_bits();
658            group_indices_remainder
659                .iter()
660                .enumerate()
661                .for_each(|(i, &group_index)| {
662                    let is_valid =
663                        remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
664                    if is_valid {
665                        index_fn(group_index)
666                    }
667                });
668        }
669    }
670}
671
672#[cfg(test)]
673mod test {
674    use super::*;
675
676    use arrow::{
677        array::{Int32Array, UInt32Array},
678        buffer::BooleanBuffer,
679    };
680    use rand::{Rng, rngs::ThreadRng};
681    use std::collections::HashSet;
682
683    #[test]
684    fn accumulate() {
685        let group_indices = (0..100).collect();
686        let values = (0..100).map(|i| (i + 1) * 10).collect();
687        let values_with_nulls = (0..100)
688            .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
689            .collect();
690
691        // default to every fifth value being false, every even
692        // being null
693        let filter: BooleanArray = (0..100)
694            .map(|i| {
695                let is_even = i % 2 == 0;
696                let is_fifth = i % 5 == 0;
697                if is_even {
698                    None
699                } else if is_fifth {
700                    Some(false)
701                } else {
702                    Some(true)
703                }
704            })
705            .collect();
706
707        Fixture {
708            group_indices,
709            values,
710            values_with_nulls,
711            filter,
712        }
713        .run()
714    }
715
716    #[test]
717    fn accumulate_fuzz() {
718        let mut rng = rand::rng();
719        for _ in 0..100 {
720            Fixture::new_random(&mut rng).run();
721        }
722    }
723
724    /// Values for testing (there are enough values to exercise the 64 bit chunks
725    struct Fixture {
726        /// 100..0
727        group_indices: Vec<usize>,
728
729        /// 10, 20, ... 1010
730        values: Vec<u32>,
731
732        /// same as values, but every third is null:
733        /// None, Some(20), Some(30), None ...
734        values_with_nulls: Vec<Option<u32>>,
735
736        /// filter (defaults to None)
737        filter: BooleanArray,
738    }
739
740    impl Fixture {
741        fn new_random(rng: &mut ThreadRng) -> Self {
742            // Number of input values in a batch
743            let num_values: usize = rng.random_range(1..200);
744            // number of distinct groups
745            let num_groups: usize = rng.random_range(2..1000);
746            let max_group = num_groups - 1;
747
748            let group_indices: Vec<usize> = (0..num_values)
749                .map(|_| rng.random_range(0..max_group))
750                .collect();
751
752            let values: Vec<u32> = (0..num_values).map(|_| rng.random()).collect();
753
754            // 10% chance of false
755            // 10% change of null
756            // 80% chance of true
757            let filter: BooleanArray = (0..num_values)
758                .map(|_| {
759                    let filter_value = rng.random_range(0.0..1.0);
760                    if filter_value < 0.1 {
761                        Some(false)
762                    } else if filter_value < 0.2 {
763                        None
764                    } else {
765                        Some(true)
766                    }
767                })
768                .collect();
769
770            // random values with random number and location of nulls
771            // random null percentage
772            let null_pct: f32 = rng.random_range(0.0..1.0);
773            let values_with_nulls: Vec<Option<u32>> = (0..num_values)
774                .map(|_| {
775                    let is_null = null_pct < rng.random_range(0.0..1.0);
776                    if is_null { None } else { Some(rng.random()) }
777                })
778                .collect();
779
780            Self {
781                group_indices,
782                values,
783                values_with_nulls,
784                filter,
785            }
786        }
787
788        /// returns `Self::values` an Array
789        fn values_array(&self) -> UInt32Array {
790            UInt32Array::from(self.values.clone())
791        }
792
793        /// returns `Self::values_with_nulls` as an Array
794        fn values_with_nulls_array(&self) -> UInt32Array {
795            UInt32Array::from(self.values_with_nulls.clone())
796        }
797
798        /// Calls `NullState::accumulate` and `accumulate_indices`
799        /// with all combinations of nulls and filter values
800        fn run(&self) {
801            let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
802
803            let group_indices = &self.group_indices;
804            let values_array = self.values_array();
805            let values_with_nulls_array = self.values_with_nulls_array();
806            let filter = &self.filter;
807
808            // no null, no filters
809            Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
810
811            // nulls, no filters
812            Self::accumulate_test(
813                group_indices,
814                &values_with_nulls_array,
815                None,
816                total_num_groups,
817            );
818
819            // no nulls, filters
820            Self::accumulate_test(
821                group_indices,
822                &values_array,
823                Some(filter),
824                total_num_groups,
825            );
826
827            // nulls, filters
828            Self::accumulate_test(
829                group_indices,
830                &values_with_nulls_array,
831                Some(filter),
832                total_num_groups,
833            );
834        }
835
836        /// Calls `NullState::accumulate` and `accumulate_indices` to
837        /// ensure it generates the correct values.
838        fn accumulate_test(
839            group_indices: &[usize],
840            values: &UInt32Array,
841            opt_filter: Option<&BooleanArray>,
842            total_num_groups: usize,
843        ) {
844            Self::accumulate_values_test(
845                group_indices,
846                values,
847                opt_filter,
848                total_num_groups,
849            );
850            Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
851
852            // Convert values into a boolean array (anything above the
853            // average is true, otherwise false)
854            let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
855            let boolean_values: BooleanArray =
856                values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
857            Self::accumulate_boolean_test(
858                group_indices,
859                &boolean_values,
860                opt_filter,
861                total_num_groups,
862            );
863        }
864
865        /// This is effectively a different implementation of
866        /// accumulate that we compare with the above implementation
867        fn accumulate_values_test(
868            group_indices: &[usize],
869            values: &UInt32Array,
870            opt_filter: Option<&BooleanArray>,
871            total_num_groups: usize,
872        ) {
873            let mut accumulated_values = vec![];
874            let mut null_state = NullState::new();
875
876            null_state.accumulate(
877                group_indices,
878                values,
879                opt_filter,
880                total_num_groups,
881                |group_index, value| {
882                    accumulated_values.push((group_index, value));
883                },
884            );
885
886            // Figure out the expected values
887            let mut expected_values = vec![];
888            let mut mock = MockNullState::new();
889
890            match opt_filter {
891                None => group_indices.iter().zip(values.iter()).for_each(
892                    |(&group_index, value)| {
893                        if let Some(value) = value {
894                            mock.saw_value(group_index);
895                            expected_values.push((group_index, value));
896                        }
897                    },
898                ),
899                Some(filter) => {
900                    group_indices
901                        .iter()
902                        .zip(values.iter())
903                        .zip(filter.iter())
904                        .for_each(|((&group_index, value), is_included)| {
905                            // if value passed filter
906                            if let Some(true) = is_included
907                                && let Some(value) = value
908                            {
909                                mock.saw_value(group_index);
910                                expected_values.push((group_index, value));
911                            }
912                        });
913                }
914            }
915
916            assert_eq!(
917                accumulated_values, expected_values,
918                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
919            );
920
921            match &null_state.seen_values {
922                SeenValues::All { num_values } => {
923                    assert_eq!(*num_values, total_num_groups);
924                }
925                SeenValues::Some { values } => {
926                    let seen_values = values.finish_cloned();
927                    mock.validate_seen_values(&seen_values);
928                }
929            }
930
931            // Validate the final buffer (one value per group)
932            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
933
934            let null_buffer = null_state.build(EmitTo::All);
935            if let Some(nulls) = &null_buffer {
936                assert_eq!(*nulls, expected_null_buffer);
937            }
938        }
939
940        // Calls `accumulate_indices`
941        // and opt_filter and ensures it calls the right values
942        fn accumulate_indices_test(
943            group_indices: &[usize],
944            nulls: Option<&NullBuffer>,
945            opt_filter: Option<&BooleanArray>,
946        ) {
947            let mut accumulated_values = vec![];
948
949            accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
950                accumulated_values.push(group_index);
951            });
952
953            // Figure out the expected values
954            let mut expected_values = vec![];
955
956            match (nulls, opt_filter) {
957                (None, None) => group_indices.iter().for_each(|&group_index| {
958                    expected_values.push(group_index);
959                }),
960                (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
961                    |(&group_index, is_valid)| {
962                        if is_valid {
963                            expected_values.push(group_index);
964                        }
965                    },
966                ),
967                (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
968                    |(&group_index, is_included)| {
969                        if let Some(true) = is_included {
970                            expected_values.push(group_index);
971                        }
972                    },
973                ),
974                (Some(nulls), Some(filter)) => {
975                    group_indices
976                        .iter()
977                        .zip(nulls.iter())
978                        .zip(filter.iter())
979                        .for_each(|((&group_index, is_valid), is_included)| {
980                            // if value passed filter
981                            if let (true, Some(true)) = (is_valid, is_included) {
982                                expected_values.push(group_index);
983                            }
984                        });
985                }
986            }
987
988            assert_eq!(
989                accumulated_values, expected_values,
990                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
991            );
992        }
993
994        /// This is effectively a different implementation of
995        /// accumulate_boolean that we compare with the above implementation
996        fn accumulate_boolean_test(
997            group_indices: &[usize],
998            values: &BooleanArray,
999            opt_filter: Option<&BooleanArray>,
1000            total_num_groups: usize,
1001        ) {
1002            let mut accumulated_values = vec![];
1003            let mut null_state = NullState::new();
1004
1005            null_state.accumulate_boolean(
1006                group_indices,
1007                values,
1008                opt_filter,
1009                total_num_groups,
1010                |group_index, value| {
1011                    accumulated_values.push((group_index, value));
1012                },
1013            );
1014
1015            // Figure out the expected values
1016            let mut expected_values = vec![];
1017            let mut mock = MockNullState::new();
1018
1019            match opt_filter {
1020                None => group_indices.iter().zip(values.iter()).for_each(
1021                    |(&group_index, value)| {
1022                        if let Some(value) = value {
1023                            mock.saw_value(group_index);
1024                            expected_values.push((group_index, value));
1025                        }
1026                    },
1027                ),
1028                Some(filter) => {
1029                    group_indices
1030                        .iter()
1031                        .zip(values.iter())
1032                        .zip(filter.iter())
1033                        .for_each(|((&group_index, value), is_included)| {
1034                            // if value passed filter
1035                            if let Some(true) = is_included
1036                                && let Some(value) = value
1037                            {
1038                                mock.saw_value(group_index);
1039                                expected_values.push((group_index, value));
1040                            }
1041                        });
1042                }
1043            }
1044
1045            assert_eq!(
1046                accumulated_values, expected_values,
1047                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
1048            );
1049
1050            match &null_state.seen_values {
1051                SeenValues::All { num_values } => {
1052                    assert_eq!(*num_values, total_num_groups);
1053                }
1054                SeenValues::Some { values } => {
1055                    let seen_values = values.finish_cloned();
1056                    mock.validate_seen_values(&seen_values);
1057                }
1058            }
1059
1060            // Validate the final buffer (one value per group)
1061            let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));
1062
1063            let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
1064            let null_buffer = null_state.build(EmitTo::All);
1065
1066            if !is_all_seen {
1067                assert_eq!(null_buffer, expected_null_buffer);
1068            }
1069        }
1070    }
1071
1072    /// Parallel implementation of NullState to check expected values
1073    #[derive(Debug, Default)]
1074    struct MockNullState {
1075        /// group indices that had values that passed the filter
1076        seen_values: HashSet<usize>,
1077    }
1078
1079    impl MockNullState {
1080        fn new() -> Self {
1081            Default::default()
1082        }
1083
1084        fn saw_value(&mut self, group_index: usize) {
1085            self.seen_values.insert(group_index);
1086        }
1087
1088        /// did this group index see any input?
1089        fn expected_seen(&self, group_index: usize) -> bool {
1090            self.seen_values.contains(&group_index)
1091        }
1092
1093        /// Validate that the seen_values matches self.seen_values
1094        fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
1095            for (group_index, is_seen) in seen_values.iter().enumerate() {
1096                let expected_seen = self.expected_seen(group_index);
1097                assert_eq!(
1098                    expected_seen, is_seen,
1099                    "mismatch at for group {group_index}"
1100                );
1101            }
1102        }
1103
1104        /// Create the expected null buffer based on if the input had nulls and a filter
1105        fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
1106            (0..total_num_groups)
1107                .map(|group_index| self.expected_seen(group_index))
1108                .collect()
1109        }
1110    }
1111
1112    #[test]
1113    fn test_accumulate_multiple_no_nulls_no_filter() {
1114        let group_indices = vec![0, 1, 0, 1];
1115        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1116        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1117        let value_columns = [values1, values2];
1118
1119        let mut accumulated = vec![];
1120        accumulate_multiple(
1121            &group_indices,
1122            &value_columns.iter().collect::<Vec<_>>(),
1123            None,
1124            |group_idx, batch_idx, columns| {
1125                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1126                accumulated.push((group_idx, values));
1127            },
1128        );
1129
1130        let expected = vec![
1131            (0, vec![1, 10]),
1132            (1, vec![2, 20]),
1133            (0, vec![3, 30]),
1134            (1, vec![4, 40]),
1135        ];
1136        assert_eq!(accumulated, expected);
1137    }
1138
1139    #[test]
1140    fn test_accumulate_multiple_with_nulls() {
1141        let group_indices = vec![0, 1, 0, 1];
1142        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1143        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1144        let value_columns = [values1, values2];
1145
1146        let mut accumulated = vec![];
1147        accumulate_multiple(
1148            &group_indices,
1149            &value_columns.iter().collect::<Vec<_>>(),
1150            None,
1151            |group_idx, batch_idx, columns| {
1152                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1153                accumulated.push((group_idx, values));
1154            },
1155        );
1156
1157        // Only rows where both columns are non-null should be accumulated
1158        let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
1159        assert_eq!(accumulated, expected);
1160    }
1161
1162    #[test]
1163    fn test_accumulate_multiple_with_filter() {
1164        let group_indices = vec![0, 1, 0, 1];
1165        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1166        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1167        let value_columns = [values1, values2];
1168
1169        let filter = BooleanArray::from(vec![true, false, true, false]);
1170
1171        let mut accumulated = vec![];
1172        accumulate_multiple(
1173            &group_indices,
1174            &value_columns.iter().collect::<Vec<_>>(),
1175            Some(&filter),
1176            |group_idx, batch_idx, columns| {
1177                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1178                accumulated.push((group_idx, values));
1179            },
1180        );
1181
1182        // Only rows where filter is true should be accumulated
1183        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1184        assert_eq!(accumulated, expected);
1185    }
1186
1187    #[test]
1188    fn test_accumulate_indices_with_null_filter() {
1189        let group_indices = vec![0, 1, 0, 1];
1190        let filter = BooleanArray::new(
1191            BooleanBuffer::from(vec![true, true, true, false]),
1192            Some(NullBuffer::from(vec![true, false, true, true])),
1193        );
1194
1195        let mut accumulated = vec![];
1196        accumulate_indices(&group_indices, None, Some(&filter), |group_idx| {
1197            accumulated.push(group_idx);
1198        });
1199
1200        // A NULL filter value should be treated the same as false, even if the
1201        // underlying BooleanBuffer value is true.
1202        let expected = vec![0, 0];
1203        assert_eq!(accumulated, expected);
1204
1205        let value_validity = NullBuffer::from(vec![true, true, false, true]);
1206        let mut accumulated = vec![];
1207        accumulate_indices(
1208            &group_indices,
1209            Some(&value_validity),
1210            Some(&filter),
1211            |group_idx| {
1212                accumulated.push(group_idx);
1213            },
1214        );
1215
1216        let expected = vec![0];
1217        assert_eq!(accumulated, expected);
1218    }
1219
1220    #[test]
1221    fn test_accumulate_multiple_with_null_filter() {
1222        let group_indices = vec![0, 1, 0, 1];
1223        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1224        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1225        let value_columns = [values1, values2];
1226
1227        let filter = BooleanArray::new(
1228            BooleanBuffer::from(vec![true, true, true, false]),
1229            Some(NullBuffer::from(vec![true, false, true, true])),
1230        );
1231
1232        let mut accumulated = vec![];
1233        accumulate_multiple(
1234            &group_indices,
1235            &value_columns.iter().collect::<Vec<_>>(),
1236            Some(&filter),
1237            |group_idx, batch_idx, columns| {
1238                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1239                accumulated.push((group_idx, values));
1240            },
1241        );
1242
1243        // A NULL filter value should be treated the same as false, even if the
1244        // underlying BooleanBuffer value is true.
1245        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1246        assert_eq!(accumulated, expected);
1247    }
1248
1249    #[test]
1250    fn test_accumulate_multiple_with_nulls_and_filter() {
1251        let group_indices = vec![0, 1, 0, 1];
1252        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1253        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1254        let value_columns = [values1, values2];
1255
1256        let filter = BooleanArray::from(vec![true, true, true, false]);
1257
1258        let mut accumulated = vec![];
1259        accumulate_multiple(
1260            &group_indices,
1261            &value_columns.iter().collect::<Vec<_>>(),
1262            Some(&filter),
1263            |group_idx, batch_idx, columns| {
1264                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1265                accumulated.push((group_idx, values));
1266            },
1267        );
1268
1269        // Only rows where both:
1270        // 1. Filter is true
1271        // 2. Both columns are non-null
1272        // should be accumulated
1273        let expected = [(0, vec![1, 10])];
1274        assert_eq!(accumulated, expected);
1275    }
1276}