1#![allow(clippy::inline_always)]
8
9use pulp::Arch;
10use rayon::prelude::*;
11
/// Element count at or above which `parallel_sum` / `parallel_prod` switch
/// from sequential reduction to a rayon parallel reduction.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;

/// Element count at or above which the sort helpers switch from the
/// sequential `sort_unstable_by` / `sort_by` to rayon's parallel sorts.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;

/// Block size (in elements) at which pairwise summation bottoms out and the
/// unrolled sequential base sum takes over.
const PAIRWISE_BASE: usize = 4096;
41
/// Unrolled sequential sum of `data`, seeded with `identity`.
///
/// Eight independent accumulators break the add dependency chain so the loop
/// can pipeline/auto-vectorize. `chunks_exact` replaces manual index
/// arithmetic, eliminating per-access bounds checks while keeping the exact
/// accumulation order of the original: lane `i` of `acc` sums element `i` of
/// each 8-wide chunk, and the `n % 8` tail folds into lane 0.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let mut acc = [identity; 8];
    let mut chunks = data.chunks_exact(8);
    for chunk in chunks.by_ref() {
        for (a, &v) in acc.iter_mut().zip(chunk) {
            *a = *a + v;
        }
    }
    // Tail: fewer than 8 elements remain; all go into lane 0, as before.
    for &v in chunks.remainder() {
        acc[0] = acc[0] + v;
    }
    // Balanced combine keeps the reduction tree shallow for floats.
    (acc[0] + acc[1]) + (acc[2] + acc[3]) + ((acc[4] + acc[5]) + (acc[6] + acc[7]))
}
72
73pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
81where
82 T: Copy + std::ops::Add<Output = T>,
83{
84 let n = data.len();
85 if n == 0 {
86 return identity;
87 }
88 if n <= PAIRWISE_BASE {
89 return base_sum(data, identity);
90 }
91
92 let mut stack_val: [T; 24] = [identity; 24];
97 let mut stack_lvl: [usize; 24] = [0; 24];
98 let mut depth = 0usize;
99
100 let mut offset = 0;
101 while offset < n {
102 let end = (offset + PAIRWISE_BASE).min(n);
103 let mut current = base_sum(&data[offset..end], identity);
104 offset = end;
105
106 let mut level = 1usize;
108 while depth > 0 && stack_lvl[depth - 1] == level {
109 depth -= 1;
110 current = stack_val[depth] + current;
111 level += 1;
112 }
113 stack_val[depth] = current;
114 stack_lvl[depth] = level;
115 depth += 1;
116 }
117
118 let mut result = stack_val[depth - 1];
120 for i in (0..depth - 1).rev() {
121 result = stack_val[i] + result;
122 }
123 result
124}
125
126#[must_use]
136pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
137 Arch::new().dispatch(PairwiseSumF64Op { data })
138}
139
/// `pulp::WithSimd` payload for [`pairwise_sum_f64`]: carries the borrowed
/// input slice into the runtime-dispatched SIMD kernel.
struct PairwiseSumF64Op<'a> {
    data: &'a [f64],
}

impl pulp::WithSimd for PairwiseSumF64Op<'_> {
    type Output = f64;

    /// Entry point invoked by `Arch::dispatch` with the selected backend.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
        simd_pairwise_f64(simd, self.data)
    }
}
152
153#[inline(always)]
154fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
155 let n = data.len();
156 let lane_count = size_of::<S::f64s>() / size_of::<f64>();
157 let simd_end = n - (n % lane_count);
158
159 let zero = simd.splat_f64s(0.0);
160 let mut acc0 = zero;
161 let mut acc1 = zero;
162 let mut acc2 = zero;
163 let mut acc3 = zero;
164
165 let stride = lane_count * 4;
167 let unrolled_end = n - (n % stride);
168 let mut i = 0;
169 while i < unrolled_end {
170 let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
171 let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
172 let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
173 let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
174 acc0 = simd.add_f64s(acc0, v0);
175 acc1 = simd.add_f64s(acc1, v1);
176 acc2 = simd.add_f64s(acc2, v2);
177 acc3 = simd.add_f64s(acc3, v3);
178 i += stride;
179 }
180 while i + lane_count <= simd_end {
181 let v = simd.partial_load_f64s(&data[i..i + lane_count]);
182 acc0 = simd.add_f64s(acc0, v);
183 i += lane_count;
184 }
185 acc0 = simd.add_f64s(acc0, acc1);
186 acc2 = simd.add_f64s(acc2, acc3);
187 acc0 = simd.add_f64s(acc0, acc2);
188
189 let mut temp = [0.0f64; 8]; simd.partial_store_f64s(&mut temp[..lane_count], acc0);
192 let mut sum = 0.0f64;
193 for t in temp.iter().take(lane_count) {
194 sum += t;
195 }
196 for &val in &data[simd_end..n] {
198 sum += val;
199 }
200 sum
201}
202
203#[inline(always)]
204fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
205 let n = data.len();
206 if n == 0 {
207 return 0.0;
208 }
209 if n <= PAIRWISE_BASE {
210 return simd_base_sum_f64(simd, data);
211 }
212
213 let mut stack_val = [0.0f64; 24];
214 let mut stack_lvl = [0usize; 24];
215 let mut depth = 0usize;
216
217 let mut offset = 0;
218 while offset < n {
219 let end = (offset + PAIRWISE_BASE).min(n);
220 let mut current = simd_base_sum_f64(simd, &data[offset..end]);
221 offset = end;
222
223 let mut level = 1usize;
224 while depth > 0 && stack_lvl[depth - 1] == level {
225 depth -= 1;
226 current += stack_val[depth];
227 level += 1;
228 }
229 stack_val[depth] = current;
230 stack_lvl[depth] = level;
231 depth += 1;
232 }
233
234 let mut result = stack_val[depth - 1];
235 for i in (0..depth - 1).rev() {
236 result += stack_val[i];
237 }
238 result
239}
240
241#[must_use]
250pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
251 Arch::new().dispatch(SumSqDiffF64Op { data, mean })
252}
253
/// `pulp::WithSimd` payload for [`simd_sum_sq_diff_f64`]: the input slice
/// plus the precomputed mean subtracted from each element before squaring.
struct SumSqDiffF64Op<'a> {
    data: &'a [f64],
    mean: f64,
}
258
impl pulp::WithSimd for SumSqDiffF64Op<'_> {
    type Output = f64;

    /// Computes `sum((x - mean)^2)` in one SIMD pass: four FMA accumulators
    /// (4x unroll) hide latency, then a single-vector loop and scalar tail.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
        let data = self.data;
        let n = data.len();
        // Lanes per vector for this backend's f64 type.
        let lane_count = size_of::<S::f64s>() / size_of::<f64>();
        // Largest prefix that is a whole number of vectors.
        let simd_end = n - (n % lane_count);

        let zero = simd.splat_f64s(0.0);
        let mean_v = simd.splat_f64s(self.mean);
        let mut acc0 = zero;
        let mut acc1 = zero;
        let mut acc2 = zero;
        let mut acc3 = zero;

        // 4x-unrolled main loop: `stride` elements per iteration.
        let stride = lane_count * 4;
        let unrolled_end = n - (n % stride);
        let mut i = 0;
        while i < unrolled_end {
            let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
            let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
            let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
            let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
            let d0 = simd.sub_f64s(v0, mean_v);
            let d1 = simd.sub_f64s(v1, mean_v);
            let d2 = simd.sub_f64s(v2, mean_v);
            let d3 = simd.sub_f64s(v3, mean_v);
            // acc += d * d, fused multiply-add per accumulator.
            acc0 = simd.mul_add_f64s(d0, d0, acc0);
            acc1 = simd.mul_add_f64s(d1, d1, acc1);
            acc2 = simd.mul_add_f64s(d2, d2, acc2);
            acc3 = simd.mul_add_f64s(d3, d3, acc3);
            i += stride;
        }
        // Remaining whole vectors, one at a time.
        while i + lane_count <= simd_end {
            let v = simd.partial_load_f64s(&data[i..i + lane_count]);
            let d = simd.sub_f64s(v, mean_v);
            acc0 = simd.mul_add_f64s(d, d, acc0);
            i += lane_count;
        }
        acc0 = simd.add_f64s(acc0, acc1);
        acc2 = simd.add_f64s(acc2, acc3);
        acc0 = simd.add_f64s(acc0, acc2);

        // Horizontal reduction: spill lanes to a stack buffer and add them.
        // Buffer assumes f64 lane count <= 8 — TODO confirm for all backends.
        let mut temp = [0.0f64; 8];
        simd.partial_store_f64s(&mut temp[..lane_count], acc0);
        let mut sum = 0.0f64;
        for t in temp.iter().take(lane_count) {
            sum += t;
        }
        // Scalar tail: note the FMA path may round differently than `d * d`
        // here, so accumulator and tail are not bit-identical per element.
        for &val in &data[simd_end..n] {
            let d = val - self.mean;
            sum += d * d;
        }
        sum
    }
}
317
318#[must_use]
330pub fn pairwise_sum_f32(data: &[f32]) -> f32 {
331 Arch::new().dispatch(PairwiseSumF32Op { data })
332}
333
/// `pulp::WithSimd` payload for [`pairwise_sum_f32`]: carries the borrowed
/// input slice into the runtime-dispatched SIMD kernel.
struct PairwiseSumF32Op<'a> {
    data: &'a [f32],
}

impl pulp::WithSimd for PairwiseSumF32Op<'_> {
    type Output = f32;

    /// Entry point invoked by `Arch::dispatch` with the selected backend.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
        simd_pairwise_f32(simd, self.data)
    }
}
346
347#[inline(always)]
348fn simd_base_sum_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
349 let n = data.len();
350 let lane_count = size_of::<S::f32s>() / size_of::<f32>();
351 let simd_end = n - (n % lane_count);
352
353 let zero = simd.splat_f32s(0.0);
354 let mut acc0 = zero;
355 let mut acc1 = zero;
356 let mut acc2 = zero;
357 let mut acc3 = zero;
358
359 let stride = lane_count * 4;
360 let unrolled_end = n - (n % stride);
361 let mut i = 0;
362 while i < unrolled_end {
363 let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
364 let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
365 let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
366 let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
367 acc0 = simd.add_f32s(acc0, v0);
368 acc1 = simd.add_f32s(acc1, v1);
369 acc2 = simd.add_f32s(acc2, v2);
370 acc3 = simd.add_f32s(acc3, v3);
371 i += stride;
372 }
373 while i + lane_count <= simd_end {
374 let v = simd.partial_load_f32s(&data[i..i + lane_count]);
375 acc0 = simd.add_f32s(acc0, v);
376 i += lane_count;
377 }
378 acc0 = simd.add_f32s(acc0, acc1);
379 acc2 = simd.add_f32s(acc2, acc3);
380 acc0 = simd.add_f32s(acc0, acc2);
381
382 let mut temp = [0.0f32; 16];
385 simd.partial_store_f32s(&mut temp[..lane_count], acc0);
386 let mut sum = 0.0f32;
387 for t in temp.iter().take(lane_count) {
388 sum += t;
389 }
390 for &val in &data[simd_end..n] {
391 sum += val;
392 }
393 sum
394}
395
396#[inline(always)]
397fn simd_pairwise_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
398 let n = data.len();
399 if n == 0 {
400 return 0.0;
401 }
402 if n <= PAIRWISE_BASE {
403 return simd_base_sum_f32(simd, data);
404 }
405
406 let mut stack_val = [0.0f32; 24];
407 let mut stack_lvl = [0usize; 24];
408 let mut depth = 0usize;
409
410 let mut offset = 0;
411 while offset < n {
412 let end = (offset + PAIRWISE_BASE).min(n);
413 let mut current = simd_base_sum_f32(simd, &data[offset..end]);
414 offset = end;
415
416 let mut level = 1usize;
417 while depth > 0 && stack_lvl[depth - 1] == level {
418 depth -= 1;
419 current += stack_val[depth];
420 level += 1;
421 }
422 stack_val[depth] = current;
423 stack_lvl[depth] = level;
424 depth += 1;
425 }
426
427 let mut result = stack_val[depth - 1];
428 for i in (0..depth - 1).rev() {
429 result += stack_val[i];
430 }
431 result
432}
433
434#[must_use]
436pub fn simd_sum_sq_diff_f32(data: &[f32], mean: f32) -> f32 {
437 Arch::new().dispatch(SumSqDiffF32Op { data, mean })
438}
439
/// `pulp::WithSimd` payload for [`simd_sum_sq_diff_f32`]: the input slice
/// plus the precomputed mean subtracted from each element before squaring.
struct SumSqDiffF32Op<'a> {
    data: &'a [f32],
    mean: f32,
}
444
impl pulp::WithSimd for SumSqDiffF32Op<'_> {
    type Output = f32;

    /// Computes `sum((x - mean)^2)` in one SIMD pass: four FMA accumulators
    /// (4x unroll) hide latency, then a single-vector loop and scalar tail.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
        let data = self.data;
        let n = data.len();
        // Lanes per vector for this backend's f32 type.
        let lane_count = size_of::<S::f32s>() / size_of::<f32>();
        // Largest prefix that is a whole number of vectors.
        let simd_end = n - (n % lane_count);

        let zero = simd.splat_f32s(0.0);
        let mean_v = simd.splat_f32s(self.mean);
        let mut acc0 = zero;
        let mut acc1 = zero;
        let mut acc2 = zero;
        let mut acc3 = zero;

        // 4x-unrolled main loop: `stride` elements per iteration.
        let stride = lane_count * 4;
        let unrolled_end = n - (n % stride);
        let mut i = 0;
        while i < unrolled_end {
            let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
            let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
            let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
            let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
            let d0 = simd.sub_f32s(v0, mean_v);
            let d1 = simd.sub_f32s(v1, mean_v);
            let d2 = simd.sub_f32s(v2, mean_v);
            let d3 = simd.sub_f32s(v3, mean_v);
            // acc += d * d, fused multiply-add per accumulator.
            acc0 = simd.mul_add_f32s(d0, d0, acc0);
            acc1 = simd.mul_add_f32s(d1, d1, acc1);
            acc2 = simd.mul_add_f32s(d2, d2, acc2);
            acc3 = simd.mul_add_f32s(d3, d3, acc3);
            i += stride;
        }
        // Remaining whole vectors, one at a time.
        while i + lane_count <= simd_end {
            let v = simd.partial_load_f32s(&data[i..i + lane_count]);
            let d = simd.sub_f32s(v, mean_v);
            acc0 = simd.mul_add_f32s(d, d, acc0);
            i += lane_count;
        }
        acc0 = simd.add_f32s(acc0, acc1);
        acc2 = simd.add_f32s(acc2, acc3);
        acc0 = simd.add_f32s(acc0, acc2);

        // Horizontal reduction: spill lanes to a stack buffer and add them.
        // Buffer assumes f32 lane count <= 16 — TODO confirm for all backends.
        let mut temp = [0.0f32; 16];
        simd.partial_store_f32s(&mut temp[..lane_count], acc0);
        let mut sum = 0.0f32;
        for t in temp.iter().take(lane_count) {
            sum += t;
        }
        // Scalar tail: note the FMA path may round differently than `d * d`
        // here, so accumulator and tail are not bit-identical per element.
        for &val in &data[simd_end..n] {
            let d = val - self.mean;
            sum += d * d;
        }
        sum
    }
}
503
504pub fn parallel_sum<T>(data: &[T], identity: T) -> T
513where
514 T: Copy + Send + Sync + std::ops::Add<Output = T>,
515{
516 if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
517 data.par_iter().copied().reduce(|| identity, |a, b| a + b)
519 } else {
520 pairwise_sum(data, identity)
521 }
522}
523
524pub fn parallel_prod<T>(data: &[T], identity: T) -> T
526where
527 T: Copy + Send + Sync + std::ops::Mul<Output = T>,
528{
529 if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
530 data.par_iter().copied().reduce(|| identity, |a, b| a * b)
531 } else {
532 data.iter().copied().fold(identity, |a, b| a * b)
533 }
534}
535
536pub fn parallel_sort<T>(data: &mut [T])
538where
539 T: Copy + Send + Sync + PartialOrd,
540{
541 if data.len() >= PARALLEL_SORT_THRESHOLD {
542 data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
543 } else {
544 data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
545 }
546}
547
548pub fn parallel_sort_stable<T>(data: &mut [T])
550where
551 T: Copy + Send + Sync + PartialOrd,
552{
553 if data.len() >= PARALLEL_SORT_THRESHOLD {
554 data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
555 } else {
556 data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
557 }
558}