use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use crate::ArrayRef;
use crate::Columnar;
use crate::ExecutionCtx;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::columnar::AnyColumnar;
use crate::dtype::DType;
use crate::executor::max_iterations;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::scalar::Scalar;
/// A boxed, type-erased accumulator.
pub type AccumulatorRef = Box<dyn DynAccumulator>;
/// Accumulates input batches into the partial state of the aggregate function described by
/// the vtable `V`.
///
/// Every accumulated batch must have the input `dtype` the accumulator was created with. The
/// partial state can be read back as a scalar of `partial_dtype` or finalized into a scalar of
/// `return_dtype`.
pub struct Accumulator<V: AggregateFnVTable> {
    // The aggregate function's vtable; creates, merges, and finalizes the partial state.
    vtable: V,
    // Type-erased handle to the same aggregate function; handed to aggregate kernels and used
    // to map the aggregate onto a legacy statistic.
    aggregate_fn: AggregateFnRef,
    // The dtype every accumulated batch must have.
    dtype: DType,
    // The dtype of the value returned by `final_scalar`.
    return_dtype: DType,
    // The dtype of the partial state in scalar form (see `partial_scalar`).
    partial_dtype: DType,
    // The current partial state.
    partial: V::Partial,
}
impl<V: AggregateFnVTable> Accumulator<V> {
    /// Creates an accumulator for the aggregate function `vtable`, configured with `options`,
    /// that consumes input batches of `dtype`.
    ///
    /// # Errors
    ///
    /// Fails if the aggregate function reports no return dtype or no partial dtype for `dtype`,
    /// or if the vtable cannot build the initial (empty) partial state.
    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
        // Shared by both dtype lookups; it only captures shared references, so it is `Copy`
        // and can be handed to `ok_or_else` twice.
        let unsupported = || {
            vortex_err!(
                "Aggregate function {} cannot be applied to dtype {}",
                vtable.id(),
                dtype
            )
        };

        let return_dtype = vtable
            .return_dtype(&options, &dtype)
            .ok_or_else(unsupported)?;
        let partial_dtype = vtable
            .partial_dtype(&options, &dtype)
            .ok_or_else(unsupported)?;
        let partial = vtable.empty_partial(&options, &dtype)?;

        // Keep a type-erased copy of the function for kernels and statistics lookups.
        let erased_fn = AggregateFn::new(vtable.clone(), options).erased();

        Ok(Self {
            vtable,
            aggregate_fn: erased_fn,
            dtype,
            return_dtype,
            partial_dtype,
            partial,
        })
    }
}
/// Object-safe interface to an [`Accumulator`].
///
/// Lets accumulators for different aggregate functions be stored and driven uniformly, e.g.
/// behind an [`AccumulatorRef`].
pub trait DynAccumulator: Send + 'static {
    /// Folds `batch` into the current partial state.
    ///
    /// # Errors
    ///
    /// Fails if `batch` does not have the accumulator's input dtype, or if aggregation fails.
    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;

    /// Merges `other`, a partial state in scalar form, into the current partial state.
    ///
    /// Presumably `other` comes from [`DynAccumulator::partial_scalar`] or
    /// [`DynAccumulator::flush`] of a compatible accumulator.
    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;

    /// Returns `true` once [`DynAccumulator::accumulate`] ignores further input.
    fn is_saturated(&self) -> bool;

    /// Returns the current partial state as a scalar without modifying it.
    fn partial_scalar(&self) -> VortexResult<Scalar>;

    /// Returns the finalized aggregate value for the current state without modifying it.
    fn final_scalar(&self) -> VortexResult<Scalar>;

    /// Returns [`DynAccumulator::partial_scalar`] and then resets the state.
    fn flush(&mut self) -> VortexResult<Scalar>;

    /// Returns [`DynAccumulator::final_scalar`] and then resets the state.
    fn finish(&mut self) -> VortexResult<Scalar>;

    /// Resets the partial state so the accumulator can be reused.
    fn reset(&mut self);
}
impl<V: AggregateFnVTable> Accumulator<V> {
    /// Offers `batch` to the aggregate kernel registered in the session for the batch's
    /// encoding. If the kernel produces a partial, it is merged into this accumulator.
    ///
    /// A kernel registered for this exact aggregate function is preferred over one registered
    /// for the encoding without a specific function (the `None` key).
    ///
    /// Returns `Ok(true)` if a kernel handled the batch. Returns `Ok(false)` if no kernel is
    /// registered for the encoding or the kernel declined.
    ///
    /// # Errors
    ///
    /// Propagates kernel and merge errors. Also fails if the kernel's partial does not have the
    /// accumulator's partial dtype.
    fn try_aggregate_kernel(
        &mut self,
        batch: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<bool> {
        // Reach the registry through a clone of the session, so that it is not borrowed from
        // `ctx`, which the kernel needs mutably.
        let session = ctx.session().clone();
        let kernels = &session.aggregate_fns().kernels;

        let registry = kernels.read();
        let batch_id = batch.encoding_id();
        let kernel = registry
            .get(&(batch_id, Some(self.aggregate_fn.id())))
            .or_else(|| registry.get(&(batch_id, None)))
            .copied();
        // Release the read lock before running the kernel.
        drop(registry);

        let Some(kernel) = kernel else {
            return Ok(false);
        };
        let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)? else {
            return Ok(false);
        };
        vortex_ensure!(
            result.dtype() == &self.partial_dtype,
            "Aggregate kernel returned {}, expected {}",
            result.dtype(),
            self.partial_dtype,
        );
        self.vtable.combine_partials(&mut self.partial, result)?;
        Ok(true)
    }
}

impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
    /// Folds `batch` into the partial state. The first strategy that succeeds wins:
    ///
    /// 1. an exact legacy statistic stored on the batch;
    /// 2. a session-registered aggregate kernel for the batch's encoding;
    /// 3. the vtable's `try_accumulate` on the undecoded batch;
    /// 4. decoding one step at a time (bounded by `max_iterations`), offering each newly
    ///    produced encoding to the kernels;
    /// 5. full decoding to a columnar array followed by the vtable's `accumulate`.
    ///
    /// Saturated accumulators return immediately without inspecting `batch`.
    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
        if self.is_saturated() {
            return Ok(());
        }
        vortex_ensure!(
            batch.dtype() == &self.dtype,
            "Input DType mismatch: expected {}, got {}",
            self.dtype,
            batch.dtype()
        );

        // 1. An exact statistic that already holds this aggregate's partial.
        if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
            && let Some(Precision::Exact(partial)) = batch.statistics().get(stat)
        {
            vortex_ensure!(
                partial.dtype() == &self.partial_dtype,
                "Aggregate {} read legacy stat {} with dtype {}, expected {}",
                self.aggregate_fn,
                stat,
                partial.dtype(),
                self.partial_dtype,
            );
            self.vtable.combine_partials(&mut self.partial, partial)?;
            return Ok(());
        }

        // 2. A kernel for the input encoding.
        if self.try_aggregate_kernel(batch, ctx)? {
            return Ok(());
        }

        // 3. The vtable's own fast path on the undecoded batch.
        if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
            return Ok(());
        }

        // 4. Decode step by step, giving kernels a chance at each intermediate encoding.
        let mut batch = batch.clone();
        for step in 0..max_iterations() {
            if batch.is::<AnyColumnar>() {
                break;
            }
            // Step 0 is the input batch, which was already offered to the kernels above;
            // asking again would just repeat the same (declined) kernel call.
            if step > 0 && self.try_aggregate_kernel(&batch, ctx)? {
                return Ok(());
            }
            batch = batch.execute(ctx)?;
        }

        // 5. Fully decode and accumulate the canonical form.
        let columnar = batch.execute::<Columnar>(ctx)?;
        self.vtable.accumulate(&mut self.partial, &columnar, ctx)
    }

    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
        self.vtable.combine_partials(&mut self.partial, other)
    }

    fn is_saturated(&self) -> bool {
        self.vtable.is_saturated(&self.partial)
    }

    fn reset(&mut self) {
        self.vtable.reset(&mut self.partial);
    }

    fn partial_scalar(&self) -> VortexResult<Scalar> {
        let partial = self.vtable.to_scalar(&self.partial)?;
        // Unlike `final_scalar`, the partial dtype is only validated in debug builds.
        #[cfg(debug_assertions)]
        {
            vortex_ensure!(
                partial.dtype() == &self.partial_dtype,
                "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
                self.partial_dtype,
                partial.dtype(),
            );
        }
        Ok(partial)
    }

    fn final_scalar(&self) -> VortexResult<Scalar> {
        let result = self.vtable.finalize_scalar(&self.partial)?;
        vortex_ensure!(
            result.dtype() == &self.return_dtype,
            "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
            self.return_dtype,
            result.dtype(),
        );
        Ok(result)
    }

    fn flush(&mut self) -> VortexResult<Scalar> {
        let partial = self.partial_scalar()?;
        self.reset();
        Ok(partial)
    }

    fn finish(&mut self) -> VortexResult<Scalar> {
        let result = self.final_scalar()?;
        self.reset();
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Kernel that answers every request with the fixed mean partial from `sentinel_partial`.
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;

    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Kernel that always declines, forcing the accumulator onto its fallback paths.
    #[derive(Debug)]
    struct DeclineKernel;

    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Kernel that answers every request with a nullable `42.0` sum partial.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;

    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// A session with only the array session registered (no aggregate kernels yet).
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// A one-element dictionary array whose only value is `7.0`.
    fn dict_of_seven() -> ArrayRef {
        let codes = buffer![0u32].into_array();
        let values = buffer![7.0f64].into_array();
        DictArray::try_new(codes, values)
            .expect("valid dictionary")
            .into_array()
    }

    /// A `Combined<Mean>` accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )
    }

    /// A mean partial of `{sum: 42.0, count: 1}`, built with the accumulator's partial dtype.
    fn sentinel_partial() -> Scalar {
        let accumulator = mean_f64_accumulator().expect("build accumulator");
        let fields = vec![
            Scalar::primitive(42.0f64, Nullability::Nullable),
            Scalar::primitive(1u64, Nullability::NonNullable),
        ];
        Scalar::struct_(accumulator.partial_dtype, fields)
    }

    /// Accumulates `dict_of_seven()` into a fresh mean accumulator executing in `session`
    /// and returns the flushed partial.
    fn flush_dict_of_seven(session: VortexSession) -> VortexResult<Scalar> {
        let mut ctx = session.create_execution_ctx();
        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), &mut ctx)?;
        accumulator.flush()
    }

    /// Asserts that a mean partial carries the expected `sum` and `count` fields.
    fn assert_mean_partial(partial: &Scalar, sum: f64, count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(count)
        );
    }

    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_dict_of_seven(session)?;
        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }

    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_dict_of_seven(session)?;
        assert_mean_partial(&partial, 7.0, 1);
        Ok(())
    }

    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);

        let partial = flush_dict_of_seven(session)?;
        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }
}