// polars_compute/float_sum.rs

1use std::ops::{Add, IndexMut};
2#[cfg(feature = "simd")]
3use std::simd::{prelude::*, *};
4
5use arrow::array::{Array, PrimitiveArray};
6use arrow::bitmap::Bitmap;
7use arrow::bitmap::bitmask::BitMask;
8use arrow::types::NativeType;
9use num_traits::{AsPrimitive, Float};
10
/// Number of independent accumulator lanes used by the block kernels.
///
/// `vector_horizontal_sum` folds exactly this many lanes in half repeatedly
/// until four remain, so this must be a power of two that is at least 4.
const STRIPE: usize = 16;

/// Number of elements handled by one `SumBlock` kernel call; `pairwise_sum`
/// splits longer inputs until every piece has exactly this length.
///
/// The kernels walk a block with `chunks_exact(STRIPE)`, which silently drops
/// a trailing partial chunk, so this must be a multiple of `STRIPE`.
const PAIRWISE_RECURSION_LIMIT: usize = 128;

// Turn the invariants documented above into compile errors instead of
// silently wrong sums if either constant is ever retuned.
const _: () = {
    assert!(STRIPE.is_power_of_two() && STRIPE >= 4);
    assert!(PAIRWISE_RECURSION_LIMIT > 0 && PAIRWISE_RECURSION_LIMIT % STRIPE == 0);
};
13
/// Lane-wise cast between `Simd` element types, exposed as a trait.
///
/// The block kernels are generic over both integer and float element types,
/// so they need one bound that lets any supported `Simd<_, N>` be cast to the
/// accumulator's element type. `impl_cast_custom!` provides the impls by
/// forwarding to `Simd::cast`.
#[cfg(feature = "simd")]
pub trait SimdCastGeneric<const N: usize>
where
    LaneCount<N>: SupportedLaneCount,
{
    /// Returns `self` with every lane cast to `Dst`.
    fn cast_generic<Dst: SimdCast>(self) -> Simd<Dst, N>;
}
22
/// Implements `SimdCastGeneric` for `Simd<$elem, N>` for every listed element
/// type by forwarding to `Simd::cast`. Accepts one or more comma-separated
/// types (a trailing comma is allowed).
macro_rules! impl_cast_custom {
    ($($elem:ty),+ $(,)?) => {
        $(
            #[cfg(feature = "simd")]
            impl<const N: usize> SimdCastGeneric<N> for Simd<$elem, N>
            where
                LaneCount<N>: SupportedLaneCount,
            {
                fn cast_generic<U: SimdCast>(self) -> Simd<U, N> {
                    self.cast::<U>()
                }
            }
        )+
    };
}

impl_cast_custom!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
47
48fn vector_horizontal_sum<V, T>(mut v: V) -> T
49where
50    V: IndexMut<usize, Output = T>,
51    T: Add<T, Output = T> + Sized + Copy,
52{
53    // We have to be careful about this reduction, floating
54    // point math is NOT associative so we have to write this
55    // in a form that maps to good shuffle instructions.
56    // We fold the vector onto itself, halved, until we are down to
57    // four elements which we add in a shuffle-friendly way.
58    let mut width = STRIPE;
59    while width > 4 {
60        for j in 0..width / 2 {
61            v[j] = v[j] + v[width / 2 + j];
62        }
63        width /= 2;
64    }
65
66    (v[0] + v[2]) + (v[1] + v[3])
67}
68
69// As a trait to not proliferate SIMD bounds.
70pub trait SumBlock<F> {
71    fn sum_block_vectorized(&self) -> F;
72    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F;
73}
74
/// SIMD kernel: accumulates the block in `STRIPE`-lane vectors of `F` (casting
/// each input vector first), then reduces the lanes with `vector_horizontal_sum`.
#[cfg(feature = "simd")]
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
where
    T: SimdElement,
    F: SimdElement + SimdCast + Add<Output = F> + Default,
    Simd<T, STRIPE>: SimdCastGeneric<STRIPE>,
    Simd<F, STRIPE>: std::iter::Sum,
{
    fn sum_block_vectorized(&self) -> F {
        let widened = self.chunks_exact(STRIPE).map(|chunk| {
            let raw: Simd<T, STRIPE> = Simd::from_slice(chunk);
            raw.cast_generic::<F>()
        });
        let acc: Simd<F, STRIPE> = widened.sum();
        vector_horizontal_sum(acc)
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        // Masked-out lanes contribute this all-zero vector via `select`.
        let blank = Simd::<F, STRIPE>::default();
        let widened = self.chunks_exact(STRIPE).enumerate().map(|(k, chunk)| {
            let keep: Mask<_, STRIPE> = mask.get_simd(k * STRIPE);
            let raw: Simd<T, STRIPE> = Simd::from_slice(chunk);
            keep.select(raw.cast_generic::<F>(), blank)
        });
        let acc: Simd<F, STRIPE> = widened.sum();
        vector_horizontal_sum(acc)
    }
}
104
/// Scalar `SumBlock` impls for 128-bit integers.
///
/// With the `simd` feature on, the vectorized impl requires `SimdElement`.
/// `std::simd` has no 128-bit integer lanes, so each listed type instead gets
/// an impl that converts elements with `AsPrimitive` and sums them in `F`.
/// Masked-out elements contribute `F::zero()`.
macro_rules! impl_scalar_sum_block {
    ($($int:ty),+ $(,)?) => {
        $(
            #[cfg(feature = "simd")]
            impl<F> SumBlock<F> for [$int; PAIRWISE_RECURSION_LIMIT]
            where
                $int: AsPrimitive<F>,
                F: Float + std::iter::Sum + 'static,
            {
                fn sum_block_vectorized(&self) -> F {
                    self.iter().map(|&x| x.as_()).sum()
                }

                fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
                    self.iter()
                        .zip(0usize..)
                        .map(|(&x, idx)| if mask.get(idx) { x.as_() } else { F::zero() })
                        .sum()
                }
            }
        )+
    };
}

impl_scalar_sum_block!(i128, u128);
140
141#[cfg(not(feature = "simd"))]
142impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
143where
144    T: AsPrimitive<F> + 'static,
145    F: Default + Add<Output = F> + Copy + 'static,
146{
147    fn sum_block_vectorized(&self) -> F {
148        let mut vsum = [F::default(); STRIPE];
149        for chunk in self.chunks_exact(STRIPE) {
150            for j in 0..STRIPE {
151                vsum[j] = vsum[j] + chunk[j].as_();
152            }
153        }
154        vector_horizontal_sum(vsum)
155    }
156
157    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
158        let mut vsum = [F::default(); STRIPE];
159        for (i, chunk) in self.chunks_exact(STRIPE).enumerate() {
160            for j in 0..STRIPE {
161                // Unconditional add with select for better branch-free opts.
162                let addend = if mask.get(i * STRIPE + j) {
163                    chunk[j].as_()
164                } else {
165                    F::default()
166                };
167                vsum[j] = vsum[j] + addend;
168            }
169        }
170        vector_horizontal_sum(vsum)
171    }
172}
173
174/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
175unsafe fn pairwise_sum<F, T>(f: &[T]) -> F
176where
177    [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
178    F: Add<Output = F>,
179{
180    debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
181
182    let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
183    if let Some(block) = block {
184        return block.sum_block_vectorized();
185    }
186
187    // SAFETY: we maintain the invariant. `try_into` array of len PAIRWISE_RECURSION_LIMIT
188    // failed so we know f.len() >= 2*PAIRWISE_RECURSION_LIMIT, and thus blocks >= 2.
189    // This means 0 < left_len < f.len() and left_len is divisible by PAIRWISE_RECURSION_LIMIT,
190    // maintaining the invariant for both recursive calls.
191    unsafe {
192        let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
193        let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
194        let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
195        pairwise_sum(left) + pairwise_sum(right)
196    }
197}
198
199/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
200/// Also, f.len() == mask.len().
201unsafe fn pairwise_sum_with_mask<F, T>(f: &[T], mask: BitMask<'_>) -> F
202where
203    [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
204    F: Add<Output = F>,
205{
206    debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
207    debug_assert!(f.len() == mask.len());
208
209    let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
210    if let Some(block) = block {
211        return block.sum_block_vectorized_with_mask(mask);
212    }
213
214    // SAFETY: see pairwise_sum.
215    unsafe {
216        let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
217        let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
218        let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
219        let (left_mask, right_mask) = mask.split_at_unchecked(left_len);
220        pairwise_sum_with_mask(left, left_mask) + pairwise_sum_with_mask(right, right_mask)
221    }
222}
223
224pub trait FloatSum<F>: Sized {
225    fn sum(f: &[Self]) -> F;
226    fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F;
227}
228
229impl<T, F> FloatSum<F> for T
230where
231    F: Float + std::iter::Sum + 'static,
232    T: AsPrimitive<F>,
233    [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
234{
235    fn sum(f: &[Self]) -> F {
236        let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
237        let (rest, main) = f.split_at(remainder);
238        let mainsum = if f.len() > remainder {
239            unsafe { pairwise_sum(main) }
240        } else {
241            F::zero()
242        };
243        // TODO: faster remainder.
244        let restsum: F = rest.iter().map(|x| x.as_()).sum();
245        mainsum + restsum
246    }
247
248    fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F {
249        let mask = BitMask::from_bitmap(validity);
250        assert!(f.len() == mask.len());
251
252        let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
253        let (rest, main) = f.split_at(remainder);
254        let (rest_mask, main_mask) = mask.split_at(remainder);
255        let mainsum = if f.len() > remainder {
256            unsafe { pairwise_sum_with_mask(main, main_mask) }
257        } else {
258            F::zero()
259        };
260        // TODO: faster remainder.
261        let restsum: F = rest
262            .iter()
263            .enumerate()
264            .map(|(i, x)| {
265                // No filter but rather select of 0.0 for cmov opt.
266                if rest_mask.get(i) { x.as_() } else { F::zero() }
267            })
268            .sum();
269        mainsum + restsum
270    }
271}
272
273pub fn sum_arr_as_f32<T>(arr: &PrimitiveArray<T>) -> f32
274where
275    T: NativeType + FloatSum<f32>,
276{
277    let validity = arr.validity().filter(|_| arr.null_count() > 0);
278    if let Some(mask) = validity {
279        FloatSum::sum_with_validity(arr.values(), mask)
280    } else {
281        FloatSum::sum(arr.values())
282    }
283}
284
285pub fn sum_arr_as_f64<T>(arr: &PrimitiveArray<T>) -> f64
286where
287    T: NativeType + FloatSum<f64>,
288{
289    let validity = arr.validity().filter(|_| arr.null_count() > 0);
290    if let Some(mask) = validity {
291        FloatSum::sum_with_validity(arr.values(), mask)
292    } else {
293        FloatSum::sum(arr.values())
294    }
295}