use arrow_buffer::ArrowNativeType;
use vortex_buffer::Buffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_panic;
use vortex_mask::Mask;
use vortex_session::VortexSession;
use crate::AnyCanonical;
use crate::ArrayRef;
use crate::Canonical;
use crate::DynArray;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::VortexSessionExecute;
use crate::aggregate_fn::Accumulator;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::DynAccumulator;
use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::arrays::ChunkedArray;
use crate::arrays::FixedSizeListArray;
use crate::arrays::ListViewArray;
use crate::builders::builder_with_capacity;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
use crate::executor::MAX_ITERATIONS;
use crate::match_each_integer_ptype;
use crate::vtable::ValidityHelper;
/// A boxed, type-erased grouped accumulator, i.e. any [`DynGroupedAccumulator`] implementation.
pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
/// Accumulates an aggregate function over groups of elements, where each group is one entry
/// of a `List` or `FixedSizeList` array.
///
/// Every call that accumulates a batch of groups appends one array of partial states to
/// `partials`; flushing drains them and finishing finalizes them through the vtable.
pub struct GroupedAccumulator<V: AggregateFnVTable> {
/// The aggregate function implementation (dtype computation and finalization).
vtable: V,
/// Options passed to the vtable alongside the input dtype.
options: V::Options,
/// Type-erased form of `vtable` + `options`; used to look up and invoke grouped kernels.
aggregate_fn: AggregateFnRef,
/// Element dtype that every incoming group must have.
dtype: DType,
/// DType of the finalized results, as reported by `vtable.return_dtype`.
return_dtype: DType,
/// DType of the partial states, as reported by `vtable.partial_dtype`.
partial_dtype: DType,
/// Partial state arrays pushed since the last flush, one per accumulated batch of groups.
partials: Vec<ArrayRef>,
/// Session used to create execution contexts and to reach the grouped-kernel registry.
session: VortexSession,
}
impl<V: AggregateFnVTable> GroupedAccumulator<V> {
    /// Creates a grouped accumulator for the aggregate described by `vtable` and `options`,
    /// operating on groups whose elements have the given `dtype`.
    ///
    /// The result and partial-state dtypes are resolved once here and used later to validate
    /// what the kernels and the vtable produce.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the vtable's `return_dtype` or `partial_dtype`.
    pub fn try_new(
        vtable: V,
        options: V::Options,
        dtype: DType,
        session: VortexSession,
    ) -> VortexResult<Self> {
        let erased_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
        let output_dtype = vtable.return_dtype(&options, &dtype)?;
        let state_dtype = vtable.partial_dtype(&options, &dtype)?;

        Ok(Self {
            vtable,
            options,
            aggregate_fn: erased_fn,
            dtype,
            return_dtype: output_dtype,
            partial_dtype: state_dtype,
            partials: Vec::new(),
            session,
        })
    }
}
/// Object-safe interface to a grouped accumulator, independent of the concrete aggregate
/// function vtable.
pub trait DynGroupedAccumulator: 'static + Send {
/// Accumulates one batch of groups. In the [`GroupedAccumulator`] implementation, `groups`
/// must be a `List` or `FixedSizeList` array whose element dtype matches the accumulator's
/// input dtype; any other dtype is an error.
fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;
/// Returns the partial states accumulated so far. The [`GroupedAccumulator`] implementation
/// drains its buffered partials, so a subsequent flush starts from empty.
fn flush(&mut self) -> VortexResult<ArrayRef>;
/// Returns the final results. The [`GroupedAccumulator`] implementation flushes the partial
/// states and finalizes them through the aggregate function vtable.
fn finish(&mut self) -> VortexResult<ArrayRef>;
}
impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
    /// Accumulates one batch of groups.
    ///
    /// `groups` must be a `List` or `FixedSizeList` array whose element dtype equals the
    /// accumulator's input dtype. The array is executed to its canonical form and handed to
    /// the matching list-view or fixed-size-list routine.
    ///
    /// # Errors
    ///
    /// Fails when the dtype check does not pass, or when execution or accumulation fails.
    fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
        // Both list flavours carry their element dtype in the first position.
        let inner_dtype = match groups.dtype() {
            DType::List(inner, _) | DType::FixedSizeList(inner, ..) => inner,
            other => vortex_bail!(
                "Input DType mismatch: expected List or FixedSizeList, got {}",
                other
            ),
        };
        vortex_ensure!(
            inner_dtype.as_ref() == &self.dtype,
            "Input DType mismatch: expected {}, got {}",
            self.dtype,
            inner_dtype
        );

        let mut ctx = self.session.create_execution_ctx();
        let canonical = groups.clone().execute::<Canonical>(&mut ctx)?;
        match canonical {
            Canonical::List(list_view) => self.accumulate_list_view(&list_view, &mut ctx),
            Canonical::FixedSizeList(fixed_size) => {
                self.accumulate_fixed_size_list(&fixed_size, &mut ctx)
            }
            _ => vortex_panic!("We checked the DType above, so this should never happen"),
        }
    }

    /// Drains the buffered partial-state arrays and returns them as a single chunked array
    /// of the partial dtype. The internal buffer is left empty.
    fn flush(&mut self) -> VortexResult<ArrayRef> {
        let drained = std::mem::take(&mut self.partials);
        let chunked = ChunkedArray::try_new(drained, self.partial_dtype.clone())?;
        Ok(chunked.into_array())
    }

    /// Flushes the partial states and finalizes them through the vtable.
    ///
    /// # Errors
    ///
    /// Fails when flushing or finalization fails, or when the finalized array does not have
    /// the return dtype resolved at construction time.
    fn finish(&mut self) -> VortexResult<ArrayRef> {
        let partial_states = self.flush()?;
        let finalized = self.vtable.finalize(partial_states)?;
        vortex_ensure!(
            finalized.dtype() == &self.return_dtype,
            "Return DType mismatch: expected {}, got {}",
            self.return_dtype,
            finalized.dtype()
        );
        Ok(finalized)
    }
}
impl<V: AggregateFnVTable> GroupedAccumulator<V> {
    /// Accumulates every group of a canonical list-view array.
    ///
    /// For at most `MAX_ITERATIONS` rounds, while the elements are not canonical, a grouped
    /// kernel is looked up for the elements' encoding: first one registered for this specific
    /// aggregate function, then one registered for any aggregate function (`None` key). If a
    /// kernel produces a result it is pushed as the partial state for this batch and the
    /// method returns. Otherwise the elements are executed further and the lookup is retried.
    ///
    /// When no kernel applies, the elements are canonicalized and each group is accumulated
    /// individually by [`Self::accumulate_list_view_typed`].
    fn accumulate_list_view(
        &mut self,
        groups: &ListViewArray,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<()> {
        let mut elements = groups.elements().clone();
        let session = self.session.clone();
        let kernels = &session.aggregate_fns().grouped_kernels;

        for _ in 0..*MAX_ITERATIONS {
            if elements.is::<AnyCanonical>() {
                break;
            }

            let kernels_r = kernels.read();
            if let Some(result) = kernels_r
                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
                .and_then(|kernel| {
                    // SAFETY: offsets, sizes and validity are taken unchanged from an existing
                    // `ListViewArray`; only the elements child is replaced by a (possibly
                    // partially executed) form of the same elements. This assumes `execute`
                    // preserves the elements' length and dtype.
                    let groups = unsafe {
                        ListViewArray::new_unchecked(
                            elements.clone(),
                            groups.offsets().clone(),
                            groups.sizes().clone(),
                            groups.validity().clone(),
                        )
                    };
                    kernel
                        .grouped_aggregate(&self.aggregate_fn, &groups)
                        .transpose()
                })
                .transpose()?
            {
                return self.push_result(result);
            }

            elements = elements.execute(ctx)?;
        }

        // No kernel handled the batch: fall back to accumulating group by group.
        let elements = elements.execute::<Canonical>(ctx)?.into_array();
        let offsets = groups.offsets();
        // Bring sizes to the offsets' integer type so both can be read with one native type.
        let sizes = groups.sizes().cast(offsets.dtype().clone())?;
        // One validity bit per group (not per element).
        let validity = groups.validity().to_mask(offsets.len());

        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
            let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
            let sizes = sizes.execute::<Buffer<O>>(ctx)?;
            self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
        })
    }

    /// Accumulates each list-view group individually and pushes one partial state per group.
    ///
    /// `offsets[i]` and `sizes[i]` describe the element range of group `i`, and `validity`
    /// holds one bit per group. Null groups contribute a null state.
    ///
    /// # Errors
    ///
    /// Fails when slicing, accumulation, or appending a state fails, or when the resulting
    /// state array does not have the partial dtype.
    fn accumulate_list_view_typed<O: IntegerPType>(
        &mut self,
        elements: &ArrayRef,
        offsets: &[O],
        sizes: &[O],
        validity: &Mask,
    ) -> VortexResult<()> {
        let mut accumulator = Accumulator::try_new(
            self.vtable.clone(),
            self.options.clone(),
            self.dtype.clone(),
            self.session.clone(),
        )?;
        let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());

        for (group_idx, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
            let offset = offset.to_usize().vortex_expect("Offset value is not usize");
            let size = size.to_usize().vortex_expect("Size value is not usize");

            // The mask is per group, so it must be indexed by the group's position, not by
            // the group's element offset.
            if validity.value(group_idx) {
                let group = elements.slice(offset..offset + size)?;
                accumulator.accumulate(&group)?;
                // NOTE(review): the value appended here goes into a `partial_dtype` builder
                // and is later passed through `vtable.finalize`; confirm that
                // `Accumulator::finish` yields a partial state rather than a finalized value.
                states.append_scalar(&accumulator.finish()?)?;
            } else {
                states.append_null();
            }
        }

        self.push_result(states.finish())
    }

    /// Accumulates every group of a canonical fixed-size-list array.
    ///
    /// Mirrors [`Self::accumulate_list_view`]: for at most `MAX_ITERATIONS` rounds a grouped
    /// kernel is tried against the current elements encoding (specific aggregate function
    /// first, then the `None` key), executing the elements further between attempts. When no
    /// kernel applies, the elements are canonicalized and each group of `list_size`
    /// consecutive elements is accumulated individually; null groups contribute a null state.
    fn accumulate_fixed_size_list(
        &mut self,
        groups: &FixedSizeListArray,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<()> {
        let mut elements = groups.elements().clone();
        let session = self.session.clone();
        let kernels = &session.aggregate_fns().grouped_kernels;

        // Same iteration bound as the list-view path.
        for _ in 0..*MAX_ITERATIONS {
            if elements.is::<AnyCanonical>() {
                break;
            }

            let kernels_r = kernels.read();
            if let Some(result) = kernels_r
                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
                .and_then(|kernel| {
                    // SAFETY: list size, validity and length are taken unchanged from an
                    // existing `FixedSizeListArray`; only the elements child is replaced by a
                    // (possibly partially executed) form of the same elements. This assumes
                    // `execute` preserves the elements' length and dtype.
                    let groups = unsafe {
                        FixedSizeListArray::new_unchecked(
                            elements.clone(),
                            groups.list_size(),
                            groups.validity().clone(),
                            groups.len(),
                        )
                    };
                    kernel
                        .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
                        .transpose()
                })
                .transpose()?
            {
                return self.push_result(result);
            }

            elements = elements.execute(ctx)?;
        }

        // No kernel handled the batch: fall back to accumulating group by group.
        let elements = elements.execute::<Canonical>(ctx)?.into_array();
        let validity = groups.validity().to_mask(groups.len());

        let mut accumulator = Accumulator::try_new(
            self.vtable.clone(),
            self.options.clone(),
            self.dtype.clone(),
            self.session.clone(),
        )?;
        let mut states = builder_with_capacity(&self.partial_dtype, groups.len());

        let size = groups
            .list_size()
            .to_usize()
            .vortex_expect("List size is not usize");
        // Start of the current group's elements; advances by `size` for null groups too.
        let mut offset = 0;

        for i in 0..groups.len() {
            if validity.value(i) {
                let group = elements.slice(offset..offset + size)?;
                accumulator.accumulate(&group)?;
                // NOTE(review): see the matching note in `accumulate_list_view_typed` about
                // `Accumulator::finish` versus a partial state.
                states.append_scalar(&accumulator.finish()?)?;
            } else {
                states.append_null();
            }
            offset += size;
        }

        self.push_result(states.finish())
    }

    /// Buffers one array of partial states after checking that it has the partial dtype.
    ///
    /// # Errors
    ///
    /// Fails when `state` does not have the expected partial dtype.
    fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
        vortex_ensure!(
            state.dtype() == &self.partial_dtype,
            "State DType mismatch: expected {}, got {}",
            self.partial_dtype,
            state.dtype()
        );
        self.partials.push(state);
        Ok(())
    }
}