// polars_compute/sum.rs

1use std::ops::Add;
2#[cfg(feature = "simd")]
3use std::simd::prelude::*;
4
5use arrow::array::{Array, PrimitiveArray};
6use arrow::bitmap::bitmask::BitMask;
7use arrow::types::NativeType;
8use num_traits::Zero;
9use polars_utils::float16::pf16;
10
/// Implements the single-method trait `$tr` for `$ty`, forwarding its
/// `wrapping_add(&self, &Self)` to the associated function `<$ty>::$op(lhs, rhs)`.
///
/// `$op` is `wrapping_add` for integers (the inherent overflow-wrapping method) and
/// `add` for floats (`std::ops::Add::add`, which must be in scope at the call site).
macro_rules! wrapping_impl {
    ($tr:ident, $op:ident, $ty:ty) => {
        impl $tr for $ty {
            #[inline(always)]
            fn wrapping_add(&self, v: &Self) -> Self {
                let (lhs, rhs) = (*self, *v);
                <$ty>::$op(lhs, rhs)
            }
        }
    };
}
21
/// Performs addition that wraps around on overflow.
///
/// Differs from `num_traits::WrappingAdd` in that this is also implemented for floats.
/// For integer types the result wraps modulo `2^BITS` instead of panicking or
/// saturating. For floating-point types there is nothing to wrap: the result is
/// ordinary IEEE 754 addition, so overflow produces an infinity.
pub trait WrappingAdd: Sized {
    /// Computes `self + v`.
    ///
    /// Integer implementations wrap around at the boundary of the type (modular
    /// addition). Floating-point implementations perform plain addition.
    fn wrapping_add(&self, v: &Self) -> Self;
}
30
31wrapping_impl!(WrappingAdd, wrapping_add, u8);
32wrapping_impl!(WrappingAdd, wrapping_add, u16);
33wrapping_impl!(WrappingAdd, wrapping_add, u32);
34wrapping_impl!(WrappingAdd, wrapping_add, u64);
35wrapping_impl!(WrappingAdd, wrapping_add, usize);
36wrapping_impl!(WrappingAdd, wrapping_add, u128);
37
38wrapping_impl!(WrappingAdd, wrapping_add, i8);
39wrapping_impl!(WrappingAdd, wrapping_add, i16);
40wrapping_impl!(WrappingAdd, wrapping_add, i32);
41wrapping_impl!(WrappingAdd, wrapping_add, i64);
42wrapping_impl!(WrappingAdd, wrapping_add, isize);
43wrapping_impl!(WrappingAdd, wrapping_add, i128);
44
45wrapping_impl!(WrappingAdd, add, pf16);
46wrapping_impl!(WrappingAdd, add, f32);
47wrapping_impl!(WrappingAdd, add, f64);
48
/// Lane count of the `Simd` vectors used by the masked SIMD sum kernel; inputs are
/// processed in chunks of exactly this many elements.
#[cfg(feature = "simd")]
const STRIPE: usize = 16;
51
52fn wrapping_sum_with_mask_scalar<T: Zero + WrappingAdd + Copy>(vals: &[T], mask: &BitMask) -> T {
53    assert!(vals.len() == mask.len());
54    vals.iter()
55        .enumerate()
56        .map(|(i, x)| {
57            // No filter but rather select of 0 for cmov opt.
58            if mask.get(i) { *x } else { T::zero() }
59        })
60        .fold(T::zero(), |a, b| a.wrapping_add(&b))
61}
62
63#[cfg(not(feature = "simd"))]
64impl<T> WrappingSum for T
65where
66    T: NativeType + WrappingAdd + Zero,
67{
68    fn wrapping_sum(vals: &[Self]) -> Self {
69        vals.iter()
70            .copied()
71            .fold(T::zero(), |a, b| a.wrapping_add(&b))
72    }
73
74    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
75        wrapping_sum_with_mask_scalar(vals, mask)
76    }
77}
78
/// `WrappingSum` for every type that also implements `crate::SimdPrimitive`, used
/// when the `simd` feature is enabled.
///
/// Only the masked kernel uses explicit `std::simd` vectors; the unmasked kernel is a
/// plain scalar fold.
#[cfg(feature = "simd")]
impl<T> WrappingSum for T
where
    T: NativeType + WrappingAdd + Zero + crate::SimdPrimitive,
{
    /// Folds every element into a running `wrapping_add`, starting from zero.
    // NOTE(review): no explicit SIMD here; presumably relies on the compiler to
    // auto-vectorize this fold — confirm if this path is hot.
    fn wrapping_sum(vals: &[Self]) -> Self {
        vals.iter()
            .copied()
            .fold(T::zero(), |a, b| a.wrapping_add(&b))
    }

    /// Sums the elements of `vals` whose bit in `mask` is set.
    ///
    /// The first `vals.len() % STRIPE` elements are summed by the scalar helper. The
    /// rest is processed `STRIPE` lanes at a time, with masked-off lanes replaced by
    /// zero. Lanes are accumulated separately and reduced only at the end, so the
    /// association order differs from a left-to-right fold. That can change the
    /// result for floating-point `T`.
    ///
    /// # Panics
    ///
    /// Panics if `vals.len() != mask.len()`.
    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        assert!(vals.len() == mask.len());
        // Split off the short prefix so that `main` is an exact multiple of `STRIPE`;
        // the mask is split at the same point to stay aligned with the values.
        let remainder = vals.len() % STRIPE;
        let (rest, main) = vals.split_at(remainder);
        let (rest_mask, main_mask) = mask.split_at(remainder);
        let zero: Simd<T, STRIPE> = Simd::default();

        let vsum = main
            .chunks_exact(STRIPE)
            .enumerate()
            .map(|(i, a)| {
                // Load this chunk's validity bits and keep only the valid lanes,
                // replacing the others with `zero`.
                let m: Mask<_, STRIPE> = main_mask.get_simd(i * STRIPE);
                m.select(Simd::from_slice(a), zero)
            })
            .fold(zero, |a, b| {
                // `T` is only bounded by `WrappingAdd`, not by a vector `+`, so the
                // lanes are added pairwise through plain arrays.
                let a = a.to_array();
                let b = b.to_array();
                Simd::from_array(std::array::from_fn(|i| a[i].wrapping_add(&b[i])))
            });

        // Horizontal reduction of the per-lane partial sums.
        let mainsum = vsum
            .to_array()
            .into_iter()
            .fold(T::zero(), |a, b| a.wrapping_add(&b));

        // TODO: faster remainder.
        let restsum = wrapping_sum_with_mask_scalar(rest, &rest_mask);
        mainsum.wrapping_add(&restsum)
    }
}
120
/// Scalar-only `WrappingSum` for `u128` when the `simd` feature is enabled.
// Presumably needed because `u128` is not a `crate::SimdPrimitive` and therefore is not
// covered by the SIMD blanket impl — confirm.
#[cfg(feature = "simd")]
impl WrappingSum for u128 {
    /// Sums every element with the inherent overflow-wrapping `u128::wrapping_add`.
    fn wrapping_sum(vals: &[Self]) -> Self {
        let mut total: Self = 0;
        for &x in vals {
            total = total.wrapping_add(x);
        }
        total
    }

    /// Sums the elements whose bit in `mask` is set, using the scalar masked kernel.
    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        wrapping_sum_with_mask_scalar(vals, mask)
    }
}
131
/// Scalar-only `WrappingSum` for `i128` when the `simd` feature is enabled.
// Presumably needed because `i128` is not a `crate::SimdPrimitive` and therefore is not
// covered by the SIMD blanket impl — confirm.
#[cfg(feature = "simd")]
impl WrappingSum for i128 {
    /// Sums every element with the inherent overflow-wrapping `i128::wrapping_add`.
    fn wrapping_sum(vals: &[Self]) -> Self {
        let mut total: Self = 0;
        for &x in vals {
            total = total.wrapping_add(x);
        }
        total
    }

    /// Sums the elements whose bit in `mask` is set, using the scalar masked kernel.
    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        wrapping_sum_with_mask_scalar(vals, mask)
    }
}
142
/// Guard impl of `WrappingSum` for `pf16` when the `simd` feature is enabled.
///
/// Half-precision sums are expected to be routed to a different sum kernel before they
/// reach this trait, so both methods exist only to satisfy the trait bound.
///
/// # Panics
///
/// Both methods always panic.
#[cfg(feature = "simd")]
impl WrappingSum for pf16 {
    fn wrapping_sum(_: &[Self]) -> Self {
        unimplemented!("should have been dispatched to other sum kernel")
    }

    fn wrapping_sum_with_validity(_: &[Self], _: &BitMask) -> Self {
        unimplemented!("should have been dispatched to other sum kernel")
    }
}
153
154pub trait WrappingSum: Sized {
155    fn wrapping_sum(vals: &[Self]) -> Self;
156    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self;
157}
158
159pub fn wrapping_sum_arr<T>(arr: &PrimitiveArray<T>) -> T
160where
161    T: NativeType + WrappingSum,
162{
163    let validity = arr.validity().filter(|_| arr.null_count() > 0);
164    if let Some(mask) = validity {
165        WrappingSum::wrapping_sum_with_validity(arr.values(), &BitMask::from_bitmap(mask))
166    } else {
167        WrappingSum::wrapping_sum(arr.values())
168    }
169}