1#![allow(clippy::inline_always)]
8
9use pulp::Arch;
10use rayon::prelude::*;
11
/// Minimum slice length at which [`parallel_sum`] and [`parallel_prod`]
/// switch from a sequential reduction to a rayon parallel reduction.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;

/// Minimum slice length at which [`parallel_sort`] and
/// [`parallel_sort_stable`] switch from the standard-library sorts to rayon's
/// parallel sorts.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;

/// Leaf size for pairwise summation: slices of at most this many elements are
/// summed directly, and longer slices are split into leaves of this length
/// whose sums are then merged pairwise.
const PAIRWISE_BASE: usize = 256;
28
/// Sums a short slice with eight independent accumulators.
///
/// Element `8*k + j` of every complete group of eight is added into lane `j`.
/// Trailing elements that do not fill a group all go into lane 0. The lanes
/// are finally combined as a small balanced tree:
/// `((l0 + l1) + (l2 + l3)) + ((l4 + l5) + (l6 + l7))`.
/// Each lane is always the left operand of its additions.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let mut lanes = [identity; 8];

    let groups = data.chunks_exact(8);
    let tail = groups.remainder();
    for group in groups {
        for (lane, &value) in lanes.iter_mut().zip(group) {
            *lane = *lane + value;
        }
    }

    let [mut l0, l1, l2, l3, l4, l5, l6, l7] = lanes;
    for &value in tail {
        l0 = l0 + value;
    }

    (l0 + l1) + (l2 + l3) + ((l4 + l5) + (l6 + l7))
}
59
60pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
68where
69 T: Copy + std::ops::Add<Output = T>,
70{
71 let n = data.len();
72 if n == 0 {
73 return identity;
74 }
75 if n <= PAIRWISE_BASE {
76 return base_sum(data, identity);
77 }
78
79 let mut stack_val: [T; 24] = [identity; 24];
84 let mut stack_lvl: [usize; 24] = [0; 24];
85 let mut depth = 0usize;
86
87 let mut offset = 0;
88 while offset < n {
89 let end = (offset + PAIRWISE_BASE).min(n);
90 let mut current = base_sum(&data[offset..end], identity);
91 offset = end;
92
93 let mut level = 1usize;
95 while depth > 0 && stack_lvl[depth - 1] == level {
96 depth -= 1;
97 current = stack_val[depth] + current;
98 level += 1;
99 }
100 stack_val[depth] = current;
101 stack_lvl[depth] = level;
102 depth += 1;
103 }
104
105 let mut result = stack_val[depth - 1];
107 for i in (0..depth - 1).rev() {
108 result = stack_val[i] + result;
109 }
110 result
111}
112
113#[must_use]
123pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
124 Arch::new().dispatch(PairwiseSumF64Op { data })
125}
126
127struct PairwiseSumF64Op<'a> {
128 data: &'a [f64],
129}
130
131impl pulp::WithSimd for PairwiseSumF64Op<'_> {
132 type Output = f64;
133
134 #[inline(always)]
135 fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
136 simd_pairwise_f64(simd, self.data)
137 }
138}
139
140#[inline(always)]
141fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
142 let n = data.len();
143 let lane_count = size_of::<S::f64s>() / size_of::<f64>();
144 let simd_end = n - (n % lane_count);
145
146 let zero = simd.splat_f64s(0.0);
147 let mut acc0 = zero;
148 let mut acc1 = zero;
149 let mut acc2 = zero;
150 let mut acc3 = zero;
151
152 let stride = lane_count * 4;
154 let unrolled_end = n - (n % stride);
155 let mut i = 0;
156 while i < unrolled_end {
157 let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
158 let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
159 let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
160 let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
161 acc0 = simd.add_f64s(acc0, v0);
162 acc1 = simd.add_f64s(acc1, v1);
163 acc2 = simd.add_f64s(acc2, v2);
164 acc3 = simd.add_f64s(acc3, v3);
165 i += stride;
166 }
167 while i + lane_count <= simd_end {
168 let v = simd.partial_load_f64s(&data[i..i + lane_count]);
169 acc0 = simd.add_f64s(acc0, v);
170 i += lane_count;
171 }
172 acc0 = simd.add_f64s(acc0, acc1);
173 acc2 = simd.add_f64s(acc2, acc3);
174 acc0 = simd.add_f64s(acc0, acc2);
175
176 let mut temp = [0.0f64; 8]; simd.partial_store_f64s(&mut temp[..lane_count], acc0);
179 let mut sum = 0.0f64;
180 for t in temp.iter().take(lane_count) {
181 sum += t;
182 }
183 for &val in &data[simd_end..n] {
185 sum += val;
186 }
187 sum
188}
189
190#[inline(always)]
191fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
192 let n = data.len();
193 if n == 0 {
194 return 0.0;
195 }
196 if n <= PAIRWISE_BASE {
197 return simd_base_sum_f64(simd, data);
198 }
199
200 let mut stack_val = [0.0f64; 24];
201 let mut stack_lvl = [0usize; 24];
202 let mut depth = 0usize;
203
204 let mut offset = 0;
205 while offset < n {
206 let end = (offset + PAIRWISE_BASE).min(n);
207 let mut current = simd_base_sum_f64(simd, &data[offset..end]);
208 offset = end;
209
210 let mut level = 1usize;
211 while depth > 0 && stack_lvl[depth - 1] == level {
212 depth -= 1;
213 current += stack_val[depth];
214 level += 1;
215 }
216 stack_val[depth] = current;
217 stack_lvl[depth] = level;
218 depth += 1;
219 }
220
221 let mut result = stack_val[depth - 1];
222 for i in (0..depth - 1).rev() {
223 result += stack_val[i];
224 }
225 result
226}
227
228#[must_use]
237pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
238 Arch::new().dispatch(SumSqDiffF64Op { data, mean })
239}
240
241struct SumSqDiffF64Op<'a> {
242 data: &'a [f64],
243 mean: f64,
244}
245
246impl pulp::WithSimd for SumSqDiffF64Op<'_> {
247 type Output = f64;
248
249 #[inline(always)]
250 fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
251 let data = self.data;
252 let n = data.len();
253 let lane_count = size_of::<S::f64s>() / size_of::<f64>();
254 let simd_end = n - (n % lane_count);
255
256 let zero = simd.splat_f64s(0.0);
257 let mean_v = simd.splat_f64s(self.mean);
258 let mut acc0 = zero;
259 let mut acc1 = zero;
260 let mut acc2 = zero;
261 let mut acc3 = zero;
262
263 let stride = lane_count * 4;
264 let unrolled_end = n - (n % stride);
265 let mut i = 0;
266 while i < unrolled_end {
267 let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
268 let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
269 let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
270 let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
271 let d0 = simd.sub_f64s(v0, mean_v);
272 let d1 = simd.sub_f64s(v1, mean_v);
273 let d2 = simd.sub_f64s(v2, mean_v);
274 let d3 = simd.sub_f64s(v3, mean_v);
275 acc0 = simd.mul_add_f64s(d0, d0, acc0);
276 acc1 = simd.mul_add_f64s(d1, d1, acc1);
277 acc2 = simd.mul_add_f64s(d2, d2, acc2);
278 acc3 = simd.mul_add_f64s(d3, d3, acc3);
279 i += stride;
280 }
281 while i + lane_count <= simd_end {
282 let v = simd.partial_load_f64s(&data[i..i + lane_count]);
283 let d = simd.sub_f64s(v, mean_v);
284 acc0 = simd.mul_add_f64s(d, d, acc0);
285 i += lane_count;
286 }
287 acc0 = simd.add_f64s(acc0, acc1);
288 acc2 = simd.add_f64s(acc2, acc3);
289 acc0 = simd.add_f64s(acc0, acc2);
290
291 let mut temp = [0.0f64; 8];
292 simd.partial_store_f64s(&mut temp[..lane_count], acc0);
293 let mut sum = 0.0f64;
294 for t in temp.iter().take(lane_count) {
295 sum += t;
296 }
297 for &val in &data[simd_end..n] {
298 let d = val - self.mean;
299 sum += d * d;
300 }
301 sum
302 }
303}
304
305pub fn parallel_sum<T>(data: &[T], identity: T) -> T
314where
315 T: Copy + Send + Sync + std::ops::Add<Output = T>,
316{
317 if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
318 data.par_iter().copied().reduce(|| identity, |a, b| a + b)
320 } else {
321 pairwise_sum(data, identity)
322 }
323}
324
325pub fn parallel_prod<T>(data: &[T], identity: T) -> T
327where
328 T: Copy + Send + Sync + std::ops::Mul<Output = T>,
329{
330 if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
331 data.par_iter().copied().reduce(|| identity, |a, b| a * b)
332 } else {
333 data.iter().copied().fold(identity, |a, b| a * b)
334 }
335}
336
337pub fn parallel_sort<T>(data: &mut [T])
339where
340 T: Copy + Send + Sync + PartialOrd,
341{
342 if data.len() >= PARALLEL_SORT_THRESHOLD {
343 data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
344 } else {
345 data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
346 }
347}
348
349pub fn parallel_sort_stable<T>(data: &mut [T])
351where
352 T: Copy + Send + Sync + PartialOrd,
353{
354 if data.len() >= PARALLEL_SORT_THRESHOLD {
355 data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
356 } else {
357 data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
358 }
359}