use std::sync::LazyLock;
use arcref::ArcRef;
use vortex_dtype::DType;
use vortex_error::{VortexResult, vortex_err, vortex_panic};
use vortex_scalar::Scalar;
use crate::Array;
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
use crate::stats::{Precision, Stat, StatsProvider};
use crate::vtable::VTable;
/// The process-wide `sum` compute function.
///
/// Built on first access: the [`Sum`] vtable is wrapped in a [`ComputeFn`], and then every
/// [`SumKernelRef`] collected through `inventory` is registered on it as a kernel.
static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let sum_fn = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
    for registered in inventory::iter::<SumKernelRef> {
        // The inventory entry is shared; clone its handle so the registry owns its own copy.
        let kernel = registered.0.clone();
        sum_fn.register_kernel(kernel);
    }
    sum_fn
});
pub(crate) fn warm_up_vtable() -> usize {
SUM_FN.kernels().len()
}
/// Computes the sum of the values in `array`, returned as a single [`Scalar`].
///
/// The result dtype is the one reported by [`Stat::Sum`] for the array's dtype. An empty array
/// sums to zero, and an array whose values are all invalid sums to null.
///
/// # Errors
///
/// Returns an error if summation is not supported for the array's dtype, or if the selected
/// implementation fails.
pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
    let args = InvocationArgs {
        inputs: &[array.into()],
        options: &(),
    };
    let output = SUM_FN.invoke(&args)?;
    output.unwrap_scalar()
}
struct Sum;
impl ComputeFnVTable for Sum {
fn invoke(
&self,
args: &InvocationArgs,
kernels: &[ArcRef<dyn Kernel>],
) -> VortexResult<Output> {
let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
let sum_dtype = self.return_dtype(args)?;
if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
return Ok(sum.into());
}
let sum_scalar = sum_impl(array, sum_dtype, kernels)?;
array
.statistics()
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
Ok(sum_scalar.into())
}
fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
Stat::Sum
.dtype(array.dtype())
.ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
}
fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
Ok(1)
}
fn is_elementwise(&self) -> bool {
false
}
}
/// A type-erased [`Kernel`] registered with the `sum` compute function.
///
/// Values of this type are collected with `inventory`. When `SUM_FN` is first initialised,
/// each one is registered on it.
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);

/// Encoding-specific implementation of `sum` for arrays of a given [`VTable`].
pub trait SumKernel: VTable {
    /// Computes the sum of `array`.
    ///
    /// Any error is propagated unchanged to the caller of `sum`.
    /// NOTE(review): the returned scalar is presumably expected to carry the dtype that
    /// `Stat::Sum` reports for the array's dtype; nothing in this file re-checks it, so confirm.
    fn sum(&self, array: &Self::Array) -> VortexResult<Scalar>;
}
/// Adapter that exposes a [`SumKernel`] implementation for vtable `V` as a type-erased
/// [`Kernel`].
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);

impl<V: VTable + SumKernel> SumKernelAdapter<V> {
    /// Wraps this `'static` adapter in a [`SumKernelRef`] so it can be registered through
    /// `inventory`. This is a `const fn`, so it can be used in static initialisers.
    pub const fn lift(&'static self) -> SumKernelRef {
        SumKernelRef(ArcRef::new_ref(self))
    }
}
impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
    /// Runs the wrapped [`SumKernel`] when the input array is encoded with vtable `V`.
    ///
    /// Returns `Ok(None)` for arrays of any other encoding, so the caller can try the next
    /// kernel.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
        let unary = UnaryArgs::<()>::try_from(args)?;
        match unary.array.as_opt::<V>() {
            Some(typed) => {
                let total = <V as SumKernel>::sum(&self.0, typed)?;
                Ok(Some(total.into()))
            }
            None => Ok(None),
        }
    }
}
/// Computes the sum of `array` as a scalar of dtype `sum_dtype`.
///
/// This function does not read or write the cached statistics of `array` itself. It tries, in
/// order:
/// 1. An empty array sums to zero (`0.0` for float dtypes, `0` otherwise).
/// 2. An array whose values are all invalid sums to null.
/// 3. The first of `kernels` that produces a result.
/// 4. The array's own `invoke` hook for the `sum` function.
/// 5. Canonicalise the array and call [`sum`] on the canonical form.
///
/// # Errors
///
/// Propagates errors from the kernels, the array's `invoke` hook, and the recursive [`sum`].
///
/// # Panics
///
/// Panics if the array is already canonical and neither step 3 nor step 4 handled it.
pub fn sum_impl(
    array: &dyn Array,
    sum_dtype: DType,
    kernels: &[ArcRef<dyn Kernel>],
) -> VortexResult<Scalar> {
    // Summing nothing yields the additive identity of the matching numeric kind.
    if array.is_empty() {
        let zero = if sum_dtype.is_float() {
            Scalar::new(sum_dtype, 0.0.into())
        } else {
            Scalar::new(sum_dtype, 0.into())
        };
        return Ok(zero);
    }

    // No valid values to add up.
    if array.all_invalid() {
        return Ok(Scalar::null(sum_dtype));
    }

    let args = InvocationArgs {
        inputs: &[array.into()],
        options: &(),
    };

    // Registered kernels get the first chance; the first one that accepts the array wins.
    for kernel in kernels {
        let Some(output) = kernel.invoke(&args)? else {
            continue;
        };
        return output.unwrap_scalar();
    }

    // Next, let the array's encoding handle the function itself.
    let encoding_output = array.invoke(&SUM_FN, &args)?;
    if let Some(output) = encoding_output {
        return output.unwrap_scalar();
    }

    // Last resort: canonicalise and sum again.
    log::debug!("No sum implementation found for {}", array.encoding_id());
    if array.is_canonical() {
        // Panic instead of canonicalising again, which would presumably re-enter this path.
        vortex_panic!(
            "No sum implementation found for canonical array: {}",
            array.encoding_id()
        );
    }
    let canonical = array.to_canonical();
    sum(canonical.as_ref())
}
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;
    use crate::IntoArray as _;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An all-null integer array sums to a null scalar of the widened integer dtype.
    #[test]
    fn sum_all_invalid() {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable))
        );
    }

    /// An all-null float array sums to a null scalar of the widened float dtype.
    #[test]
    fn sum_all_invalid_float() {
        let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(
            result,
            Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable))
        );
    }

    #[test]
    fn sum_constant() {
        let array = buffer![1, 1, 1, 1].into_array();
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let array = buffer![1., 1., 1., 1.].into_array();
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Booleans sum as the count of `true` values.
    #[test]
    fn sum_boolean() {
        let array = BoolArray::from_iter([true, false, false, true]);
        let result = sum(array.as_ref()).unwrap();
        assert_eq!(result.as_primitive().as_::<i32>(), Some(2));
    }

    /// The first `sum` caches `Stat::Sum` on the array and the second one is answered from that
    /// cache. Both must agree, including the dtype of the returned scalar.
    #[test]
    fn sum_cached_statistic() {
        let array = buffer![1i32, 2, 3].into_array();
        let computed = sum(array.as_ref()).unwrap();
        let cached = sum(array.as_ref()).unwrap();
        assert_eq!(computed.dtype(), cached.dtype());
        assert_eq!(computed, cached);
        assert_eq!(cached.as_primitive().as_::<i64>(), Some(6));
    }
}