use std::sync::Arc;
use arrow::{array::AsArray, datatypes::ArrowPrimitiveType};
use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray};
use arrow_schema::DataType;
use datafusion_common::Result;
use crate::GroupsAccumulator;
use super::{accumulate::NullState, EmitTo};
/// A [`GroupsAccumulator`] that keeps one `T::Native` value per group and folds
/// every incoming input value into it with a user-supplied combining function.
///
/// `F` receives the group's current value (mutated in place) as its first argument
/// and the new input value as its second.
#[derive(Debug)]
pub struct PrimitiveGroupsAccumulator<T, F>
where
    T: Send + ArrowPrimitiveType,
    F: Send + Sync + Fn(&mut T::Native, T::Native),
{
    /// Current accumulated value for each group, indexed by group index.
    values: Vec<T::Native>,
    /// Data type stamped onto the arrays this accumulator emits.
    data_type: DataType,
    /// Value that newly created group slots are initialized with.
    starting_value: T::Native,
    /// Per-group null/validity bookkeeping used while accumulating and emitting.
    null_state: NullState,
    /// Combining function: updates the first argument using the second.
    prim_fn: F,
}
impl<T, F> PrimitiveGroupsAccumulator<T, F>
where
    T: Send + ArrowPrimitiveType,
    F: Send + Sync + Fn(&mut T::Native, T::Native),
{
    /// Creates an empty accumulator that emits arrays of `data_type` and combines
    /// values with `prim_fn`.
    ///
    /// New groups start at `T::default_value()` unless overridden with
    /// [`Self::with_starting_value`].
    pub fn new(data_type: &DataType, prim_fn: F) -> Self {
        let output_type = data_type.clone();
        let initial = T::default_value();
        Self {
            values: Vec::new(),
            data_type: output_type,
            null_state: NullState::new(),
            starting_value: initial,
            prim_fn,
        }
    }

    /// Builder-style setter that replaces the value new group slots are seeded with.
    pub fn with_starting_value(self, starting_value: T::Native) -> Self {
        Self {
            starting_value,
            ..self
        }
    }
}
impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
where
    T: Send + ArrowPrimitiveType,
    F: Send + Sync + Fn(&mut T::Native, T::Native),
{
    /// Folds one batch of input into the per-group values.
    ///
    /// Expects exactly one input column, read as a `PrimitiveArray<T>`. Every row
    /// that `NullState::accumulate` hands to the callback is combined into the slot
    /// for its group via `prim_fn`.
    ///
    /// # Panics
    /// Panics if `values` does not contain exactly one array, or if that array is
    /// not a primitive array of type `T`.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let input = values[0].as_primitive::<T>();

        // Grow the per-group storage so every group index up to `total_num_groups`
        // has a slot; fresh slots are seeded with the configured starting value.
        self.values.resize(total_num_groups, self.starting_value);

        // Borrow the fields separately so the callback can mutate `values` while
        // `null_state` is itself mutably borrowed by `accumulate`.
        let group_values = &mut self.values;
        let combine = &self.prim_fn;
        self.null_state.accumulate(
            group_indices,
            input,
            opt_filter,
            total_num_groups,
            |group_index, new_value| combine(&mut group_values[group_index], new_value),
        );
        Ok(())
    }

    /// Emits the accumulated values selected by `emit_to` as a `PrimitiveArray<T>`
    /// tagged with the configured output data type.
    ///
    /// The validity buffer comes from `NullState::build`; presumably groups that
    /// never received a non-null, unfiltered value are marked null (confirm in
    /// `NullState`).
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        // Take the values first, then build the matching validity for the same rows.
        let emitted = emit_to.take_needed(&mut self.values);
        let validity = self.null_state.build(emit_to);
        let array = PrimitiveArray::<T>::new(emitted.into(), Some(validity))
            .with_data_type(self.data_type.clone());
        Ok(Arc::new(array))
    }

    /// The intermediate state is the single evaluated column.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let column = self.evaluate(emit_to)?;
        Ok(vec![column])
    }

    /// Merges partial states produced by `state`.
    ///
    /// Because the state has the same shape as the input, merging reuses
    /// `update_batch` directly.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.update_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Approximate heap footprint: the allocated capacity of the value buffer plus
    /// whatever `NullState` reports.
    fn size(&self) -> usize {
        let value_bytes = std::mem::size_of::<T::Native>() * self.values.capacity();
        value_bytes + self.null_state.size()
    }
}