Skip to main content

datafusion_functions_aggregate_common/aggregate/groups_accumulator/
accumulate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`]
19//!
20//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
21
22use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23use arrow::buffer::NullBuffer;
24use arrow::datatypes::ArrowPrimitiveType;
25
26use datafusion_expr_common::groups_accumulator::EmitTo;
27
28/// If the input has nulls, then the accumulator must potentially
29/// handle each input null value specially (e.g. for `SUM` to mark the
30/// corresponding sum as null)
31///
32/// If there are filters present, `NullState` tracks if it has seen
33/// *any* value for that group (as some values may be filtered
34/// out). Without a filter, the accumulator is only passed groups that
35/// had at least one value to accumulate so they do not need to track
36/// if they have seen values for a particular group.
37#[derive(Debug)]
38pub enum SeenValues {
39    /// All groups seen so far have seen at least one non-null value
40    All {
41        num_values: usize,
42    },
43    // Some groups have not yet seen a non-null value
44    Some {
45        values: BooleanBufferBuilder,
46    },
47}
48
49impl Default for SeenValues {
50    fn default() -> Self {
51        SeenValues::All { num_values: 0 }
52    }
53}
54
55impl SeenValues {
56    /// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`.
57    ///
58    /// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some`
59    /// by creating a new `BooleanBufferBuilder` where the first `num_values` are true.
60    ///
61    /// The builder is then ensured to have at least `total_num_groups` length,
62    /// with any new entries initialized to false.
63    fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
64        match self {
65            SeenValues::All { num_values } => {
66                let mut builder = BooleanBufferBuilder::new(total_num_groups);
67                builder.append_n(*num_values, true);
68                if total_num_groups > *num_values {
69                    builder.append_n(total_num_groups - *num_values, false);
70                }
71                *self = SeenValues::Some { values: builder };
72                match self {
73                    SeenValues::Some { values } => values,
74                    _ => unreachable!(),
75                }
76            }
77            SeenValues::Some { values } => {
78                if values.len() < total_num_groups {
79                    values.append_n(total_num_groups - values.len(), false);
80                }
81                values
82            }
83        }
84    }
85}
86
87/// Track the accumulator null state per row: if any values for that
88/// group were null and if any values have been seen at all for that group.
89///
90/// This is part of the inner loop for many [`GroupsAccumulator`]s,
91/// and thus the performance is critical and so there are multiple
92/// specialized implementations, invoked depending on the specific
93/// combinations of the input.
94///
95/// Typically there are 4 potential combinations of inputs must be
96/// special cased for performance:
97///
98/// * With / Without filter
99/// * With / Without nulls in the input
100///
101/// If the input has nulls, then the accumulator must potentially
102/// handle each input null value specially (e.g. for `SUM` to mark the
103/// corresponding sum as null)
104///
105/// If there are filters present, `NullState` tracks if it has seen
106/// *any* value for that group (as some values may be filtered
107/// out). Without a filter, the accumulator is only passed groups that
108/// had at least one value to accumulate so they do not need to track
109/// if they have seen values for a particular group.
110///
111/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
112#[derive(Debug)]
113pub struct NullState {
114    /// Have we seen any non-filtered input values for `group_index`?
115    ///
116    /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null
117    /// value for group `i`
118    ///
119    /// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that
120    /// pass the filter yet for group `i`
121    ///
122    /// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value
123    seen_values: SeenValues,
124}
125
126impl Default for NullState {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132impl NullState {
133    pub fn new() -> Self {
134        Self {
135            seen_values: SeenValues::All { num_values: 0 },
136        }
137    }
138
139    /// return the size of all buffers allocated by this null state, not including self
140    pub fn size(&self) -> usize {
141        match &self.seen_values {
142            SeenValues::All { .. } => 0,
143            SeenValues::Some { values } => values.capacity() / 8,
144        }
145    }
146
147    /// Invokes `value_fn(group_index, value)` for each non null, non
148    /// filtered value of `value`, while tracking which groups have
149    /// seen null inputs and which groups have seen any inputs if necessary
150    //
151    /// # Arguments:
152    ///
153    /// * `values`: the input arguments to the accumulator
154    /// * `group_indices`:  To which groups do the rows in `values` belong, (aka group_index)
155    /// * `opt_filter`: if present, only rows for which is Some(true) are included
156    /// * `value_fn`: function invoked for  (group_index, value) where value is non null
157    ///
158    /// See [`accumulate`], for more details on how value_fn is called
159    ///
160    /// When value_fn is called it also sets
161    ///
162    /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
163    pub fn accumulate<T, F>(
164        &mut self,
165        group_indices: &[usize],
166        values: &PrimitiveArray<T>,
167        opt_filter: Option<&BooleanArray>,
168        total_num_groups: usize,
169        mut value_fn: F,
170    ) where
171        T: ArrowPrimitiveType + Send,
172        F: FnMut(usize, T::Native) + Send,
173    {
174        // skip null handling if no nulls in input or accumulator
175        if let SeenValues::All { num_values } = &mut self.seen_values
176            && opt_filter.is_none()
177            && values.null_count() == 0
178        {
179            accumulate(group_indices, values, None, value_fn);
180            *num_values = total_num_groups;
181            return;
182        }
183
184        let seen_values = self.seen_values.get_builder(total_num_groups);
185        accumulate(group_indices, values, opt_filter, |group_index, value| {
186            seen_values.set_bit(group_index, true);
187            value_fn(group_index, value);
188        });
189    }
190
191    /// Invokes `value_fn(group_index, value)` for each non null, non
192    /// filtered value in `values`, while tracking which groups have
193    /// seen null inputs and which groups have seen any inputs, for
194    /// [`BooleanArray`]s.
195    ///
196    /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be
197    /// handled specially.
198    ///
199    /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for
200    /// more details on other arguments.
201    pub fn accumulate_boolean<F>(
202        &mut self,
203        group_indices: &[usize],
204        values: &BooleanArray,
205        opt_filter: Option<&BooleanArray>,
206        total_num_groups: usize,
207        mut value_fn: F,
208    ) where
209        F: FnMut(usize, bool) + Send,
210    {
211        let data = values.values();
212        assert_eq!(data.len(), group_indices.len());
213
214        // skip null handling if no nulls in input or accumulator
215        if let SeenValues::All { num_values } = &mut self.seen_values
216            && opt_filter.is_none()
217            && values.null_count() == 0
218        {
219            group_indices
220                .iter()
221                .zip(data.iter())
222                .for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
223            *num_values = total_num_groups;
224
225            return;
226        }
227
228        let seen_values = self.seen_values.get_builder(total_num_groups);
229
230        // These could be made more performant by iterating in chunks of 64 bits at a time
231        match (values.null_count() > 0, opt_filter) {
232            // no nulls, no filter,
233            (false, None) => {
234                // if we have previously seen nulls, ensure the null
235                // buffer is big enough (start everything at valid)
236                group_indices.iter().zip(data.iter()).for_each(
237                    |(&group_index, new_value)| {
238                        seen_values.set_bit(group_index, true);
239                        value_fn(group_index, new_value)
240                    },
241                )
242            }
243            // nulls, no filter
244            (true, None) => {
245                let nulls = values.nulls().unwrap();
246                group_indices
247                    .iter()
248                    .zip(data.iter())
249                    .zip(nulls.iter())
250                    .for_each(|((&group_index, new_value), is_valid)| {
251                        if is_valid {
252                            seen_values.set_bit(group_index, true);
253                            value_fn(group_index, new_value);
254                        }
255                    })
256            }
257            // no nulls, but a filter
258            (false, Some(filter)) => {
259                assert_eq!(filter.len(), group_indices.len());
260
261                group_indices
262                    .iter()
263                    .zip(data.iter())
264                    .zip(filter.iter())
265                    .for_each(|((&group_index, new_value), filter_value)| {
266                        if let Some(true) = filter_value {
267                            seen_values.set_bit(group_index, true);
268                            value_fn(group_index, new_value);
269                        }
270                    })
271            }
272            // both null values and filters
273            (true, Some(filter)) => {
274                assert_eq!(filter.len(), group_indices.len());
275                filter
276                    .iter()
277                    .zip(group_indices.iter())
278                    .zip(values.iter())
279                    .for_each(|((filter_value, &group_index), new_value)| {
280                        if let Some(true) = filter_value
281                            && let Some(new_value) = new_value
282                        {
283                            seen_values.set_bit(group_index, true);
284                            value_fn(group_index, new_value)
285                        }
286                    })
287            }
288        }
289    }
290
291    /// Creates the a [`NullBuffer`] representing which group_indices
292    /// should have null values (because they never saw any values)
293    /// for the `emit_to` rows.
294    ///
295    /// resets the internal state appropriately
296    pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
297        match emit_to {
298            EmitTo::All => {
299                let old_seen = std::mem::take(&mut self.seen_values);
300                match old_seen {
301                    SeenValues::All { .. } => None,
302                    SeenValues::Some { mut values } => {
303                        Some(NullBuffer::new(values.finish()))
304                    }
305                }
306            }
307            EmitTo::First(n) => match &mut self.seen_values {
308                SeenValues::All { num_values } => {
309                    *num_values = num_values.saturating_sub(n);
310                    None
311                }
312                SeenValues::Some { .. } => {
313                    let mut old_values = match std::mem::take(&mut self.seen_values) {
314                        SeenValues::Some { values } => values,
315                        _ => unreachable!(),
316                    };
317                    let nulls = old_values.finish();
318                    let first_n_null = nulls.slice(0, n);
319                    let remainder = nulls.slice(n, nulls.len() - n);
320                    let mut new_builder = BooleanBufferBuilder::new(remainder.len());
321                    new_builder.append_buffer(&remainder);
322                    self.seen_values = SeenValues::Some {
323                        values: new_builder,
324                    };
325                    Some(NullBuffer::new(first_n_null))
326                }
327            },
328        }
329    }
330}
331
/// Invokes `value_fn(group_index, value)` for each non null, non
/// filtered value of `values`.
///
/// Unlike [`NullState::accumulate`], this does not record which groups have
/// seen values; it only dispatches the selected rows to `value_fn`.
///
/// # Arguments:
///
/// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index)
/// * `values`: the input arguments to the accumulator
/// * `opt_filter`: if present, only rows for which it is `Some(true)` are included
/// * `value_fn`: function invoked for (group_index, value) where value is non null
///
/// # Panics
///
/// Panics if `values` and `group_indices` have different lengths, or if
/// `opt_filter` is present and its length differs from `group_indices`.
///
/// # Example
///
/// ```text
///  ┌─────────┐   ┌─────────┐   ┌ ─ ─ ─ ─ ┐
///  │ ┌─────┐ │   │ ┌─────┐ │     ┌─────┐
///  │ │  2  │ │   │ │ 200 │ │   │ │  t  │ │
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
///  │ │  2  │ │   │ │ 100 │ │   │ │  f  │ │
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
///  │ │  0  │ │   │ │ 200 │ │   │ │  t  │ │
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
///  │ │  1  │ │   │ │ 200 │ │   │ │NULL │ │
///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
///  │ │  0  │ │   │ │ 300 │ │   │ │  t  │ │
///  │ └─────┘ │   │ └─────┘ │     └─────┘
///  └─────────┘   └─────────┘   └ ─ ─ ─ ─ ┘
///
/// group_indices   values        opt_filter
/// ```
///
/// In the example above, `value_fn` is invoked for each (group_index,
/// value) pair where `opt_filter[i]` is true and values is non null
///
/// ```text
/// value_fn(2, 200)
/// value_fn(0, 200)
/// value_fn(0, 300)
/// ```
pub fn accumulate<T, F>(
    group_indices: &[usize],
    values: &PrimitiveArray<T>,
    opt_filter: Option<&BooleanArray>,
    mut value_fn: F,
) where
    T: ArrowPrimitiveType + Send,
    F: FnMut(usize, T::Native) + Send,
{
    let data: &[T::Native] = values.values();
    assert_eq!(data.len(), group_indices.len());

    // One specialized loop per (has nulls, has filter) combination
    match (values.null_count() > 0, opt_filter) {
        // no nulls, no filter,
        (false, None) => {
            let iter = group_indices.iter().zip(data.iter());
            for (&group_index, &new_value) in iter {
                value_fn(group_index, new_value);
            }
        }
        // nulls, no filter
        (true, None) => {
            // null_count() > 0, so the null buffer is present
            let nulls = values.nulls().unwrap();
            // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
            // iterate over in chunks of 64 bits for more efficient null checking
            let group_indices_chunks = group_indices.chunks_exact(64);
            let data_chunks = data.chunks_exact(64);
            let bit_chunks = nulls.inner().bit_chunks();

            let group_indices_remainder = group_indices_chunks.remainder();
            let data_remainder = data_chunks.remainder();

            group_indices_chunks
                .zip(data_chunks)
                .zip(bit_chunks.iter())
                .for_each(|((group_index_chunk, data_chunk), mask)| {
                    // bit i of `mask` is the validity of row i of this 64-row chunk;
                    // index_mask has value 1 << i in the loop
                    let mut index_mask = 1;
                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
                        |(&group_index, &new_value)| {
                            // valid bit was set, real value
                            let is_valid = (mask & index_mask) != 0;
                            if is_valid {
                                value_fn(group_index, new_value);
                            }
                            index_mask <<= 1;
                        },
                    )
                });

            // handle any remaining bits (after the initial 64)
            let remainder_bits = bit_chunks.remainder_bits();
            group_indices_remainder
                .iter()
                .zip(data_remainder.iter())
                .enumerate()
                .for_each(|(i, (&group_index, &new_value))| {
                    let is_valid = remainder_bits & (1 << i) != 0;
                    if is_valid {
                        value_fn(group_index, new_value);
                    }
                });
        }
        // no nulls, but a filter
        (false, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            // The performance with a filter could be improved by
            // iterating over the filter in chunks, rather than a single
            // iterator. TODO file a ticket
            //
            // `filter.iter()` yields `None` for null filter entries, so only
            // `Some(true)` rows are passed on.
            group_indices
                .iter()
                .zip(data.iter())
                .zip(filter.iter())
                .for_each(|((&group_index, &new_value), filter_value)| {
                    if let Some(true) = filter_value {
                        value_fn(group_index, new_value);
                    }
                })
        }
        // both null values and filters
        (true, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            // The performance with a filter could be improved by
            // iterating over the filter in chunks, rather than using
            // iterators. TODO file a ticket
            //
            // `values.iter()` yields `None` for null values, so a row is used
            // only when the filter is `Some(true)` and the value is non-null.
            filter
                .iter()
                .zip(group_indices.iter())
                .zip(values.iter())
                .for_each(|((filter_value, &group_index), new_value)| {
                    if let Some(true) = filter_value
                        && let Some(new_value) = new_value
                    {
                        value_fn(group_index, new_value)
                    }
                })
        }
    }
}
469
470/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
471///
472/// This method assumes that for any input record index, if any of the value column
473/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
474/// (won't be accumulated by `value_fn`)
475///
476/// # Arguments
477///
478/// * `group_indices` - To which groups do the rows in `value_columns` belong
479/// * `value_columns` - The input arrays to accumulate
480/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included
481/// * `value_fn` - Callback function for each valid row, with parameters:
482///     * `group_idx`: The group index for the current row
483///     * `batch_idx`: The index of the current row in the input arrays
484///     * `columns`: Reference to all input arrays for accessing values
485pub fn accumulate_multiple<T, F>(
486    group_indices: &[usize],
487    value_columns: &[&PrimitiveArray<T>],
488    opt_filter: Option<&BooleanArray>,
489    mut value_fn: F,
490) where
491    T: ArrowPrimitiveType + Send,
492    F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
493{
494    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
495    // `valid_indices` is a bit mask corresponding to the `group_indices`. An index
496    // is considered valid if:
497    // 1. All columns are non-null at this index.
498    // 2. Not filtered out by `opt_filter`
499
500    // Take AND from all null buffers of `value_columns`.
501    let combined_nulls = value_columns
502        .iter()
503        .map(|arr| arr.logical_nulls())
504        .fold(None, |acc, nulls| {
505            NullBuffer::union(acc.as_ref(), nulls.as_ref())
506        });
507
508    // Take AND from previous combined nulls and `opt_filter`.
509    let valid_indices = match (combined_nulls, opt_filter) {
510        (None, None) => None,
511        (None, Some(filter)) => Some(filter.clone()),
512        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
513        (Some(nulls), Some(filter)) => {
514            let combined = nulls.inner() & filter.values();
515            Some(BooleanArray::new(combined, None))
516        }
517    };
518
519    for col in value_columns.iter() {
520        debug_assert_eq!(col.len(), group_indices.len());
521    }
522
523    match valid_indices {
524        None => {
525            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
526                value_fn(group_idx, batch_idx, value_columns);
527            }
528        }
529        Some(valid_indices) => {
530            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
531                if valid_indices.value(batch_idx) {
532                    value_fn(group_idx, batch_idx, value_columns);
533                }
534            }
535        }
536    }
537}
538
539/// This function is called to update the accumulator state per row
540/// when the value is not needed (e.g. COUNT)
541///
542/// `F`: Invoked like `value_fn(group_index) for all non null values
543/// passing the filter. Note that no tracking is done for null inputs
544/// or which groups have seen any values
545///
546/// See [`NullState::accumulate`], for more details on other
547/// arguments.
548pub fn accumulate_indices<F>(
549    group_indices: &[usize],
550    nulls: Option<&NullBuffer>,
551    opt_filter: Option<&BooleanArray>,
552    mut index_fn: F,
553) where
554    F: FnMut(usize) + Send,
555{
556    match (nulls, opt_filter) {
557        (None, None) => {
558            for &group_index in group_indices.iter() {
559                index_fn(group_index)
560            }
561        }
562        (None, Some(filter)) => {
563            debug_assert_eq!(filter.len(), group_indices.len());
564            let group_indices_chunks = group_indices.chunks_exact(64);
565            let bit_chunks = filter.values().bit_chunks();
566
567            let group_indices_remainder = group_indices_chunks.remainder();
568
569            group_indices_chunks.zip(bit_chunks.iter()).for_each(
570                |(group_index_chunk, mask)| {
571                    // index_mask has value 1 << i in the loop
572                    let mut index_mask = 1;
573                    group_index_chunk.iter().for_each(|&group_index| {
574                        // valid bit was set, real vale
575                        let is_valid = (mask & index_mask) != 0;
576                        if is_valid {
577                            index_fn(group_index);
578                        }
579                        index_mask <<= 1;
580                    })
581                },
582            );
583
584            // handle any remaining bits (after the initial 64)
585            let remainder_bits = bit_chunks.remainder_bits();
586            group_indices_remainder
587                .iter()
588                .enumerate()
589                .for_each(|(i, &group_index)| {
590                    let is_valid = remainder_bits & (1 << i) != 0;
591                    if is_valid {
592                        index_fn(group_index)
593                    }
594                });
595        }
596        (Some(valids), None) => {
597            debug_assert_eq!(valids.len(), group_indices.len());
598            // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
599            // iterate over in chunks of 64 bits for more efficient null checking
600            let group_indices_chunks = group_indices.chunks_exact(64);
601            let bit_chunks = valids.inner().bit_chunks();
602
603            let group_indices_remainder = group_indices_chunks.remainder();
604
605            group_indices_chunks.zip(bit_chunks.iter()).for_each(
606                |(group_index_chunk, mask)| {
607                    // index_mask has value 1 << i in the loop
608                    let mut index_mask = 1;
609                    group_index_chunk.iter().for_each(|&group_index| {
610                        // valid bit was set, real vale
611                        let is_valid = (mask & index_mask) != 0;
612                        if is_valid {
613                            index_fn(group_index);
614                        }
615                        index_mask <<= 1;
616                    })
617                },
618            );
619
620            // handle any remaining bits (after the initial 64)
621            let remainder_bits = bit_chunks.remainder_bits();
622            group_indices_remainder
623                .iter()
624                .enumerate()
625                .for_each(|(i, &group_index)| {
626                    let is_valid = remainder_bits & (1 << i) != 0;
627                    if is_valid {
628                        index_fn(group_index)
629                    }
630                });
631        }
632
633        (Some(valids), Some(filter)) => {
634            debug_assert_eq!(filter.len(), group_indices.len());
635            debug_assert_eq!(valids.len(), group_indices.len());
636
637            let group_indices_chunks = group_indices.chunks_exact(64);
638            let valid_bit_chunks = valids.inner().bit_chunks();
639            let filter_bit_chunks = filter.values().bit_chunks();
640
641            let group_indices_remainder = group_indices_chunks.remainder();
642
643            group_indices_chunks
644                .zip(valid_bit_chunks.iter())
645                .zip(filter_bit_chunks.iter())
646                .for_each(|((group_index_chunk, valid_mask), filter_mask)| {
647                    // index_mask has value 1 << i in the loop
648                    let mut index_mask = 1;
649                    group_index_chunk.iter().for_each(|&group_index| {
650                        // valid bit was set, real vale
651                        let is_valid = (valid_mask & filter_mask & index_mask) != 0;
652                        if is_valid {
653                            index_fn(group_index);
654                        }
655                        index_mask <<= 1;
656                    })
657                });
658
659            // handle any remaining bits (after the initial 64)
660            let remainder_valid_bits = valid_bit_chunks.remainder_bits();
661            let remainder_filter_bits = filter_bit_chunks.remainder_bits();
662            group_indices_remainder
663                .iter()
664                .enumerate()
665                .for_each(|(i, &group_index)| {
666                    let is_valid =
667                        remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
668                    if is_valid {
669                        index_fn(group_index)
670                    }
671                });
672        }
673    }
674}
675
676#[cfg(test)]
677mod test {
678    use super::*;
679
680    use arrow::{
681        array::{Int32Array, UInt32Array},
682        buffer::BooleanBuffer,
683    };
684    use rand::{Rng, rngs::ThreadRng};
685    use std::collections::HashSet;
686
687    #[test]
688    fn accumulate() {
689        let group_indices = (0..100).collect();
690        let values = (0..100).map(|i| (i + 1) * 10).collect();
691        let values_with_nulls = (0..100)
692            .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
693            .collect();
694
695        // default to every fifth value being false, every even
696        // being null
697        let filter: BooleanArray = (0..100)
698            .map(|i| {
699                let is_even = i % 2 == 0;
700                let is_fifth = i % 5 == 0;
701                if is_even {
702                    None
703                } else if is_fifth {
704                    Some(false)
705                } else {
706                    Some(true)
707                }
708            })
709            .collect();
710
711        Fixture {
712            group_indices,
713            values,
714            values_with_nulls,
715            filter,
716        }
717        .run()
718    }
719
720    #[test]
721    fn accumulate_fuzz() {
722        let mut rng = rand::rng();
723        for _ in 0..100 {
724            Fixture::new_random(&mut rng).run();
725        }
726    }
727
728    /// Values for testing (there are enough values to exercise the 64 bit chunks
729    struct Fixture {
730        /// 100..0
731        group_indices: Vec<usize>,
732
733        /// 10, 20, ... 1010
734        values: Vec<u32>,
735
736        /// same as values, but every third is null:
737        /// None, Some(20), Some(30), None ...
738        values_with_nulls: Vec<Option<u32>>,
739
740        /// filter (defaults to None)
741        filter: BooleanArray,
742    }
743
744    impl Fixture {
745        fn new_random(rng: &mut ThreadRng) -> Self {
746            // Number of input values in a batch
747            let num_values: usize = rng.random_range(1..200);
748            // number of distinct groups
749            let num_groups: usize = rng.random_range(2..1000);
750            let max_group = num_groups - 1;
751
752            let group_indices: Vec<usize> = (0..num_values)
753                .map(|_| rng.random_range(0..max_group))
754                .collect();
755
756            let values: Vec<u32> = (0..num_values).map(|_| rng.random()).collect();
757
758            // 10% chance of false
759            // 10% change of null
760            // 80% chance of true
761            let filter: BooleanArray = (0..num_values)
762                .map(|_| {
763                    let filter_value = rng.random_range(0.0..1.0);
764                    if filter_value < 0.1 {
765                        Some(false)
766                    } else if filter_value < 0.2 {
767                        None
768                    } else {
769                        Some(true)
770                    }
771                })
772                .collect();
773
774            // random values with random number and location of nulls
775            // random null percentage
776            let null_pct: f32 = rng.random_range(0.0..1.0);
777            let values_with_nulls: Vec<Option<u32>> = (0..num_values)
778                .map(|_| {
779                    let is_null = null_pct < rng.random_range(0.0..1.0);
780                    if is_null { None } else { Some(rng.random()) }
781                })
782                .collect();
783
784            Self {
785                group_indices,
786                values,
787                values_with_nulls,
788                filter,
789            }
790        }
791
792        /// returns `Self::values` an Array
793        fn values_array(&self) -> UInt32Array {
794            UInt32Array::from(self.values.clone())
795        }
796
797        /// returns `Self::values_with_nulls` as an Array
798        fn values_with_nulls_array(&self) -> UInt32Array {
799            UInt32Array::from(self.values_with_nulls.clone())
800        }
801
802        /// Calls `NullState::accumulate` and `accumulate_indices`
803        /// with all combinations of nulls and filter values
804        fn run(&self) {
805            let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
806
807            let group_indices = &self.group_indices;
808            let values_array = self.values_array();
809            let values_with_nulls_array = self.values_with_nulls_array();
810            let filter = &self.filter;
811
812            // no null, no filters
813            Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
814
815            // nulls, no filters
816            Self::accumulate_test(
817                group_indices,
818                &values_with_nulls_array,
819                None,
820                total_num_groups,
821            );
822
823            // no nulls, filters
824            Self::accumulate_test(
825                group_indices,
826                &values_array,
827                Some(filter),
828                total_num_groups,
829            );
830
831            // nulls, filters
832            Self::accumulate_test(
833                group_indices,
834                &values_with_nulls_array,
835                Some(filter),
836                total_num_groups,
837            );
838        }
839
840        /// Calls `NullState::accumulate` and `accumulate_indices` to
841        /// ensure it generates the correct values.
842        fn accumulate_test(
843            group_indices: &[usize],
844            values: &UInt32Array,
845            opt_filter: Option<&BooleanArray>,
846            total_num_groups: usize,
847        ) {
848            Self::accumulate_values_test(
849                group_indices,
850                values,
851                opt_filter,
852                total_num_groups,
853            );
854            Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
855
856            // Convert values into a boolean array (anything above the
857            // average is true, otherwise false)
858            let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum();
859            let boolean_values: BooleanArray =
860                values.iter().map(|v| v.map(|v| v as usize > avg)).collect();
861            Self::accumulate_boolean_test(
862                group_indices,
863                &boolean_values,
864                opt_filter,
865                total_num_groups,
866            );
867        }
868
869        /// This is effectively a different implementation of
870        /// accumulate that we compare with the above implementation
871        fn accumulate_values_test(
872            group_indices: &[usize],
873            values: &UInt32Array,
874            opt_filter: Option<&BooleanArray>,
875            total_num_groups: usize,
876        ) {
877            let mut accumulated_values = vec![];
878            let mut null_state = NullState::new();
879
880            null_state.accumulate(
881                group_indices,
882                values,
883                opt_filter,
884                total_num_groups,
885                |group_index, value| {
886                    accumulated_values.push((group_index, value));
887                },
888            );
889
890            // Figure out the expected values
891            let mut expected_values = vec![];
892            let mut mock = MockNullState::new();
893
894            match opt_filter {
895                None => group_indices.iter().zip(values.iter()).for_each(
896                    |(&group_index, value)| {
897                        if let Some(value) = value {
898                            mock.saw_value(group_index);
899                            expected_values.push((group_index, value));
900                        }
901                    },
902                ),
903                Some(filter) => {
904                    group_indices
905                        .iter()
906                        .zip(values.iter())
907                        .zip(filter.iter())
908                        .for_each(|((&group_index, value), is_included)| {
909                            // if value passed filter
910                            if let Some(true) = is_included
911                                && let Some(value) = value
912                            {
913                                mock.saw_value(group_index);
914                                expected_values.push((group_index, value));
915                            }
916                        });
917                }
918            }
919
920            assert_eq!(
921                accumulated_values, expected_values,
922                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
923            );
924
925            match &null_state.seen_values {
926                SeenValues::All { num_values } => {
927                    assert_eq!(*num_values, total_num_groups);
928                }
929                SeenValues::Some { values } => {
930                    let seen_values = values.finish_cloned();
931                    mock.validate_seen_values(&seen_values);
932                }
933            }
934
935            // Validate the final buffer (one value per group)
936            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
937
938            let null_buffer = null_state.build(EmitTo::All);
939            if let Some(nulls) = &null_buffer {
940                assert_eq!(*nulls, expected_null_buffer);
941            }
942        }
943
944        // Calls `accumulate_indices`
945        // and opt_filter and ensures it calls the right values
946        fn accumulate_indices_test(
947            group_indices: &[usize],
948            nulls: Option<&NullBuffer>,
949            opt_filter: Option<&BooleanArray>,
950        ) {
951            let mut accumulated_values = vec![];
952
953            accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
954                accumulated_values.push(group_index);
955            });
956
957            // Figure out the expected values
958            let mut expected_values = vec![];
959
960            match (nulls, opt_filter) {
961                (None, None) => group_indices.iter().for_each(|&group_index| {
962                    expected_values.push(group_index);
963                }),
964                (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
965                    |(&group_index, is_valid)| {
966                        if is_valid {
967                            expected_values.push(group_index);
968                        }
969                    },
970                ),
971                (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
972                    |(&group_index, is_included)| {
973                        if let Some(true) = is_included {
974                            expected_values.push(group_index);
975                        }
976                    },
977                ),
978                (Some(nulls), Some(filter)) => {
979                    group_indices
980                        .iter()
981                        .zip(nulls.iter())
982                        .zip(filter.iter())
983                        .for_each(|((&group_index, is_valid), is_included)| {
984                            // if value passed filter
985                            if let (true, Some(true)) = (is_valid, is_included) {
986                                expected_values.push(group_index);
987                            }
988                        });
989                }
990            }
991
992            assert_eq!(
993                accumulated_values, expected_values,
994                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
995            );
996        }
997
998        /// This is effectively a different implementation of
999        /// accumulate_boolean that we compare with the above implementation
1000        fn accumulate_boolean_test(
1001            group_indices: &[usize],
1002            values: &BooleanArray,
1003            opt_filter: Option<&BooleanArray>,
1004            total_num_groups: usize,
1005        ) {
1006            let mut accumulated_values = vec![];
1007            let mut null_state = NullState::new();
1008
1009            null_state.accumulate_boolean(
1010                group_indices,
1011                values,
1012                opt_filter,
1013                total_num_groups,
1014                |group_index, value| {
1015                    accumulated_values.push((group_index, value));
1016                },
1017            );
1018
1019            // Figure out the expected values
1020            let mut expected_values = vec![];
1021            let mut mock = MockNullState::new();
1022
1023            match opt_filter {
1024                None => group_indices.iter().zip(values.iter()).for_each(
1025                    |(&group_index, value)| {
1026                        if let Some(value) = value {
1027                            mock.saw_value(group_index);
1028                            expected_values.push((group_index, value));
1029                        }
1030                    },
1031                ),
1032                Some(filter) => {
1033                    group_indices
1034                        .iter()
1035                        .zip(values.iter())
1036                        .zip(filter.iter())
1037                        .for_each(|((&group_index, value), is_included)| {
1038                            // if value passed filter
1039                            if let Some(true) = is_included
1040                                && let Some(value) = value
1041                            {
1042                                mock.saw_value(group_index);
1043                                expected_values.push((group_index, value));
1044                            }
1045                        });
1046                }
1047            }
1048
1049            assert_eq!(
1050                accumulated_values, expected_values,
1051                "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
1052            );
1053
1054            match &null_state.seen_values {
1055                SeenValues::All { num_values } => {
1056                    assert_eq!(*num_values, total_num_groups);
1057                }
1058                SeenValues::Some { values } => {
1059                    let seen_values = values.finish_cloned();
1060                    mock.validate_seen_values(&seen_values);
1061                }
1062            }
1063
1064            // Validate the final buffer (one value per group)
1065            let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));
1066
1067            let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
1068            let null_buffer = null_state.build(EmitTo::All);
1069
1070            if !is_all_seen {
1071                assert_eq!(null_buffer, expected_null_buffer);
1072            }
1073        }
1074    }
1075
1076    /// Parallel implementation of NullState to check expected values
1077    #[derive(Debug, Default)]
1078    struct MockNullState {
1079        /// group indices that had values that passed the filter
1080        seen_values: HashSet<usize>,
1081    }
1082
1083    impl MockNullState {
1084        fn new() -> Self {
1085            Default::default()
1086        }
1087
1088        fn saw_value(&mut self, group_index: usize) {
1089            self.seen_values.insert(group_index);
1090        }
1091
1092        /// did this group index see any input?
1093        fn expected_seen(&self, group_index: usize) -> bool {
1094            self.seen_values.contains(&group_index)
1095        }
1096
1097        /// Validate that the seen_values matches self.seen_values
1098        fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
1099            for (group_index, is_seen) in seen_values.iter().enumerate() {
1100                let expected_seen = self.expected_seen(group_index);
1101                assert_eq!(
1102                    expected_seen, is_seen,
1103                    "mismatch at for group {group_index}"
1104                );
1105            }
1106        }
1107
1108        /// Create the expected null buffer based on if the input had nulls and a filter
1109        fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
1110            (0..total_num_groups)
1111                .map(|group_index| self.expected_seen(group_index))
1112                .collect()
1113        }
1114    }
1115
1116    #[test]
1117    fn test_accumulate_multiple_no_nulls_no_filter() {
1118        let group_indices = vec![0, 1, 0, 1];
1119        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1120        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1121        let value_columns = [values1, values2];
1122
1123        let mut accumulated = vec![];
1124        accumulate_multiple(
1125            &group_indices,
1126            &value_columns.iter().collect::<Vec<_>>(),
1127            None,
1128            |group_idx, batch_idx, columns| {
1129                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1130                accumulated.push((group_idx, values));
1131            },
1132        );
1133
1134        let expected = vec![
1135            (0, vec![1, 10]),
1136            (1, vec![2, 20]),
1137            (0, vec![3, 30]),
1138            (1, vec![4, 40]),
1139        ];
1140        assert_eq!(accumulated, expected);
1141    }
1142
1143    #[test]
1144    fn test_accumulate_multiple_with_nulls() {
1145        let group_indices = vec![0, 1, 0, 1];
1146        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1147        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1148        let value_columns = [values1, values2];
1149
1150        let mut accumulated = vec![];
1151        accumulate_multiple(
1152            &group_indices,
1153            &value_columns.iter().collect::<Vec<_>>(),
1154            None,
1155            |group_idx, batch_idx, columns| {
1156                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1157                accumulated.push((group_idx, values));
1158            },
1159        );
1160
1161        // Only rows where both columns are non-null should be accumulated
1162        let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
1163        assert_eq!(accumulated, expected);
1164    }
1165
1166    #[test]
1167    fn test_accumulate_multiple_with_filter() {
1168        let group_indices = vec![0, 1, 0, 1];
1169        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1170        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1171        let value_columns = [values1, values2];
1172
1173        let filter = BooleanArray::from(vec![true, false, true, false]);
1174
1175        let mut accumulated = vec![];
1176        accumulate_multiple(
1177            &group_indices,
1178            &value_columns.iter().collect::<Vec<_>>(),
1179            Some(&filter),
1180            |group_idx, batch_idx, columns| {
1181                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1182                accumulated.push((group_idx, values));
1183            },
1184        );
1185
1186        // Only rows where filter is true should be accumulated
1187        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1188        assert_eq!(accumulated, expected);
1189    }
1190
1191    #[test]
1192    fn test_accumulate_multiple_with_nulls_and_filter() {
1193        let group_indices = vec![0, 1, 0, 1];
1194        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
1195        let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
1196        let value_columns = [values1, values2];
1197
1198        let filter = BooleanArray::from(vec![true, true, true, false]);
1199
1200        let mut accumulated = vec![];
1201        accumulate_multiple(
1202            &group_indices,
1203            &value_columns.iter().collect::<Vec<_>>(),
1204            Some(&filter),
1205            |group_idx, batch_idx, columns| {
1206                let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1207                accumulated.push((group_idx, values));
1208            },
1209        );
1210
1211        // Only rows where both:
1212        // 1. Filter is true
1213        // 2. Both columns are non-null
1214        // should be accumulated
1215        let expected = [(0, vec![1, 10])];
1216        assert_eq!(accumulated, expected);
1217    }
1218}