Skip to main content

polars_compute/
sum.rs

1use std::ops::Add;
2#[cfg(feature = "simd")]
3use std::simd::Select;
4#[cfg(feature = "simd")]
5use std::simd::prelude::*;
6
7use arrow::array::{Array, PrimitiveArray};
8use arrow::bitmap::bitmask::BitMask;
9use arrow::types::NativeType;
10use num_traits::Zero;
11use polars_utils::float16::pf16;
12
// Generates `impl $tr for $ty` whose `wrapping_add` forwards both operands
// (copied out of their references) to the associated function
// `<$ty>::$forward_to`. Integer types forward to their inherent
// `wrapping_add`; float types forward to `add`.
macro_rules! wrapping_impl {
    ($tr:ident, $forward_to:ident, $ty:ty) => {
        impl $tr for $ty {
            #[inline(always)]
            fn wrapping_add(&self, v: &Self) -> Self {
                let (lhs, rhs) = (*self, *v);
                <$ty>::$forward_to(lhs, rhs)
            }
        }
    };
}
23
/// Addition that wraps around on overflow for integer types.
///
/// Unlike `num_traits::WrappingAdd`, this trait is also implemented for the
/// floating-point types (`pf16`, `f32`, `f64`). For those types it is plain
/// floating-point addition (`Add::add`): nothing wraps, and results follow
/// ordinary floating-point semantics (e.g. overflow yields infinity for
/// `f32`/`f64`).
pub trait WrappingAdd: Sized {
    /// Returns `self + v`.
    ///
    /// For integer types the result wraps around at the boundary of the type;
    /// for floating-point types this is ordinary `+`.
    fn wrapping_add(&self, v: &Self) -> Self;
}
32
33wrapping_impl!(WrappingAdd, wrapping_add, u8);
34wrapping_impl!(WrappingAdd, wrapping_add, u16);
35wrapping_impl!(WrappingAdd, wrapping_add, u32);
36wrapping_impl!(WrappingAdd, wrapping_add, u64);
37wrapping_impl!(WrappingAdd, wrapping_add, usize);
38wrapping_impl!(WrappingAdd, wrapping_add, u128);
39
40wrapping_impl!(WrappingAdd, wrapping_add, i8);
41wrapping_impl!(WrappingAdd, wrapping_add, i16);
42wrapping_impl!(WrappingAdd, wrapping_add, i32);
43wrapping_impl!(WrappingAdd, wrapping_add, i64);
44wrapping_impl!(WrappingAdd, wrapping_add, isize);
45wrapping_impl!(WrappingAdd, wrapping_add, i128);
46
47wrapping_impl!(WrappingAdd, add, pf16);
48wrapping_impl!(WrappingAdd, add, f32);
49wrapping_impl!(WrappingAdd, add, f64);
50
51#[cfg(feature = "simd")]
52const STRIPE: usize = 16;
53
54fn wrapping_sum_with_mask_scalar<T: Zero + WrappingAdd + Copy>(vals: &[T], mask: &BitMask) -> T {
55    assert!(vals.len() == mask.len());
56    vals.iter()
57        .enumerate()
58        .map(|(i, x)| {
59            // No filter but rather select of 0 for cmov opt.
60            if mask.get(i) { *x } else { T::zero() }
61        })
62        .fold(T::zero(), |a, b| a.wrapping_add(&b))
63}
64
65#[cfg(not(feature = "simd"))]
66impl<T> WrappingSum for T
67where
68    T: NativeType + WrappingAdd + Zero,
69{
70    fn wrapping_sum(vals: &[Self]) -> Self {
71        vals.iter()
72            .copied()
73            .fold(T::zero(), |a, b| a.wrapping_add(&b))
74    }
75
76    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
77        wrapping_sum_with_mask_scalar(vals, mask)
78    }
79}
80
// SIMD-enabled implementation for element types that implement
// `crate::SimdPrimitive`. The unmasked sum is a sequential scalar loop; the
// masked sum processes `STRIPE`-wide chunks lane-wise.
#[cfg(feature = "simd")]
impl<T> WrappingSum for T
where
    T: NativeType + WrappingAdd + Zero + crate::SimdPrimitive,
{
    fn wrapping_sum(vals: &[Self]) -> Self {
        let mut total = T::zero();
        for v in vals {
            total = total.wrapping_add(v);
        }
        total
    }

    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        assert!(vals.len() == mask.len());

        // Peel off the leading `len % STRIPE` elements so that what remains
        // splits into whole `STRIPE`-wide chunks.
        let head_len = vals.len() % STRIPE;
        let (head, body) = vals.split_at(head_len);
        let (head_mask, body_mask) = mask.split_at(head_len);
        let zero: Simd<T, STRIPE> = Simd::default();

        // Lane-wise accumulator; masked-out lanes are replaced by zero.
        let mut lanes = zero;
        for (chunk_idx, chunk) in body.chunks_exact(STRIPE).enumerate() {
            let valid: Mask<T::Mask, STRIPE> = body_mask.get_simd(chunk_idx * STRIPE);
            let picked = valid.select(Simd::from_slice(chunk), zero);
            let (prev, inc) = (lanes.to_array(), picked.to_array());
            lanes = Simd::from_array(std::array::from_fn(|lane| {
                prev[lane].wrapping_add(&inc[lane])
            }));
        }

        // Horizontal reduction of the accumulator, in lane order.
        let body_sum = lanes
            .to_array()
            .iter()
            .fold(T::zero(), |s, x| s.wrapping_add(x));

        // TODO: faster remainder.
        let head_sum = wrapping_sum_with_mask_scalar(head, &head_mask);
        body_sum.wrapping_add(&head_sum)
    }
}
122
// Scalar `u128` implementation used when the `simd` feature is enabled.
#[cfg(feature = "simd")]
impl WrappingSum for u128 {
    fn wrapping_sum(vals: &[Self]) -> Self {
        let mut total: u128 = 0;
        for &v in vals {
            total = total.wrapping_add(v);
        }
        total
    }

    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        wrapping_sum_with_mask_scalar::<u128>(vals, mask)
    }
}
133
// Scalar `i128` implementation used when the `simd` feature is enabled.
#[cfg(feature = "simd")]
impl WrappingSum for i128 {
    fn wrapping_sum(vals: &[Self]) -> Self {
        vals.iter().fold(0i128, |acc, &v| acc.wrapping_add(v))
    }

    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self {
        wrapping_sum_with_mask_scalar::<i128>(vals, mask)
    }
}
144
// With the `simd` feature enabled, `pf16` sums are not computed here: both
// entry points panic, signalling that the caller should have routed the
// request to a different sum kernel.
#[cfg(feature = "simd")]
impl WrappingSum for pf16 {
    fn wrapping_sum(_: &[Self]) -> Self {
        unimplemented!("should have been dispatched to other sum kernel")
    }

    fn wrapping_sum_with_validity(_: &[Self], _: &BitMask) -> Self {
        unimplemented!("should have been dispatched to other sum kernel")
    }
}
155
156pub trait WrappingSum: Sized {
157    fn wrapping_sum(vals: &[Self]) -> Self;
158    fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self;
159}
160
161pub fn wrapping_sum_arr<T>(arr: &PrimitiveArray<T>) -> T
162where
163    T: NativeType + WrappingSum,
164{
165    let validity = arr.validity().filter(|_| arr.null_count() > 0);
166    if let Some(mask) = validity {
167        WrappingSum::wrapping_sum_with_validity(arr.values(), &BitMask::from_bitmap(mask))
168    } else {
169        WrappingSum::wrapping_sum(arr.values())
170    }
171}