use std::mem::size_of;
use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
use arrow::buffer::NullBuffer;
use arrow::compute;
use arrow::datatypes::ArrowPrimitiveType;
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result, internal_datafusion_err};
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
use super::accumulate::NullState;
/// A [`GroupsAccumulator`] for primitive Arrow types whose per-group result is
/// built by repeatedly folding input values into a running value.
///
/// Every group's running value starts at `starting_value`. Each value handed
/// to the accumulator is combined into it via `prim_fn(&mut running, input)`.
#[derive(Debug)]
pub struct PrimitiveGroupsAccumulator<
    T: ArrowPrimitiveType + Send,
    F: Fn(&mut T::Native, T::Native) + Send + Sync,
> {
    /// Running value for each group, indexed by group index.
    values: Vec<<T as ArrowPrimitiveType>::Native>,
    /// Arrow type stamped onto every emitted array. It presumably carries
    /// parameters, such as decimal precision or timezone, that `T`'s default
    /// type lacks.
    data_type: DataType,
    /// Value a group's slot is initialised with when it is first created.
    starting_value: <T as ArrowPrimitiveType>::Native,
    /// Produces the output null mask via `build`. It presumably tracks which
    /// groups have received a value.
    null_state: NullState,
    /// Folding function, invoked as `prim_fn(&mut running, new_value)`.
    prim_fn: F,
}
impl<T, F> PrimitiveGroupsAccumulator<T, F>
where
    T: ArrowPrimitiveType + Send,
    F: Fn(&mut T::Native, T::Native) + Send + Sync,
{
    /// Creates an empty accumulator.
    ///
    /// The accumulator emits arrays typed as `data_type` and folds input
    /// values with `prim_fn`. New groups start at `T::default_value()`; call
    /// [`with_starting_value`](Self::with_starting_value) to choose a
    /// different initial value.
    pub fn new(data_type: &DataType, prim_fn: F) -> Self {
        let starting_value = T::default_value();
        Self {
            prim_fn,
            starting_value,
            data_type: data_type.to_owned(),
            null_state: NullState::new(),
            values: Vec::new(),
        }
    }

    /// Returns this accumulator with `starting_value` as the initial value
    /// for every group created from now on.
    pub fn with_starting_value(self, starting_value: T::Native) -> Self {
        Self {
            starting_value,
            ..self
        }
    }
}
impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
where
T: ArrowPrimitiveType + Send,
F: Fn(&mut T::Native, T::Native) + Send + Sync,
{
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<T>();
self.values.resize(total_num_groups, self.starting_value);
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value);
},
);
Ok(())
}
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let values = emit_to.take_needed(&mut self.values);
let nulls = self.null_state.build(emit_to);
let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) .with_data_type(self.data_type.clone());
Ok(Arc::new(values))
}
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
self.evaluate(emit_to).map(|arr| vec![arr])
}
fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
self.update_batch(values, group_indices, opt_filter, total_num_groups)
}
fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = values[0].as_primitive::<T>().clone();
let initial_state =
PrimitiveArray::<T>::from_value(self.starting_value, values.len());
let values = match opt_filter {
None => values,
Some(filter) => {
let (filter_values, filter_nulls) = filter.clone().into_parts();
let filter_bool = match filter_nulls {
Some(filter_nulls) => filter_nulls.inner() & &filter_values,
None => filter_values,
};
let filter_nulls = NullBuffer::from(filter_bool);
let (dt, values_buf, original_nulls) = values.into_parts();
let nulls_buf =
NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls));
PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt)
}
};
let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
(self.prim_fn)(&mut x, y);
x
});
let state_values = state_values
.map_err(|_| {
internal_datafusion_err!(
"initial_values underlying buffer must not be shared"
)
})?
.map_err(DataFusionError::from)?
.with_data_type(self.data_type.clone());
Ok(vec![Arc::new(state_values)])
}
fn supports_convert_to_state(&self) -> bool {
true
}
fn size(&self) -> usize {
self.values.capacity() * size_of::<T::Native>() + self.null_state.size()
}
}