// arrow2/compute/aggregate/sum.rs

1use std::ops::Add;
2
3use multiversion::multiversion;
4
5use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
6use crate::datatypes::{DataType, PhysicalType, PrimitiveType};
7use crate::error::{Error, Result};
8use crate::scalar::*;
9use crate::types::simd::*;
10use crate::types::NativeType;
11use crate::{
12    array::{Array, PrimitiveArray},
13    bitmap::Bitmap,
14};
15
/// Horizontal reduction of a value into a single number.
///
/// This is used in the context of SIMD to collapse a SIMD vector
/// (e.g. `[f32; 16]`) into one scalar (`f32`).
pub trait Sum<T> {
    /// Consumes `self` and returns the single value it reduces to.
    fn simd_sum(self) -> T;
}
22
23#[multiversion(targets = "simd")]
24/// Compute the sum of a slice
25pub fn sum_slice<T>(values: &[T]) -> T
26where
27    T: NativeType + Simd + Add<Output = T> + std::iter::Sum<T>,
28    T::Simd: Sum<T> + Add<Output = T::Simd>,
29{
30    let (head, simd_vals, tail) = T::Simd::align(values);
31
32    let mut reduced = T::Simd::from_incomplete_chunk(&[], T::default());
33    for chunk in simd_vals {
34        reduced = reduced + *chunk;
35    }
36
37    reduced.simd_sum() + head.iter().copied().sum() + tail.iter().copied().sum()
38}
39
40/// # Panics
41/// iff `values.len() != bitmap.len()` or the operation overflows.
42#[multiversion(targets = "simd")]
43fn null_sum_impl<T, I>(values: &[T], mut validity_masks: I) -> T
44where
45    T: NativeType + Simd,
46    T::Simd: Add<Output = T::Simd> + Sum<T>,
47    I: BitChunkIterExact<<<T as Simd>::Simd as NativeSimd>::Chunk>,
48{
49    let mut chunks = values.chunks_exact(T::Simd::LANES);
50
51    let sum = chunks.by_ref().zip(validity_masks.by_ref()).fold(
52        T::Simd::default(),
53        |acc, (chunk, validity_chunk)| {
54            let chunk = T::Simd::from_chunk(chunk);
55            let mask = <T::Simd as NativeSimd>::Mask::from_chunk(validity_chunk);
56            let selected = chunk.select(mask, T::Simd::default());
57            acc + selected
58        },
59    );
60
61    let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::default());
62    let mask = <T::Simd as NativeSimd>::Mask::from_chunk(validity_masks.remainder());
63    let remainder = remainder.select(mask, T::Simd::default());
64    let reduced = sum + remainder;
65
66    reduced.simd_sum()
67}
68
69/// # Panics
70/// iff `values.len() != bitmap.len()` or the operation overflows.
71fn null_sum<T>(values: &[T], bitmap: &Bitmap) -> T
72where
73    T: NativeType + Simd,
74    T::Simd: Add<Output = T::Simd> + Sum<T>,
75{
76    let (slice, offset, length) = bitmap.as_slice();
77    if offset == 0 {
78        let validity_masks = BitChunksExact::<<T::Simd as NativeSimd>::Chunk>::new(slice, length);
79        null_sum_impl(values, validity_masks)
80    } else {
81        let validity_masks = bitmap.chunks::<<T::Simd as NativeSimd>::Chunk>();
82        null_sum_impl(values, validity_masks)
83    }
84}
85
86/// Returns the sum of values in the array.
87///
88/// Returns `None` if the array is empty or only contains null values.
89pub fn sum_primitive<T>(array: &PrimitiveArray<T>) -> Option<T>
90where
91    T: NativeType + Simd + Add<Output = T> + std::iter::Sum<T>,
92    T::Simd: Add<Output = T::Simd> + Sum<T>,
93{
94    let null_count = array.null_count();
95
96    if null_count == array.len() {
97        return None;
98    }
99
100    match array.validity() {
101        None => Some(sum_slice(array.values())),
102        Some(bitmap) => Some(null_sum(array.values(), bitmap)),
103    }
104}
105
106/// Whether [`sum`] supports `data_type`
107pub fn can_sum(data_type: &DataType) -> bool {
108    if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() {
109        use PrimitiveType::*;
110        matches!(
111            primitive,
112            Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64
113        )
114    } else {
115        false
116    }
117}
118
// Dispatches on a `PrimitiveType` value and expands `$body` for the matching
// variant, with `$T` standing for the corresponding native Rust type
// (`Int8` -> `i8`, ..., `Float64` -> `f64`).
//
// Invocation shape: `with_match_primitive_type!(key, |$T| { ... })`.
//
// The `$_:tt` fragment captures the caller's literal `$` token so that the
// nested `__with_ty__` macro can declare its own `$T` metavariable; the body is
// then instantiated once per arm by calling `__with_ty__!` with a concrete type.
//
// Any variant without an arm makes the *enclosing function* return
// `Err(Error::InvalidArgumentError(..))`, so the macro is only usable inside a
// function whose return type accepts that `Err`. Note that `$key_type` is
// expanded twice (in the `match` scrutinee and in the error message).
macro_rules! with_match_primitive_type {(
    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use crate::datatypes::PrimitiveType::*;
    match $key_type {
        Int8 => __with_ty__! { i8 },
        Int16 => __with_ty__! { i16 },
        Int32 => __with_ty__! { i32 },
        Int64 => __with_ty__! { i64 },
        Int128 => __with_ty__! { i128 },
        UInt8 => __with_ty__! { u8 },
        UInt16 => __with_ty__! { u16 },
        UInt32 => __with_ty__! { u32 },
        UInt64 => __with_ty__! { u64 },
        Float32 => __with_ty__! { f32 },
        Float64 => __with_ty__! { f64 },
        // Unsupported primitive: bail out of the enclosing function.
        _ => return Err(Error::InvalidArgumentError(format!(
            "`sum` operator do not support primitive `{:?}`",
            $key_type,
        ))),
    }
})}
142
143/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical
144/// and logical types as `array`.
145/// # Error
146/// Errors iff the operation is not supported.
147pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
148    Ok(match array.data_type().to_physical_type() {
149        PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
150            let data_type = array.data_type().clone();
151            let array = array.as_any().downcast_ref().unwrap();
152            Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array)))
153        }),
154        _ => {
155            return Err(Error::InvalidArgumentError(format!(
156                "The `sum` operator does not support type `{:?}`",
157                array.data_type(),
158            )))
159        }
160    })
161}