1use rten_simd::ops::NumOps;
2use rten_simd::{Isa, Simd, SimdIterable, SimdOp};
3
/// Operation that adds up all the elements of an `f32` slice.
///
/// This type only holds the borrowed input; the reduction itself is carried
/// out by its `SimdOp` implementation.
pub struct Sum<'a> {
    /// Elements to be summed.
    input: &'a [f32],
}

impl<'a> Sum<'a> {
    /// Creates a summation over `input`.
    ///
    /// The slice is only borrowed; nothing is computed until the operation is
    /// evaluated.
    pub fn new(input: &'a [f32]) -> Self {
        Self { input }
    }
}
19
20impl SimdOp for Sum<'_> {
21 type Output = f32;
22
23 #[inline(always)]
24 fn eval<I: Isa>(self, isa: I) -> Self::Output {
25 let ops = isa.f32();
26 let vec_sum = self.input.simd_iter(ops).fold_unroll::<4>(
27 ops.zero(),
28 |sum, x| ops.add(sum, x),
29 |sum, x| ops.add(sum, x),
30 );
31 vec_sum.to_array().into_iter().sum()
32 }
33}
34
/// Operation that computes the sum of the squares of an `f32` slice.
///
/// This type only holds the borrowed input; the reduction itself is carried
/// out by its `SimdOp` implementation.
pub struct SumSquare<'a> {
    /// Elements whose squares are summed.
    input: &'a [f32],
}

impl<'a> SumSquare<'a> {
    /// Creates a sum-of-squares operation over `input`.
    ///
    /// The slice is only borrowed; nothing is computed until the operation is
    /// evaluated.
    pub fn new(input: &'a [f32]) -> Self {
        Self { input }
    }
}
51
52impl SimdOp for SumSquare<'_> {
53 type Output = f32;
54
55 #[inline(always)]
56 fn eval<I: Isa>(self, isa: I) -> Self::Output {
57 let ops = isa.f32();
58 let vec_sum = self.input.simd_iter(ops).fold_unroll::<4>(
59 ops.zero(),
60 |sum, x| ops.mul_add(x, x, sum),
61 |sum, x| ops.add(sum, x),
62 );
63 vec_sum.to_array().into_iter().sum()
64 }
65}
66
/// Operation that computes the sum of squares of `input - offset`.
///
/// When `offset` is the mean of `input`, the result is the sum of squared
/// deviations from the mean. This type only holds the arguments; the reduction
/// itself is carried out by its `SimdOp` implementation.
pub struct SumSquareSub<'a> {
    /// Elements to be offset, squared and summed.
    input: &'a [f32],
    /// Value subtracted from every element before squaring.
    offset: f32,
}

impl<'a> SumSquareSub<'a> {
    /// Creates an operation that sums `(x - offset)²` over every `x` in
    /// `input`.
    ///
    /// The slice is only borrowed; nothing is computed until the operation is
    /// evaluated.
    pub fn new(input: &'a [f32], offset: f32) -> Self {
        Self { input, offset }
    }
}
82
83impl SimdOp for SumSquareSub<'_> {
84 type Output = f32;
85
86 #[inline(always)]
87 fn eval<I: Isa>(self, isa: I) -> Self::Output {
88 let ops = isa.f32();
89 let offset_vec = ops.splat(self.offset);
90
91 let vec_sum = self.input.simd_iter(ops).fold_unroll::<4>(
92 ops.zero(),
93 |sum, x| {
94 let x_offset = ops.sub(x, offset_vec);
95 ops.mul_add(x_offset, x_offset, sum)
96 },
97 |sum, x| ops.add(sum, x),
98 );
99
100 vec_sum.to_array().into_iter().sum()
101 }
102}
103
#[cfg(test)]
mod tests {
    use crate::ulp::assert_ulp_diff_le;

    use super::{Sum, SumSquare, SumSquareSub};
    use rten_simd::SimdOp;

    /// Number of elements in the generated test input.
    ///
    /// NOTE(review): 100 is a multiple of 4 but not of 8 or 16, so whether a
    /// partially filled final vector occurs presumably depends on the lane
    /// count of the dispatched ISA — TODO confirm that tail handling is covered
    /// on 4-lane targets.
    const LEN: usize = 100;

    /// Builds the input shared by all tests: `i * 0.1` for `i` in `0..LEN`.
    fn test_input() -> Vec<f32> {
        (0..LEN).map(|i| i as f32 * 0.1).collect()
    }

    #[test]
    fn test_sum() {
        let xs = test_input();
        // Reference value is accumulated in f64 and rounded to f32 afterwards.
        let expected_sum: f64 = xs.iter().copied().map(f64::from).sum();
        let sum = Sum::new(&xs).dispatch();
        assert_ulp_diff_le!(sum, expected_sum as f32, 1.0);
    }

    #[test]
    fn test_sum_square() {
        let xs = test_input();
        // Reference value is accumulated in f64 and rounded to f32 afterwards.
        let expected_sum: f64 = xs
            .iter()
            .map(|&x| {
                let wide = f64::from(x);
                wide * wide
            })
            .sum();
        let sum = SumSquare::new(&xs).dispatch();
        assert_ulp_diff_le!(sum, expected_sum as f32, 2.0);
    }

    #[test]
    fn test_sum_square_sub() {
        let xs = test_input();
        // The mean is computed in f32 because it is passed to the op as the
        // offset.
        let mean = xs.iter().sum::<f32>() / xs.len() as f32;
        let wide_mean = f64::from(mean);
        // Reference value is accumulated in f64 and rounded to f32 afterwards.
        let expected_sum: f64 = xs
            .iter()
            .map(|&x| {
                let deviation = f64::from(x) - wide_mean;
                deviation * deviation
            })
            .sum();
        let sum = SumSquareSub::new(&xs, mean).dispatch();
        assert_ulp_diff_le!(sum, expected_sum as f32, 2.0);
    }
}