1#![allow(clippy::inline_always)]
26
27use pulp::Arch;
28use rayon::prelude::*;
29
/// Minimum slice length at which `parallel_sum` and `parallel_prod` leave
/// their sequential path and reduce on the rayon thread pool instead.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;

/// Minimum slice length at which `parallel_sort` and `parallel_sort_stable`
/// use rayon's parallel sort instead of the standard-library slice sort.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;

/// Longest run of elements passed to a base summation kernel in one call.
/// The pairwise summation routines split longer inputs into chunks of this
/// length and combine the per-chunk sums.
const PAIRWISE_BASE: usize = 4096;
59
/// Sums `data` using eight independent accumulators, each seeded with `identity`.
///
/// Within every full group of eight elements, element `k` is added to
/// accumulator `k`. Leftover elements (fewer than eight) are added to
/// accumulator 0. The eight partial sums are finally combined as a balanced
/// tree: `((a0 + a1) + (a2 + a3)) + ((a4 + a5) + (a6 + a7))`.
///
/// `identity` is folded into the result once per accumulator, so it must be
/// the additive identity of `T` for the result to equal the plain sum.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let mut lanes = [identity; 8];

    let mut groups = data.chunks_exact(8);
    for group in groups.by_ref() {
        for (lane, &value) in lanes.iter_mut().zip(group) {
            *lane = *lane + value;
        }
    }

    // Fewer than eight elements remain; they all go to the first accumulator.
    for &value in groups.remainder() {
        lanes[0] = lanes[0] + value;
    }

    let [a0, a1, a2, a3, a4, a5, a6, a7] = lanes;
    ((a0 + a1) + (a2 + a3)) + ((a4 + a5) + (a6 + a7))
}
90
91pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
99where
100 T: Copy + std::ops::Add<Output = T>,
101{
102 let n = data.len();
103 if n == 0 {
104 return identity;
105 }
106 if n <= PAIRWISE_BASE {
107 return base_sum(data, identity);
108 }
109
110 let mut stack_val: [T; 24] = [identity; 24];
115 let mut stack_lvl: [usize; 24] = [0; 24];
116 let mut depth = 0usize;
117
118 let mut offset = 0;
119 while offset < n {
120 let end = (offset + PAIRWISE_BASE).min(n);
121 let mut current = base_sum(&data[offset..end], identity);
122 offset = end;
123
124 let mut level = 1usize;
126 while depth > 0 && stack_lvl[depth - 1] == level {
127 depth -= 1;
128 current = stack_val[depth] + current;
129 level += 1;
130 }
131 stack_val[depth] = current;
132 stack_lvl[depth] = level;
133 depth += 1;
134 }
135
136 let mut result = stack_val[depth - 1];
138 for i in (0..depth - 1).rev() {
139 result = stack_val[i] + result;
140 }
141 result
142}
143
144#[must_use]
154pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
155 Arch::new().dispatch(PairwiseSumF64Op { data })
156}
157
158struct PairwiseSumF64Op<'a> {
159 data: &'a [f64],
160}
161
162impl pulp::WithSimd for PairwiseSumF64Op<'_> {
163 type Output = f64;
164
165 #[inline(always)]
166 fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
167 simd_pairwise_f64(simd, self.data)
168 }
169}
170
171#[inline(always)]
172fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
173 let n = data.len();
174 let lane_count = size_of::<S::f64s>() / size_of::<f64>();
175 let simd_end = n - (n % lane_count);
176
177 let zero = simd.splat_f64s(0.0);
178 let mut acc0 = zero;
179 let mut acc1 = zero;
180 let mut acc2 = zero;
181 let mut acc3 = zero;
182
183 let stride = lane_count * 4;
185 let unrolled_end = n - (n % stride);
186 let mut i = 0;
187 while i < unrolled_end {
188 let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
189 let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
190 let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
191 let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
192 acc0 = simd.add_f64s(acc0, v0);
193 acc1 = simd.add_f64s(acc1, v1);
194 acc2 = simd.add_f64s(acc2, v2);
195 acc3 = simd.add_f64s(acc3, v3);
196 i += stride;
197 }
198 while i + lane_count <= simd_end {
199 let v = simd.partial_load_f64s(&data[i..i + lane_count]);
200 acc0 = simd.add_f64s(acc0, v);
201 i += lane_count;
202 }
203 acc0 = simd.add_f64s(acc0, acc1);
204 acc2 = simd.add_f64s(acc2, acc3);
205 acc0 = simd.add_f64s(acc0, acc2);
206
207 let mut temp = [0.0f64; 8]; simd.partial_store_f64s(&mut temp[..lane_count], acc0);
210 let mut sum = 0.0f64;
211 for t in temp.iter().take(lane_count) {
212 sum += t;
213 }
214 for &val in &data[simd_end..n] {
216 sum += val;
217 }
218 sum
219}
220
221#[inline(always)]
222fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
223 let n = data.len();
224 if n == 0 {
225 return 0.0;
226 }
227 if n <= PAIRWISE_BASE {
228 return simd_base_sum_f64(simd, data);
229 }
230
231 let mut stack_val = [0.0f64; 24];
232 let mut stack_lvl = [0usize; 24];
233 let mut depth = 0usize;
234
235 let mut offset = 0;
236 while offset < n {
237 let end = (offset + PAIRWISE_BASE).min(n);
238 let mut current = simd_base_sum_f64(simd, &data[offset..end]);
239 offset = end;
240
241 let mut level = 1usize;
242 while depth > 0 && stack_lvl[depth - 1] == level {
243 depth -= 1;
244 current += stack_val[depth];
245 level += 1;
246 }
247 stack_val[depth] = current;
248 stack_lvl[depth] = level;
249 depth += 1;
250 }
251
252 let mut result = stack_val[depth - 1];
253 for i in (0..depth - 1).rev() {
254 result += stack_val[i];
255 }
256 result
257}
258
259#[must_use]
268pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
269 Arch::new().dispatch(SumSqDiffF64Op { data, mean })
270}
271
272struct SumSqDiffF64Op<'a> {
273 data: &'a [f64],
274 mean: f64,
275}
276
277impl pulp::WithSimd for SumSqDiffF64Op<'_> {
278 type Output = f64;
279
280 #[inline(always)]
281 fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
282 let data = self.data;
283 let n = data.len();
284 let lane_count = size_of::<S::f64s>() / size_of::<f64>();
285 let simd_end = n - (n % lane_count);
286
287 let zero = simd.splat_f64s(0.0);
288 let mean_v = simd.splat_f64s(self.mean);
289 let mut acc0 = zero;
290 let mut acc1 = zero;
291 let mut acc2 = zero;
292 let mut acc3 = zero;
293
294 let stride = lane_count * 4;
295 let unrolled_end = n - (n % stride);
296 let mut i = 0;
297 while i < unrolled_end {
298 let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
299 let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
300 let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
301 let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
302 let d0 = simd.sub_f64s(v0, mean_v);
303 let d1 = simd.sub_f64s(v1, mean_v);
304 let d2 = simd.sub_f64s(v2, mean_v);
305 let d3 = simd.sub_f64s(v3, mean_v);
306 acc0 = simd.mul_add_f64s(d0, d0, acc0);
307 acc1 = simd.mul_add_f64s(d1, d1, acc1);
308 acc2 = simd.mul_add_f64s(d2, d2, acc2);
309 acc3 = simd.mul_add_f64s(d3, d3, acc3);
310 i += stride;
311 }
312 while i + lane_count <= simd_end {
313 let v = simd.partial_load_f64s(&data[i..i + lane_count]);
314 let d = simd.sub_f64s(v, mean_v);
315 acc0 = simd.mul_add_f64s(d, d, acc0);
316 i += lane_count;
317 }
318 acc0 = simd.add_f64s(acc0, acc1);
319 acc2 = simd.add_f64s(acc2, acc3);
320 acc0 = simd.add_f64s(acc0, acc2);
321
322 let mut temp = [0.0f64; 8];
323 simd.partial_store_f64s(&mut temp[..lane_count], acc0);
324 let mut sum = 0.0f64;
325 for t in temp.iter().take(lane_count) {
326 sum += t;
327 }
328 for &val in &data[simd_end..n] {
329 let d = val - self.mean;
330 sum += d * d;
331 }
332 sum
333 }
334}
335
336#[must_use]
348pub fn pairwise_sum_f32(data: &[f32]) -> f32 {
349 Arch::new().dispatch(PairwiseSumF32Op { data })
350}
351
352struct PairwiseSumF32Op<'a> {
353 data: &'a [f32],
354}
355
356impl pulp::WithSimd for PairwiseSumF32Op<'_> {
357 type Output = f32;
358
359 #[inline(always)]
360 fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
361 simd_pairwise_f32(simd, self.data)
362 }
363}
364
365#[inline(always)]
366fn simd_base_sum_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
367 let n = data.len();
368 let lane_count = size_of::<S::f32s>() / size_of::<f32>();
369 let simd_end = n - (n % lane_count);
370
371 let zero = simd.splat_f32s(0.0);
372 let mut acc0 = zero;
373 let mut acc1 = zero;
374 let mut acc2 = zero;
375 let mut acc3 = zero;
376
377 let stride = lane_count * 4;
378 let unrolled_end = n - (n % stride);
379 let mut i = 0;
380 while i < unrolled_end {
381 let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
382 let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
383 let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
384 let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
385 acc0 = simd.add_f32s(acc0, v0);
386 acc1 = simd.add_f32s(acc1, v1);
387 acc2 = simd.add_f32s(acc2, v2);
388 acc3 = simd.add_f32s(acc3, v3);
389 i += stride;
390 }
391 while i + lane_count <= simd_end {
392 let v = simd.partial_load_f32s(&data[i..i + lane_count]);
393 acc0 = simd.add_f32s(acc0, v);
394 i += lane_count;
395 }
396 acc0 = simd.add_f32s(acc0, acc1);
397 acc2 = simd.add_f32s(acc2, acc3);
398 acc0 = simd.add_f32s(acc0, acc2);
399
400 let mut temp = [0.0f32; 16];
403 simd.partial_store_f32s(&mut temp[..lane_count], acc0);
404 let mut sum = 0.0f32;
405 for t in temp.iter().take(lane_count) {
406 sum += t;
407 }
408 for &val in &data[simd_end..n] {
409 sum += val;
410 }
411 sum
412}
413
414#[inline(always)]
415fn simd_pairwise_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
416 let n = data.len();
417 if n == 0 {
418 return 0.0;
419 }
420 if n <= PAIRWISE_BASE {
421 return simd_base_sum_f32(simd, data);
422 }
423
424 let mut stack_val = [0.0f32; 24];
425 let mut stack_lvl = [0usize; 24];
426 let mut depth = 0usize;
427
428 let mut offset = 0;
429 while offset < n {
430 let end = (offset + PAIRWISE_BASE).min(n);
431 let mut current = simd_base_sum_f32(simd, &data[offset..end]);
432 offset = end;
433
434 let mut level = 1usize;
435 while depth > 0 && stack_lvl[depth - 1] == level {
436 depth -= 1;
437 current += stack_val[depth];
438 level += 1;
439 }
440 stack_val[depth] = current;
441 stack_lvl[depth] = level;
442 depth += 1;
443 }
444
445 let mut result = stack_val[depth - 1];
446 for i in (0..depth - 1).rev() {
447 result += stack_val[i];
448 }
449 result
450}
451
452#[must_use]
454pub fn simd_sum_sq_diff_f32(data: &[f32], mean: f32) -> f32 {
455 Arch::new().dispatch(SumSqDiffF32Op { data, mean })
456}
457
458struct SumSqDiffF32Op<'a> {
459 data: &'a [f32],
460 mean: f32,
461}
462
463impl pulp::WithSimd for SumSqDiffF32Op<'_> {
464 type Output = f32;
465
466 #[inline(always)]
467 fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
468 let data = self.data;
469 let n = data.len();
470 let lane_count = size_of::<S::f32s>() / size_of::<f32>();
471 let simd_end = n - (n % lane_count);
472
473 let zero = simd.splat_f32s(0.0);
474 let mean_v = simd.splat_f32s(self.mean);
475 let mut acc0 = zero;
476 let mut acc1 = zero;
477 let mut acc2 = zero;
478 let mut acc3 = zero;
479
480 let stride = lane_count * 4;
481 let unrolled_end = n - (n % stride);
482 let mut i = 0;
483 while i < unrolled_end {
484 let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
485 let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
486 let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
487 let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
488 let d0 = simd.sub_f32s(v0, mean_v);
489 let d1 = simd.sub_f32s(v1, mean_v);
490 let d2 = simd.sub_f32s(v2, mean_v);
491 let d3 = simd.sub_f32s(v3, mean_v);
492 acc0 = simd.mul_add_f32s(d0, d0, acc0);
493 acc1 = simd.mul_add_f32s(d1, d1, acc1);
494 acc2 = simd.mul_add_f32s(d2, d2, acc2);
495 acc3 = simd.mul_add_f32s(d3, d3, acc3);
496 i += stride;
497 }
498 while i + lane_count <= simd_end {
499 let v = simd.partial_load_f32s(&data[i..i + lane_count]);
500 let d = simd.sub_f32s(v, mean_v);
501 acc0 = simd.mul_add_f32s(d, d, acc0);
502 i += lane_count;
503 }
504 acc0 = simd.add_f32s(acc0, acc1);
505 acc2 = simd.add_f32s(acc2, acc3);
506 acc0 = simd.add_f32s(acc0, acc2);
507
508 let mut temp = [0.0f32; 16];
509 simd.partial_store_f32s(&mut temp[..lane_count], acc0);
510 let mut sum = 0.0f32;
511 for t in temp.iter().take(lane_count) {
512 sum += t;
513 }
514 for &val in &data[simd_end..n] {
515 let d = val - self.mean;
516 sum += d * d;
517 }
518 sum
519 }
520}
521
522pub fn parallel_sum<T>(data: &[T], identity: T) -> T
531where
532 T: Copy + Send + Sync + std::ops::Add<Output = T>,
533{
534 if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
535 data.par_iter().copied().reduce(|| identity, |a, b| a + b)
537 } else {
538 pairwise_sum(data, identity)
539 }
540}
541
542pub fn parallel_prod<T>(data: &[T], identity: T) -> T
544where
545 T: Copy + Send + Sync + std::ops::Mul<Output = T>,
546{
547 if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
548 data.par_iter().copied().reduce(|| identity, |a, b| a * b)
549 } else {
550 data.iter().copied().fold(identity, |a, b| a * b)
551 }
552}
553
/// Comparison function that places NaN-like values after everything else.
///
/// When `a` and `b` are comparable, their `partial_cmp` result is returned
/// unchanged. Otherwise a value counts as "NaN" if it is not comparable with
/// itself:
///
/// * exactly one side is NaN: the NaN side is `Greater`;
/// * both or neither side is NaN: `Equal`.
#[inline]
pub fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    a.partial_cmp(b).unwrap_or_else(|| {
        let a_is_nan = a.partial_cmp(a).is_none();
        let b_is_nan = b.partial_cmp(b).is_none();

        if a_is_nan == b_is_nan {
            // Both NaN, or incomparable for some other reason: treat as ties.
            Ordering::Equal
        } else if a_is_nan {
            Ordering::Greater
        } else {
            Ordering::Less
        }
    })
}
588
589pub fn parallel_sort<T>(data: &mut [T])
591where
592 T: Copy + Send + Sync + PartialOrd,
593{
594 if data.len() >= PARALLEL_SORT_THRESHOLD {
595 data.par_sort_unstable_by(nan_last_cmp);
596 } else {
597 data.sort_unstable_by(nan_last_cmp);
598 }
599}
600
601pub fn parallel_sort_stable<T>(data: &mut [T])
603where
604 T: Copy + Send + Sync + PartialOrd,
605{
606 if data.len() >= PARALLEL_SORT_THRESHOLD {
607 data.par_sort_by(nan_last_cmp);
608 } else {
609 data.sort_by(nan_last_cmp);
610 }
611}
612
#[cfg(test)]
mod sort_cmp_tests {
    use super::nan_last_cmp;

    /// Sorts a fixed-size array with `nan_last_cmp` and hands it back.
    fn sorted<T: PartialOrd, const N: usize>(mut values: [T; N]) -> [T; N] {
        values.sort_by(nan_last_cmp);
        values
    }

    /// A single NaN ends up after the finite values.
    #[test]
    fn nan_last_basic() {
        let [first, second, last] = sorted([3.0f64, f64::NAN, 1.0]);
        assert_eq!(first, 1.0);
        assert_eq!(second, 3.0);
        assert!(last.is_nan());
    }

    /// Several NaNs all end up at the tail.
    #[test]
    fn multiple_nans_last() {
        let [first, second, third, fourth] = sorted([f64::NAN, f64::NAN, 1.0, 2.0]);
        assert_eq!(first, 1.0);
        assert_eq!(second, 2.0);
        assert!(third.is_nan() && fourth.is_nan());
    }

    /// Infinities keep their numeric order and NaN still comes last.
    #[test]
    fn inf_order_then_nan() {
        let [a, b, c, d, e] = sorted([f64::INFINITY, f64::NEG_INFINITY, 0.0, f64::NAN, 1.0]);
        assert_eq!(a, f64::NEG_INFINITY);
        assert_eq!(b, 0.0);
        assert_eq!(c, 1.0);
        assert_eq!(d, f64::INFINITY);
        assert!(e.is_nan());
    }

    /// A NaN with the sign bit set is sorted last as well.
    #[test]
    fn negative_nan_also_last() {
        let [first, second, third, fourth] = sorted([1.0f64, f64::NAN, -f64::NAN, 2.0]);
        assert_eq!(first, 1.0);
        assert_eq!(second, 2.0);
        assert!(third.is_nan() && fourth.is_nan());
    }

    /// Totally ordered types sort exactly as usual.
    #[test]
    fn integers_unaffected() {
        assert_eq!(sorted([5i32, 2, 8, 1]), [1, 2, 5, 8]);
    }

    /// Same NaN-last behaviour for `f32`.
    #[test]
    fn f32_nan_last() {
        let [first, second, last] = sorted([3.0f32, f32::NAN, 1.0]);
        assert_eq!(first, 1.0);
        assert_eq!(second, 3.0);
        assert!(last.is_nan());
    }
}