arrow2/compute/aggregate/
sum.rs1use std::ops::Add;
2
3use multiversion::multiversion;
4
5use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
6use crate::datatypes::{DataType, PhysicalType, PrimitiveType};
7use crate::error::{Error, Result};
8use crate::scalar::*;
9use crate::types::simd::*;
10use crate::types::NativeType;
11use crate::{
12 array::{Array, PrimitiveArray},
13 bitmap::Bitmap,
14};
15
/// A value that can be collapsed into a single `T`.
///
/// In this module the trait is required of `T::Simd`, and `simd_sum` is called
/// on an accumulated SIMD value to obtain the final scalar result.
pub trait Sum<T> {
    /// Consumes `self` and returns the reduced value.
    ///
    /// NOTE(review): by name and usage this is expected to add all lanes
    /// together — confirm against the implementations.
    fn simd_sum(self) -> T;
}
22
23#[multiversion(targets = "simd")]
24pub fn sum_slice<T>(values: &[T]) -> T
26where
27 T: NativeType + Simd + Add<Output = T> + std::iter::Sum<T>,
28 T::Simd: Sum<T> + Add<Output = T::Simd>,
29{
30 let (head, simd_vals, tail) = T::Simd::align(values);
31
32 let mut reduced = T::Simd::from_incomplete_chunk(&[], T::default());
33 for chunk in simd_vals {
34 reduced = reduced + *chunk;
35 }
36
37 reduced.simd_sum() + head.iter().copied().sum() + tail.iter().copied().sum()
38}
39
40#[multiversion(targets = "simd")]
43fn null_sum_impl<T, I>(values: &[T], mut validity_masks: I) -> T
44where
45 T: NativeType + Simd,
46 T::Simd: Add<Output = T::Simd> + Sum<T>,
47 I: BitChunkIterExact<<<T as Simd>::Simd as NativeSimd>::Chunk>,
48{
49 let mut chunks = values.chunks_exact(T::Simd::LANES);
50
51 let sum = chunks.by_ref().zip(validity_masks.by_ref()).fold(
52 T::Simd::default(),
53 |acc, (chunk, validity_chunk)| {
54 let chunk = T::Simd::from_chunk(chunk);
55 let mask = <T::Simd as NativeSimd>::Mask::from_chunk(validity_chunk);
56 let selected = chunk.select(mask, T::Simd::default());
57 acc + selected
58 },
59 );
60
61 let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::default());
62 let mask = <T::Simd as NativeSimd>::Mask::from_chunk(validity_masks.remainder());
63 let remainder = remainder.select(mask, T::Simd::default());
64 let reduced = sum + remainder;
65
66 reduced.simd_sum()
67}
68
69fn null_sum<T>(values: &[T], bitmap: &Bitmap) -> T
72where
73 T: NativeType + Simd,
74 T::Simd: Add<Output = T::Simd> + Sum<T>,
75{
76 let (slice, offset, length) = bitmap.as_slice();
77 if offset == 0 {
78 let validity_masks = BitChunksExact::<<T::Simd as NativeSimd>::Chunk>::new(slice, length);
79 null_sum_impl(values, validity_masks)
80 } else {
81 let validity_masks = bitmap.chunks::<<T::Simd as NativeSimd>::Chunk>();
82 null_sum_impl(values, validity_masks)
83 }
84}
85
86pub fn sum_primitive<T>(array: &PrimitiveArray<T>) -> Option<T>
90where
91 T: NativeType + Simd + Add<Output = T> + std::iter::Sum<T>,
92 T::Simd: Add<Output = T::Simd> + Sum<T>,
93{
94 let null_count = array.null_count();
95
96 if null_count == array.len() {
97 return None;
98 }
99
100 match array.validity() {
101 None => Some(sum_slice(array.values())),
102 Some(bitmap) => Some(null_sum(array.values(), bitmap)),
103 }
104}
105
106pub fn can_sum(data_type: &DataType) -> bool {
108 if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() {
109 use PrimitiveType::*;
110 matches!(
111 primitive,
112 Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64
113 )
114 } else {
115 false
116 }
117}
118
// Dispatches on a `PrimitiveType` value and expands the caller's body once per
// supported variant, with the caller-chosen identifier bound to the matching
// native Rust type.
//
// Invocation shape: `with_match_primitive_type!(expr, |$T| { ...body using $T... })`.
// The `$_:tt` capture grabs the literal `$` written by the caller so that it
// can be re-emitted when the inner helper macro is defined.
//
// Any variant without an arm makes the *enclosing function* return
// `Err(Error::InvalidArgumentError(..))`.
macro_rules! with_match_primitive_type {(
    $primitive:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Helper that substitutes a concrete type for the caller's identifier.
    macro_rules! __expand_for__ {( $_ $T:ident ) => ( $($body)* )}
    use crate::datatypes::PrimitiveType::*;
    match $primitive {
        // Signed integers.
        Int8 => __expand_for__! { i8 },
        Int16 => __expand_for__! { i16 },
        Int32 => __expand_for__! { i32 },
        Int64 => __expand_for__! { i64 },
        Int128 => __expand_for__! { i128 },
        // Unsigned integers.
        UInt8 => __expand_for__! { u8 },
        UInt16 => __expand_for__! { u16 },
        UInt32 => __expand_for__! { u32 },
        UInt64 => __expand_for__! { u64 },
        // Floating point.
        Float32 => __expand_for__! { f32 },
        Float64 => __expand_for__! { f64 },
        // Everything else is rejected.
        _ => return Err(Error::InvalidArgumentError(format!(
            "`sum` operator do not support primitive `{:?}`",
            $primitive,
        ))),
    }
})}
142
143pub fn sum(array: &dyn Array) -> Result<Box<dyn Scalar>> {
148 Ok(match array.data_type().to_physical_type() {
149 PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
150 let data_type = array.data_type().clone();
151 let array = array.as_any().downcast_ref().unwrap();
152 Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array)))
153 }),
154 _ => {
155 return Err(Error::InvalidArgumentError(format!(
156 "The `sum` operator does not support type `{:?}`",
157 array.data_type(),
158 )))
159 }
160 })
161}