Skip to main content

polars_compute/
float_sum.rs

1use std::ops::{Add, IndexMut};
2#[cfg(feature = "simd")]
3use std::simd::{prelude::*, *};
4
5use arrow::array::{Array, PrimitiveArray};
6use arrow::bitmap::Bitmap;
7use arrow::bitmap::bitmask::BitMask;
8use arrow::types::NativeType;
9use num_traits::{AsPrimitive, Float};
10#[cfg(feature = "simd")]
11use polars_utils::float16::pf16;
12
13const STRIPE: usize = 16;
14const PAIRWISE_RECURSION_LIMIT: usize = 128;
15
// We want to be generic over both integers and floats, requiring this helper trait.
/// Lane-wise cast of a SIMD vector to a vector with another element type.
///
/// In this file it is implemented, through `impl_cast_custom!`, for `Simd`
/// vectors of `u8`..`u64`, `i8`..`i64`, `f32` and `f64`.
#[cfg(feature = "simd")]
pub trait SimdCastGeneric<const N: usize>
where
    LaneCount<N>: SupportedLaneCount,
{
    /// Converts `self` into a vector of `U` with the same lane count `N`.
    fn cast_generic<U: SimdCast>(self) -> Simd<U, N>;
}
24
// Implements `SimdCastGeneric` for `Simd<$elem, N>` for every listed element
// type by forwarding to the vector's own `cast` method. The generated impls
// only exist when the `simd` feature is enabled.
macro_rules! impl_cast_custom {
    ($($elem:ty),+) => {
        $(
            #[cfg(feature = "simd")]
            impl<const N: usize> SimdCastGeneric<N> for Simd<$elem, N>
            where
                LaneCount<N>: SupportedLaneCount,
            {
                fn cast_generic<U: SimdCast>(self) -> Simd<U, N> {
                    self.cast::<U>()
                }
            }
        )+
    };
}

// Unsigned integers, signed integers, then floats.
impl_cast_custom!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
49
50fn vector_horizontal_sum<V, T>(mut v: V) -> T
51where
52    V: IndexMut<usize, Output = T>,
53    T: Add<T, Output = T> + Sized + Copy,
54{
55    // We have to be careful about this reduction, floating
56    // point math is NOT associative so we have to write this
57    // in a form that maps to good shuffle instructions.
58    // We fold the vector onto itself, halved, until we are down to
59    // four elements which we add in a shuffle-friendly way.
60    let mut width = STRIPE;
61    while width > 4 {
62        for j in 0..width / 2 {
63            v[j] = v[j] + v[width / 2 + j];
64        }
65        width /= 2;
66    }
67
68    (v[0] + v[2]) + (v[1] + v[3])
69}
70
// As a trait to not proliferate SIMD bounds.
/// Reduction of one block of values to a single sum of type `F`.
///
/// In this file it is implemented for arrays of exactly
/// `PAIRWISE_RECURSION_LIMIT` elements.
pub trait SumBlock<F> {
    /// Sums every element of the block.
    fn sum_block_vectorized(&self) -> F;
    /// Sums the block under `mask`. In the implementations in this file an
    /// element whose mask bit is unset contributes a zero (`F::default()` in
    /// the non-SIMD fallback) instead of its value.
    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F;
}
76
// SIMD implementation: the block is processed `STRIPE` elements at a time,
// each stripe is cast lane-wise to `F` and accumulated into one vector of
// partial sums, which is then reduced by `vector_horizontal_sum`.
#[cfg(feature = "simd")]
impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
where
    T: SimdElement,
    F: SimdElement + SimdCast + Add<Output = F> + Default,
    Simd<T, STRIPE>: SimdCastGeneric<STRIPE>,
    Simd<F, STRIPE>: std::iter::Sum,
{
    fn sum_block_vectorized(&self) -> F {
        let lanes: Simd<F, STRIPE> = self
            .chunks_exact(STRIPE)
            .map(|stripe| {
                let raw = Simd::<T, STRIPE>::from_slice(stripe);
                raw.cast_generic::<F>()
            })
            .sum();
        vector_horizontal_sum(lanes)
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        // Lanes whose mask bit is unset are replaced by this default vector.
        let zero = Simd::<F, STRIPE>::default();
        let lanes: Simd<F, STRIPE> = self
            .chunks_exact(STRIPE)
            .enumerate()
            .map(|(stripe_idx, stripe)| {
                // Mask bits for this stripe start at `stripe_idx * STRIPE`.
                let keep: Mask<_, STRIPE> = mask.get_simd(stripe_idx * STRIPE);
                let converted = Simd::<T, STRIPE>::from_slice(stripe).cast_generic::<F>();
                keep.select(converted, zero)
            })
            .sum();
        vector_horizontal_sum(lanes)
    }
}
106
// `i128` blocks get their own impl under the `simd` feature: each element is
// converted to `F` with `as_()` and the results are summed in iteration order.
#[cfg(feature = "simd")]
impl<F> SumBlock<F> for [i128; PAIRWISE_RECURSION_LIMIT]
where
    i128: AsPrimitive<F>,
    F: Float + std::iter::Sum + 'static,
{
    fn sum_block_vectorized(&self) -> F {
        self.iter().map(|&value| -> F { value.as_() }).sum()
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        self.iter()
            .enumerate()
            .map(|(pos, &value)| -> F {
                // Elements with an unset mask bit contribute zero.
                if mask.get(pos) {
                    value.as_()
                } else {
                    F::zero()
                }
            })
            .sum()
    }
}
124
// `u128` blocks get their own impl under the `simd` feature: each element is
// converted to `F` with `as_()` and the results are summed in iteration order.
#[cfg(feature = "simd")]
impl<F> SumBlock<F> for [u128; PAIRWISE_RECURSION_LIMIT]
where
    u128: AsPrimitive<F>,
    F: Float + std::iter::Sum + 'static,
{
    fn sum_block_vectorized(&self) -> F {
        self.iter().map(|&value| -> F { value.as_() }).sum()
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        self.iter()
            .enumerate()
            .map(|(pos, &value)| -> F {
                // Elements with an unset mask bit contribute zero.
                if mask.get(pos) {
                    value.as_()
                } else {
                    F::zero()
                }
            })
            .sum()
    }
}
142
// `pf16` blocks get their own impl under the `simd` feature: each element is
// converted to `F` with `as_()` and the results are summed in iteration order.
#[cfg(feature = "simd")]
impl<F> SumBlock<F> for [pf16; PAIRWISE_RECURSION_LIMIT]
where
    pf16: AsPrimitive<F>,
    F: Float + std::iter::Sum + 'static,
{
    fn sum_block_vectorized(&self) -> F {
        self.iter().map(|&value| -> F { value.as_() }).sum()
    }

    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
        self.iter()
            .enumerate()
            .map(|(pos, &value)| -> F {
                // Elements with an unset mask bit contribute zero.
                if mask.get(pos) {
                    value.as_()
                } else {
                    F::zero()
                }
            })
            .sum()
    }
}
160
161#[cfg(not(feature = "simd"))]
162impl<T, F> SumBlock<F> for [T; PAIRWISE_RECURSION_LIMIT]
163where
164    T: AsPrimitive<F> + 'static,
165    F: Default + Add<Output = F> + Copy + 'static,
166{
167    fn sum_block_vectorized(&self) -> F {
168        let mut vsum = [F::default(); STRIPE];
169        for chunk in self.chunks_exact(STRIPE) {
170            for j in 0..STRIPE {
171                vsum[j] = vsum[j] + chunk[j].as_();
172            }
173        }
174        vector_horizontal_sum(vsum)
175    }
176
177    fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F {
178        let mut vsum = [F::default(); STRIPE];
179        for (i, chunk) in self.chunks_exact(STRIPE).enumerate() {
180            for j in 0..STRIPE {
181                // Unconditional add with select for better branch-free opts.
182                let addend = if mask.get(i * STRIPE + j) {
183                    chunk[j].as_()
184                } else {
185                    F::default()
186                };
187                vsum[j] = vsum[j] + addend;
188            }
189        }
190        vector_horizontal_sum(vsum)
191    }
192}
193
194/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
195unsafe fn pairwise_sum<F, T>(f: &[T]) -> F
196where
197    [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
198    F: Add<Output = F>,
199{
200    debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
201
202    let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
203    if let Some(block) = block {
204        return block.sum_block_vectorized();
205    }
206
207    // SAFETY: we maintain the invariant. `try_into` array of len PAIRWISE_RECURSION_LIMIT
208    // failed so we know f.len() >= 2*PAIRWISE_RECURSION_LIMIT, and thus blocks >= 2.
209    // This means 0 < left_len < f.len() and left_len is divisible by PAIRWISE_RECURSION_LIMIT,
210    // maintaining the invariant for both recursive calls.
211    unsafe {
212        let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
213        let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
214        let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
215        pairwise_sum(left) + pairwise_sum(right)
216    }
217}
218
219/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0.
220/// Also, f.len() == mask.len().
221unsafe fn pairwise_sum_with_mask<F, T>(f: &[T], mask: BitMask<'_>) -> F
222where
223    [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
224    F: Add<Output = F>,
225{
226    debug_assert!(!f.is_empty() && f.len().is_multiple_of(PAIRWISE_RECURSION_LIMIT));
227    debug_assert!(f.len() == mask.len());
228
229    let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok();
230    if let Some(block) = block {
231        return block.sum_block_vectorized_with_mask(mask);
232    }
233
234    // SAFETY: see pairwise_sum.
235    unsafe {
236        let blocks = f.len() / PAIRWISE_RECURSION_LIMIT;
237        let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT;
238        let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..));
239        let (left_mask, right_mask) = mask.split_at_unchecked(left_len);
240        pairwise_sum_with_mask(left, left_mask) + pairwise_sum_with_mask(right, right_mask)
241    }
242}
243
/// Summation of a slice of `Self` values into a result of type `F`.
pub trait FloatSum<F>: Sized {
    /// Returns the sum of all elements of `f`.
    fn sum(f: &[Self]) -> F;
    /// Returns the sum of the elements of `f` under `validity`. The blanket
    /// implementation in this file panics if `f.len()` differs from the mask
    /// length, and adds zero for every element whose validity bit is unset.
    fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F;
}
248
249impl<T, F> FloatSum<F> for T
250where
251    F: Float + std::iter::Sum + 'static,
252    T: AsPrimitive<F>,
253    [T; PAIRWISE_RECURSION_LIMIT]: SumBlock<F>,
254{
255    fn sum(f: &[Self]) -> F {
256        let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
257        let (rest, main) = f.split_at(remainder);
258        let mainsum = if f.len() > remainder {
259            unsafe { pairwise_sum(main) }
260        } else {
261            F::zero()
262        };
263        // TODO: faster remainder.
264        let restsum: F = rest.iter().map(|x| x.as_()).sum();
265        mainsum + restsum
266    }
267
268    fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F {
269        let mask = BitMask::from_bitmap(validity);
270        assert!(f.len() == mask.len());
271
272        let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
273        let (rest, main) = f.split_at(remainder);
274        let (rest_mask, main_mask) = mask.split_at(remainder);
275        let mainsum = if f.len() > remainder {
276            unsafe { pairwise_sum_with_mask(main, main_mask) }
277        } else {
278            F::zero()
279        };
280        // TODO: faster remainder.
281        let restsum: F = rest
282            .iter()
283            .enumerate()
284            .map(|(i, x)| {
285                // No filter but rather select of 0.0 for cmov opt.
286                if rest_mask.get(i) { x.as_() } else { F::zero() }
287            })
288            .sum();
289        mainsum + restsum
290    }
291}
292
293pub fn sum_arr_as_f32<T>(arr: &PrimitiveArray<T>) -> f32
294where
295    T: NativeType + FloatSum<f32>,
296{
297    let validity = arr.validity().filter(|_| arr.null_count() > 0);
298    if let Some(mask) = validity {
299        FloatSum::sum_with_validity(arr.values(), mask)
300    } else {
301        FloatSum::sum(arr.values())
302    }
303}
304
305pub fn sum_arr_as_f64<T>(arr: &PrimitiveArray<T>) -> f64
306where
307    T: NativeType + FloatSum<f64>,
308{
309    let validity = arr.validity().filter(|_| arr.null_count() > 0);
310    if let Some(mask) = validity {
311        FloatSum::sum_with_validity(arr.values(), mask)
312    } else {
313        FloatSum::sum(arr.values())
314    }
315}