Skip to main content

ferray_stats/
parallel.rs

1// ferray-stats: Rayon threshold dispatch for reductions and sorting (REQ-19, REQ-20)
2//
3// ## REQ status (ferray-stats parallel kernels, NumPy parity)
4//  - REQ-19 (large reductions use parallel tree-reduce via Rayon) — SHIPPED:
5//    `pub fn pairwise_sum` / `pub fn pairwise_sum_f64` / `pub fn pairwise_sum_f32`
6//    (pairwise/tree summation), `pub fn simd_sum_sq_diff_f64` /
7//    `pub fn simd_sum_sq_diff_f32` (fused sum-of-squared-deviations), and
8//    `pub fn parallel_sum` / `pub fn parallel_prod` (Rayon tree-reduce above the
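//    element-count threshold). The pairwise tree follows the same strategy, and
//    has the same O(ε log N) error growth, as numpy's pairwise summation
//    (numpy/_core/src/umath/loops_utils.h.src pairwise reduction). It is not
//    bit-identical, because the base block is `PAIRWISE_BASE` elements rather
//    than numpy's 128. Non-test consumers: `reductions::sum`/`mean`/`var`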
12//    (`crate::parallel::pairwise_sum*` / `simd_sum_sq_diff*` in
13//    `reductions/mod.rs`) and `correlation::corrcoef` (`crate::parallel::pairwise_sum`).
14//  - REQ-20 (large-array sort uses parallel merge sort via Rayon) — SHIPPED:
//    `pub fn parallel_sort` / `pub fn parallel_sort_stable` (Rayon's parallel
//    unstable sort / parallel stable merge sort at or above the sort threshold), plus
17//    `pub fn nan_last_cmp` (the NaN-last total order numpy's sort uses,
18//    numpy/_core/src/npysort). Non-test consumer: `sorting::sort`
19//    (`parallel::parallel_sort` / `parallel_sort_stable` in `sorting.rs`).
20
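// The pulp `with_simd` entry points and the SIMD kernels they call
// (`simd_pairwise_f64`/`_f32`, `simd_base_sum_f64`/`_f32`) need
// `#[inline(always)]`. pulp enables the detected instruction set only inside
// `with_simd`, so a kernel that is not inlined into it is compiled for the
// baseline target and the dispatch becomes a no-op. The scalar `base_sum`
// helper is also `#[inline(always)]`, hence the module-wide allow below.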
25#![allow(clippy::inline_always)]
26
27use pulp::Arch;
28use rayon::prelude::*;
29
/// Element count at or above which [`parallel_sum`] and [`parallel_prod`]
/// hand the reduction to Rayon.
///
/// This is set high so that Rayon's overhead does not dominate small
/// workloads. The figures behind the choice are:
/// - Rayon's global pool costs roughly 1–3 ms to start on first use.
/// - A sequential pairwise sum of fewer than ~1M elements takes under 1 ms.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;

/// Element count at or above which [`parallel_sort`] and
/// [`parallel_sort_stable`] use Rayon's parallel sorts.
///
/// - `parallel_sort` uses `par_sort_unstable_by`.
/// - `parallel_sort_stable` uses `par_sort_by`, a parallel merge sort.
///
/// Smaller slices use the sequential standard-library sorts.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;

/// Base-chunk size for pairwise summation.
///
/// Slices of at most this many elements are summed directly by the flat
/// base-case kernels:
/// - `base_sum`, an 8-accumulator scalar loop;
/// - `simd_base_sum_f64` / `simd_base_sum_f32`, 4-accumulator SIMD loops
///   with one horizontal reduction at the end.
///
/// Longer slices are cut into chunks of this size. The chunk sums are then
/// combined with a pairwise (binary carry-merge) tree.
///
/// The knob trades off two costs:
/// - Per-chunk overhead: the horizontal reduction and the carry-merge, both
///   paid once per chunk.
/// - Accuracy: rounding error grows roughly linearly with the number of
///   additions each accumulator performs inside a chunk, but only
///   logarithmically, O(ε log(N / base)), across the tree.
///
/// The value was raised from 256 to 4096 to cut the per-chunk overhead 16×.
/// For N = 1M that is ~245 chunks instead of ~3900. Benchmark notes from that
/// change: NumPy's memory-bandwidth-bound floor is ~154 µs at 1M elements,
/// versus ~242 µs for the add-throughput-bound configuration.
///
/// NumPy itself uses a 128-element block, so results are close to, but not
/// bit-identical with, NumPy's.
const PAIRWISE_BASE: usize = 4096;
59
/// Flat base-case sum with eight independent accumulators.
///
/// Element `8*i + k` is added into accumulator `k`. This breaks the serial
/// dependency on a single running total so the optimiser can auto-vectorise
/// the loop. A tail of fewer than eight elements is folded into accumulator 0.
/// The eight partials are then combined as
/// `((a0 + a1) + (a2 + a3)) + ((a4 + a5) + (a6 + a7))`.
///
/// `identity` seeds every accumulator, so it must be the additive identity
/// (e.g. `0` / `0.0`) for the result to be the plain sum.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let mut lanes = [identity; 8];
    let blocks = data.chunks_exact(8);
    let tail = blocks.remainder();

    for block in blocks {
        for (lane, &x) in lanes.iter_mut().zip(block) {
            *lane = *lane + x;
        }
    }
    for &x in tail {
        lanes[0] = lanes[0] + x;
    }

    let [a0, a1, a2, a3, a4, a5, a6, a7] = lanes;
    (a0 + a1) + (a2 + a3) + ((a4 + a5) + (a6 + a7))
}
90
91/// Pairwise summation of a slice.
92///
93/// Uses an iterative carry-merge algorithm with O(ε log N) error bound,
94/// matching `NumPy`'s summation accuracy. Data is processed in 128-element
95/// base chunks, then merged pairwise using a small stack (like binary
96/// carry addition). Zero recursion overhead — stack depth is at most
97/// ceil(log2(N/128)) + 1 ≈ 20 entries.
98pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
99where
100    T: Copy + std::ops::Add<Output = T>,
101{
102    let n = data.len();
103    if n == 0 {
104        return identity;
105    }
106    if n <= PAIRWISE_BASE {
107        return base_sum(data, identity);
108    }
109
110    // Iterative pairwise summation using a carry-merge stack.
111    // Each entry holds (partial_sum, level) where level is how many base
112    // chunks it represents. Same-level adjacent entries merge on push,
113    // exactly like binary carry addition. Max depth: ~24 for any realistic size.
114    let mut stack_val: [T; 24] = [identity; 24];
115    let mut stack_lvl: [usize; 24] = [0; 24];
116    let mut depth = 0usize;
117
118    let mut offset = 0;
119    while offset < n {
120        let end = (offset + PAIRWISE_BASE).min(n);
121        let mut current = base_sum(&data[offset..end], identity);
122        offset = end;
123
124        // Push and carry-merge: merge with stack top while levels match
125        let mut level = 1usize;
126        while depth > 0 && stack_lvl[depth - 1] == level {
127            depth -= 1;
128            current = stack_val[depth] + current;
129            level += 1;
130        }
131        stack_val[depth] = current;
132        stack_lvl[depth] = level;
133        depth += 1;
134    }
135
136    // Merge remaining stack entries
137    let mut result = stack_val[depth - 1];
138    for i in (0..depth - 1).rev() {
139        result = stack_val[i] + result;
140    }
141    result
142}
143
144// ---------------------------------------------------------------------------
145// SIMD-accelerated pairwise sum for f64
146// ---------------------------------------------------------------------------
147
148/// SIMD-accelerated pairwise summation for f64 slices.
149///
150/// Uses pulp for hardware SIMD dispatch (AVX2/SSE2/NEON) in the 128-element
151/// base case, with the same iterative carry-merge structure for O(ε log N)
152/// accuracy.
153#[must_use]
154pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
155    Arch::new().dispatch(PairwiseSumF64Op { data })
156}
157
158struct PairwiseSumF64Op<'a> {
159    data: &'a [f64],
160}
161
162impl pulp::WithSimd for PairwiseSumF64Op<'_> {
163    type Output = f64;
164
165    #[inline(always)]
166    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
167        simd_pairwise_f64(simd, self.data)
168    }
169}
170
171#[inline(always)]
172fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
173    let n = data.len();
174    let lane_count = size_of::<S::f64s>() / size_of::<f64>();
175    let simd_end = n - (n % lane_count);
176
177    let zero = simd.splat_f64s(0.0);
178    let mut acc0 = zero;
179    let mut acc1 = zero;
180    let mut acc2 = zero;
181    let mut acc3 = zero;
182
183    // 4-accumulator unroll to saturate FPU throughput
184    let stride = lane_count * 4;
185    let unrolled_end = n - (n % stride);
186    let mut i = 0;
187    while i < unrolled_end {
188        let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
189        let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
190        let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
191        let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
192        acc0 = simd.add_f64s(acc0, v0);
193        acc1 = simd.add_f64s(acc1, v1);
194        acc2 = simd.add_f64s(acc2, v2);
195        acc3 = simd.add_f64s(acc3, v3);
196        i += stride;
197    }
198    while i + lane_count <= simd_end {
199        let v = simd.partial_load_f64s(&data[i..i + lane_count]);
200        acc0 = simd.add_f64s(acc0, v);
201        i += lane_count;
202    }
203    acc0 = simd.add_f64s(acc0, acc1);
204    acc2 = simd.add_f64s(acc2, acc3);
205    acc0 = simd.add_f64s(acc0, acc2);
206
207    // Horizontal sum: store SIMD register to temp array
208    let mut temp = [0.0f64; 8]; // max 8 lanes (AVX-512)
209    simd.partial_store_f64s(&mut temp[..lane_count], acc0);
210    let mut sum = 0.0f64;
211    for t in temp.iter().take(lane_count) {
212        sum += t;
213    }
214    // Scalar remainder
215    for &val in &data[simd_end..n] {
216        sum += val;
217    }
218    sum
219}
220
221#[inline(always)]
222fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
223    let n = data.len();
224    if n == 0 {
225        return 0.0;
226    }
227    if n <= PAIRWISE_BASE {
228        return simd_base_sum_f64(simd, data);
229    }
230
231    let mut stack_val = [0.0f64; 24];
232    let mut stack_lvl = [0usize; 24];
233    let mut depth = 0usize;
234
235    let mut offset = 0;
236    while offset < n {
237        let end = (offset + PAIRWISE_BASE).min(n);
238        let mut current = simd_base_sum_f64(simd, &data[offset..end]);
239        offset = end;
240
241        let mut level = 1usize;
242        while depth > 0 && stack_lvl[depth - 1] == level {
243            depth -= 1;
244            current += stack_val[depth];
245            level += 1;
246        }
247        stack_val[depth] = current;
248        stack_lvl[depth] = level;
249        depth += 1;
250    }
251
252    let mut result = stack_val[depth - 1];
253    for i in (0..depth - 1).rev() {
254        result += stack_val[i];
255    }
256    result
257}
258
259// ---------------------------------------------------------------------------
260// SIMD-accelerated fused sum of squared differences for variance
261// ---------------------------------------------------------------------------
262
263/// SIMD-accelerated computation of sum((x - mean)²) for f64 slices.
264///
265/// Computes the sum of squared differences from the mean in a single pass
266/// without allocating an intermediate Vec. Uses 4 SIMD accumulators for ILP.
267#[must_use]
268pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
269    Arch::new().dispatch(SumSqDiffF64Op { data, mean })
270}
271
272struct SumSqDiffF64Op<'a> {
273    data: &'a [f64],
274    mean: f64,
275}
276
277impl pulp::WithSimd for SumSqDiffF64Op<'_> {
278    type Output = f64;
279
280    #[inline(always)]
281    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
282        let data = self.data;
283        let n = data.len();
284        let lane_count = size_of::<S::f64s>() / size_of::<f64>();
285        let simd_end = n - (n % lane_count);
286
287        let zero = simd.splat_f64s(0.0);
288        let mean_v = simd.splat_f64s(self.mean);
289        let mut acc0 = zero;
290        let mut acc1 = zero;
291        let mut acc2 = zero;
292        let mut acc3 = zero;
293
294        let stride = lane_count * 4;
295        let unrolled_end = n - (n % stride);
296        let mut i = 0;
297        while i < unrolled_end {
298            let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
299            let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
300            let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
301            let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
302            let d0 = simd.sub_f64s(v0, mean_v);
303            let d1 = simd.sub_f64s(v1, mean_v);
304            let d2 = simd.sub_f64s(v2, mean_v);
305            let d3 = simd.sub_f64s(v3, mean_v);
306            acc0 = simd.mul_add_f64s(d0, d0, acc0);
307            acc1 = simd.mul_add_f64s(d1, d1, acc1);
308            acc2 = simd.mul_add_f64s(d2, d2, acc2);
309            acc3 = simd.mul_add_f64s(d3, d3, acc3);
310            i += stride;
311        }
312        while i + lane_count <= simd_end {
313            let v = simd.partial_load_f64s(&data[i..i + lane_count]);
314            let d = simd.sub_f64s(v, mean_v);
315            acc0 = simd.mul_add_f64s(d, d, acc0);
316            i += lane_count;
317        }
318        acc0 = simd.add_f64s(acc0, acc1);
319        acc2 = simd.add_f64s(acc2, acc3);
320        acc0 = simd.add_f64s(acc0, acc2);
321
322        let mut temp = [0.0f64; 8];
323        simd.partial_store_f64s(&mut temp[..lane_count], acc0);
324        let mut sum = 0.0f64;
325        for t in temp.iter().take(lane_count) {
326            sum += t;
327        }
328        for &val in &data[simd_end..n] {
329            let d = val - self.mean;
330            sum += d * d;
331        }
332        sum
333    }
334}
335
336// ---------------------------------------------------------------------------
337// SIMD-accelerated pairwise sum / sum-of-squared-differences for f32 (#173)
338//
339// Mirrors the f64 versions above with pulp's f32s SIMD type. Twice as
340// many lanes per register (e.g. AVX-512 f32 = 16 lanes vs f64 = 8), so
341// the constants for max-temp-buffer below need to match the widest f32
342// lane count. The base case and stack-based pairwise tree are
343// structurally identical.
344// ---------------------------------------------------------------------------
345
346/// SIMD-accelerated pairwise summation for f32 slices (#173).
347#[must_use]
348pub fn pairwise_sum_f32(data: &[f32]) -> f32 {
349    Arch::new().dispatch(PairwiseSumF32Op { data })
350}
351
352struct PairwiseSumF32Op<'a> {
353    data: &'a [f32],
354}
355
356impl pulp::WithSimd for PairwiseSumF32Op<'_> {
357    type Output = f32;
358
359    #[inline(always)]
360    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
361        simd_pairwise_f32(simd, self.data)
362    }
363}
364
365#[inline(always)]
366fn simd_base_sum_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
367    let n = data.len();
368    let lane_count = size_of::<S::f32s>() / size_of::<f32>();
369    let simd_end = n - (n % lane_count);
370
371    let zero = simd.splat_f32s(0.0);
372    let mut acc0 = zero;
373    let mut acc1 = zero;
374    let mut acc2 = zero;
375    let mut acc3 = zero;
376
377    let stride = lane_count * 4;
378    let unrolled_end = n - (n % stride);
379    let mut i = 0;
380    while i < unrolled_end {
381        let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
382        let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
383        let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
384        let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
385        acc0 = simd.add_f32s(acc0, v0);
386        acc1 = simd.add_f32s(acc1, v1);
387        acc2 = simd.add_f32s(acc2, v2);
388        acc3 = simd.add_f32s(acc3, v3);
389        i += stride;
390    }
391    while i + lane_count <= simd_end {
392        let v = simd.partial_load_f32s(&data[i..i + lane_count]);
393        acc0 = simd.add_f32s(acc0, v);
394        i += lane_count;
395    }
396    acc0 = simd.add_f32s(acc0, acc1);
397    acc2 = simd.add_f32s(acc2, acc3);
398    acc0 = simd.add_f32s(acc0, acc2);
399
400    // Horizontal sum: store SIMD register to temp array. AVX-512 f32
401    // is 16 lanes wide, so size the temp accordingly.
402    let mut temp = [0.0f32; 16];
403    simd.partial_store_f32s(&mut temp[..lane_count], acc0);
404    let mut sum = 0.0f32;
405    for t in temp.iter().take(lane_count) {
406        sum += t;
407    }
408    for &val in &data[simd_end..n] {
409        sum += val;
410    }
411    sum
412}
413
414#[inline(always)]
415fn simd_pairwise_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
416    let n = data.len();
417    if n == 0 {
418        return 0.0;
419    }
420    if n <= PAIRWISE_BASE {
421        return simd_base_sum_f32(simd, data);
422    }
423
424    let mut stack_val = [0.0f32; 24];
425    let mut stack_lvl = [0usize; 24];
426    let mut depth = 0usize;
427
428    let mut offset = 0;
429    while offset < n {
430        let end = (offset + PAIRWISE_BASE).min(n);
431        let mut current = simd_base_sum_f32(simd, &data[offset..end]);
432        offset = end;
433
434        let mut level = 1usize;
435        while depth > 0 && stack_lvl[depth - 1] == level {
436            depth -= 1;
437            current += stack_val[depth];
438            level += 1;
439        }
440        stack_val[depth] = current;
441        stack_lvl[depth] = level;
442        depth += 1;
443    }
444
445    let mut result = stack_val[depth - 1];
446    for i in (0..depth - 1).rev() {
447        result += stack_val[i];
448    }
449    result
450}
451
452/// SIMD-accelerated computation of sum((x - mean)²) for f32 slices (#173).
453#[must_use]
454pub fn simd_sum_sq_diff_f32(data: &[f32], mean: f32) -> f32 {
455    Arch::new().dispatch(SumSqDiffF32Op { data, mean })
456}
457
458struct SumSqDiffF32Op<'a> {
459    data: &'a [f32],
460    mean: f32,
461}
462
463impl pulp::WithSimd for SumSqDiffF32Op<'_> {
464    type Output = f32;
465
466    #[inline(always)]
467    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
468        let data = self.data;
469        let n = data.len();
470        let lane_count = size_of::<S::f32s>() / size_of::<f32>();
471        let simd_end = n - (n % lane_count);
472
473        let zero = simd.splat_f32s(0.0);
474        let mean_v = simd.splat_f32s(self.mean);
475        let mut acc0 = zero;
476        let mut acc1 = zero;
477        let mut acc2 = zero;
478        let mut acc3 = zero;
479
480        let stride = lane_count * 4;
481        let unrolled_end = n - (n % stride);
482        let mut i = 0;
483        while i < unrolled_end {
484            let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
485            let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
486            let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
487            let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
488            let d0 = simd.sub_f32s(v0, mean_v);
489            let d1 = simd.sub_f32s(v1, mean_v);
490            let d2 = simd.sub_f32s(v2, mean_v);
491            let d3 = simd.sub_f32s(v3, mean_v);
492            acc0 = simd.mul_add_f32s(d0, d0, acc0);
493            acc1 = simd.mul_add_f32s(d1, d1, acc1);
494            acc2 = simd.mul_add_f32s(d2, d2, acc2);
495            acc3 = simd.mul_add_f32s(d3, d3, acc3);
496            i += stride;
497        }
498        while i + lane_count <= simd_end {
499            let v = simd.partial_load_f32s(&data[i..i + lane_count]);
500            let d = simd.sub_f32s(v, mean_v);
501            acc0 = simd.mul_add_f32s(d, d, acc0);
502            i += lane_count;
503        }
504        acc0 = simd.add_f32s(acc0, acc1);
505        acc2 = simd.add_f32s(acc2, acc3);
506        acc0 = simd.add_f32s(acc0, acc2);
507
508        let mut temp = [0.0f32; 16];
509        simd.partial_store_f32s(&mut temp[..lane_count], acc0);
510        let mut sum = 0.0f32;
511        for t in temp.iter().take(lane_count) {
512            sum += t;
513        }
514        for &val in &data[simd_end..n] {
515            let d = val - self.mean;
516            sum += d * d;
517        }
518        sum
519    }
520}
521
522// ---------------------------------------------------------------------------
523// Parallel dispatch
524// ---------------------------------------------------------------------------
525
526/// Perform a parallel pairwise sum on a slice.
527///
528/// Uses rayon tree-reduce for large slices (which is inherently pairwise),
529/// and iterative pairwise summation for smaller slices.
530pub fn parallel_sum<T>(data: &[T], identity: T) -> T
531where
532    T: Copy + Send + Sync + std::ops::Add<Output = T>,
533{
534    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
535        // Rayon's reduce is a tree-reduce, which gives pairwise-like accuracy
536        data.par_iter().copied().reduce(|| identity, |a, b| a + b)
537    } else {
538        pairwise_sum(data, identity)
539    }
540}
541
542/// Perform a parallel product on a slice of `Copy + Send + Sync` values.
543pub fn parallel_prod<T>(data: &[T], identity: T) -> T
544where
545    T: Copy + Send + Sync + std::ops::Mul<Output = T>,
546{
547    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
548        data.par_iter().copied().reduce(|| identity, |a, b| a * b)
549    } else {
550        data.iter().copied().fold(identity, |a, b| a * b)
551    }
552}
553
/// Total-order comparator that places every NaN (of either sign) after all
/// other values, matching `NumPy`'s `np.sort` convention.
///
/// Mirrors numpy/_core/src/common/numpy_tag.h:62,
/// `floating_point_type::less(a, b) = a < b || (b != b && a == a)`.
/// Comparable values keep their `partial_cmp` order. Otherwise:
/// - a NaN-like value is Greater than any non-NaN value;
/// - two NaN-likes compare Equal.
///
/// This differs from `f64::total_cmp`, which would put a negative NaN before
/// `-inf`.
///
/// NaN is detected without a `Float` bound: a value is NaN-like exactly when
/// `x.partial_cmp(&x)` is `None`. Integer and other NaN-free types are
/// therefore ordered exactly as by `partial_cmp`.
#[inline]
pub fn nan_last_cmp<T: PartialOrd>(a: &T, b: &T) -> std::cmp::Ordering {
    fn is_nan_like<U: PartialOrd>(x: &U) -> bool {
        x.partial_cmp(x).is_none()
    }

    match a.partial_cmp(b) {
        Some(order) => order,
        // `false < true`, so:
        // - a NaN-like operand compares Greater than a non-NaN one;
        // - two NaN-likes compare Equal;
        // - two incomparable non-NaN values (no numeric type produces these)
        //   also compare Equal, which keeps the order total.
        None => is_nan_like(a).cmp(&is_nan_like(b)),
    }
}
588
589/// Parallel sort (unstable) for large slices. Returns a sorted copy.
590pub fn parallel_sort<T>(data: &mut [T])
591where
592    T: Copy + Send + Sync + PartialOrd,
593{
594    if data.len() >= PARALLEL_SORT_THRESHOLD {
595        data.par_sort_unstable_by(nan_last_cmp);
596    } else {
597        data.sort_unstable_by(nan_last_cmp);
598    }
599}
600
601/// Parallel stable sort for large slices.
602pub fn parallel_sort_stable<T>(data: &mut [T])
603where
604    T: Copy + Send + Sync + PartialOrd,
605{
606    if data.len() >= PARALLEL_SORT_THRESHOLD {
607        data.par_sort_by(nan_last_cmp);
608    } else {
609        data.sort_by(nan_last_cmp);
610    }
611}
612
#[cfg(test)]
mod sort_cmp_tests {
    // Unit tests for `nan_last_cmp`. Expected values pulled from live numpy 2.4:
    //   np.sort([3., nan, 1.])            -> [1., 3., nan]
    //   np.sort([nan, nan, 1., 2.])       -> [1., 2., nan, nan]
    //   np.sort([inf, -inf, 0., nan, 1.]) -> [-inf, 0., 1., inf, nan]
    //   np.sort([1., nan, -nan, 2.])      -> [1., 2., nan, nan] (both NaN last)
    use super::nan_last_cmp;

    /// Return `values` sorted (stably) with `nan_last_cmp`.
    fn sorted<T: PartialOrd, const N: usize>(mut values: [T; N]) -> [T; N] {
        values.sort_by(nan_last_cmp);
        values
    }

    #[test]
    fn nan_last_basic() {
        let v = sorted([3.0f64, f64::NAN, 1.0]);
        assert_eq!(v[..2], [1.0, 3.0]);
        assert!(v[2].is_nan());
    }

    #[test]
    fn multiple_nans_last() {
        let v = sorted([f64::NAN, f64::NAN, 1.0, 2.0]);
        assert_eq!(v[..2], [1.0, 2.0]);
        assert!(v[2..].iter().all(|x| x.is_nan()));
    }

    #[test]
    fn inf_order_then_nan() {
        let v = sorted([f64::INFINITY, f64::NEG_INFINITY, 0.0, f64::NAN, 1.0]);
        assert_eq!(v[..4], [f64::NEG_INFINITY, 0.0, 1.0, f64::INFINITY]);
        assert!(v[4].is_nan());
    }

    #[test]
    fn negative_nan_also_last() {
        let v = sorted([1.0f64, f64::NAN, -f64::NAN, 2.0]);
        assert_eq!(v[..2], [1.0, 2.0]);
        assert!(v[2..].iter().all(|x| x.is_nan()));
    }

    #[test]
    fn integers_unaffected() {
        assert_eq!(sorted([5i32, 2, 8, 1]), [1, 2, 5, 8]);
    }

    #[test]
    fn f32_nan_last() {
        let v = sorted([3.0f32, f32::NAN, 1.0]);
        assert_eq!(v[..2], [1.0, 3.0]);
        assert!(v[2].is_nan());
    }
}