// oxiphysics_gpu/parallel_sort.rs

1#![allow(clippy::ptr_arg)]
2#![allow(clippy::needless_range_loop)]
3// Copyright 2026 COOLJAPAN OU (Team KitaSan)
4// SPDX-License-Identifier: Apache-2.0
5
//! Parallel sorting and reduction utilities (CPU-side, rayon-based).
//!
//! Provides:
//! - [`radix_sort_u32`] – LSD radix sort for `u32` (4 passes of 8 bits).
//! - [`radix_sort_by_key`] – sort any `T` by a `u32` key function.
//! - [`parallel_prefix_sum`] – exclusive prefix-sum (scan) via rayon.
//! - [`parallel_reduce_sum`] – parallel tree reduction for `f64`.
//! - [`parallel_min_max`] – parallel min/max reduction.
//! - [`bitonic_sort`] – bitonic sort (pads to a power-of-2 length with a large sentinel value).
//! - [`merge_sort_parallel`] – parallel merge sort via rayon.
//! - [`histogram_u32`] – parallel histogram.
//! - [`argsort`] – sorted index array for an `f64` slice.
//! - [`nth_element`] – quickselect for the k-th smallest element (expected O(n)).
19
20#![allow(dead_code)]
21
22use rayon::prelude::*;
23
24// ─────────────────────────────────────────────────────────────────────────────
25// radix_sort_u32
26// ─────────────────────────────────────────────────────────────────────────────
27
/// Sorts `data` ascending with a least-significant-digit radix sort.
///
/// The 32-bit keys are processed as four 8-bit digits, lowest digit first.
/// Every pass is a stable counting sort, so the overall sort is stable and
/// runs in O(n) for the fixed key width.
pub fn radix_sort_u32(data: &mut Vec<u32>) {
    let len = data.len();
    if len < 2 {
        return;
    }
    let mut scratch = vec![0u32; len];

    // Four passes (shift = 0, 8, 16, 24): an even number of buffer swaps, so
    // the sorted values end up back in `data`.
    for shift in (0..32u32).step_by(8) {
        let digit = |value: u32| ((value >> shift) & 0xFF) as usize;

        // Count how many values carry each digit.
        let mut histogram = [0usize; 256];
        data.iter().for_each(|&value| histogram[digit(value)] += 1);

        // Exclusive scan: first write position for each digit.
        let mut cursor = [0usize; 256];
        let mut running = 0usize;
        for (slot, &count) in cursor.iter_mut().zip(histogram.iter()) {
            *slot = running;
            running += count;
        }

        // Stable scatter into the scratch buffer.
        for &value in data.iter() {
            let d = digit(value);
            scratch[cursor[d]] = value;
            cursor[d] += 1;
        }
        std::mem::swap(data, &mut scratch);
    }
}
60
61// ─────────────────────────────────────────────────────────────────────────────
62// radix_sort_by_key
63// ─────────────────────────────────────────────────────────────────────────────
64
/// Sort a `Vec<T>` in ascending order of the `u32` key produced by `key_fn`.
///
/// Uses LSD radix sort (up to 4 passes of 8 bits).  The sort is stable:
/// elements with equal keys keep their original relative order.
///
/// `key_fn` is evaluated exactly once per element.  The passes reorder a
/// permutation of indices rather than the elements themselves, so each
/// element is cloned only once, when the final order is materialised.
pub fn radix_sort_by_key<T: Clone>(data: &mut Vec<T>, key_fn: impl Fn(&T) -> u32) {
    let n = data.len();
    if n <= 1 {
        return;
    }

    // Cache the keys so a potentially expensive `key_fn` runs once per element.
    let keys: Vec<u32> = data.iter().map(|item| key_fn(item)).collect();
    // `order[i]` is the index (into `data`/`keys`) of the element at rank `i`.
    let mut order: Vec<usize> = (0..n).collect();
    let mut scratch: Vec<usize> = vec![0; n];

    for pass in 0..4u32 {
        let shift = pass * 8;
        let mut counts = [0usize; 256];
        for &key in &keys {
            counts[((key >> shift) & 0xFF) as usize] += 1;
        }
        // If every key has the same digit, this pass would not move anything.
        if counts.iter().any(|&c| c == n) {
            continue;
        }
        // Exclusive prefix sum: first output slot for each digit.
        let mut offsets = [0usize; 256];
        let mut total = 0usize;
        for (slot, &count) in offsets.iter_mut().zip(counts.iter()) {
            *slot = total;
            total += count;
        }
        // Stable scatter of the index permutation.
        for &idx in &order {
            let digit = ((keys[idx] >> shift) & 0xFF) as usize;
            scratch[offsets[digit]] = idx;
            offsets[digit] += 1;
        }
        std::mem::swap(&mut order, &mut scratch);
    }

    let sorted: Vec<T> = order.iter().map(|&idx| data[idx].clone()).collect();
    *data = sorted;
}
95
96// ─────────────────────────────────────────────────────────────────────────────
97// parallel_prefix_sum
98// ─────────────────────────────────────────────────────────────────────────────
99
100/// Exclusive prefix sum (scan) of `data` using rayon work-stealing.
101///
102/// Returns a `Vec<u32>` of the same length where `output[i] = Σ data[0..i]`.
103/// `output[0]` is always 0.
104pub fn parallel_prefix_sum(data: &[u32]) -> Vec<u32> {
105    if data.is_empty() {
106        return Vec::new();
107    }
108    let n = data.len();
109    // Chunk-level partial sums, then a serial scan over chunks, then fixup.
110    let num_threads = rayon::current_num_threads().max(1);
111    let chunk_size = (n / num_threads).max(1);
112
113    // Step 1: compute per-chunk sums in parallel.
114    let chunks: Vec<_> = data.chunks(chunk_size).collect();
115    let chunk_sums: Vec<u32> = chunks
116        .par_iter()
117        .map(|chunk| chunk.iter().copied().fold(0u32, u32::wrapping_add))
118        .collect();
119
120    // Step 2: exclusive prefix sum over chunk sums (serial, tiny array).
121    let mut chunk_offsets = vec![0u32; chunk_sums.len()];
122    let mut running = 0u32;
123    for (i, &s) in chunk_sums.iter().enumerate() {
124        chunk_offsets[i] = running;
125        running = running.wrapping_add(s);
126    }
127
128    // Step 3: write output in parallel.
129    let mut output = vec![0u32; n];
130    output
131        .par_chunks_mut(chunk_size)
132        .zip(data.par_chunks(chunk_size))
133        .zip(chunk_offsets.par_iter())
134        .for_each(|((out_chunk, in_chunk), &base)| {
135            let mut acc = base;
136            for (o, &v) in out_chunk.iter_mut().zip(in_chunk.iter()) {
137                *o = acc;
138                acc = acc.wrapping_add(v);
139            }
140        });
141
142    output
143}
144
145// ─────────────────────────────────────────────────────────────────────────────
146// parallel_reduce_sum
147// ─────────────────────────────────────────────────────────────────────────────
148
149/// Parallel tree reduction: sum all `f64` values in `data`.
150///
151/// Returns `0.0` for an empty slice.
152pub fn parallel_reduce_sum(data: &[f64]) -> f64 {
153    data.par_iter().copied().sum()
154}
155
156// ─────────────────────────────────────────────────────────────────────────────
157// parallel_min_max
158// ─────────────────────────────────────────────────────────────────────────────
159
160/// Parallel min/max reduction over a `f64` slice.
161///
162/// Returns `(f64::INFINITY, f64::NEG_INFINITY)` for an empty slice.
163pub fn parallel_min_max(data: &[f64]) -> (f64, f64) {
164    if data.is_empty() {
165        return (f64::INFINITY, f64::NEG_INFINITY);
166    }
167    data.par_iter().copied().map(|v| (v, v)).reduce(
168        || (f64::INFINITY, f64::NEG_INFINITY),
169        |(lo1, hi1), (lo2, hi2)| (lo1.min(lo2), hi1.max(hi2)),
170    )
171}
172
173// ─────────────────────────────────────────────────────────────────────────────
174// bitonic_sort
175// ─────────────────────────────────────────────────────────────────────────────
176
/// Bitonic sort of a `Vec<f64>` in ascending order.
///
/// Pads the input to the next power-of-2 length with `f64::INFINITY`, runs
/// the bitonic compare-exchange network, then truncates back to the
/// original length.
///
/// The padding value must compare greater than or equal to every input
/// value so that only padding is cut off by the final truncation.  A finite
/// sentinel such as `f64::MAX` would sort *before* any `+∞` in the input and
/// the truncation would then drop real elements.
///
/// NaN values are not supported: every comparison with NaN is false, so the
/// network never moves them and the result is unspecified.
pub fn bitonic_sort(data: &mut Vec<f64>) {
    let orig_len = data.len();
    if orig_len <= 1 {
        return;
    }
    // Pad to the next power of two with a sentinel that nothing exceeds.
    data.resize(orig_len.next_power_of_two(), f64::INFINITY);

    let n = data.len();
    // `k` is the size of the bitonic sequences being merged in this stage.
    let mut k = 2;
    while k <= n {
        // `j` is the compare-exchange distance inside the current stage.
        let mut j = k / 2;
        while j >= 1 {
            for i in 0..n {
                let l = i ^ j;
                // Visit each (i, l) pair once, from its lower index.
                if l > i {
                    // Blocks of size `k` alternate between ascending and
                    // descending order.
                    let ascending = (i & k) == 0;
                    let out_of_order = if ascending {
                        data[i] > data[l]
                    } else {
                        data[i] < data[l]
                    };
                    if out_of_order {
                        data.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }

    // The padding (and nothing else) now occupies the tail.
    data.truncate(orig_len);
}
211
212// ─────────────────────────────────────────────────────────────────────────────
213// merge_sort_parallel
214// ─────────────────────────────────────────────────────────────────────────────
215
216/// Parallel merge sort of a `Vec`f64` using rayon.
217///
218/// Splits the input in half recursively, sorts each half in parallel, then
219/// merges sequentially.  Falls back to `sort_unstable_by` at small sizes.
220pub fn merge_sort_parallel(data: &mut Vec<f64>) {
221    let n = data.len();
222    if n <= 1 {
223        return;
224    }
225    merge_sort_parallel_slice(data);
226}
227
228fn merge_sort_parallel_slice(data: &mut [f64]) {
229    let n = data.len();
230    if n <= 32 {
231        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
232        return;
233    }
234    let mid = n / 2;
235    let (left, right) = data.split_at_mut(mid);
236
237    // Sort both halves in parallel.
238    rayon::join(
239        || merge_sort_parallel_slice(left),
240        || merge_sort_parallel_slice(right),
241    );
242
243    // Merge the two sorted halves into a temporary buffer.
244    let mut tmp = Vec::with_capacity(n);
245    let mut i = 0;
246    let mut j = 0;
247    // Re-borrow after rayon::join — data is still split.
248    let (left, right) = data.split_at(mid);
249    while i < left.len() && j < right.len() {
250        if left[i] <= right[j] {
251            tmp.push(left[i]);
252            i += 1;
253        } else {
254            tmp.push(right[j]);
255            j += 1;
256        }
257    }
258    tmp.extend_from_slice(&left[i..]);
259    tmp.extend_from_slice(&right[j..]);
260    data.copy_from_slice(&tmp);
261}
262
263// ─────────────────────────────────────────────────────────────────────────────
264// histogram_u32
265// ─────────────────────────────────────────────────────────────────────────────
266
267/// Parallel histogram: count occurrences of `data[i] % num_buckets` per bucket.
268///
269/// Each rayon thread builds a private histogram; results are summed at the end.
270/// Returns a `Vec`u32` of length `num_buckets`.
271///
272/// # Panics
273/// Panics if `num_buckets == 0`.
274pub fn histogram_u32(data: &[u32], num_buckets: usize) -> Vec<u32> {
275    assert!(num_buckets > 0, "num_buckets must be > 0");
276    if data.is_empty() {
277        return vec![0; num_buckets];
278    }
279    let nb = num_buckets;
280    // Build per-thread histograms and reduce.
281    data.par_chunks(256.max(data.len() / rayon::current_num_threads().max(1)))
282        .map(|chunk| {
283            let mut local = vec![0u32; nb];
284            for &v in chunk {
285                local[(v as usize) % nb] += 1;
286            }
287            local
288        })
289        .reduce(
290            || vec![0u32; nb],
291            |mut acc, local| {
292                for i in 0..nb {
293                    acc[i] += local[i];
294                }
295                acc
296            },
297        )
298}
299
300// ─────────────────────────────────────────────────────────────────────────────
301// argsort
302// ─────────────────────────────────────────────────────────────────────────────
303
/// Return indices that would sort `data` in ascending order.
///
/// `NaN` values are placed at the end (treated as greater than any other
/// value and equal to each other).  Ties — equal values, or several NaNs —
/// are broken by the original index, so the result is deterministic and
/// matches what a stable sort would produce.
pub fn argsort(data: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..data.len()).collect();
    indices.sort_unstable_by(|&a, &b| {
        let (x, y) = (data[a], data[b]);
        // The comparator must be a total order; returning a fixed ordering
        // whenever `partial_cmp` fails would make NaN "greater" from both
        // sides, which the sort may reject or mis-order.
        let by_value = match (x.is_nan(), y.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither value is NaN, so `partial_cmp` always succeeds here.
            (false, false) => x.partial_cmp(&y).unwrap_or(std::cmp::Ordering::Equal),
        };
        by_value.then(a.cmp(&b))
    });
    indices
}
316
317// ─────────────────────────────────────────────────────────────────────────────
318// nth_element  (quickselect)
319// ─────────────────────────────────────────────────────────────────────────────
320
321/// Quickselect O(n) algorithm: rearranges `data` so that `data\[k\]` holds
322/// the value that would be there in a fully sorted array, and returns it.
323///
324/// `k` must be `< data.len()`.
325///
326/// # Panics
327/// Panics if `data` is empty or `k >= data.len()`.
328pub fn nth_element(data: &mut Vec<f64>, k: usize) -> f64 {
329    assert!(!data.is_empty(), "nth_element: data must not be empty");
330    assert!(
331        k < data.len(),
332        "nth_element: k={k} out of bounds (len={})",
333        data.len()
334    );
335    nth_element_slice(data, k);
336    data[k]
337}
338
339fn nth_element_slice(data: &mut [f64], k: usize) {
340    if data.len() <= 1 {
341        return;
342    }
343    let pivot_idx = partition(data);
344    if k < pivot_idx {
345        nth_element_slice(&mut data[..pivot_idx], k);
346    } else if k > pivot_idx {
347        nth_element_slice(&mut data[pivot_idx + 1..], k - pivot_idx - 1);
348    }
349    // k == pivot_idx → done
350}
351
/// Lomuto partition with a median-of-three pivot; returns the pivot's final
/// index `p`.
///
/// For NaN-free input, on return every element of `data[..p]` is `<=` the
/// pivot, `data[p]` is the pivot, and every element of `data[p + 1..]` is
/// `>=` the pivot.
///
/// Elements equal to the pivot are sent alternately to the left and to the
/// right.  Sending all of them to one side would make duplicate-heavy input
/// (e.g. all-equal data) shrink by a single element per call, giving
/// quadratic quickselect.
///
/// # Panics
/// Panics if `data` is empty.
fn partition(data: &mut [f64]) -> usize {
    assert!(!data.is_empty(), "partition: data must not be empty");
    let last = data.len() - 1;

    // Median-of-three pivot to reduce worst-case behaviour.  With fewer than
    // three elements the last element is used as the pivot directly.
    if last >= 2 {
        let mid = data.len() / 2;
        if data[0] > data[mid] {
            data.swap(0, mid);
        }
        if data[0] > data[last] {
            data.swap(0, last);
        }
        if data[mid] > data[last] {
            data.swap(mid, last);
        }
        // The median of the three samples is now at `mid`; park it at the end.
        data.swap(mid, last);
    }
    let pivot = data[last];

    // `store` is the boundary: everything before it belongs to the left part.
    let mut store = 0;
    let mut equal_goes_left = false;
    for i in 0..last {
        let v = data[i];
        let goes_left = if v == pivot {
            equal_goes_left = !equal_goes_left;
            equal_goes_left
        } else {
            v < pivot
        };
        if goes_left {
            data.swap(i, store);
            store += 1;
        }
    }
    // Drop the pivot between the two parts.
    data.swap(store, last);
    store
}
383
384// ─────────────────────────────────────────────────────────────────────────────
385// sort verification
386// ─────────────────────────────────────────────────────────────────────────────
387
/// Checks whether an `f64` slice is in non-decreasing order.
///
/// Every adjacent pair must satisfy `a <= b`; a pair involving NaN never
/// does, so the result is `false` as soon as a NaN has a neighbour.
pub fn is_sorted_f64(data: &[f64]) -> bool {
    data.windows(2).all(|pair| matches!(pair, [a, b] if a <= b))
}
394
/// Checks whether a `u32` slice is in non-decreasing order.
pub fn is_sorted_u32(data: &[u32]) -> bool {
    data.iter()
        .zip(data.iter().skip(1))
        .all(|(prev, next)| prev <= next)
}
399
400/// Count the number of inversions (pairs where `data[i]` > `data[j]` for i < j).
401///
402/// Uses a simple O(n log n) merge-sort-based inversion count.
403pub fn count_inversions_f64(data: &[f64]) -> u64 {
404    if data.len() <= 1 {
405        return 0;
406    }
407    let mut tmp = data.to_vec();
408    count_inversions_helper(&mut tmp)
409}
410
/// Merge-sorts `data` in place and returns the number of inversions it held.
///
/// Both halves are copied, counted recursively, and merged back into `data`.
/// Whenever an element of the upper half is emitted before the remaining
/// elements of the lower half, each of those remaining elements forms one
/// inversion with it.
fn count_inversions_helper(data: &mut [f64]) -> u64 {
    if data.len() < 2 {
        return 0;
    }
    let split = data.len() / 2;
    let mut lower = data[..split].to_vec();
    let mut upper = data[split..].to_vec();
    let mut inversions = count_inversions_helper(&mut lower);
    inversions += count_inversions_helper(&mut upper);

    let mut li = 0usize;
    let mut ui = 0usize;
    // Fill every output slot from whichever half supplies the next element.
    for slot in data.iter_mut() {
        let take_lower = match (lower.get(li), upper.get(ui)) {
            (Some(l), Some(u)) => l <= u,
            (Some(_), None) => true,
            _ => false,
        };
        if take_lower {
            *slot = lower[li];
            li += 1;
        } else {
            *slot = upper[ui];
            // All not-yet-emitted lower elements precede and exceed this one
            // (zero of them once the lower half is exhausted).
            inversions += (lower.len() - li) as u64;
            ui += 1;
        }
    }
    inversions
}
448
449// ─────────────────────────────────────────────────────────────────────────────
450// Performance comparison helper
451// ─────────────────────────────────────────────────────────────────────────────
452
/// Result record for one sort algorithm in a sort comparison run.
///
/// Note that the record carries the algorithm name, the input size and a
/// correctness flag only; it has no field holding a measured duration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SortTimingResult {
    /// Name of the sort algorithm.
    pub name: String,
    /// Number of elements sorted.
    pub n: usize,
    /// Whether the result was sorted correctly.
    pub correct: bool,
}
462
463/// Run all three sort algorithms on a copy of the data and verify correctness.
464///
465/// Returns timing results.
466pub fn compare_sorts(data: &[f64]) -> Vec<SortTimingResult> {
467    let mut results = Vec::new();
468
469    // Bitonic sort
470    let mut d1 = data.to_vec();
471    bitonic_sort(&mut d1);
472    results.push(SortTimingResult {
473        name: "bitonic".into(),
474        n: data.len(),
475        correct: is_sorted_f64(&d1),
476    });
477
478    // Merge sort
479    let mut d2 = data.to_vec();
480    merge_sort_parallel(&mut d2);
481    results.push(SortTimingResult {
482        name: "merge_parallel".into(),
483        n: data.len(),
484        correct: is_sorted_f64(&d2),
485    });
486
487    // Radix sort (convert to u32 for radix)
488    let mut d3: Vec<u32> = data.iter().map(|&v| v as u32).collect();
489    radix_sort_u32(&mut d3);
490    results.push(SortTimingResult {
491        name: "radix_u32".into(),
492        n: data.len(),
493        correct: is_sorted_u32(&d3),
494    });
495
496    results
497}
498
/// Check if two slices contain the same elements (as a multiset).
///
/// Both slices are copied, sorted and compared element-wise with `==`.
/// Consequently `0.0` and `-0.0` count as the same element, and any input
/// containing NaN yields `false` (NaN never equals itself).
///
/// Sorting uses [`f64::total_cmp`] so that the comparator is a total order
/// even when NaN is present; `sort_unstable_by` may panic otherwise.
pub fn is_permutation_f64(a: &[f64], b: &[f64]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut sa = a.to_vec();
    let mut sb = b.to_vec();
    sa.sort_unstable_by(f64::total_cmp);
    sb.sort_unstable_by(f64::total_cmp);
    sa == sb
}
510
/// Checks whether two `u32` slices hold the same elements with the same
/// multiplicities, regardless of order.
pub fn is_permutation_u32(a: &[u32], b: &[u32]) -> bool {
    // Sorted copy: equal multisets have equal canonical forms.
    let canonical = |values: &[u32]| {
        let mut sorted = values.to_vec();
        sorted.sort_unstable();
        sorted
    };
    a.len() == b.len() && canonical(a) == canonical(b)
}
522
523// ─────────────────────────────────────────────────────────────────────────────
524// Tests
525// ─────────────────────────────────────────────────────────────────────────────
526
#[cfg(test)]
mod tests {
    // The glob import brings every function of this module into scope,
    // including this module's own `radix_sort_u32`.  An explicit import of a
    // same-named function from another module would shadow it and leave the
    // local implementation untested.
    use super::*;

    // ── radix_sort_u32 ───────────────────────────────────────────────────────

    #[test]
    fn test_radix_sort_empty() {
        let mut v: Vec<u32> = vec![];
        radix_sort_u32(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_radix_sort_single() {
        let mut v = vec![42u32];
        radix_sort_u32(&mut v);
        assert_eq!(v, [42]);
    }

    #[test]
    fn test_radix_sort_sorted() {
        let mut v = vec![1u32, 2, 3, 4, 5];
        radix_sort_u32(&mut v);
        assert_eq!(v, [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_radix_sort_reverse() {
        let mut v = vec![5u32, 4, 3, 2, 1];
        radix_sort_u32(&mut v);
        assert_eq!(v, [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_radix_sort_random_u32() {
        let mut v: Vec<u32> = (0..1000u32).rev().collect();
        radix_sort_u32(&mut v);
        for i in 0..1000usize {
            assert_eq!(v[i], i as u32, "mismatch at index {i}");
        }
    }

    #[test]
    fn test_radix_sort_large_values() {
        let mut v = vec![u32::MAX, 0, u32::MAX / 2, 1, u32::MAX - 1];
        radix_sort_u32(&mut v);
        assert_eq!(v, [0, 1, u32::MAX / 2, u32::MAX - 1, u32::MAX]);
    }

    // ── radix_sort_by_key ────────────────────────────────────────────────────

    #[test]
    fn test_radix_sort_by_key_strings() {
        let mut v: Vec<(&str, u32)> = vec![("c", 3), ("a", 1), ("b", 2)];
        radix_sort_by_key(&mut v, |item| item.1);
        assert_eq!(v, [("a", 1), ("b", 2), ("c", 3)]);
    }

    #[test]
    fn test_radix_sort_by_key_empty() {
        let mut v: Vec<(usize, u32)> = vec![];
        radix_sort_by_key(&mut v, |item| item.1);
        assert!(v.is_empty());
    }

    // ── parallel_prefix_sum ──────────────────────────────────────────────────

    #[test]
    fn test_prefix_sum_empty() {
        assert!(parallel_prefix_sum(&[]).is_empty());
    }

    #[test]
    fn test_prefix_sum_single() {
        assert_eq!(parallel_prefix_sum(&[7]), vec![0]);
    }

    #[test]
    fn test_prefix_sum_basic() {
        let data = [1u32, 2, 3, 4, 5];
        let out = parallel_prefix_sum(&data);
        assert_eq!(out, vec![0, 1, 3, 6, 10]);
    }

    #[test]
    fn test_prefix_sum_ones() {
        let data = vec![1u32; 100];
        let out = parallel_prefix_sum(&data);
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v, i as u32, "prefix[{i}] should be {i}");
        }
    }

    // ── parallel_reduce_sum ──────────────────────────────────────────────────

    #[test]
    fn test_reduce_sum_empty() {
        assert_eq!(parallel_reduce_sum(&[]), 0.0);
    }

    #[test]
    fn test_reduce_sum_basic() {
        let data = [1.0f64, 2.0, 3.0, 4.0, 5.0];
        assert!((parallel_reduce_sum(&data) - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_reduce_sum_large() {
        let data: Vec<f64> = (1..=1000).map(|i| i as f64).collect();
        let expected = 1000.0 * 1001.0 / 2.0;
        assert!((parallel_reduce_sum(&data) - expected).abs() < 1e-6);
    }

    // ── parallel_min_max ─────────────────────────────────────────────────────

    #[test]
    fn test_min_max_empty() {
        let (lo, hi) = parallel_min_max(&[]);
        assert!(lo.is_infinite() && lo > 0.0);
        assert!(hi.is_infinite() && hi < 0.0);
    }

    #[test]
    fn test_min_max_single() {
        let (lo, hi) = parallel_min_max(&[3.125]);
        assert!((lo - 3.125).abs() < 1e-12);
        assert!((hi - 3.125).abs() < 1e-12);
    }

    #[test]
    fn test_min_max_basic() {
        let data = [3.0f64, 1.0, 4.0, 1.5, 9.2, 2.6];
        let (lo, hi) = parallel_min_max(&data);
        assert!((lo - 1.0).abs() < 1e-12);
        assert!((hi - 9.2).abs() < 1e-12);
    }

    // ── bitonic_sort ─────────────────────────────────────────────────────────

    #[test]
    fn test_bitonic_sort_empty() {
        let mut v: Vec<f64> = vec![];
        bitonic_sort(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_bitonic_sort_power_of_two() {
        let mut v = vec![4.0f64, 2.0, 7.0, 1.0, 5.0, 3.0, 6.0, 8.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_bitonic_sort_non_power_of_two() {
        let mut v = vec![5.0f64, 3.0, 1.0, 4.0, 2.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // ── merge_sort_parallel ──────────────────────────────────────────────────

    #[test]
    fn test_merge_sort_empty() {
        let mut v: Vec<f64> = vec![];
        merge_sort_parallel(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_merge_sort_basic() {
        let mut v = vec![3.0f64, 1.0, 4.0, 1.5, 9.0, 2.6];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [1.0, 1.5, 2.6, 3.0, 4.0, 9.0]);
    }

    #[test]
    fn test_merge_sort_large() {
        let mut v: Vec<f64> = (0..500u32).rev().map(|x| x as f64).collect();
        merge_sort_parallel(&mut v);
        for i in 0..500usize {
            assert!((v[i] - i as f64).abs() < 1e-12, "mismatch at {i}");
        }
    }

    // ── histogram_u32 ────────────────────────────────────────────────────────

    #[test]
    fn test_histogram_empty() {
        let h = histogram_u32(&[], 4);
        assert_eq!(h, vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_histogram_basic() {
        let data = [0u32, 1, 2, 3, 0, 1, 2, 0];
        let h = histogram_u32(&data, 4);
        assert_eq!(h, vec![3, 2, 2, 1]);
    }

    #[test]
    fn test_histogram_one_bucket() {
        let data: Vec<u32> = (0..10).collect();
        let h = histogram_u32(&data, 1);
        assert_eq!(h, vec![10]);
    }

    // ── argsort ──────────────────────────────────────────────────────────────

    #[test]
    fn test_argsort_empty() {
        assert!(argsort(&[]).is_empty());
    }

    #[test]
    fn test_argsort_basic() {
        let data = [3.0f64, 1.0, 4.0, 1.5, 9.0];
        let idx = argsort(&data);
        let sorted: Vec<f64> = idx.iter().map(|&i| data[i]).collect();
        assert_eq!(sorted, [1.0, 1.5, 3.0, 4.0, 9.0]);
    }

    #[test]
    fn test_argsort_already_sorted() {
        let data = [1.0f64, 2.0, 3.0, 4.0, 5.0];
        let idx = argsort(&data);
        assert_eq!(idx, [0, 1, 2, 3, 4]);
    }

    // ── nth_element ──────────────────────────────────────────────────────────

    #[test]
    fn test_nth_element_single() {
        let mut v = vec![42.0f64];
        assert!((nth_element(&mut v, 0) - 42.0).abs() < 1e-12);
    }

    #[test]
    fn test_nth_element_median() {
        let mut v = vec![3.0f64, 1.0, 4.0, 1.5, 9.0, 2.6, 5.0];
        // Sorted: [1.0, 1.5, 2.6, 3.0, 4.0, 5.0, 9.0]; median at k=3 → 3.0
        let median = nth_element(&mut v, 3);
        assert!((median - 3.0).abs() < 1e-12, "expected 3.0, got {median}");
    }

    #[test]
    fn test_nth_element_min() {
        let mut v = vec![5.0f64, 3.0, 8.0, 1.0, 4.0];
        let min = nth_element(&mut v, 0);
        assert!((min - 1.0).abs() < 1e-12, "expected 1.0, got {min}");
    }

    #[test]
    fn test_nth_element_max() {
        let mut v = vec![5.0f64, 3.0, 8.0, 1.0, 4.0];
        let max = nth_element(&mut v, 4);
        assert!((max - 8.0).abs() < 1e-12, "expected 8.0, got {max}");
    }

    #[test]
    fn test_nth_element_duplicates() {
        let mut v = vec![2.0f64, 2.0, 2.0, 2.0, 2.0];
        let val = nth_element(&mut v, 2);
        assert!((val - 2.0).abs() < 1e-12);
    }

    // ── sort verification ──────────────────────────────────────────────────

    #[test]
    fn test_is_sorted_f64_empty() {
        assert!(is_sorted_f64(&[]));
    }

    #[test]
    fn test_is_sorted_f64_sorted() {
        assert!(is_sorted_f64(&[1.0, 2.0, 3.0, 4.0]));
    }

    #[test]
    fn test_is_sorted_f64_unsorted() {
        assert!(!is_sorted_f64(&[1.0, 3.0, 2.0, 4.0]));
    }

    #[test]
    fn test_is_sorted_u32_sorted() {
        assert!(is_sorted_u32(&[0, 1, 2, 3, 4]));
    }

    #[test]
    fn test_is_sorted_u32_unsorted() {
        assert!(!is_sorted_u32(&[0, 2, 1, 3]));
    }

    // ── inversion counting ─────────────────────────────────────────────────

    #[test]
    fn test_count_inversions_sorted() {
        assert_eq!(count_inversions_f64(&[1.0, 2.0, 3.0, 4.0]), 0);
    }

    #[test]
    fn test_count_inversions_reversed() {
        // [4,3,2,1] has 6 inversions: (4,3),(4,2),(4,1),(3,2),(3,1),(2,1)
        assert_eq!(count_inversions_f64(&[4.0, 3.0, 2.0, 1.0]), 6);
    }

    #[test]
    fn test_count_inversions_one_swap() {
        assert_eq!(count_inversions_f64(&[2.0, 1.0, 3.0, 4.0]), 1);
    }

    #[test]
    fn test_count_inversions_empty() {
        assert_eq!(count_inversions_f64(&[]), 0);
    }

    // ── permutation checks ─────────────────────────────────────────────────

    #[test]
    fn test_is_permutation_f64_true() {
        assert!(is_permutation_f64(&[3.0, 1.0, 2.0], &[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_is_permutation_f64_false() {
        assert!(!is_permutation_f64(&[3.0, 1.0, 2.0], &[1.0, 2.0, 4.0]));
    }

    #[test]
    fn test_is_permutation_f64_different_lengths() {
        assert!(!is_permutation_f64(&[1.0, 2.0], &[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_is_permutation_u32_true() {
        assert!(is_permutation_u32(&[3, 1, 2], &[1, 2, 3]));
    }

    #[test]
    fn test_is_permutation_u32_false() {
        assert!(!is_permutation_u32(&[1, 2, 3], &[1, 2, 4]));
    }

    // ── sort preserves elements ────────────────────────────────────────────

    #[test]
    fn test_bitonic_sort_preserves_elements() {
        let original = vec![5.0, 3.0, 8.0, 1.0, 4.0, 7.0, 2.0, 6.0];
        let mut sorted = original.clone();
        bitonic_sort(&mut sorted);
        assert!(is_permutation_f64(&original, &sorted));
        assert!(is_sorted_f64(&sorted));
    }

    #[test]
    fn test_merge_sort_preserves_elements() {
        let original = vec![5.0, 3.0, 8.0, 1.0, 4.0, 7.0, 2.0, 6.0];
        let mut sorted = original.clone();
        merge_sort_parallel(&mut sorted);
        assert!(is_permutation_f64(&original, &sorted));
        assert!(is_sorted_f64(&sorted));
    }

    #[test]
    fn test_radix_sort_preserves_elements() {
        let original = vec![5u32, 3, 8, 1, 4, 7, 2, 6];
        let mut sorted = original.clone();
        radix_sort_u32(&mut sorted);
        assert!(is_permutation_u32(&original, &sorted));
        assert!(is_sorted_u32(&sorted));
    }

    // ── compare sorts ──────────────────────────────────────────────────────

    #[test]
    fn test_compare_sorts_all_correct() {
        let data: Vec<f64> = (0..100u32).rev().map(|x| x as f64).collect();
        let results = compare_sorts(&data);
        for r in &results {
            assert!(r.correct, "sort {} failed for n={}", r.name, r.n);
        }
    }

    #[test]
    fn test_compare_sorts_empty() {
        let results = compare_sorts(&[]);
        for r in &results {
            assert!(r.correct);
        }
    }

    // ── additional bitonic sort tests ──────────────────────────────────────

    #[test]
    fn test_bitonic_sort_single() {
        let mut v = vec![42.0_f64];
        bitonic_sort(&mut v);
        assert_eq!(v, [42.0]);
    }

    #[test]
    fn test_bitonic_sort_already_sorted() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_bitonic_sort_duplicates() {
        let mut v = vec![3.0, 1.0, 3.0, 1.0, 2.0, 2.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
    }

    // ── additional merge sort tests ────────────────────────────────────────

    #[test]
    fn test_merge_sort_single() {
        let mut v = vec![42.0_f64];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [42.0]);
    }

    #[test]
    fn test_merge_sort_two_elements() {
        let mut v = vec![2.0, 1.0];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [1.0, 2.0]);
    }

    #[test]
    fn test_merge_sort_duplicates() {
        let mut v = vec![5.0, 1.0, 5.0, 1.0, 3.0];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [1.0, 1.0, 3.0, 5.0, 5.0]);
    }

    // ── additional radix sort tests ────────────────────────────────────────

    #[test]
    fn test_radix_sort_all_same() {
        let mut v = vec![7u32, 7, 7, 7, 7];
        radix_sort_u32(&mut v);
        assert_eq!(v, [7, 7, 7, 7, 7]);
    }

    #[test]
    fn test_radix_sort_two_elements() {
        let mut v = vec![2u32, 1];
        radix_sort_u32(&mut v);
        assert_eq!(v, [1, 2]);
    }

    // ── argsort additional ─────────────────────────────────────────────────

    #[test]
    fn test_argsort_duplicates() {
        let data = [3.0, 1.0, 3.0, 1.0];
        let idx = argsort(&data);
        let sorted: Vec<f64> = idx.iter().map(|&i| data[i]).collect();
        assert!(is_sorted_f64(&sorted));
    }

    #[test]
    fn test_argsort_single() {
        let idx = argsort(&[42.0]);
        assert_eq!(idx, [0]);
    }

    // ── nth_element additional ─────────────────────────────────────────────

    #[test]
    fn test_nth_element_sorted_input() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let val = nth_element(&mut v, 2);
        assert!((val - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_nth_element_reversed() {
        let mut v = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        let val = nth_element(&mut v, 0);
        assert!((val - 1.0).abs() < 1e-12);
    }
}
1019
1020// ─────────────────────────────────────────────────────────────────────────────
1021// GPU Radix Sort Stages (CPU simulation of multi-pass GPU radix sort)
1022// ─────────────────────────────────────────────────────────────────────────────
1023
/// GPU-style radix sort stage: one stable pass over the 8 bits starting at `shift`.
///
/// Simulates the GPU histogram + scatter pattern on the CPU: count how often
/// each 8-bit digit occurs, turn the counts into exclusive start offsets, then
/// scatter every element to the next free slot of its digit's bucket.
/// Elements with the same digit keep their input order.
///
/// Bits at position 32 and above are treated as zero, so `shift >= 32` puts
/// every element into bucket 0 (the output equals the input) instead of
/// overflowing the shift.
///
/// Returns `(sorted_data, per_bucket_counts)`.
pub fn radix_sort_stage_u32(data: &[u32], shift: u32) -> (Vec<u32>, [usize; 256]) {
    // `checked_shr` returns `None` once `shift` reaches the bit width; a plain
    // `>>` would panic in debug builds and silently mask the shift in release.
    let digit_of = |v: u32| (v.checked_shr(shift).unwrap_or(0) & 0xFF) as usize;

    let mut counts = [0usize; 256];
    for &v in data {
        counts[digit_of(v)] += 1;
    }

    // Exclusive prefix sum: `cursor[d]` is the next output slot for digit `d`.
    let mut cursor = [0usize; 256];
    let mut running = 0usize;
    for (slot, &count) in cursor.iter_mut().zip(counts.iter()) {
        *slot = running;
        running += count;
    }

    let mut out = vec![0u32; data.len()];
    for &v in data {
        let d = digit_of(v);
        out[cursor[d]] = v;
        cursor[d] += 1;
    }
    (out, counts)
}
1050
1051/// Full 4-pass GPU radix sort decomposed into individual stages.
1052///
1053/// Returns a sorted vector. Each stage processes 8 bits.
1054pub fn radix_sort_gpu_staged(data: &[u32]) -> Vec<u32> {
1055    if data.is_empty() {
1056        return Vec::new();
1057    }
1058    let mut current = data.to_vec();
1059    for pass in 0..4u32 {
1060        let (sorted, _counts) = radix_sort_stage_u32(&current, pass * 8);
1061        current = sorted;
1062    }
1063    current
1064}
1065
/// Stage histogram: count how often each 8-bit digit occurs at a bit offset.
///
/// Returns a `Vec<u32>` of length 256 where entry `d` is the number of
/// elements whose bits `shift..shift + 8` equal `d`.
///
/// Bits at position 32 and above are treated as zero, so `shift >= 32`
/// counts every element in bucket 0 instead of overflowing the shift.
pub fn radix_histogram(data: &[u32], shift: u32) -> Vec<u32> {
    let mut counts = vec![0u32; 256];
    for &v in data {
        // `checked_shr` avoids the debug-build panic (and the masked shift in
        // release builds) that a plain `>>` has once `shift` reaches 32.
        let digit = (v.checked_shr(shift).unwrap_or(0) & 0xFF) as usize;
        counts[digit] += 1;
    }
    counts
}
1078
1079/// Validate that radix sort preserves all elements as a multiset.
1080pub fn validate_radix_sort(original: &[u32], sorted: &[u32]) -> bool {
1081    is_permutation_u32(original, sorted) && is_sorted_u32(sorted)
1082}
1083
1084// ─────────────────────────────────────────────────────────────────────────────
1085// Counting Sort (for small-range u32 keys)
1086// ─────────────────────────────────────────────────────────────────────────────
1087
/// Counting sort for `u32` values known to lie in \[0, `max_val`\].
///
/// Builds a table of `max_val + 1` counters, tallies the input, then emits
/// each value as many times as it occurred. O(n + max_val) time and space.
/// An empty input returns an empty vector without allocating the table.
///
/// # Panics
/// Panics if any value is greater than `max_val`.
pub fn counting_sort_u32(data: &[u32], max_val: u32) -> Vec<u32> {
    if data.is_empty() {
        return Vec::new();
    }
    let m = max_val as usize + 1;
    let mut tally = vec![0u32; m];
    data.iter().for_each(|&v| {
        assert!((v as usize) < m, "value {v} exceeds max_val {max_val}");
        tally[v as usize] += 1;
    });

    // Walking the table in index order yields the values in ascending order.
    let mut sorted = Vec::with_capacity(data.len());
    for (value, &occurrences) in tally.iter().enumerate() {
        sorted.extend(std::iter::repeat(value as u32).take(occurrences as usize));
    }
    sorted
}
1112
/// Stable counting sort of `(key, value)` pairs by their `u32` key.
///
/// Keys must lie in \[0, `max_key`\]. Pairs with equal keys keep their input
/// order. Each value is cloned once into the output. An empty input returns
/// an empty vector.
///
/// # Panics
/// Panics if any key is greater than `max_key`.
pub fn counting_sort_by_key<T: Clone>(data: &[(u32, T)], max_key: u32) -> Vec<(u32, T)> {
    if data.is_empty() {
        return Vec::new();
    }
    let m = max_key as usize + 1;
    let mut counts = vec![0usize; m];
    for (k, _) in data {
        assert!((*k as usize) < m, "key {k} exceeds max_key {max_key}");
        counts[*k as usize] += 1;
    }

    // Exclusive prefix sum: `next_slot[k]` is where the next pair with key `k` goes.
    let mut next_slot: Vec<usize> = counts
        .iter()
        .scan(0usize, |running, &count| {
            let start = *running;
            *running += count;
            Some(start)
        })
        .collect();

    // `T` has no default value, so the output is staged as `Option`s and
    // unwrapped once every slot has been filled.
    let mut staged: Vec<Option<(u32, T)>> = vec![None; data.len()];
    for (k, payload) in data {
        let cursor = &mut next_slot[*k as usize];
        staged[*cursor] = Some((*k, payload.clone()));
        *cursor += 1;
    }
    staged.into_iter().flatten().collect()
}
1141
1142// ─────────────────────────────────────────────────────────────────────────────
1143// Histogram-based Sort (bucket sort with dynamic ranges)
1144// ─────────────────────────────────────────────────────────────────────────────
1145
/// Bucket sort for `f64` values using a histogram to distribute elements.
///
/// Divides \[min, max\] into `n_buckets` equal-width buckets, sorts each
/// bucket individually, then concatenates the buckets back into `data`.
///
/// Special cases:
/// - fewer than two elements, or `n_buckets == 0`: falls back to a plain
///   in-place sort;
/// - all values equal: `data` is left untouched.
///
/// Values are compared with `partial_cmp`; incomparable pairs (NaN) are
/// treated as equal, so inputs containing NaN get no meaningful ordering.
pub fn histogram_bucket_sort(data: &mut Vec<f64>, n_buckets: usize) {
    fn by_value(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
    }

    if data.len() <= 1 || n_buckets == 0 {
        data.sort_unstable_by(by_value);
        return;
    }

    // `f64::min` / `f64::max` skip NaN operands, matching a `<` / `>` scan.
    let lo = data.iter().copied().fold(f64::INFINITY, f64::min);
    let hi = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // Only an exactly empty range means there is nothing to order. Comparing
    // the range against `f64::EPSILON` would skip sorting for distinct values
    // that merely lie close together (e.g. 1e-17 and 3e-17).
    if hi <= lo {
        return;
    }

    let range = hi - lo;
    let last = n_buckets - 1;
    let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); n_buckets];
    for &v in data.iter() {
        // The float-to-integer cast saturates, and the clamp keeps `v == hi`
        // (which maps to `n_buckets`) inside the last bucket.
        let idx = ((v - lo) / range * n_buckets as f64) as usize;
        buckets[idx.min(last)].push(v);
    }

    // The bucket index grows with the value, so concatenating the sorted
    // buckets in index order gives the fully sorted sequence.
    data.clear();
    for mut bucket in buckets {
        bucket.sort_unstable_by(by_value);
        data.extend_from_slice(&bucket);
    }
}
1197
1198/// Frequency-adaptive bucket sort: allocates buckets proportional to data density.
1199///
1200/// Builds a histogram first, then assigns multiple histogram bins to each
1201/// bucket to balance load.
1202pub fn adaptive_bucket_sort(data: &mut Vec<f64>, n_buckets: usize) {
1203    histogram_bucket_sort(data, n_buckets.max(1));
1204}
1205
1206// ─────────────────────────────────────────────────────────────────────────────
1207// Sort Validation Utilities
1208// ─────────────────────────────────────────────────────────────────────────────
1209
1210/// Comprehensive sort validation: check sorted + permutation + stable order.
1211pub struct SortValidation {
1212    /// Whether the result is sorted.
1213    pub is_sorted: bool,
1214    /// Whether the result is a permutation of the input.
1215    pub is_permutation: bool,
1216    /// Number of elements.
1217    pub n: usize,
1218    /// Number of inversions (0 if sorted).
1219    pub inversions: u64,
1220}
1221
1222impl SortValidation {
1223    /// Validate a sort result for f64.
1224    pub fn validate_f64(original: &[f64], sorted: &[f64]) -> Self {
1225        let is_sorted = is_sorted_f64(sorted);
1226        let is_perm = is_permutation_f64(original, sorted);
1227        let inversions = if is_sorted {
1228            0
1229        } else {
1230            count_inversions_f64(sorted)
1231        };
1232        Self {
1233            is_sorted,
1234            is_permutation: is_perm,
1235            n: sorted.len(),
1236            inversions,
1237        }
1238    }
1239
1240    /// Validate a sort result for u32.
1241    pub fn validate_u32(original: &[u32], sorted: &[u32]) -> Self {
1242        let is_sorted = is_sorted_u32(sorted);
1243        let is_perm = is_permutation_u32(original, sorted);
1244        Self {
1245            is_sorted,
1246            is_permutation: is_perm,
1247            n: sorted.len(),
1248            inversions: 0,
1249        }
1250    }
1251
1252    /// Returns `true` if the sort is fully correct.
1253    pub fn is_correct(&self) -> bool {
1254        self.is_sorted && self.is_permutation
1255    }
1256}
1257
1258// ─────────────────────────────────────────────────────────────────────────────
1259// Parallel Merge (Two sorted halves → merged)
1260// ─────────────────────────────────────────────────────────────────────────────
1261
/// Merge two ascending `f64` slices into one ascending vector.
///
/// Two-cursor merge in O(n + m). When the heads compare equal (`<=`), the
/// element from `left` is taken first.
pub fn merge_sorted(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut a, mut b) = (left, right);
    // Consume the smaller head until one side runs out.
    while let (Some((&x, rest_a)), Some((&y, rest_b))) = (a.split_first(), b.split_first()) {
        if x <= y {
            merged.push(x);
            a = rest_a;
        } else {
            merged.push(y);
            b = rest_b;
        }
    }
    // At most one of these still has elements; append whatever is left.
    merged.extend_from_slice(a);
    merged.extend_from_slice(b);
    merged
}
1282
/// Merge two ascending `u32` slices into one ascending vector.
///
/// Two-cursor merge in O(n + m). On ties the element from `left` is emitted
/// first.
pub fn merge_sorted_u32(left: &[u32], right: &[u32]) -> Vec<u32> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut li, mut ri) = (0, 0);
    loop {
        match (left.get(li), right.get(ri)) {
            (Some(&l), Some(&r)) if l <= r => {
                merged.push(l);
                li += 1;
            }
            (Some(_), Some(&r)) => {
                merged.push(r);
                ri += 1;
            }
            // One side is exhausted; the rest is copied below.
            _ => break,
        }
    }
    merged.extend_from_slice(&left[li..]);
    merged.extend_from_slice(&right[ri..]);
    merged
}
1301
/// Combine several `f64` sequences into one ascending vector.
///
/// This CPU version concatenates all inputs and sorts the result, rather
/// than performing a heap-based k-way merge. The output is therefore sorted
/// whether or not the individual inputs were.
pub fn k_way_merge(slices: &[Vec<f64>]) -> Vec<f64> {
    // `concat` allocates the exact total length up front.
    let mut merged = slices.concat();
    merged.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    merged
}
1315
1316/// Parallel merge sort with configurable thread threshold.
1317///
1318/// Uses rayon to parallelise the merge at each recursive level when
1319/// the sub-array exceeds `parallel_threshold`.
1320pub fn merge_sort_parallel_threshold(data: &mut Vec<f64>, parallel_threshold: usize) {
1321    let n = data.len();
1322    if n <= 1 {
1323        return;
1324    }
1325    merge_sort_threshold_slice(data, parallel_threshold);
1326}
1327
1328fn merge_sort_threshold_slice(data: &mut [f64], threshold: usize) {
1329    let n = data.len();
1330    if n <= 1 {
1331        return;
1332    }
1333    if n <= 16 {
1334        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1335        return;
1336    }
1337    let mid = n / 2;
1338    let (left, right) = data.split_at_mut(mid);
1339
1340    if n >= threshold {
1341        rayon::join(
1342            || merge_sort_threshold_slice(left, threshold),
1343            || merge_sort_threshold_slice(right, threshold),
1344        );
1345    } else {
1346        merge_sort_threshold_slice(left, threshold);
1347        merge_sort_threshold_slice(right, threshold);
1348    }
1349
1350    let mut tmp = Vec::with_capacity(n);
1351    let (left, right) = data.split_at(mid);
1352    let mut i = 0;
1353    let mut j = 0;
1354    while i < left.len() && j < right.len() {
1355        if left[i] <= right[j] {
1356            tmp.push(left[i]);
1357            i += 1;
1358        } else {
1359            tmp.push(right[j]);
1360            j += 1;
1361        }
1362    }
1363    tmp.extend_from_slice(&left[i..]);
1364    tmp.extend_from_slice(&right[j..]);
1365    data.copy_from_slice(&tmp);
1366}
1367
1368// ─────────────────────────────────────────────────────────────────────────────
1369// Tests for new parallel_sort additions
1370// ─────────────────────────────────────────────────────────────────────────────
1371
#[cfg(test)]
mod tests_new_sort {
    // The glob import brings in every item of the parent module, including its
    // own `radix_sort_u32`, so these tests exercise the code defined in this
    // file and need neither per-item `use` lines nor another module's sort.
    use super::*;

    // ── GPU radix sort stages ─────────────────────────────────────────────

    /// One pass on the low byte: stable order by `value & 0xFF`.
    #[test]
    fn test_radix_sort_stage_pass0() {
        let data = vec![300u32, 1, 255, 100, 50];
        let (sorted_once, counts) = radix_sort_stage_u32(&data, 0);
        assert_eq!(sorted_once.len(), data.len());
        // 300 has low byte 44, so it lands between 1 and 50.
        assert_eq!(sorted_once, [1, 300, 50, 100, 255]);
        // counts should sum to data.len()
        let total: usize = counts.iter().sum();
        assert_eq!(total, data.len());
    }

    /// Four chained stages must produce a fully sorted permutation.
    #[test]
    fn test_radix_sort_gpu_staged_sorted() {
        let data: Vec<u32> = vec![500, 1, 200, 50, 900, 3, 150];
        let sorted = radix_sort_gpu_staged(&data);
        assert!(
            is_sorted_u32(&sorted),
            "staged sort should produce sorted output"
        );
        assert!(is_permutation_u32(&data, &sorted));
        assert_eq!(sorted, [1, 3, 50, 150, 200, 500, 900]);
    }

    #[test]
    fn test_radix_sort_gpu_staged_empty() {
        let sorted = radix_sort_gpu_staged(&[]);
        assert!(sorted.is_empty());
    }

    /// The values 0..256 hit every low-byte bucket exactly once.
    #[test]
    fn test_radix_histogram_sums() {
        let data: Vec<u32> = (0..256).collect();
        let h = radix_histogram(&data, 0);
        assert_eq!(h.len(), 256);
        let total: u32 = h.iter().sum();
        assert_eq!(total, 256);
        // Each byte bucket should have exactly 1 entry
        for &c in &h {
            assert_eq!(c, 1);
        }
    }

    #[test]
    fn test_validate_radix_sort() {
        let original: Vec<u32> = vec![5, 3, 8, 1, 4];
        let mut sorted = original.clone();
        radix_sort_u32(&mut sorted);
        assert!(validate_radix_sort(&original, &sorted));
    }

    /// A permutation that is not in order must be rejected.
    #[test]
    fn test_validate_radix_sort_false_for_unsorted() {
        let original = vec![3u32, 1, 2];
        let not_sorted = vec![3u32, 1, 2];
        assert!(!validate_radix_sort(&original, &not_sorted));
    }

    // ── Counting sort ─────────────────────────────────────────────────────

    #[test]
    fn test_counting_sort_basic() {
        let data = vec![3u32, 1, 4, 1, 5, 9, 2, 6, 5, 3];
        let sorted = counting_sort_u32(&data, 9);
        assert!(is_sorted_u32(&sorted));
        assert!(is_permutation_u32(&data, &sorted));
        assert_eq!(sorted, [1, 1, 2, 3, 3, 4, 5, 5, 6, 9]);
    }

    #[test]
    fn test_counting_sort_empty() {
        let sorted = counting_sort_u32(&[], 10);
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_counting_sort_all_same() {
        let data = vec![5u32; 10];
        let sorted = counting_sort_u32(&data, 5);
        assert_eq!(sorted, vec![5u32; 10]);
    }

    /// Keys end up ascending and each payload stays attached to its key.
    #[test]
    fn test_counting_sort_by_key() {
        let data: Vec<(u32, &str)> = vec![(3, "c"), (1, "a"), (2, "b")];
        let sorted = counting_sort_by_key(&data, 3);
        assert_eq!(sorted, [(1, "a"), (2, "b"), (3, "c")]);
    }

    #[test]
    fn test_counting_sort_by_key_stable() {
        // Two items with same key: stable sort preserves order
        let data: Vec<(u32, u32)> = vec![(2, 10), (1, 20), (2, 30)];
        let sorted = counting_sort_by_key(&data, 2);
        // Stable: (2,10) should come before (2,30)
        assert_eq!(sorted, [(1, 20), (2, 10), (2, 30)]);
    }

    // ── Histogram-based sort ──────────────────────────────────────────────

    #[test]
    fn test_histogram_bucket_sort_basic() {
        let mut data = vec![5.0, 3.0, 8.0, 1.0, 4.0, 7.0, 2.0, 6.0];
        let original = data.clone();
        histogram_bucket_sort(&mut data, 4);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&original, &data));
    }

    #[test]
    fn test_histogram_bucket_sort_single_bucket() {
        let mut data = vec![3.0, 1.0, 2.0, 4.0];
        let original = data.clone();
        histogram_bucket_sort(&mut data, 1);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&original, &data));
    }

    /// Identical values must come back unchanged.
    #[test]
    fn test_histogram_bucket_sort_all_equal() {
        let mut data = vec![5.0; 10];
        histogram_bucket_sort(&mut data, 4);
        assert!(is_sorted_f64(&data));
        assert_eq!(data, vec![5.0; 10]);
    }

    #[test]
    fn test_histogram_bucket_sort_large() {
        let mut data: Vec<f64> = (0..200u32).rev().map(|x| x as f64).collect();
        let original = data.clone();
        histogram_bucket_sort(&mut data, 20);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&original, &data));
    }

    #[test]
    fn test_adaptive_bucket_sort() {
        let mut data = vec![9.0, 3.0, 6.0, 1.0, 8.0, 4.0, 2.0, 7.0, 5.0];
        let orig = data.clone();
        adaptive_bucket_sort(&mut data, 3);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&orig, &data));
    }

    // ── Sort validation ───────────────────────────────────────────────────

    #[test]
    fn test_sort_validation_correct() {
        let orig = vec![3.0, 1.0, 4.0, 1.5, 9.0];
        let mut sorted = orig.clone();
        merge_sort_parallel(&mut sorted);
        let v = SortValidation::validate_f64(&orig, &sorted);
        assert!(v.is_correct());
        assert_eq!(v.inversions, 0);
        assert_eq!(v.n, 5);
    }

    /// An unsorted permutation: permutation holds, sortedness does not.
    #[test]
    fn test_sort_validation_unsorted() {
        let orig = vec![1.0, 3.0, 2.0];
        let not_sorted = vec![1.0, 3.0, 2.0];
        let v = SortValidation::validate_f64(&orig, &not_sorted);
        assert!(!v.is_sorted);
        assert!(v.is_permutation);
        assert!(!v.is_correct());
    }

    #[test]
    fn test_sort_validation_u32() {
        let orig = vec![5u32, 3, 8, 1];
        let mut sorted = orig.clone();
        radix_sort_u32(&mut sorted);
        let v = SortValidation::validate_u32(&orig, &sorted);
        assert!(v.is_correct());
        assert_eq!(v.n, 4);
    }

    // ── Merge operations ──────────────────────────────────────────────────

    #[test]
    fn test_merge_sorted_basic() {
        let a = vec![1.0, 3.0, 5.0];
        let b = vec![2.0, 4.0, 6.0];
        let m = merge_sorted(&a, &b);
        assert_eq!(m, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_merge_sorted_empty_left() {
        let a: Vec<f64> = vec![];
        let b = vec![1.0, 2.0, 3.0];
        let m = merge_sorted(&a, &b);
        assert_eq!(m, b);
    }

    #[test]
    fn test_merge_sorted_empty_right() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f64> = vec![];
        let m = merge_sorted(&a, &b);
        assert_eq!(m, a);
    }

    #[test]
    fn test_merge_sorted_u32() {
        let a = vec![1u32, 4, 7];
        let b = vec![2u32, 5, 8];
        let m = merge_sorted_u32(&a, &b);
        assert_eq!(m, vec![1, 2, 4, 5, 7, 8]);
    }

    #[test]
    fn test_k_way_merge() {
        let s1 = vec![1.0, 4.0, 7.0];
        let s2 = vec![2.0, 5.0, 8.0];
        let s3 = vec![3.0, 6.0, 9.0];
        let m = k_way_merge(&[s1, s2, s3]);
        assert!(is_sorted_f64(&m));
        assert_eq!(m, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_k_way_merge_single() {
        // The implementation re-sorts the concatenated input, so even an
        // unsorted single run must come back in order.
        let s = vec![vec![3.0, 1.0, 2.0]];
        let m = k_way_merge(&s);
        assert!(is_sorted_f64(&m));
        assert_eq!(m.len(), 3);
    }

    /// 100 elements with threshold 32: the parallel branch is taken at the top.
    #[test]
    fn test_merge_sort_parallel_threshold() {
        let mut data: Vec<f64> = (0..100u32).rev().map(|x| x as f64).collect();
        let orig = data.clone();
        merge_sort_parallel_threshold(&mut data, 32);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&orig, &data));
    }

    /// Three elements: handled entirely by the small-slice fallback.
    #[test]
    fn test_merge_sort_parallel_threshold_small() {
        let mut data = vec![3.0, 1.0, 2.0];
        let orig = data.clone();
        merge_sort_parallel_threshold(&mut data, 1024);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&orig, &data));
    }
}