use std::sync::LazyLock;
use arcref::ArcRef;
use num_traits::CheckedAdd;
use num_traits::CheckedSub;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use crate::ArrayRef;
use crate::DynArray;
use crate::IntoArray as _;
use crate::compute::ComputeFn;
use crate::compute::ComputeFnVTable;
use crate::compute::InvocationArgs;
use crate::compute::Kernel;
use crate::compute::Output;
use crate::dtype::DType;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::scalar::NumericOperator;
use crate::scalar::Scalar;
use crate::vtable::VTable;
/// The lazily-built `sum` compute function.
///
/// On first access it is created with the [`Sum`] vtable, and every [`SumKernelRef`] collected
/// through `inventory` is registered on it.
static SUM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let sum_fn = ComputeFn::new("sum".into(), ArcRef::new_ref(&Sum));
    // Pick up every encoding-specific kernel submitted to the inventory.
    for registered in inventory::iter::<SumKernelRef> {
        sum_fn.register_kernel(registered.0.clone());
    }
    sum_fn
});
/// Forces initialization of the `sum` compute function.
///
/// Returns the number of kernels registered on it.
pub(crate) fn warm_up_vtable() -> usize {
    let registered_kernels = SUM_FN.kernels();
    registered_kernels.len()
}
/// Sums the elements of `array` on top of `accumulator`, using the registered `sum` function.
///
/// `accumulator` must have exactly the sum dtype of `array`; the vtable rejects any mismatch.
///
/// # Errors
///
/// Returns an error if the dtype is unsupported, the accumulator dtype does not match, or the
/// selected kernel fails.
pub(crate) fn sum_with_accumulator(
    array: &ArrayRef,
    accumulator: &Scalar,
) -> VortexResult<Scalar> {
    let invocation = InvocationArgs {
        inputs: &[array.into(), accumulator.into()],
        options: &(),
    };
    let output = SUM_FN.invoke(&invocation)?;
    output.unwrap_scalar()
}
pub fn sum(array: &ArrayRef) -> VortexResult<Scalar> {
let sum_dtype = Stat::Sum
.dtype(array.dtype())
.ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
let zero = Scalar::zero_value(&sum_dtype);
sum_with_accumulator(array, &zero)
}
pub struct SumArgs<'a> {
pub array: &'a dyn DynArray,
pub accumulator: &'a Scalar,
}
impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
type Error = VortexError;
fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
if value.inputs.len() != 2 {
vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
}
let array = value.inputs[0]
.array()
.ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
let accumulator = value.inputs[1]
.scalar()
.ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?;
Ok(SumArgs { array, accumulator })
}
}
struct Sum;
impl ComputeFnVTable for Sum {
fn invoke(
&self,
args: &InvocationArgs,
kernels: &[ArcRef<dyn Kernel>],
) -> VortexResult<Output> {
let SumArgs { array, accumulator } = args.try_into()?;
let array = array.to_array();
let sum_dtype = self.return_dtype(args)?;
vortex_ensure!(
&sum_dtype == accumulator.dtype(),
"sum_dtype {sum_dtype} must match accumulator dtype {}",
accumulator.dtype()
);
if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
match &sum_dtype {
DType::Primitive(p, _) => {
if p.is_float() && accumulator.is_zero() == Some(true) {
return Ok(sum_scalar.into());
} else if p.is_int() {
let sum_from_stat = accumulator
.as_primitive()
.checked_add(&sum_scalar.as_primitive())
.map(Scalar::from);
return Ok(sum_from_stat
.unwrap_or_else(|| Scalar::null(sum_dtype))
.into());
}
}
DType::Decimal(..) => {
let sum_from_stat = accumulator
.as_decimal()
.checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add)
.map(Scalar::from);
return Ok(sum_from_stat
.unwrap_or_else(|| Scalar::null(sum_dtype))
.into());
}
_ => unreachable!("Sum will always be a decimal or a primitive dtype"),
}
}
let sum_scalar = sum_impl(&array, accumulator, kernels)?;
match sum_dtype {
DType::Primitive(p, _) => {
if p.is_float()
&& accumulator.is_zero() == Some(true)
&& let Some(sum_value) = sum_scalar.value().cloned()
{
array
.statistics()
.set(Stat::Sum, Precision::Exact(sum_value));
} else if p.is_int()
&& let Some(less_accumulator) = sum_scalar
.as_primitive()
.checked_sub(&accumulator.as_primitive())
&& let Some(val) = Scalar::from(less_accumulator).into_value()
{
array.statistics().set(Stat::Sum, Precision::Exact(val));
}
}
DType::Decimal(..) => {
if let Some(less_accumulator) = sum_scalar
.as_decimal()
.checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
&& let Some(val) = Scalar::from(less_accumulator).into_value()
{
array.statistics().set(Stat::Sum, Precision::Exact(val));
}
}
_ => unreachable!("Sum will always be a decimal or a primitive dtype"),
}
Ok(sum_scalar.into())
}
fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
let SumArgs { array, .. } = args.try_into()?;
Stat::Sum
.dtype(array.dtype())
.ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
}
fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
Ok(1)
}
fn is_elementwise(&self) -> bool {
false
}
}
/// A type-erased `sum` kernel.
///
/// Instances are collected through `inventory`. The `sum` compute function registers each one
/// when it is first initialized.
pub struct SumKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(SumKernelRef);
/// Encoding-specific implementation of `sum` for arrays of this vtable.
pub trait SumKernel: VTable {
    /// Returns `accumulator` plus the sum of the elements of `array`.
    ///
    /// When dispatched through `sum_impl`, kernels are only called for non-empty arrays that
    /// are not entirely invalid, and with a non-null accumulator.
    ///
    /// NOTE(review): the cached-stat path turns overflow into a null scalar. Kernels presumably
    /// signal overflow the same way — confirm.
    fn sum(&self, array: &Self::Array, accumulator: &Scalar) -> VortexResult<Scalar>;
}
/// Adapts a [`SumKernel`] implemented by vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct SumKernelAdapter<V: VTable>(pub V);
impl<V: VTable + SumKernel> SumKernelAdapter<V> {
    /// Wraps a `'static` adapter into a [`SumKernelRef`].
    ///
    /// This is a `const fn`, so the reference can be built in a constant/static context, e.g.
    /// for submission to the `inventory` collection of `SumKernelRef`s.
    pub const fn lift(&'static self) -> SumKernelRef {
        SumKernelRef(ArcRef::new_ref(self))
    }
}
impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
let SumArgs { array, accumulator } = args.try_into()?;
let Some(array) = array.as_opt::<V>() else {
return Ok(None);
};
Ok(Some(V::sum(&self.0, array, accumulator)?.into()))
}
}
/// Adds the elements of `array` to `accumulator` using the first kernel that accepts the array.
///
/// Returns `accumulator` unchanged, without consulting any kernel, if any of these hold:
/// - the array is empty;
/// - the array is entirely invalid;
/// - the accumulator is null.
///
/// Otherwise `kernels` are tried in order and the first one that produces a result wins. If none
/// does, the array is canonicalized and summed again through the `sum` compute function.
///
/// # Errors
///
/// Propagates errors from any of these steps:
/// - the validity check;
/// - the first kernel that fails;
/// - canonicalization;
/// - the recursive sum.
///
/// # Panics
///
/// Panics if no kernel handles `array` and `array` is already canonical. Otherwise the
/// canonicalize-and-retry fallback would make no progress.
pub fn sum_impl(
    array: &ArrayRef,
    accumulator: &Scalar,
    kernels: &[ArcRef<dyn Kernel>],
) -> VortexResult<Scalar> {
    let nothing_to_add = array.is_empty() || array.all_invalid()? || accumulator.is_null();
    if nothing_to_add {
        return Ok(accumulator.clone());
    }

    let invocation = InvocationArgs {
        inputs: &[array.into(), accumulator.into()],
        options: &(),
    };
    // Stops at the first kernel that either errors or returns a result.
    if let Some(output) = kernels
        .iter()
        .find_map(|kernel| kernel.invoke(&invocation).transpose())
    {
        return output?.unwrap_scalar();
    }

    tracing::debug!("No sum implementation found for {}", array.encoding_id());
    if array.is_canonical() {
        vortex_panic!(
            "No sum implementation found for canonical array: {}",
            array.encoding_id()
        );
    }
    let canonical = array.to_canonical()?.into_array();
    sum_with_accumulator(&canonical, accumulator)
}
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray as _;
    use crate::arrays::BoolArray;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::sum;
    use crate::compute::sum_with_accumulator;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;

    /// An all-null integer array sums to the nullable zero accumulator.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let expected = Scalar::primitive(0i64, Nullability::Nullable);
        assert_eq!(sum(&nulls).unwrap(), expected);
    }

    /// An all-null float array sums to the nullable zero accumulator.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]).into_array();
        let expected = Scalar::primitive(0f64, Nullability::Nullable);
        assert_eq!(sum(&nulls).unwrap(), expected);
    }

    #[test]
    fn sum_constant() {
        let ones = buffer![1, 1, 1, 1].into_array();
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(4));
    }

    #[test]
    fn sum_constant_float() {
        let ones = buffer![1., 1., 1., 1.].into_array();
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>(), Some(4.));
    }

    /// Booleans sum as the number of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]).into_array();
        let total = sum(&flags).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>(), Some(2));
    }

    /// Summing with a non-zero accumulator caches the accumulator-free sum as a statistic. A
    /// later plain `sum` must therefore not include the earlier accumulator.
    #[test]
    fn sum_stats() {
        let chunks = vec![
            PrimitiveArray::from_iter([1, 1, 1]).into_array(),
            PrimitiveArray::from_iter([2, 2, 2]).into_array(),
        ];
        let chunked = ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .vortex_expect("operation should succeed in test")
        .into_array();

        sum_with_accumulator(&chunked, &Scalar::primitive(2i64, Nullability::Nullable)).unwrap();

        assert_eq!(
            sum(&chunked).unwrap(),
            Scalar::primitive(9i64, Nullability::Nullable)
        );
    }
}