// rten_vecmath/sum.rs

1use rten_simd::ops::NumOps;
2use rten_simd::{Isa, Simd, SimdIterable, SimdOp};
3
/// Computes the sum of a sequence of numbers.
///
/// This is more efficient than `slice.iter().sum()` as it computes multiple
/// partial sums in parallel using SIMD and then sums across the SIMD lanes at
/// the end. This will produce very slightly different results because the
/// additions are happening in a different order.
///
/// The operation is evaluated through this type's `SimdOp` implementation.
pub struct Sum<'a> {
    /// Values to be summed.
    input: &'a [f32],
}

impl<'a> Sum<'a> {
    /// Creates a sum operation over `input`.
    ///
    /// The slice is only borrowed; no work is done until the operation is
    /// evaluated.
    pub fn new(input: &'a [f32]) -> Self {
        Self { input }
    }
}
19
20impl SimdOp for Sum<'_> {
21    type Output = f32;
22
23    #[inline(always)]
24    fn eval<I: Isa>(self, isa: I) -> Self::Output {
25        let ops = isa.f32();
26        let vec_sum = self.input.simd_iter(ops).fold_unroll::<4>(
27            ops.zero(),
28            |sum, x| ops.add(sum, x),
29            |sum, x| ops.add(sum, x),
30        );
31        vec_sum.to_array().into_iter().sum()
32    }
33}
34
/// Computes the sum of squares of a sequence of numbers.
///
/// This is conceptually equivalent to `slice.iter().map(|&x| x * x).sum()` but
/// more efficient as it computes multiple partial sums in parallel using SIMD
/// and then sums across the SIMD lanes at the end. This will produce very
/// slightly different results because the additions are happening in a
/// different order.
///
/// The operation is evaluated through this type's `SimdOp` implementation.
pub struct SumSquare<'a> {
    /// Values whose squares are to be summed.
    input: &'a [f32],
}

impl<'a> SumSquare<'a> {
    /// Creates a sum-of-squares operation over `input`.
    ///
    /// The slice is only borrowed; no work is done until the operation is
    /// evaluated.
    pub fn new(input: &'a [f32]) -> Self {
        Self { input }
    }
}
51
52impl SimdOp for SumSquare<'_> {
53    type Output = f32;
54
55    #[inline(always)]
56    fn eval<I: Isa>(self, isa: I) -> Self::Output {
57        let ops = isa.f32();
58        let vec_sum = self.input.simd_iter(ops).fold_unroll::<4>(
59            ops.zero(),
60            |sum, x| ops.mul_add(x, x, sum),
61            |sum, x| ops.add(sum, x),
62        );
63        vec_sum.to_array().into_iter().sum()
64    }
65}
66
/// Computes the sum of squares of input with a bias subtracted.
///
/// This is a variant of [`SumSquare`] which subtracts a constant value from each
/// element before squaring it. A typical use case is to compute the variance of
/// a sequence, which is defined as `mean((X - x_mean)^2)`.
///
/// The operation is evaluated through this type's `SimdOp` implementation.
pub struct SumSquareSub<'a> {
    /// Values to be offset, squared and summed.
    input: &'a [f32],
    /// Constant subtracted from every element before it is squared.
    offset: f32,
}

impl<'a> SumSquareSub<'a> {
    /// Creates an operation computing the sum of `(x - offset)^2` over `input`.
    ///
    /// The slice is only borrowed; no work is done until the operation is
    /// evaluated.
    pub fn new(input: &'a [f32], offset: f32) -> Self {
        Self { input, offset }
    }
}
82
83impl SimdOp for SumSquareSub<'_> {
84    type Output = f32;
85
86    #[inline(always)]
87    fn eval<I: Isa>(self, isa: I) -> Self::Output {
88        let ops = isa.f32();
89        let offset_vec = ops.splat(self.offset);
90
91        let vec_sum = self.input.simd_iter(ops).fold_unroll::<4>(
92            ops.zero(),
93            |sum, x| {
94                let x_offset = ops.sub(x, offset_vec);
95                ops.mul_add(x_offset, x_offset, sum)
96            },
97            |sum, x| ops.add(sum, x),
98        );
99
100        vec_sum.to_array().into_iter().sum()
101    }
102}
103
#[cfg(test)]
mod tests {
    use crate::ulp::assert_ulp_diff_le;

    use super::{Sum, SumSquare, SumSquareSub};
    use rten_simd::SimdOp;

    // Intended to not be a multiple of the vector size, so that tail handling
    // is exercised.
    //
    // NOTE(review): 100 is divisible by 4, so with 4-lane `f32` vectors there
    // is no partial final vector and the tail path would not be covered.
    // Consider a length such as 101; the ULP tolerances below would need to be
    // re-checked if this value changes.
    const LEN: usize = 100;

    /// Builds the input shared by all tests: `i as f32 * 0.1` for `i` in
    /// `0..LEN`.
    fn test_input() -> Vec<f32> {
        (0..LEN).map(|i| i as f32 * 0.1).collect()
    }

    /// `Sum` matches a reference sum accumulated in `f64`.
    #[test]
    fn test_sum() {
        let xs = test_input();
        let expected_sum: f64 = xs.iter().copied().map(f64::from).sum();
        let sum = Sum::new(&xs).dispatch();
        assert_ulp_diff_le!(sum, expected_sum as f32, 1.0);
    }

    /// `SumSquare` matches a reference sum of squares accumulated in `f64`.
    #[test]
    fn test_sum_square() {
        let xs = test_input();
        let expected_sum: f64 = xs
            .iter()
            .copied()
            .map(|x| {
                let x = f64::from(x);
                x * x
            })
            .sum();
        let sum = SumSquare::new(&xs).dispatch();
        assert_ulp_diff_le!(sum, expected_sum as f32, 2.0);
    }

    /// `SumSquareSub` with the mean as offset matches a reference sum of
    /// squared deviations accumulated in `f64`.
    #[test]
    fn test_sum_square_sub() {
        let xs = test_input();
        let mean = xs.iter().sum::<f32>() / xs.len() as f32;
        let expected_sum: f64 = xs
            .iter()
            .copied()
            .map(|x| {
                let centered = f64::from(x) - f64::from(mean);
                centered * centered
            })
            .sum();
        let sum = SumSquareSub::new(&xs, mean).dispatch();
        assert_ulp_diff_le!(sum, expected_sum as f32, 2.0);
    }
}