use arrow::datatypes::ArrowPrimitiveType;
use arrow_array::{Array, BooleanArray, PrimitiveArray};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
use crate::EmitTo;
#[derive(Debug)]
pub struct NullState {
seen_values: BooleanBufferBuilder,
}
impl NullState {
/// Creates a `NullState` that has not seen any groups yet.
pub fn new() -> Self {
Self {
seen_values: BooleanBufferBuilder::new(0),
}
}
/// Returns the memory reserved by `seen_values`, in bytes.
///
/// `BooleanBufferBuilder::capacity` is reported in bits, hence the
/// division by 8.
pub fn size(&self) -> usize {
self.seen_values.capacity() / 8
}
/// Feeds every valid row of `values` to `value_fn` and records which groups
/// received at least one such row.
///
/// A row `i` is valid when `values[i]` is non-null and, if `opt_filter` is
/// given, its filter entry is `Some(true)`. Null and `Some(false)` filter
/// entries both exclude the row. For each valid row, `group_indices[i]` is
/// marked as seen and `value_fn(group_indices[i], values[i])` is called.
/// Rows are visited in ascending order.
///
/// Before any rows are processed, `seen_values` is grown to
/// `total_num_groups` entries, and the new groups start as unseen.
/// Entries of `group_indices` are expected to be `< total_num_groups`;
/// this is not checked here.
///
/// # Panics
///
/// Panics if `values.len()` or, when given, the filter length differs from
/// `group_indices.len()`.
pub fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
let data: &[T::Native] = values.values();
assert_eq!(data.len(), group_indices.len());
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);
// Dispatch to a loop specialised for the presence of nulls and/or a filter.
match (values.null_count() > 0, opt_filter) {
// No nulls and no filter: every row is valid.
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
}
// Nulls but no filter. A null_count > 0 guarantees a validity buffer,
// so the unwrap cannot fail.
(true, None) => {
let nulls = values.nulls().unwrap();
let group_indices_chunks = group_indices.chunks_exact(64);
let data_chunks = data.chunks_exact(64);
let bit_chunks = nulls.inner().bit_chunks();
let group_indices_remainder = group_indices_chunks.remainder();
let data_remainder = data_chunks.remainder();
// Process 64 rows at a time against one u64 word of the validity
// bitmap. Bit k of `mask` (tested via `index_mask`, which starts at 1
// and shifts left per row) is the validity of row k in the chunk.
group_indices_chunks
.zip(data_chunks)
.zip(bit_chunks.iter())
.for_each(|((group_index_chunk, data_chunk), mask)| {
let mut index_mask = 1;
group_index_chunk.iter().zip(data_chunk.iter()).for_each(
|(&group_index, &new_value)| {
let is_valid = (mask & index_mask) != 0;
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
index_mask <<= 1;
},
)
});
// Trailing rows that do not fill a whole 64-row chunk. Their validity
// bits sit in the low bits of `remainder_bits`.
let remainder_bits = bit_chunks.remainder_bits();
group_indices_remainder
.iter()
.zip(data_remainder.iter())
.enumerate()
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
});
}
// No nulls, with a filter: only rows whose filter entry is Some(true).
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
group_indices
.iter()
.zip(data.iter())
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
})
}
// Nulls and a filter: a row must pass the filter AND be non-null.
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
}
}
})
}
}
}
/// Boolean counterpart of [`Self::accumulate`].
///
/// For every row whose value is non-null and whose `opt_filter` entry (if a
/// filter is given) is `Some(true)`, marks the row's group as seen and calls
/// `value_fn(group_index, value)`. Rows are visited in order, and
/// `seen_values` is first grown to `total_num_groups` unseen entries.
///
/// # Panics
///
/// Panics if `values.len()` or, when given, the filter length differs from
/// `group_indices.len()`.
pub fn accumulate_boolean<F>(
&mut self,
group_indices: &[usize],
values: &BooleanArray,
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
mut value_fn: F,
) where
F: FnMut(usize, bool) + Send,
{
let data = values.values();
assert_eq!(data.len(), group_indices.len());
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);
match (values.null_count() > 0, opt_filter) {
// No nulls and no filter: every row is valid.
(false, None) => {
group_indices.iter().zip(data.iter()).for_each(
|(&group_index, new_value)| {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
},
)
}
// Nulls but no filter. A null_count > 0 guarantees a validity buffer;
// this walks it bit by bit, unlike the chunked loop in `accumulate`.
(true, None) => {
let nulls = values.nulls().unwrap();
group_indices
.iter()
.zip(data.iter())
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
})
}
// No nulls, with a filter: only rows whose filter entry is Some(true).
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
group_indices
.iter()
.zip(data.iter())
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
})
}
// Nulls and a filter: a row must pass the filter AND be non-null.
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
}
}
})
}
}
}
/// Converts the seen flags into a [`NullBuffer`] in which unseen groups are
/// null.
///
/// - `EmitTo::All` emits every group and leaves `seen_values` empty.
/// - `EmitTo::First(n)` emits only the first `n` groups. The flags of the
///   remaining groups are re-appended to `seen_values`, so their indices
///   shift down by `n` for later calls. If `n` exceeds the number of groups,
///   all groups are emitted.
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
// `finish` drains the builder. The First(n) branch relies on that when it
// appends the un-emitted flags back.
let nulls: BooleanBuffer = self.seen_values.finish();
let nulls = match emit_to {
EmitTo::All => nulls,
EmitTo::First(n) => {
// Bit-by-bit copies of the emitted prefix and the retained suffix.
let first_n_null: BooleanBuffer = nulls.iter().take(n).collect();
for seen in nulls.iter().skip(n) {
self.seen_values.append(seen);
}
first_n_null
}
};
NullBuffer::new(nulls)
}
}
/// Calls `index_fn(group_index)` once for every kept row, visiting rows in
/// order.
///
/// A row is kept when:
/// - its validity bit in `nulls` is set (every row is valid when `nulls` is
///   `None`), and
/// - its `opt_filter` entry is `Some(true)` (every row passes when
///   `opt_filter` is `None`). Null and `false` filter entries drop the row.
///
/// # Panics
///
/// Panics if `opt_filter` or `nulls` is supplied with a length different from
/// `group_indices.len()`.
pub fn accumulate_indices<F>(
    group_indices: &[usize],
    nulls: Option<&NullBuffer>,
    opt_filter: Option<&BooleanArray>,
    mut index_fn: F,
) where
    F: FnMut(usize) + Send,
{
    match (nulls, opt_filter) {
        (None, None) => group_indices.iter().for_each(|&idx| index_fn(idx)),
        (None, Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            group_indices
                .iter()
                .zip(filter.iter())
                .filter(|(_, keep)| *keep == Some(true))
                .for_each(|(&idx, _)| index_fn(idx));
        }
        (Some(valids), None) => {
            assert_eq!(valids.len(), group_indices.len());
            // Walk the validity bitmap one u64 word at a time. Bit `k` of a
            // word is the validity of row `k` in the matching 64-row slice.
            let words = valids.inner().bit_chunks();
            let full_slices = group_indices.chunks_exact(64);
            let tail = full_slices.remainder();
            for (slice, word) in full_slices.zip(words.iter()) {
                for (k, &idx) in slice.iter().enumerate() {
                    if word & (1u64 << k) != 0 {
                        index_fn(idx);
                    }
                }
            }
            // Rows after the last full word have their validity bits packed
            // into the low bits of `remainder_bits`.
            let tail_word = words.remainder_bits();
            for (k, &idx) in tail.iter().enumerate() {
                if tail_word & (1u64 << k) != 0 {
                    index_fn(idx);
                }
            }
        }
        (Some(valids), Some(filter)) => {
            assert_eq!(filter.len(), group_indices.len());
            assert_eq!(valids.len(), group_indices.len());
            for ((keep, &idx), is_valid) in
                filter.iter().zip(group_indices.iter()).zip(valids.iter())
            {
                if is_valid && keep == Some(true) {
                    index_fn(idx);
                }
            }
        }
    }
}
/// Pads `builder` with `default_value` bits until it holds at least
/// `total_num_groups` entries, then hands it back.
///
/// A builder that already holds `total_num_groups` or more bits is returned
/// unchanged.
fn initialize_builder(
    builder: &mut BooleanBufferBuilder,
    total_num_groups: usize,
    default_value: bool,
) -> &mut BooleanBufferBuilder {
    let missing = total_num_groups.saturating_sub(builder.len());
    if missing > 0 {
        builder.append_n(missing, default_value);
    }
    builder
}
#[cfg(test)]
mod test {
    use super::*;

    use arrow_array::UInt32Array;
    use arrow_buffer::BooleanBuffer;
    use hashbrown::HashSet;
    use rand::{rngs::ThreadRng, Rng};

    /// Deterministic fixture with 100 rows mapped to groups `0..100`. In the
    /// nullable variant every third value is null, and the filter mixes
    /// `None`, `Some(false)` and `Some(true)`.
    #[test]
    fn accumulate() {
        let group_indices = (0..100).collect();
        let values = (0..100).map(|i| (i + 1) * 10).collect();
        let values_with_nulls = (0..100)
            .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) })
            .collect();

        // Even rows get a null filter entry and odd multiples of five get an
        // explicit `false`; all other rows pass the filter.
        let filter: BooleanArray = (0..100)
            .map(|i| {
                let is_even = i % 2 == 0;
                let is_fifth = i % 5 == 0;
                if is_even {
                    None
                } else if is_fifth {
                    Some(false)
                } else {
                    Some(true)
                }
            })
            .collect();

        Fixture {
            group_indices,
            values,
            values_with_nulls,
            filter,
        }
        .run()
    }

    /// Randomized fixtures with 1..200 rows. Some row counts span full
    /// 64-row chunks and some need the remainder path.
    #[test]
    fn accumulate_fuzz() {
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            Fixture::new_random(&mut rng).run();
        }
    }

    /// `build(EmitTo::First(n))` returns the first `n` groups and keeps the
    /// remaining groups' flags for the next emit.
    #[test]
    fn build_emit_first() {
        let mut null_state = NullState::new();
        let values = UInt32Array::from(vec![Some(1), None, Some(3), None]);
        null_state.accumulate(&[0, 1, 2, 3], &values, None, 4, |_, _| {});

        let first = null_state.build(EmitTo::First(3));
        let expected_first: NullBuffer = [true, false, true].iter().copied().collect();
        assert_eq!(first, expected_first);

        let rest = null_state.build(EmitTo::All);
        let expected_rest: NullBuffer = [false].iter().copied().collect();
        assert_eq!(rest, expected_rest);
    }

    /// Inputs for one round of checks.
    struct Fixture {
        /// Group index of each row.
        group_indices: Vec<usize>,
        /// Values without nulls, one per row.
        values: Vec<u32>,
        /// Values with some nulls, one per row.
        values_with_nulls: Vec<Option<u32>>,
        /// Row filter; `None` and `Some(false)` entries exclude the row.
        filter: BooleanArray,
    }

    impl Fixture {
        /// Builds a random fixture from `rng`.
        fn new_random(rng: &mut ThreadRng) -> Self {
            let num_values: usize = rng.gen_range(1..200);
            let num_groups: usize = rng.gen_range(2..1000);
            let max_group = num_groups - 1;

            // `max_group` is the largest valid group index, so sample the
            // inclusive range to make it reachable.
            let group_indices: Vec<usize> = (0..num_values)
                .map(|_| rng.gen_range(0..=max_group))
                .collect();

            let values: Vec<u32> = (0..num_values).map(|_| rng.gen()).collect();

            // Roughly 10% `false`, 10% null and 80% `true`.
            let filter: BooleanArray = (0..num_values)
                .map(|_| {
                    let filter_value = rng.gen_range(0.0..1.0);
                    if filter_value < 0.1 {
                        Some(false)
                    } else if filter_value < 0.2 {
                        None
                    } else {
                        Some(true)
                    }
                })
                .collect();

            // Each row is null with probability `null_pct`.
            let null_pct: f32 = rng.gen_range(0.0..1.0);
            let values_with_nulls: Vec<Option<u32>> = (0..num_values)
                .map(|_| {
                    let draw: f32 = rng.gen_range(0.0..1.0);
                    let is_null = draw < null_pct;
                    if is_null {
                        None
                    } else {
                        Some(rng.gen())
                    }
                })
                .collect();

            Self {
                group_indices,
                values,
                values_with_nulls,
                filter,
            }
        }

        /// `values` as a non-nullable `UInt32Array`.
        fn values_array(&self) -> UInt32Array {
            UInt32Array::from(self.values.clone())
        }

        /// `values_with_nulls` as a nullable `UInt32Array`.
        fn values_with_nulls_array(&self) -> UInt32Array {
            UInt32Array::from(self.values_with_nulls.clone())
        }

        /// Checks every accumulate variant on each combination of
        /// {non-nullable, nullable} values × {no filter, filter}.
        fn run(&self) {
            let total_num_groups = *self.group_indices.iter().max().unwrap() + 1;

            let group_indices = &self.group_indices;
            let values_array = self.values_array();
            let values_with_nulls_array = self.values_with_nulls_array();
            let filter = &self.filter;

            Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
            Self::accumulate_test(
                group_indices,
                &values_with_nulls_array,
                None,
                total_num_groups,
            );
            Self::accumulate_test(
                group_indices,
                &values_array,
                Some(filter),
                total_num_groups,
            );
            Self::accumulate_test(
                group_indices,
                &values_with_nulls_array,
                Some(filter),
                total_num_groups,
            );
        }

        /// Runs the value, index and boolean checks for one input combination.
        fn accumulate_test(
            group_indices: &[usize],
            values: &UInt32Array,
            opt_filter: Option<&BooleanArray>,
            total_num_groups: usize,
        ) {
            Self::accumulate_values_test(
                group_indices,
                values,
                opt_filter,
                total_num_groups,
            );
            Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);

            // Derive a boolean input that is true for values strictly above the
            // mean of the non-null values. Unless all values are equal, both
            // `true` and `false` (plus any nulls) reach `accumulate_boolean`.
            // Sum in u64 so the total cannot overflow on 32-bit targets.
            let (sum, count) = values
                .iter()
                .flatten()
                .fold((0u64, 0u64), |(sum, count), v| {
                    (sum + u64::from(v), count + 1)
                });
            let avg = if count == 0 { 0 } else { sum / count };
            let boolean_values: BooleanArray = values
                .iter()
                .map(|v| v.map(|v| u64::from(v) > avg))
                .collect();
            Self::accumulate_boolean_test(
                group_indices,
                &boolean_values,
                opt_filter,
                total_num_groups,
            );
        }

        /// Compares `NullState::accumulate` against a straightforward
        /// reference loop, both for the callback sequence and the seen flags.
        fn accumulate_values_test(
            group_indices: &[usize],
            values: &UInt32Array,
            opt_filter: Option<&BooleanArray>,
            total_num_groups: usize,
        ) {
            let mut accumulated_values = vec![];
            let mut null_state = NullState::new();

            null_state.accumulate(
                group_indices,
                values,
                opt_filter,
                total_num_groups,
                |group_index, value| {
                    accumulated_values.push((group_index, value));
                },
            );

            // Reference implementation.
            let mut expected_values = vec![];
            let mut mock = MockNullState::new();

            match opt_filter {
                None => group_indices.iter().zip(values.iter()).for_each(
                    |(&group_index, value)| {
                        if let Some(value) = value {
                            mock.saw_value(group_index);
                            expected_values.push((group_index, value));
                        }
                    },
                ),
                Some(filter) => {
                    group_indices
                        .iter()
                        .zip(values.iter())
                        .zip(filter.iter())
                        .for_each(|((&group_index, value), is_included)| {
                            // Null filter entries exclude the row.
                            if let Some(true) = is_included {
                                if let Some(value) = value {
                                    mock.saw_value(group_index);
                                    expected_values.push((group_index, value));
                                }
                            }
                        });
                }
            }

            assert_eq!(accumulated_values, expected_values,
                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
            let seen_values = null_state.seen_values.finish_cloned();
            mock.validate_seen_values(&seen_values);

            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
            let null_buffer = null_state.build(EmitTo::All);
            assert_eq!(null_buffer, expected_null_buffer);
        }

        /// Compares `accumulate_indices` against a straightforward reference
        /// loop.
        fn accumulate_indices_test(
            group_indices: &[usize],
            nulls: Option<&NullBuffer>,
            opt_filter: Option<&BooleanArray>,
        ) {
            let mut accumulated_values = vec![];

            accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
                accumulated_values.push(group_index);
            });

            // Reference implementation.
            let mut expected_values = vec![];

            match (nulls, opt_filter) {
                (None, None) => group_indices.iter().for_each(|&group_index| {
                    expected_values.push(group_index);
                }),
                (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
                    |(&group_index, is_valid)| {
                        if is_valid {
                            expected_values.push(group_index);
                        }
                    },
                ),
                (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
                    |(&group_index, is_included)| {
                        if let Some(true) = is_included {
                            expected_values.push(group_index);
                        }
                    },
                ),
                (Some(nulls), Some(filter)) => {
                    group_indices
                        .iter()
                        .zip(nulls.iter())
                        .zip(filter.iter())
                        .for_each(|((&group_index, is_valid), is_included)| {
                            // Both the value and the filter entry must be valid.
                            if let (true, Some(true)) = (is_valid, is_included) {
                                expected_values.push(group_index);
                            }
                        });
                }
            }

            assert_eq!(accumulated_values, expected_values,
                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
        }

        /// Compares `NullState::accumulate_boolean` against a straightforward
        /// reference loop, both for the callback sequence and the seen flags.
        fn accumulate_boolean_test(
            group_indices: &[usize],
            values: &BooleanArray,
            opt_filter: Option<&BooleanArray>,
            total_num_groups: usize,
        ) {
            let mut accumulated_values = vec![];
            let mut null_state = NullState::new();

            null_state.accumulate_boolean(
                group_indices,
                values,
                opt_filter,
                total_num_groups,
                |group_index, value| {
                    accumulated_values.push((group_index, value));
                },
            );

            // Reference implementation.
            let mut expected_values = vec![];
            let mut mock = MockNullState::new();

            match opt_filter {
                None => group_indices.iter().zip(values.iter()).for_each(
                    |(&group_index, value)| {
                        if let Some(value) = value {
                            mock.saw_value(group_index);
                            expected_values.push((group_index, value));
                        }
                    },
                ),
                Some(filter) => {
                    group_indices
                        .iter()
                        .zip(values.iter())
                        .zip(filter.iter())
                        .for_each(|((&group_index, value), is_included)| {
                            // Null filter entries exclude the row.
                            if let Some(true) = is_included {
                                if let Some(value) = value {
                                    mock.saw_value(group_index);
                                    expected_values.push((group_index, value));
                                }
                            }
                        });
                }
            }

            assert_eq!(accumulated_values, expected_values,
                       "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");

            let seen_values = null_state.seen_values.finish_cloned();
            mock.validate_seen_values(&seen_values);

            let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
            let null_buffer = null_state.build(EmitTo::All);
            assert_eq!(null_buffer, expected_null_buffer);
        }
    }

    /// Reference model of `NullState`: the set of groups that received at
    /// least one value.
    #[derive(Debug, Default)]
    struct MockNullState {
        seen_values: HashSet<usize>,
    }

    impl MockNullState {
        fn new() -> Self {
            Default::default()
        }

        /// Records that `group_index` received a value.
        fn saw_value(&mut self, group_index: usize) {
            self.seen_values.insert(group_index);
        }

        /// Whether `group_index` received a value.
        fn expected_seen(&self, group_index: usize) -> bool {
            self.seen_values.contains(&group_index)
        }

        /// Asserts that every bit of `seen_values` matches the model.
        fn validate_seen_values(&self, seen_values: &BooleanBuffer) {
            for (group_index, is_seen) in seen_values.iter().enumerate() {
                let expected_seen = self.expected_seen(group_index);
                assert_eq!(
                    expected_seen, is_seen,
                    "mismatch for group {group_index}"
                );
            }
        }

        /// The null buffer `NullState::build(EmitTo::All)` should produce:
        /// valid exactly for the groups that received a value.
        fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer {
            (0..total_num_groups)
                .map(|group_index| self.expected_seen(group_index))
                .collect()
        }
    }
}