use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::ArrowPrimitiveType;
use datafusion_expr_common::groups_accumulator::EmitTo;
/// Tracks, per group, whether at least one value has been accumulated.
///
/// A group's bit is set by [`NullState::accumulate`] and
/// [`NullState::accumulate_boolean`] each time a row is forwarded to the
/// caller's `value_fn` for that group. [`NullState::build`] converts the bits
/// into a [`NullBuffer`], so groups that never received a value come out as
/// null.
#[derive(Debug)]
pub struct NullState {
/// One bit per group index: `true` once a value has been accumulated for
/// that group, `false` otherwise.
seen_values: BooleanBufferBuilder,
}
impl Default for NullState {
fn default() -> Self {
Self::new()
}
}
impl NullState {
/// Creates an empty `NullState` that does not track any group yet.
pub fn new() -> Self {
    let seen_values = BooleanBufferBuilder::new(0);
    Self { seen_values }
}
/// Returns the size of the seen-bits storage, in bytes.
pub fn size(&self) -> usize {
    // The builder reports its capacity in bits; convert to whole bytes.
    const BITS_PER_BYTE: usize = 8;
    self.seen_values.capacity() / BITS_PER_BYTE
}
/// Forwards the rows of `values` to `value_fn`, recording which groups
/// received a value.
///
/// Row selection is delegated to the free function `accumulate`: `value_fn`
/// is invoked as `value_fn(group_index, value)` for each row it selects, and
/// the group's "seen" bit is set just before each invocation.
///
/// # Arguments
///
/// * `group_indices` - group index for each row of `values`
/// * `values` - the input values
/// * `opt_filter` - optional row filter passed through to `accumulate`
/// * `total_num_groups` - the seen-bits buffer is grown to at least this
///   many bits before accumulation starts
/// * `value_fn` - callback invoked for each accumulated row
pub fn accumulate<T, F>(
    &mut self,
    group_indices: &[usize],
    values: &PrimitiveArray<T>,
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
    mut value_fn: F,
) where
    T: ArrowPrimitiveType + Send,
    F: FnMut(usize, T::Native) + Send,
{
    // Groups added since the last call start out as "not seen".
    let seen = initialize_builder(&mut self.seen_values, total_num_groups, false);

    let mark_and_forward = |group_index: usize, value: T::Native| {
        seen.set_bit(group_index, true);
        value_fn(group_index, value);
    };
    accumulate(group_indices, values, opt_filter, mark_and_forward);
}
/// Boolean counterpart of [`NullState::accumulate`].
///
/// Calls `value_fn(group_index, value)` for every row of `values` that is
/// non-null and, when `opt_filter` is supplied, whose filter entry is
/// `Some(true)`. The group's "seen" bit is set just before each call.
///
/// # Arguments
///
/// * `group_indices` - group index for each row of `values`
/// * `values` - the boolean input values
/// * `opt_filter` - optional row filter
/// * `total_num_groups` - the seen-bits buffer is grown to at least this
///   many bits before accumulation starts
/// * `value_fn` - callback invoked for each accumulated row
///
/// # Panics
///
/// Panics if `values` (or a supplied filter) does not have the same length
/// as `group_indices`.
pub fn accumulate_boolean<F>(
    &mut self,
    group_indices: &[usize],
    values: &BooleanArray,
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
    mut value_fn: F,
) where
    F: FnMut(usize, bool) + Send,
{
    let bits = values.values();
    assert_eq!(bits.len(), group_indices.len());

    // Groups added since the last call start out as "not seen".
    let seen = initialize_builder(&mut self.seen_values, total_num_groups, false);

    let has_nulls = values.null_count() > 0;
    match (has_nulls, opt_filter) {
        // No nulls and no filter: every row is accumulated.
        (false, None) => {
            for (&group_index, bit) in group_indices.iter().zip(bits.iter()) {
                seen.set_bit(group_index, true);
                value_fn(group_index, bit);
            }
        }
        // Nulls but no filter: skip the null rows.
        (true, None) => {
            let validity = values.nulls().unwrap();
            let rows = group_indices.iter().zip(bits.iter()).zip(validity.iter());
            for ((&group_index, bit), is_valid) in rows {
                if !is_valid {
                    continue;
                }
                seen.set_bit(group_index, true);
                value_fn(group_index, bit);
            }
        }
        // Filter but no nulls: keep rows whose filter entry is `Some(true)`.
        (false, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            let rows = group_indices.iter().zip(bits.iter()).zip(filter.iter());
            for ((&group_index, bit), keep) in rows {
                if keep != Some(true) {
                    continue;
                }
                seen.set_bit(group_index, true);
                value_fn(group_index, bit);
            }
        }
        // Both nulls and a filter: a row must pass both checks.
        (true, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            let rows = filter.iter().zip(group_indices.iter()).zip(values.iter());
            for ((keep, &group_index), maybe_bit) in rows {
                if let (Some(true), Some(bit)) = (keep, maybe_bit) {
                    seen.set_bit(group_index, true);
                    value_fn(group_index, bit);
                }
            }
        }
    }
}
/// Produces the [`NullBuffer`] for the groups selected by `emit_to`.
///
/// * `EmitTo::All` - returns the bits of every group; no bits are kept.
/// * `EmitTo::First(n)` - returns the bits of the first `n` groups and keeps
///   the remaining bits (now starting at index 0) for later calls.
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
    // Take the accumulated bits out of the builder.
    let all_seen: BooleanBuffer = self.seen_values.finish();

    let emitted = match emit_to {
        EmitTo::All => all_seen,
        EmitTo::First(n) => {
            let head: BooleanBuffer = all_seen.slice(0, n);
            let tail: BooleanBuffer = all_seen.slice(n, all_seen.len() - n);
            // Put the groups that were not emitted back into the builder.
            self.seen_values.append_buffer(&tail);
            head
        }
    };
    NullBuffer::new(emitted)
}
}
/// Calls `value_fn(group_index, value)` for every row of `values` that should
/// be accumulated, in row order.
///
/// A row is accumulated when its value is non-null and, if `opt_filter` is
/// supplied, its filter entry is `Some(true)` (null and `false` filter
/// entries are skipped).
///
/// # Arguments
///
/// * `group_indices` - group index for each row of `values`
/// * `values` - the input values
/// * `opt_filter` - optional row filter
/// * `value_fn` - callback invoked for each accumulated row
///
/// # Panics
///
/// Panics if `values` (or a supplied filter) does not have the same length
/// as `group_indices`.
pub fn accumulate<T, F>(
    group_indices: &[usize],
    values: &PrimitiveArray<T>,
    opt_filter: Option<&BooleanArray>,
    mut value_fn: F,
) where
    T: ArrowPrimitiveType + Send,
    F: FnMut(usize, T::Native) + Send,
{
    let data: &[T::Native] = values.values();
    assert_eq!(data.len(), group_indices.len());

    let has_nulls = values.null_count() > 0;
    match (has_nulls, opt_filter) {
        // No nulls and no filter: every row is accumulated.
        (false, None) => {
            group_indices
                .iter()
                .zip(data)
                .for_each(|(&group_index, &new_value)| value_fn(group_index, new_value));
        }
        // Nulls but no filter: walk the validity bitmap 64 rows at a time.
        (true, None) => {
            let validity = values.nulls().unwrap();
            let index_chunks = group_indices.chunks_exact(64);
            let value_chunks = data.chunks_exact(64);
            let tail_indices = index_chunks.remainder();
            let tail_values = value_chunks.remainder();
            let bit_chunks = validity.inner().bit_chunks();

            let full_chunks = index_chunks.zip(value_chunks).zip(bit_chunks.iter());
            for ((index_chunk, value_chunk), word) in full_chunks {
                let rows = index_chunk.iter().zip(value_chunk).enumerate();
                for (bit, (&group_index, &new_value)) in rows {
                    // Bit `bit` of `word` is the validity of this row.
                    if (word >> bit) & 1 != 0 {
                        value_fn(group_index, new_value);
                    }
                }
            }

            // Rows left over after the last full 64-row chunk.
            let tail_bits = bit_chunks.remainder_bits();
            let tail_rows = tail_indices.iter().zip(tail_values).enumerate();
            for (bit, (&group_index, &new_value)) in tail_rows {
                if (tail_bits >> bit) & 1 != 0 {
                    value_fn(group_index, new_value);
                }
            }
        }
        // Filter but no nulls: keep rows whose filter entry is `Some(true)`.
        (false, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            let rows = group_indices.iter().zip(data).zip(filter.iter());
            for ((&group_index, &new_value), keep) in rows {
                if keep == Some(true) {
                    value_fn(group_index, new_value);
                }
            }
        }
        // Both nulls and a filter: a row must pass both checks.
        (true, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            let rows = filter.iter().zip(group_indices.iter()).zip(values.iter());
            for ((keep, &group_index), maybe_value) in rows {
                if let (Some(true), Some(new_value)) = (keep, maybe_value) {
                    value_fn(group_index, new_value);
                }
            }
        }
    }
}
pub fn accumulate_multiple<T, F>(
group_indices: &[usize],
value_columns: &[&PrimitiveArray<T>],
opt_filter: Option<&BooleanArray>,
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
{
let combined_nulls = value_columns
.iter()
.map(|arr| arr.logical_nulls())
.fold(None, |acc, nulls| {
NullBuffer::union(acc.as_ref(), nulls.as_ref())
});
let valid_indices = match (combined_nulls, opt_filter) {
(None, None) => None,
(None, Some(filter)) => Some(filter.clone()),
(Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
(Some(nulls), Some(filter)) => {
let combined = nulls.inner() & filter.values();
Some(BooleanArray::new(combined, None))
}
};
for col in value_columns.iter() {
debug_assert_eq!(col.len(), group_indices.len());
}
match valid_indices {
None => {
for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
value_fn(group_idx, batch_idx, value_columns);
}
}
Some(valid_indices) => {
for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
if valid_indices.value(batch_idx) {
value_fn(group_idx, batch_idx, value_columns);
}
}
}
}
}
pub fn accumulate_indices<F>(
group_indices: &[usize],
nulls: Option<&NullBuffer>,
opt_filter: Option<&BooleanArray>,
mut index_fn: F,
) where
F: FnMut(usize) + Send,
{
match (nulls, opt_filter) {
(None, None) => {
for &group_index in group_indices.iter() {
index_fn(group_index)
}
}
(None, Some(filter)) => {
debug_assert_eq!(filter.len(), group_indices.len());
let group_indices_chunks = group_indices.chunks_exact(64);
let bit_chunks = filter.values().bit_chunks();
let group_indices_remainder = group_indices_chunks.remainder();
group_indices_chunks.zip(bit_chunks.iter()).for_each(
|(group_index_chunk, mask)| {
let mut index_mask = 1;
group_index_chunk.iter().for_each(|&group_index| {
let is_valid = (mask & index_mask) != 0;
if is_valid {
index_fn(group_index);
}
index_mask <<= 1;
})
},
);
let remainder_bits = bit_chunks.remainder_bits();
group_indices_remainder
.iter()
.enumerate()
.for_each(|(i, &group_index)| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
index_fn(group_index)
}
});
}
(Some(valids), None) => {
debug_assert_eq!(valids.len(), group_indices.len());
let group_indices_chunks = group_indices.chunks_exact(64);
let bit_chunks = valids.inner().bit_chunks();
let group_indices_remainder = group_indices_chunks.remainder();
group_indices_chunks.zip(bit_chunks.iter()).for_each(
|(group_index_chunk, mask)| {
let mut index_mask = 1;
group_index_chunk.iter().for_each(|&group_index| {
let is_valid = (mask & index_mask) != 0;
if is_valid {
index_fn(group_index);
}
index_mask <<= 1;
})
},
);
let remainder_bits = bit_chunks.remainder_bits();
group_indices_remainder
.iter()
.enumerate()
.for_each(|(i, &group_index)| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
index_fn(group_index)
}
});
}
(Some(valids), Some(filter)) => {
debug_assert_eq!(filter.len(), group_indices.len());
debug_assert_eq!(valids.len(), group_indices.len());
let group_indices_chunks = group_indices.chunks_exact(64);
let valid_bit_chunks = valids.inner().bit_chunks();
let filter_bit_chunks = filter.values().bit_chunks();
let group_indices_remainder = group_indices_chunks.remainder();
group_indices_chunks
.zip(valid_bit_chunks.iter())
.zip(filter_bit_chunks.iter())
.for_each(|((group_index_chunk, valid_mask), filter_mask)| {
let mut index_mask = 1;
group_index_chunk.iter().for_each(|&group_index| {
let is_valid = (valid_mask & filter_mask & index_mask) != 0;
if is_valid {
index_fn(group_index);
}
index_mask <<= 1;
})
});
let remainder_valid_bits = valid_bit_chunks.remainder_bits();
let remainder_filter_bits = filter_bit_chunks.remainder_bits();
group_indices_remainder
.iter()
.enumerate()
.for_each(|(i, &group_index)| {
let is_valid =
remainder_valid_bits & remainder_filter_bits & (1 << i) != 0;
if is_valid {
index_fn(group_index)
}
});
}
}
}
/// Grows `builder` so that it holds at least `total_num_groups` bits and
/// returns it.
///
/// Newly added bits are set to `default_value`; existing bits are left
/// untouched. The builder is never shrunk.
fn initialize_builder(
    builder: &mut BooleanBufferBuilder,
    total_num_groups: usize,
    default_value: bool,
) -> &mut BooleanBufferBuilder {
    let missing = total_num_groups.saturating_sub(builder.len());
    if missing > 0 {
        builder.append_n(missing, default_value);
    }
    builder
}
#[cfg(test)]
mod test {
use super::*;
use arrow::array::{Int32Array, UInt32Array};
use rand::{Rng, rngs::ThreadRng};
use std::collections::HashSet;
// Deterministic fixture: 100 rows, one group per row, every third value null,
// and a filter that is null on even rows, `false` on odd multiples of five
// and `true` elsewhere.
#[test]
fn accumulate() {
    const NUM_ROWS: u32 = 100;
    let scaled = |i: u32| (i + 1) * 10;

    let group_indices: Vec<usize> = (0..NUM_ROWS as usize).collect();
    let values: Vec<u32> = (0..NUM_ROWS).map(scaled).collect();
    let values_with_nulls: Vec<Option<u32>> = (0..NUM_ROWS)
        .map(|i| (i % 3 != 0).then(|| scaled(i)))
        .collect();
    let filter: BooleanArray = (0..NUM_ROWS)
        .map(|i| match (i % 2 == 0, i % 5 == 0) {
            (true, _) => None,
            (false, true) => Some(false),
            (false, false) => Some(true),
        })
        .collect();

    let fixture = Fixture {
        group_indices,
        values,
        values_with_nulls,
        filter,
    };
    fixture.run()
}
// Runs the full fixture checks against 100 randomly generated inputs.
#[test]
fn accumulate_fuzz() {
    let mut rng = rand::rng();
    (0..100).for_each(|_| Fixture::new_random(&mut rng).run());
}
/// Input data shared by the accumulate tests.
struct Fixture {
/// Group index for each input row.
group_indices: Vec<usize>,
/// Input values without any nulls.
values: Vec<u32>,
/// Input values in which some rows may be null (`None`).
values_with_nulls: Vec<Option<u32>>,
/// Row filter applied in the "with filter" test variants.
filter: BooleanArray,
}
impl Fixture {
// Builds a fixture with a random number of rows and random group indices,
// values, filter entries and null positions.
fn new_random(rng: &mut ThreadRng) -> Self {
    let num_values: usize = rng.random_range(1..200);
    let num_groups: usize = rng.random_range(2..1000);
    let max_group = num_groups - 1;

    let mut group_indices: Vec<usize> = Vec::with_capacity(num_values);
    for _ in 0..num_values {
        group_indices.push(rng.random_range(0..max_group));
    }

    let mut values: Vec<u32> = Vec::with_capacity(num_values);
    for _ in 0..num_values {
        values.push(rng.random());
    }

    // Entries below 0.1 become `false`, below 0.2 null, the rest `true`.
    let filter: BooleanArray = (0..num_values)
        .map(|_| {
            let draw = rng.random_range(0.0..1.0);
            match draw {
                d if d < 0.1 => Some(false),
                d if d < 0.2 => None,
                _ => Some(true),
            }
        })
        .collect();

    // One threshold per fixture, then one draw per row to decide nullness.
    let null_pct: f32 = rng.random_range(0.0..1.0);
    let values_with_nulls: Vec<Option<u32>> = (0..num_values)
        .map(|_| {
            let draw: f32 = rng.random_range(0.0..1.0);
            if null_pct < draw {
                None
            } else {
                Some(rng.random())
            }
        })
        .collect();

    Self {
        group_indices,
        values,
        values_with_nulls,
        filter,
    }
}
// Returns `self.values` as an Arrow array.
fn values_array(&self) -> UInt32Array {
    let owned: Vec<u32> = self.values.clone();
    UInt32Array::from(owned)
}
// Returns `self.values_with_nulls` as an Arrow array.
fn values_with_nulls_array(&self) -> UInt32Array {
    let owned: Vec<Option<u32>> = self.values_with_nulls.clone();
    UInt32Array::from(owned)
}
// Runs `accumulate_test` for each combination of values with and without
// nulls, and with and without the filter.
fn run(&self) {
    let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;
    let dense = self.values_array();
    let sparse = self.values_with_nulls_array();

    let cases = [
        (&dense, None),
        (&sparse, None),
        (&dense, Some(&self.filter)),
        (&sparse, Some(&self.filter)),
    ];
    for (values, opt_filter) in cases {
        Self::accumulate_test(&self.group_indices, values, opt_filter, total_num_groups);
    }
}
// Exercises the value, index and boolean accumulation paths for one
// combination of `values` and `opt_filter`.
fn accumulate_test(
    group_indices: &[usize],
    values: &UInt32Array,
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) {
    Self::accumulate_values_test(
        group_indices,
        values,
        opt_filter,
        total_num_groups,
    );
    Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);

    // Derive boolean inputs: a value above the mean of the non-null values
    // maps to `true`, anything else to `false`, nulls stay null. Using the
    // real mean (not the sum) ensures both `true` and `false` are exercised.
    let (sum, count) = values
        .iter()
        .flatten()
        .fold((0u64, 0u64), |(sum, count), v| (sum + u64::from(v), count + 1));
    let avg = if count == 0 { 0 } else { sum / count };
    let boolean_values: BooleanArray = values
        .iter()
        .map(|v| v.map(|v| u64::from(v) > avg))
        .collect();
    Self::accumulate_boolean_test(
        group_indices,
        &boolean_values,
        opt_filter,
        total_num_groups,
    );
}
// Checks `NullState::accumulate` against a straightforward reference
// computation: same `(group, value)` calls, same seen bits, same null buffer.
fn accumulate_values_test(
    group_indices: &[usize],
    values: &UInt32Array,
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) {
    let mut accumulated_values = vec![];
    let mut null_state = NullState::new();
    null_state.accumulate(
        group_indices,
        values,
        opt_filter,
        total_num_groups,
        |group_index, value| accumulated_values.push((group_index, value)),
    );

    // Per-row "passes the filter" flags; without a filter every row passes.
    let included: Vec<bool> = match opt_filter {
        None => vec![true; group_indices.len()],
        Some(filter) => filter.iter().map(|f| f == Some(true)).collect(),
    };

    let mut mock = MockNullState::new();
    let mut expected_values = vec![];
    let rows = group_indices.iter().zip(values.iter()).zip(included);
    for ((&group_index, value), keep) in rows {
        if let (true, Some(value)) = (keep, value) {
            mock.saw_value(group_index);
            expected_values.push((group_index, value));
        }
    }

    assert_eq!(
        accumulated_values, expected_values,
        "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
    );

    mock.validate_seen_values(&null_state.seen_values.finish_cloned());

    let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
    assert_eq!(null_state.build(EmitTo::All), expected_null_buffer);
}
// Checks `accumulate_indices` against a straightforward reference
// computation of the group indices that should be reported.
fn accumulate_indices_test(
    group_indices: &[usize],
    nulls: Option<&NullBuffer>,
    opt_filter: Option<&BooleanArray>,
) {
    let mut accumulated_values = vec![];
    accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
        accumulated_values.push(group_index)
    });

    // Per-row flags; a missing null buffer or filter lets every row through.
    let row_is_valid: Vec<bool> = match nulls {
        None => vec![true; group_indices.len()],
        Some(nulls) => nulls.iter().collect(),
    };
    let row_is_included: Vec<bool> = match opt_filter {
        None => vec![true; group_indices.len()],
        Some(filter) => filter.iter().map(|f| f == Some(true)).collect(),
    };

    let expected_values: Vec<usize> = group_indices
        .iter()
        .zip(row_is_valid)
        .zip(row_is_included)
        .filter_map(|((&group_index, is_valid), is_included)| {
            (is_valid && is_included).then_some(group_index)
        })
        .collect();

    assert_eq!(
        accumulated_values, expected_values,
        "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
    );
}
// Checks `NullState::accumulate_boolean` against a straightforward reference
// computation: same `(group, value)` calls, same seen bits, same null buffer.
fn accumulate_boolean_test(
    group_indices: &[usize],
    values: &BooleanArray,
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) {
    let mut accumulated_values = vec![];
    let mut null_state = NullState::new();
    null_state.accumulate_boolean(
        group_indices,
        values,
        opt_filter,
        total_num_groups,
        |group_index, value| accumulated_values.push((group_index, value)),
    );

    // Per-row "passes the filter" flags; without a filter every row passes.
    let included: Vec<bool> = match opt_filter {
        None => vec![true; group_indices.len()],
        Some(filter) => filter.iter().map(|f| f == Some(true)).collect(),
    };

    let mut mock = MockNullState::new();
    let mut expected_values = vec![];
    let rows = group_indices.iter().zip(values.iter()).zip(included);
    for ((&group_index, value), keep) in rows {
        if let (true, Some(value)) = (keep, value) {
            mock.saw_value(group_index);
            expected_values.push((group_index, value));
        }
    }

    assert_eq!(
        accumulated_values, expected_values,
        "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
    );

    mock.validate_seen_values(&null_state.seen_values.finish_cloned());

    let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
    assert_eq!(null_state.build(EmitTo::All), expected_null_buffer);
}
}
#[derive(Debug, Default)]
struct MockNullState {
seen_values: HashSet<usize>,
}
impl MockNullState {
fn new() -> Self {
Default::default()
}
fn saw_value(&mut self, group_index: usize) {
self.seen_values.insert(group_index);
}
fn expected_seen(&self, group_index: usize) -> bool {
self.seen_values.contains(&group_index)
}
fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
for (group_index, is_seen) in seen_values.iter().enumerate() {
let expected_seen = self.expected_seen(group_index);
assert_eq!(
expected_seen, is_seen,
"mismatch at for group {group_index}"
);
}
}
fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
(0..total_num_groups)
.map(|group_index| self.expected_seen(group_index))
.collect()
}
}
// Without nulls or a filter, every row is reported with all column values.
#[test]
fn test_accumulate_multiple_no_nulls_no_filter() {
    let first = Int32Array::from(vec![1, 2, 3, 4]);
    let second = Int32Array::from(vec![10, 20, 30, 40]);
    let columns: Vec<&Int32Array> = vec![&first, &second];
    let group_indices: Vec<usize> = vec![0, 1, 0, 1];

    let mut accumulated: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(&group_indices, &columns, None, |group_idx, batch_idx, cols| {
        let row: Vec<i32> = cols.iter().map(|col| col.value(batch_idx)).collect();
        accumulated.push((group_idx, row));
    });

    let expected: Vec<(usize, Vec<i32>)> = vec![
        (0, vec![1, 10]),
        (1, vec![2, 20]),
        (0, vec![3, 30]),
        (1, vec![4, 40]),
    ];
    assert_eq!(accumulated, expected);
}
// A row is skipped as soon as any column is null at that row.
#[test]
fn test_accumulate_multiple_with_nulls() {
    let first = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
    let second = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
    let columns: Vec<&Int32Array> = vec![&first, &second];
    let group_indices: Vec<usize> = vec![0, 1, 0, 1];

    let mut accumulated: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(&group_indices, &columns, None, |group_idx, batch_idx, cols| {
        let row: Vec<i32> = cols.iter().map(|col| col.value(batch_idx)).collect();
        accumulated.push((group_idx, row));
    });

    let expected: Vec<(usize, Vec<i32>)> = vec![(0, vec![1, 10]), (1, vec![4, 40])];
    assert_eq!(accumulated, expected);
}
// Only rows whose filter entry is `true` are reported.
#[test]
fn test_accumulate_multiple_with_filter() {
    let first = Int32Array::from(vec![1, 2, 3, 4]);
    let second = Int32Array::from(vec![10, 20, 30, 40]);
    let columns: Vec<&Int32Array> = vec![&first, &second];
    let group_indices: Vec<usize> = vec![0, 1, 0, 1];
    let filter = BooleanArray::from(vec![true, false, true, false]);

    let mut accumulated: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(
        &group_indices,
        &columns,
        Some(&filter),
        |group_idx, batch_idx, cols| {
            let row: Vec<i32> = cols.iter().map(|col| col.value(batch_idx)).collect();
            accumulated.push((group_idx, row));
        },
    );

    let expected: Vec<(usize, Vec<i32>)> = vec![(0, vec![1, 10]), (0, vec![3, 30])];
    assert_eq!(accumulated, expected);
}
// A row must be non-null in every column and pass the filter to be reported.
#[test]
fn test_accumulate_multiple_with_nulls_and_filter() {
    let first = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
    let second = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
    let columns: Vec<&Int32Array> = vec![&first, &second];
    let group_indices: Vec<usize> = vec![0, 1, 0, 1];
    let filter = BooleanArray::from(vec![true, true, true, false]);

    let mut accumulated: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(
        &group_indices,
        &columns,
        Some(&filter),
        |group_idx, batch_idx, cols| {
            let row: Vec<i32> = cols.iter().map(|col| col.value(batch_idx)).collect();
            accumulated.push((group_idx, row));
        },
    );

    let expected = [(0, vec![1, 10])];
    assert_eq!(accumulated, expected);
}
}