use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use crate::AnyCanonical;
use crate::ArrayRef;
use crate::Columnar;
use crate::ExecutionCtx;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::dtype::DType;
use crate::executor::MAX_ITERATIONS;
use crate::scalar::Scalar;
pub type AccumulatorRef = Box<dyn DynAccumulator>;
pub struct Accumulator<V: AggregateFnVTable> {
vtable: V,
aggregate_fn: AggregateFnRef,
dtype: DType,
return_dtype: DType,
partial_dtype: DType,
partial: V::Partial,
}
impl<V: AggregateFnVTable> Accumulator<V> {
pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
vortex_err!(
"Aggregate function {} cannot be applied to dtype {}",
vtable.id(),
dtype
)
})?;
let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
vortex_err!(
"Aggregate function {} cannot be applied to dtype {}",
vtable.id(),
dtype
)
})?;
let partial = vtable.empty_partial(&options, &dtype)?;
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
Ok(Self {
vtable,
aggregate_fn,
dtype,
return_dtype,
partial_dtype,
partial,
})
}
}
pub trait DynAccumulator: 'static + Send {
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
fn is_saturated(&self) -> bool;
fn flush(&mut self) -> VortexResult<Scalar>;
fn finish(&mut self) -> VortexResult<Scalar>;
}
impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
if self.is_saturated() {
return Ok(());
}
vortex_ensure!(
batch.dtype() == &self.dtype,
"Input DType mismatch: expected {}, got {}",
self.dtype,
batch.dtype()
);
if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
return Ok(());
}
let session = ctx.session().clone();
let kernels = &session.aggregate_fns().kernels;
let mut batch = batch.clone();
for _ in 0..*MAX_ITERATIONS {
if batch.is::<AnyCanonical>() {
break;
}
let kernels_r = kernels.read();
let batch_id = batch.encoding_id();
if let Some(result) = kernels_r
.get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
.or_else(|| kernels_r.get(&(batch_id, None)))
.and_then(|kernel| {
kernel
.aggregate(&self.aggregate_fn, &batch, ctx)
.transpose()
})
.transpose()?
{
vortex_ensure!(
result.dtype() == &self.partial_dtype,
"Aggregate kernel returned {}, expected {}",
result.dtype(),
self.partial_dtype,
);
self.vtable.combine_partials(&mut self.partial, result)?;
return Ok(());
}
batch = batch.execute(ctx)?;
}
let columnar = batch.execute::<Columnar>(ctx)?;
self.vtable.accumulate(&mut self.partial, &columnar, ctx)
}
fn is_saturated(&self) -> bool {
self.vtable.is_saturated(&self.partial)
}
fn flush(&mut self) -> VortexResult<Scalar> {
let partial = self.vtable.to_scalar(&self.partial)?;
self.vtable.reset(&mut self.partial);
#[cfg(debug_assertions)]
{
vortex_ensure!(
partial.dtype() == &self.partial_dtype,
"Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
self.partial_dtype,
partial.dtype(),
);
}
Ok(partial)
}
fn finish(&mut self) -> VortexResult<Scalar> {
let result = self.vtable.finalize_scalar(&self.partial)?;
self.vtable.reset(&mut self.partial);
vortex_ensure!(
result.dtype() == &self.return_dtype,
"Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
self.return_dtype,
result.dtype(),
);
Ok(result)
}
}