1#![allow(clippy::ptr_arg)]
2#![allow(clippy::needless_range_loop)]
3#![allow(dead_code)]
21
22use rayon::prelude::*;
23
/// Sorts `data` in ascending order using a least-significant-digit radix sort.
///
/// Keys are processed one byte at a time, lowest byte first, with a stable
/// counting-sort pass per byte. A scratch buffer as long as `data` is
/// allocated once and reused by every pass.
pub fn radix_sort_u32(data: &mut Vec<u32>) {
    const BUCKETS: usize = 256;

    let len = data.len();
    if len < 2 {
        return;
    }
    let mut scratch = vec![0u32; len];

    for shift in (0..32u32).step_by(8) {
        let digit = |value: u32| ((value >> shift) & 0xFF) as usize;

        // How many keys fall into each bucket for the current byte.
        let mut histogram = [0usize; BUCKETS];
        data.iter().for_each(|&value| histogram[digit(value)] += 1);

        // Exclusive prefix sum: first output slot of every bucket.
        let mut cursor = [0usize; BUCKETS];
        let mut running = 0usize;
        for (slot, &count) in cursor.iter_mut().zip(histogram.iter()) {
            *slot = running;
            running += count;
        }

        // Stable scatter into the scratch buffer.
        for &value in data.iter() {
            let bucket = digit(value);
            scratch[cursor[bucket]] = value;
            cursor[bucket] += 1;
        }

        // `data` now holds the result of this pass.
        std::mem::swap(data, &mut scratch);
    }
}
60
/// Stable radix sort of `data` by the `u32` key that `key_fn` extracts.
///
/// Four byte-wise counting passes are run from the least significant byte
/// upwards. Elements are cloned into a scratch vector on every pass, so
/// elements with equal keys keep their relative order.
pub fn radix_sort_by_key<T: Clone>(data: &mut Vec<T>, key_fn: impl Fn(&T) -> u32) {
    if data.len() < 2 {
        return;
    }
    // Scratch storage of the right length; every slot is overwritten per pass.
    let mut scratch = data.clone();

    for shift in (0..32u32).step_by(8) {
        let bucket_of = |item: &T| ((key_fn(item) >> shift) & 0xFF) as usize;

        let mut histogram = [0usize; 256];
        for item in data.iter() {
            histogram[bucket_of(item)] += 1;
        }

        // Exclusive prefix sum gives the first destination of each bucket.
        let mut next_slot = [0usize; 256];
        let mut running = 0usize;
        for (slot, &count) in next_slot.iter_mut().zip(histogram.iter()) {
            *slot = running;
            running += count;
        }

        for item in data.iter() {
            let bucket = bucket_of(item);
            scratch[next_slot[bucket]] = item.clone();
            next_slot[bucket] += 1;
        }

        std::mem::swap(data, &mut scratch);
    }
}
95
96pub fn parallel_prefix_sum(data: &[u32]) -> Vec<u32> {
105 if data.is_empty() {
106 return Vec::new();
107 }
108 let n = data.len();
109 let num_threads = rayon::current_num_threads().max(1);
111 let chunk_size = (n / num_threads).max(1);
112
113 let chunks: Vec<_> = data.chunks(chunk_size).collect();
115 let chunk_sums: Vec<u32> = chunks
116 .par_iter()
117 .map(|chunk| chunk.iter().copied().fold(0u32, u32::wrapping_add))
118 .collect();
119
120 let mut chunk_offsets = vec![0u32; chunk_sums.len()];
122 let mut running = 0u32;
123 for (i, &s) in chunk_sums.iter().enumerate() {
124 chunk_offsets[i] = running;
125 running = running.wrapping_add(s);
126 }
127
128 let mut output = vec![0u32; n];
130 output
131 .par_chunks_mut(chunk_size)
132 .zip(data.par_chunks(chunk_size))
133 .zip(chunk_offsets.par_iter())
134 .for_each(|((out_chunk, in_chunk), &base)| {
135 let mut acc = base;
136 for (o, &v) in out_chunk.iter_mut().zip(in_chunk.iter()) {
137 *o = acc;
138 acc = acc.wrapping_add(v);
139 }
140 });
141
142 output
143}
144
145pub fn parallel_reduce_sum(data: &[f64]) -> f64 {
153 data.par_iter().copied().sum()
154}
155
156pub fn parallel_min_max(data: &[f64]) -> (f64, f64) {
164 if data.is_empty() {
165 return (f64::INFINITY, f64::NEG_INFINITY);
166 }
167 data.par_iter().copied().map(|v| (v, v)).reduce(
168 || (f64::INFINITY, f64::NEG_INFINITY),
169 |(lo1, hi1), (lo2, hi2)| (lo1.min(lo2), hi1.max(hi2)),
170 )
171}
172
/// Sorts `data` in ascending order with an iterative bitonic sorting network.
///
/// The network needs a power-of-two length, so the vector is temporarily
/// padded and truncated back to its original length afterwards.
///
/// The padding value is `f64::INFINITY`: it compares greater than or equal to
/// every non-NaN input, so the padding always ends up at the tail. (Padding
/// with `f64::MAX` would sort *before* genuine `+inf` inputs and the final
/// truncation would then drop real data.)
///
/// Elements are compared with `<` / `>`; NaN values never compare as out of
/// order, so inputs containing NaN are not meaningfully sorted.
pub fn bitonic_sort(data: &mut Vec<f64>) {
    let orig_len = data.len();
    if orig_len <= 1 {
        return;
    }
    data.resize(orig_len.next_power_of_two(), f64::INFINITY);

    let n = data.len();
    // `k` is the size of the bitonic sequences being merged.
    let mut k = 2;
    while k <= n {
        // `j` is the compare-exchange distance within the current merge.
        let mut j = k / 2;
        while j >= 1 {
            for i in 0..n {
                let l = i ^ j;
                // Visit every pair once, from its lower index.
                if l > i {
                    let ascending = (i & k) == 0;
                    let out_of_order = if ascending {
                        data[i] > data[l]
                    } else {
                        data[i] < data[l]
                    };
                    if out_of_order {
                        data.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }

    data.truncate(orig_len);
}
211
212pub fn merge_sort_parallel(data: &mut Vec<f64>) {
221 let n = data.len();
222 if n <= 1 {
223 return;
224 }
225 merge_sort_parallel_slice(data);
226}
227
228fn merge_sort_parallel_slice(data: &mut [f64]) {
229 let n = data.len();
230 if n <= 32 {
231 data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
232 return;
233 }
234 let mid = n / 2;
235 let (left, right) = data.split_at_mut(mid);
236
237 rayon::join(
239 || merge_sort_parallel_slice(left),
240 || merge_sort_parallel_slice(right),
241 );
242
243 let mut tmp = Vec::with_capacity(n);
245 let mut i = 0;
246 let mut j = 0;
247 let (left, right) = data.split_at(mid);
249 while i < left.len() && j < right.len() {
250 if left[i] <= right[j] {
251 tmp.push(left[i]);
252 i += 1;
253 } else {
254 tmp.push(right[j]);
255 j += 1;
256 }
257 }
258 tmp.extend_from_slice(&left[i..]);
259 tmp.extend_from_slice(&right[j..]);
260 data.copy_from_slice(&tmp);
261}
262
263pub fn histogram_u32(data: &[u32], num_buckets: usize) -> Vec<u32> {
275 assert!(num_buckets > 0, "num_buckets must be > 0");
276 if data.is_empty() {
277 return vec![0; num_buckets];
278 }
279 let nb = num_buckets;
280 data.par_chunks(256.max(data.len() / rayon::current_num_threads().max(1)))
282 .map(|chunk| {
283 let mut local = vec![0u32; nb];
284 for &v in chunk {
285 local[(v as usize) % nb] += 1;
286 }
287 local
288 })
289 .reduce(
290 || vec![0u32; nb],
291 |mut acc, local| {
292 for i in 0..nb {
293 acc[i] += local[i];
294 }
295 acc
296 },
297 )
298}
299
/// Returns the indices that would sort `data` in ascending order.
///
/// Values are ordered with [`f64::total_cmp`], which is a genuine total order
/// (NaN included). The comparator is therefore always consistent, unlike a
/// `partial_cmp` fallback that reports `Greater` for both `(a, b)` and
/// `(b, a)` whenever a NaN is involved.
///
/// The sort is unstable: indices of equal values may appear in any order.
pub fn argsort(data: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..data.len()).collect();
    indices.sort_unstable_by(|&a, &b| data[a].total_cmp(&data[b]));
    indices
}
316
317pub fn nth_element(data: &mut Vec<f64>, k: usize) -> f64 {
329 assert!(!data.is_empty(), "nth_element: data must not be empty");
330 assert!(
331 k < data.len(),
332 "nth_element: k={k} out of bounds (len={})",
333 data.len()
334 );
335 nth_element_slice(data, k);
336 data[k]
337}
338
339fn nth_element_slice(data: &mut [f64], k: usize) {
340 if data.len() <= 1 {
341 return;
342 }
343 let pivot_idx = partition(data);
344 if k < pivot_idx {
345 nth_element_slice(&mut data[..pivot_idx], k);
346 } else if k > pivot_idx {
347 nth_element_slice(&mut data[pivot_idx + 1..], k - pivot_idx - 1);
348 }
349 }
351
/// Partitions `data` around a median-of-three pivot and returns the pivot's
/// final index.
///
/// On return every element left of the returned index is `<=` the pivot and
/// every element right of it is not. `data` must be non-empty.
fn partition(data: &mut [f64]) -> usize {
    let len = data.len();
    let last = len - 1;
    let mid = len / 2;

    // Order the first, middle and last samples among themselves.
    for (a, b) in [(0, mid), (0, last), (mid, last)] {
        if data[a] > data[b] {
            data.swap(a, b);
        }
    }

    // Park the middle sample just before the end (index 0 for tiny slices).
    data.swap(mid, last - last.min(1));

    // Choose the pivot and move it to the very end for the scan below.
    let pivot_pos = if len >= 3 { last - 1 } else { last };
    let pivot = data[pivot_pos];
    data.swap(pivot_pos, last);

    // Lomuto scan: everything `<= pivot` is compacted to the front.
    let mut store = 0;
    for i in 0..last {
        if data[i] <= pivot {
            data.swap(i, store);
            store += 1;
        }
    }

    // Drop the pivot into its final position.
    data.swap(store, last);
    store
}
383
/// Returns `true` when every element of `data` is `<=` its successor.
///
/// Empty and single-element slices count as sorted. Because the check uses
/// `<=`, any adjacent pair involving NaN makes the result `false`.
pub fn is_sorted_f64(data: &[f64]) -> bool {
    data.iter()
        .zip(data.iter().skip(1))
        .all(|(prev, next)| prev <= next)
}
394
/// Returns `true` when `data` is in non-decreasing order.
///
/// Empty and single-element slices count as sorted.
pub fn is_sorted_u32(data: &[u32]) -> bool {
    data.iter()
        .zip(data.iter().skip(1))
        .all(|(prev, next)| prev <= next)
}
399
400pub fn count_inversions_f64(data: &[f64]) -> u64 {
404 if data.len() <= 1 {
405 return 0;
406 }
407 let mut tmp = data.to_vec();
408 count_inversions_helper(&mut tmp)
409}
410
/// Merge-sort based inversion counter.
///
/// Sorts `data` in place and returns the number of inversions it contained.
/// During the merge, an element is taken from the lower half only when it is
/// `<=` the head of the upper half; otherwise the upper head is taken and
/// counted against every element still waiting in the lower half.
fn count_inversions_helper(data: &mut [f64]) -> u64 {
    let len = data.len();
    if len < 2 {
        return 0;
    }

    let (mut lower, mut upper) = {
        let (l, u) = data.split_at(len / 2);
        (l.to_vec(), u.to_vec())
    };
    let mut inversions = count_inversions_helper(&mut lower);
    inversions += count_inversions_helper(&mut upper);

    let (mut li, mut ui) = (0usize, 0usize);
    for slot in data.iter_mut() {
        let take_lower = ui == upper.len() || (li < lower.len() && lower[li] <= upper[ui]);
        if take_lower {
            *slot = lower[li];
            li += 1;
        } else {
            // Every element still pending in `lower` precedes `upper[ui]` in
            // the input but does not compare `<=` to it. This adds zero once
            // `lower` is exhausted.
            inversions += (lower.len() - li) as u64;
            *slot = upper[ui];
            ui += 1;
        }
    }
    inversions
}
448
/// Outcome of running one sorting routine inside `compare_sorts`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SortTimingResult {
    /// Short identifier of the sorting routine (e.g. `"bitonic"`).
    pub name: String,
    /// Number of elements in the input that was sorted.
    pub n: usize,
    /// Whether the routine's output passed the sortedness check.
    pub correct: bool,
}
462
463pub fn compare_sorts(data: &[f64]) -> Vec<SortTimingResult> {
467 let mut results = Vec::new();
468
469 let mut d1 = data.to_vec();
471 bitonic_sort(&mut d1);
472 results.push(SortTimingResult {
473 name: "bitonic".into(),
474 n: data.len(),
475 correct: is_sorted_f64(&d1),
476 });
477
478 let mut d2 = data.to_vec();
480 merge_sort_parallel(&mut d2);
481 results.push(SortTimingResult {
482 name: "merge_parallel".into(),
483 n: data.len(),
484 correct: is_sorted_f64(&d2),
485 });
486
487 let mut d3: Vec<u32> = data.iter().map(|&v| v as u32).collect();
489 radix_sort_u32(&mut d3);
490 results.push(SortTimingResult {
491 name: "radix_u32".into(),
492 n: data.len(),
493 correct: is_sorted_u32(&d3),
494 });
495
496 results
497}
498
/// Returns `true` when `a` and `b` hold the same values with the same
/// multiplicities, in any order.
///
/// Both slices are copied, sorted and compared with `==`; since NaN is never
/// equal to itself, inputs containing NaN are never reported as permutations.
pub fn is_permutation_f64(a: &[f64], b: &[f64]) -> bool {
    let sorted_copy = |values: &[f64]| {
        let mut copy = values.to_vec();
        copy.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
        copy
    };
    a.len() == b.len() && sorted_copy(a) == sorted_copy(b)
}
510
/// Returns `true` when `a` and `b` hold the same values with the same
/// multiplicities, in any order.
pub fn is_permutation_u32(a: &[u32], b: &[u32]) -> bool {
    let sorted_copy = |values: &[u32]| {
        let mut copy = values.to_vec();
        copy.sort_unstable();
        copy
    };
    a.len() == b.len() && sorted_copy(a) == sorted_copy(b)
}
522
// Unit tests for the core sorting, scan, reduction and selection routines.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu_sort::radix_sort_u32;

    use crate::parallel_sort::is_permutation_f64;
    use crate::parallel_sort::is_permutation_u32;
    use crate::parallel_sort::is_sorted_f64;
    use crate::parallel_sort::is_sorted_u32;

    // --- radix_sort_u32 ---

    #[test]
    fn test_radix_sort_empty() {
        let mut v: Vec<u32> = vec![];
        radix_sort_u32(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_radix_sort_single() {
        let mut v = vec![42u32];
        radix_sort_u32(&mut v);
        assert_eq!(v, [42]);
    }

    #[test]
    fn test_radix_sort_sorted() {
        let mut v = vec![1u32, 2, 3, 4, 5];
        radix_sort_u32(&mut v);
        assert_eq!(v, [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_radix_sort_reverse() {
        let mut v = vec![5u32, 4, 3, 2, 1];
        radix_sort_u32(&mut v);
        assert_eq!(v, [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_radix_sort_random_u32() {
        let mut v: Vec<u32> = (0..1000u32).rev().collect();
        radix_sort_u32(&mut v);
        for i in 0..1000usize {
            assert_eq!(v[i], i as u32, "mismatch at index {i}");
        }
    }

    #[test]
    fn test_radix_sort_large_values() {
        let mut v = vec![u32::MAX, 0, u32::MAX / 2, 1, u32::MAX - 1];
        radix_sort_u32(&mut v);
        assert_eq!(v, [0, 1, u32::MAX / 2, u32::MAX - 1, u32::MAX]);
    }

    // --- radix_sort_by_key ---

    #[test]
    fn test_radix_sort_by_key_strings() {
        let mut v: Vec<(&str, u32)> = vec![("c", 3), ("a", 1), ("b", 2)];
        radix_sort_by_key(&mut v, |item| item.1);
        assert_eq!(v, [("a", 1), ("b", 2), ("c", 3)]);
    }

    #[test]
    fn test_radix_sort_by_key_empty() {
        let mut v: Vec<(usize, u32)> = vec![];
        radix_sort_by_key(&mut v, |item| item.1);
        assert!(v.is_empty());
    }

    // --- parallel_prefix_sum (exclusive scan) ---

    #[test]
    fn test_prefix_sum_empty() {
        assert!(parallel_prefix_sum(&[]).is_empty());
    }

    #[test]
    fn test_prefix_sum_single() {
        assert_eq!(parallel_prefix_sum(&[7]), vec![0]);
    }

    #[test]
    fn test_prefix_sum_basic() {
        let data = [1u32, 2, 3, 4, 5];
        let out = parallel_prefix_sum(&data);
        assert_eq!(out, vec![0, 1, 3, 6, 10]);
    }

    #[test]
    fn test_prefix_sum_ones() {
        let data = vec![1u32; 100];
        let out = parallel_prefix_sum(&data);
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v, i as u32, "prefix[{i}] should be {i}");
        }
    }

    // --- parallel_reduce_sum ---

    #[test]
    fn test_reduce_sum_empty() {
        assert_eq!(parallel_reduce_sum(&[]), 0.0);
    }

    #[test]
    fn test_reduce_sum_basic() {
        let data = [1.0f64, 2.0, 3.0, 4.0, 5.0];
        assert!((parallel_reduce_sum(&data) - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_reduce_sum_large() {
        let data: Vec<f64> = (1..=1000).map(|i| i as f64).collect();
        let expected = 1000.0 * 1001.0 / 2.0;
        assert!((parallel_reduce_sum(&data) - expected).abs() < 1e-6);
    }

    // --- parallel_min_max ---

    #[test]
    fn test_min_max_empty() {
        // Empty input yields the reduction identity: (+inf, -inf).
        let (lo, hi) = parallel_min_max(&[]);
        assert!(lo.is_infinite() && lo > 0.0);
        assert!(hi.is_infinite() && hi < 0.0);
    }

    #[test]
    fn test_min_max_single() {
        let (lo, hi) = parallel_min_max(&[3.125]);
        assert!((lo - 3.125).abs() < 1e-12);
        assert!((hi - 3.125).abs() < 1e-12);
    }

    #[test]
    fn test_min_max_basic() {
        let data = [3.0f64, 1.0, 4.0, 1.5, 9.2, 2.6];
        let (lo, hi) = parallel_min_max(&data);
        assert!((lo - 1.0).abs() < 1e-12);
        assert!((hi - 9.2).abs() < 1e-12);
    }

    // --- bitonic_sort ---

    #[test]
    fn test_bitonic_sort_empty() {
        let mut v: Vec<f64> = vec![];
        bitonic_sort(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_bitonic_sort_power_of_two() {
        let mut v = vec![4.0f64, 2.0, 7.0, 1.0, 5.0, 3.0, 6.0, 8.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_bitonic_sort_non_power_of_two() {
        let mut v = vec![5.0f64, 3.0, 1.0, 4.0, 2.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // --- merge_sort_parallel ---

    #[test]
    fn test_merge_sort_empty() {
        let mut v: Vec<f64> = vec![];
        merge_sort_parallel(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_merge_sort_basic() {
        let mut v = vec![3.0f64, 1.0, 4.0, 1.5, 9.0, 2.6];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [1.0, 1.5, 2.6, 3.0, 4.0, 9.0]);
    }

    #[test]
    fn test_merge_sort_large() {
        let mut v: Vec<f64> = (0..500u32).rev().map(|x| x as f64).collect();
        merge_sort_parallel(&mut v);
        for i in 0..500usize {
            assert!((v[i] - i as f64).abs() < 1e-12, "mismatch at {i}");
        }
    }

    // --- histogram_u32 ---

    #[test]
    fn test_histogram_empty() {
        let h = histogram_u32(&[], 4);
        assert_eq!(h, vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_histogram_basic() {
        let data = [0u32, 1, 2, 3, 0, 1, 2, 0];
        let h = histogram_u32(&data, 4);
        assert_eq!(h, vec![3, 2, 2, 1]);
    }

    #[test]
    fn test_histogram_one_bucket() {
        let data: Vec<u32> = (0..10).collect();
        let h = histogram_u32(&data, 1);
        assert_eq!(h, vec![10]);
    }

    // --- argsort ---

    #[test]
    fn test_argsort_empty() {
        assert!(argsort(&[]).is_empty());
    }

    #[test]
    fn test_argsort_basic() {
        let data = [3.0f64, 1.0, 4.0, 1.5, 9.0];
        let idx = argsort(&data);
        let sorted: Vec<f64> = idx.iter().map(|&i| data[i]).collect();
        assert_eq!(sorted, [1.0, 1.5, 3.0, 4.0, 9.0]);
    }

    #[test]
    fn test_argsort_already_sorted() {
        let data = [1.0f64, 2.0, 3.0, 4.0, 5.0];
        let idx = argsort(&data);
        assert_eq!(idx, [0, 1, 2, 3, 4]);
    }

    // --- nth_element ---

    #[test]
    fn test_nth_element_single() {
        let mut v = vec![42.0f64];
        assert!((nth_element(&mut v, 0) - 42.0).abs() < 1e-12);
    }

    #[test]
    fn test_nth_element_median() {
        let mut v = vec![3.0f64, 1.0, 4.0, 1.5, 9.0, 2.6, 5.0];
        let median = nth_element(&mut v, 3);
        assert!((median - 3.0).abs() < 1e-12, "expected 3.0, got {median}");
    }

    #[test]
    fn test_nth_element_min() {
        let mut v = vec![5.0f64, 3.0, 8.0, 1.0, 4.0];
        let min = nth_element(&mut v, 0);
        assert!((min - 1.0).abs() < 1e-12, "expected 1.0, got {min}");
    }

    #[test]
    fn test_nth_element_max() {
        let mut v = vec![5.0f64, 3.0, 8.0, 1.0, 4.0];
        let max = nth_element(&mut v, 4);
        assert!((max - 8.0).abs() < 1e-12, "expected 8.0, got {max}");
    }

    #[test]
    fn test_nth_element_duplicates() {
        let mut v = vec![2.0f64, 2.0, 2.0, 2.0, 2.0];
        let val = nth_element(&mut v, 2);
        assert!((val - 2.0).abs() < 1e-12);
    }

    // --- is_sorted_f64 / is_sorted_u32 ---

    #[test]
    fn test_is_sorted_f64_empty() {
        assert!(is_sorted_f64(&[]));
    }

    #[test]
    fn test_is_sorted_f64_sorted() {
        assert!(is_sorted_f64(&[1.0, 2.0, 3.0, 4.0]));
    }

    #[test]
    fn test_is_sorted_f64_unsorted() {
        assert!(!is_sorted_f64(&[1.0, 3.0, 2.0, 4.0]));
    }

    #[test]
    fn test_is_sorted_u32_sorted() {
        assert!(is_sorted_u32(&[0, 1, 2, 3, 4]));
    }

    #[test]
    fn test_is_sorted_u32_unsorted() {
        assert!(!is_sorted_u32(&[0, 2, 1, 3]));
    }

    // --- count_inversions_f64 ---

    #[test]
    fn test_count_inversions_sorted() {
        assert_eq!(count_inversions_f64(&[1.0, 2.0, 3.0, 4.0]), 0);
    }

    #[test]
    fn test_count_inversions_reversed() {
        // A fully reversed sequence of 4 elements has 4 * 3 / 2 = 6 inversions.
        assert_eq!(count_inversions_f64(&[4.0, 3.0, 2.0, 1.0]), 6);
    }

    #[test]
    fn test_count_inversions_one_swap() {
        assert_eq!(count_inversions_f64(&[2.0, 1.0, 3.0, 4.0]), 1);
    }

    #[test]
    fn test_count_inversions_empty() {
        assert_eq!(count_inversions_f64(&[]), 0);
    }

    // --- is_permutation_f64 / is_permutation_u32 ---

    #[test]
    fn test_is_permutation_f64_true() {
        assert!(is_permutation_f64(&[3.0, 1.0, 2.0], &[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_is_permutation_f64_false() {
        assert!(!is_permutation_f64(&[3.0, 1.0, 2.0], &[1.0, 2.0, 4.0]));
    }

    #[test]
    fn test_is_permutation_f64_different_lengths() {
        assert!(!is_permutation_f64(&[1.0, 2.0], &[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_is_permutation_u32_true() {
        assert!(is_permutation_u32(&[3, 1, 2], &[1, 2, 3]));
    }

    #[test]
    fn test_is_permutation_u32_false() {
        assert!(!is_permutation_u32(&[1, 2, 3], &[1, 2, 4]));
    }

    // --- sorts must return a sorted permutation of their input ---

    #[test]
    fn test_bitonic_sort_preserves_elements() {
        let original = vec![5.0, 3.0, 8.0, 1.0, 4.0, 7.0, 2.0, 6.0];
        let mut sorted = original.clone();
        bitonic_sort(&mut sorted);
        assert!(is_permutation_f64(&original, &sorted));
        assert!(is_sorted_f64(&sorted));
    }

    #[test]
    fn test_merge_sort_preserves_elements() {
        let original = vec![5.0, 3.0, 8.0, 1.0, 4.0, 7.0, 2.0, 6.0];
        let mut sorted = original.clone();
        merge_sort_parallel(&mut sorted);
        assert!(is_permutation_f64(&original, &sorted));
        assert!(is_sorted_f64(&sorted));
    }

    #[test]
    fn test_radix_sort_preserves_elements() {
        let original = vec![5u32, 3, 8, 1, 4, 7, 2, 6];
        let mut sorted = original.clone();
        radix_sort_u32(&mut sorted);
        assert!(is_permutation_u32(&original, &sorted));
        assert!(is_sorted_u32(&sorted));
    }

    // --- compare_sorts ---

    #[test]
    fn test_compare_sorts_all_correct() {
        let data: Vec<f64> = (0..100u32).rev().map(|x| x as f64).collect();
        let results = compare_sorts(&data);
        for r in &results {
            assert!(r.correct, "sort {} failed for n={}", r.name, r.n);
        }
    }

    #[test]
    fn test_compare_sorts_empty() {
        let results = compare_sorts(&[]);
        for r in &results {
            assert!(r.correct);
        }
    }

    // --- bitonic_sort edge cases ---

    #[test]
    fn test_bitonic_sort_single() {
        let mut v = vec![42.0_f64];
        bitonic_sort(&mut v);
        assert_eq!(v, [42.0]);
    }

    #[test]
    fn test_bitonic_sort_already_sorted() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_bitonic_sort_duplicates() {
        let mut v = vec![3.0, 1.0, 3.0, 1.0, 2.0, 2.0];
        bitonic_sort(&mut v);
        assert_eq!(v, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
    }

    // --- merge_sort_parallel edge cases ---

    #[test]
    fn test_merge_sort_single() {
        let mut v = vec![42.0_f64];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [42.0]);
    }

    #[test]
    fn test_merge_sort_two_elements() {
        let mut v = vec![2.0, 1.0];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [1.0, 2.0]);
    }

    #[test]
    fn test_merge_sort_duplicates() {
        let mut v = vec![5.0, 1.0, 5.0, 1.0, 3.0];
        merge_sort_parallel(&mut v);
        assert_eq!(v, [1.0, 1.0, 3.0, 5.0, 5.0]);
    }

    // --- radix_sort_u32 edge cases ---

    #[test]
    fn test_radix_sort_all_same() {
        let mut v = vec![7u32, 7, 7, 7, 7];
        radix_sort_u32(&mut v);
        assert_eq!(v, [7, 7, 7, 7, 7]);
    }

    #[test]
    fn test_radix_sort_two_elements() {
        let mut v = vec![2u32, 1];
        radix_sort_u32(&mut v);
        assert_eq!(v, [1, 2]);
    }

    // --- argsort edge cases ---

    #[test]
    fn test_argsort_duplicates() {
        let data = [3.0, 1.0, 3.0, 1.0];
        let idx = argsort(&data);
        let sorted: Vec<f64> = idx.iter().map(|&i| data[i]).collect();
        assert!(is_sorted_f64(&sorted));
    }

    #[test]
    fn test_argsort_single() {
        let idx = argsort(&[42.0]);
        assert_eq!(idx, [0]);
    }

    // --- nth_element edge cases ---

    #[test]
    fn test_nth_element_sorted_input() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let val = nth_element(&mut v, 2);
        assert!((val - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_nth_element_reversed() {
        let mut v = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        let val = nth_element(&mut v, 0);
        assert!((val - 1.0).abs() < 1e-12);
    }
}
1019
/// Runs a single stable counting-sort pass over `data`, keyed on the byte
/// selected by `shift` (`(v >> shift) & 0xFF`).
///
/// Returns the reordered values together with the 256-entry histogram of the
/// selected byte.
///
/// A `shift` of 32 or more addresses bits beyond the key; the digit is then
/// treated as `0` for every value (the output is a copy of the input) instead
/// of overflowing the shift.
pub fn radix_sort_stage_u32(data: &[u32], shift: u32) -> (Vec<u32>, [usize; 256]) {
    // `checked_shr` is `None` for shift amounts >= 32.
    let digit = |v: u32| (v.checked_shr(shift).unwrap_or(0) & 0xFF) as usize;

    let mut counts = [0usize; 256];
    for &v in data {
        counts[digit(v)] += 1;
    }

    // Exclusive prefix sum: next free output slot of every bucket.
    let mut pos = [0usize; 256];
    let mut total = 0usize;
    for (slot, &count) in pos.iter_mut().zip(counts.iter()) {
        *slot = total;
        total += count;
    }

    // Stable scatter.
    let mut out = vec![0u32; data.len()];
    for &v in data {
        let bucket = digit(v);
        out[pos[bucket]] = v;
        pos[bucket] += 1;
    }
    (out, counts)
}
1050
1051pub fn radix_sort_gpu_staged(data: &[u32]) -> Vec<u32> {
1055 if data.is_empty() {
1056 return Vec::new();
1057 }
1058 let mut current = data.to_vec();
1059 for pass in 0..4u32 {
1060 let (sorted, _counts) = radix_sort_stage_u32(¤t, pass * 8);
1061 current = sorted;
1062 }
1063 current
1064}
1065
/// Builds the 256-bucket histogram of the byte selected by `shift`
/// (`(v >> shift) & 0xFF`) over all values in `data`.
///
/// A `shift` of 32 or more addresses bits beyond the key; every value is then
/// counted in bucket `0` instead of overflowing the shift.
pub fn radix_histogram(data: &[u32], shift: u32) -> Vec<u32> {
    let mut counts = vec![0u32; 256];
    for &v in data {
        // `checked_shr` is `None` for shift amounts >= 32.
        let byte = (v.checked_shr(shift).unwrap_or(0) & 0xFF) as usize;
        counts[byte] += 1;
    }
    counts
}
1078
1079pub fn validate_radix_sort(original: &[u32], sorted: &[u32]) -> bool {
1081 is_permutation_u32(original, sorted) && is_sorted_u32(sorted)
1082}
1083
/// Counting sort for values in `0..=max_val`; returns a new sorted vector.
///
/// A tally table of `max_val + 1` entries is allocated.
///
/// # Panics
///
/// Panics if any value in `data` is greater than `max_val`.
pub fn counting_sort_u32(data: &[u32], max_val: u32) -> Vec<u32> {
    if data.is_empty() {
        return Vec::new();
    }

    let slots = max_val as usize + 1;
    let mut tally = vec![0u32; slots];
    for &v in data {
        let idx = v as usize;
        assert!(idx < slots, "value {v} exceeds max_val {max_val}");
        tally[idx] += 1;
    }

    // Emit every value as many times as it was seen, in increasing order.
    let mut sorted = Vec::with_capacity(data.len());
    for (value, &count) in tally.iter().enumerate() {
        sorted.extend(std::iter::repeat(value as u32).take(count as usize));
    }
    sorted
}
1112
/// Stable counting sort of `(key, value)` pairs with keys in `0..=max_key`.
///
/// Returns a new vector; pairs with equal keys keep their input order. Each
/// pair is cloned exactly once.
///
/// # Panics
///
/// Panics if any key is greater than `max_key`.
pub fn counting_sort_by_key<T: Clone>(data: &[(u32, T)], max_key: u32) -> Vec<(u32, T)> {
    if data.is_empty() {
        return Vec::new();
    }

    let slots = max_key as usize + 1;
    let mut tally = vec![0usize; slots];
    for (k, _) in data {
        assert!((*k as usize) < slots, "key {k} exceeds max_key {max_key}");
        tally[*k as usize] += 1;
    }

    // Exclusive prefix sum: first output position of every key.
    let mut next_slot: Vec<usize> = tally
        .iter()
        .scan(0usize, |running, &count| {
            let start = *running;
            *running += count;
            Some(start)
        })
        .collect();

    // `order[p]` is the index of the input pair that belongs at position `p`.
    let mut order = vec![0usize; data.len()];
    for (src, (k, _)) in data.iter().enumerate() {
        let slot = &mut next_slot[*k as usize];
        order[*slot] = src;
        *slot += 1;
    }

    order.into_iter().map(|src| data[src].clone()).collect()
}
1141
/// Sorts `data` in ascending order with a bucket sort over `n_buckets`
/// equal-width buckets spanning `[min, max]` of the data.
///
/// Falls back to a plain comparison sort when `data` has at most one element
/// or `n_buckets` is zero. When all non-NaN values are identical there is
/// nothing to reorder and the function returns immediately.
///
/// The "all equal" test is an exact `lo >= hi` comparison: using an absolute
/// tolerance here would wrongly skip sorting for inputs whose values are
/// distinct but lie very close together (e.g. a range below `f64::EPSILON`).
pub fn histogram_bucket_sort(data: &mut Vec<f64>, n_buckets: usize) {
    let n = data.len();
    if n <= 1 || n_buckets == 0 {
        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        return;
    }

    // Value range of the input; NaN never updates either bound.
    let mut lo = f64::INFINITY;
    let mut hi = f64::NEG_INFINITY;
    for &v in data.iter() {
        if v < lo {
            lo = v;
        }
        if v > hi {
            hi = v;
        }
    }

    // Degenerate range: no two values compare as different.
    if lo >= hi {
        return;
    }

    let range = hi - lo;
    let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); n_buckets];

    for &v in data.iter() {
        // Map `v` linearly onto a bucket index; `hi` itself would land one
        // past the end, so clamp to the last bucket.
        let idx = ((v - lo) / range * n_buckets as f64) as usize;
        buckets[idx.min(n_buckets - 1)].push(v);
    }

    for bucket in &mut buckets {
        bucket.sort_unstable_by(|a, c| a.partial_cmp(c).unwrap_or(std::cmp::Ordering::Equal));
    }

    // Concatenate the buckets back into `data`.
    for (slot, &v) in data.iter_mut().zip(buckets.iter().flatten()) {
        *slot = v;
    }
}
1197
1198pub fn adaptive_bucket_sort(data: &mut Vec<f64>, n_buckets: usize) {
1203 histogram_bucket_sort(data, n_buckets.max(1));
1204}
1205
/// Summary of checking a sort's output against its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SortValidation {
    /// Whether the output is in non-decreasing order.
    pub is_sorted: bool,
    /// Whether the output is a permutation of the input.
    pub is_permutation: bool,
    /// Number of elements in the output.
    pub n: usize,
    /// Inversion count recorded for the output (`0` when it is sorted).
    pub inversions: u64,
}
1221
1222impl SortValidation {
1223 pub fn validate_f64(original: &[f64], sorted: &[f64]) -> Self {
1225 let is_sorted = is_sorted_f64(sorted);
1226 let is_perm = is_permutation_f64(original, sorted);
1227 let inversions = if is_sorted {
1228 0
1229 } else {
1230 count_inversions_f64(sorted)
1231 };
1232 Self {
1233 is_sorted,
1234 is_permutation: is_perm,
1235 n: sorted.len(),
1236 inversions,
1237 }
1238 }
1239
1240 pub fn validate_u32(original: &[u32], sorted: &[u32]) -> Self {
1242 let is_sorted = is_sorted_u32(sorted);
1243 let is_perm = is_permutation_u32(original, sorted);
1244 Self {
1245 is_sorted,
1246 is_permutation: is_perm,
1247 n: sorted.len(),
1248 inversions: 0,
1249 }
1250 }
1251
1252 pub fn is_correct(&self) -> bool {
1254 self.is_sorted && self.is_permutation
1255 }
1256}
1257
/// Merges two ascending slices into one ascending vector.
///
/// When the two heads compare equal the element from `left` is taken first.
pub fn merge_sorted(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let mut lhs = left.iter().copied().peekable();
    let mut rhs = right.iter().copied().peekable();

    while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
        if l <= r {
            merged.push(l);
            lhs.next();
        } else {
            merged.push(r);
            rhs.next();
        }
    }

    // At most one of the two still has elements left.
    merged.extend(lhs);
    merged.extend(rhs);
    merged
}
1282
/// Merges two ascending `u32` slices into one ascending vector.
///
/// When the two heads are equal the element from `left` is taken first.
pub fn merge_sorted_u32(left: &[u32], right: &[u32]) -> Vec<u32> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut li, mut ri) = (0usize, 0usize);

    loop {
        match (left.get(li), right.get(ri)) {
            (Some(&l), Some(&r)) if l <= r => {
                merged.push(l);
                li += 1;
            }
            (Some(_), Some(&r)) => {
                merged.push(r);
                ri += 1;
            }
            // One side is exhausted: append whatever remains of the other.
            (Some(_), None) => {
                merged.extend_from_slice(&left[li..]);
                break;
            }
            (None, _) => {
                merged.extend_from_slice(&right[ri..]);
                break;
            }
        }
    }
    merged
}
1301
/// Combines all of `slices` into a single ascending vector.
///
/// The inputs are concatenated and the result is sorted as a whole, so the
/// individual inputs do not need to be sorted themselves.
pub fn k_way_merge(slices: &[Vec<f64>]) -> Vec<f64> {
    let mut combined: Vec<f64> = slices.concat();
    combined.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    combined
}
1315
1316pub fn merge_sort_parallel_threshold(data: &mut Vec<f64>, parallel_threshold: usize) {
1321 let n = data.len();
1322 if n <= 1 {
1323 return;
1324 }
1325 merge_sort_threshold_slice(data, parallel_threshold);
1326}
1327
1328fn merge_sort_threshold_slice(data: &mut [f64], threshold: usize) {
1329 let n = data.len();
1330 if n <= 1 {
1331 return;
1332 }
1333 if n <= 16 {
1334 data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1335 return;
1336 }
1337 let mid = n / 2;
1338 let (left, right) = data.split_at_mut(mid);
1339
1340 if n >= threshold {
1341 rayon::join(
1342 || merge_sort_threshold_slice(left, threshold),
1343 || merge_sort_threshold_slice(right, threshold),
1344 );
1345 } else {
1346 merge_sort_threshold_slice(left, threshold);
1347 merge_sort_threshold_slice(right, threshold);
1348 }
1349
1350 let mut tmp = Vec::with_capacity(n);
1351 let (left, right) = data.split_at(mid);
1352 let mut i = 0;
1353 let mut j = 0;
1354 while i < left.len() && j < right.len() {
1355 if left[i] <= right[j] {
1356 tmp.push(left[i]);
1357 i += 1;
1358 } else {
1359 tmp.push(right[j]);
1360 j += 1;
1361 }
1362 }
1363 tmp.extend_from_slice(&left[i..]);
1364 tmp.extend_from_slice(&right[j..]);
1365 data.copy_from_slice(&tmp);
1366}
1367
// Unit tests for the staged radix sort, counting sorts, bucket sorts,
// validation helpers and merge utilities.
#[cfg(test)]
mod tests_new_sort {
    use super::*;
    use crate::gpu_sort::radix_sort_u32;
    use crate::parallel_sort::SortValidation;
    use crate::parallel_sort::adaptive_bucket_sort;
    use crate::parallel_sort::counting_sort_by_key;
    use crate::parallel_sort::counting_sort_u32;
    use crate::parallel_sort::histogram_bucket_sort;
    use crate::parallel_sort::is_permutation_f64;
    use crate::parallel_sort::is_permutation_u32;
    use crate::parallel_sort::is_sorted_f64;
    use crate::parallel_sort::is_sorted_u32;
    use crate::parallel_sort::k_way_merge;
    use crate::parallel_sort::merge_sort_parallel_threshold;
    use crate::parallel_sort::merge_sorted;
    use crate::parallel_sort::merge_sorted_u32;
    use crate::parallel_sort::radix_histogram;
    use crate::parallel_sort::radix_sort_gpu_staged;
    use crate::parallel_sort::radix_sort_stage_u32;
    use crate::parallel_sort::validate_radix_sort;

    // --- staged radix sort ---

    #[test]
    fn test_radix_sort_stage_pass0() {
        let data = vec![300u32, 1, 255, 100, 50];
        let (sorted_once, counts) = radix_sort_stage_u32(&data, 0);
        assert_eq!(sorted_once.len(), data.len());
        // The histogram must account for every input element.
        let total: usize = counts.iter().sum();
        assert_eq!(total, data.len());
    }

    #[test]
    fn test_radix_sort_gpu_staged_sorted() {
        let data: Vec<u32> = vec![500, 1, 200, 50, 900, 3, 150];
        let sorted = radix_sort_gpu_staged(&data);
        assert!(
            is_sorted_u32(&sorted),
            "staged sort should produce sorted output"
        );
        assert!(is_permutation_u32(&data, &sorted));
    }

    #[test]
    fn test_radix_sort_gpu_staged_empty() {
        let sorted = radix_sort_gpu_staged(&[]);
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_radix_histogram_sums() {
        let data: Vec<u32> = (0..256).collect();
        let h = radix_histogram(&data, 0);
        let total: u32 = h.iter().sum();
        assert_eq!(total, 256);
        // Every low byte value occurs exactly once in 0..256.
        for &c in &h {
            assert_eq!(c, 1);
        }
    }

    #[test]
    fn test_validate_radix_sort() {
        let original: Vec<u32> = vec![5, 3, 8, 1, 4];
        let mut sorted = original.clone();
        radix_sort_u32(&mut sorted);
        assert!(validate_radix_sort(&original, &sorted));
    }

    #[test]
    fn test_validate_radix_sort_false_for_unsorted() {
        let original = vec![3u32, 1, 2];
        let not_sorted = vec![3u32, 1, 2];
        assert!(!validate_radix_sort(&original, &not_sorted));
    }

    // --- counting sorts ---

    #[test]
    fn test_counting_sort_basic() {
        let data = vec![3u32, 1, 4, 1, 5, 9, 2, 6, 5, 3];
        let sorted = counting_sort_u32(&data, 9);
        assert!(is_sorted_u32(&sorted));
        assert!(is_permutation_u32(&data, &sorted));
    }

    #[test]
    fn test_counting_sort_empty() {
        let sorted = counting_sort_u32(&[], 10);
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_counting_sort_all_same() {
        let data = vec![5u32; 10];
        let sorted = counting_sort_u32(&data, 5);
        assert_eq!(sorted, vec![5u32; 10]);
    }

    #[test]
    fn test_counting_sort_by_key() {
        let data: Vec<(u32, &str)> = vec![(3, "c"), (1, "a"), (2, "b")];
        let sorted = counting_sort_by_key(&data, 3);
        assert_eq!(sorted[0].0, 1);
        assert_eq!(sorted[1].0, 2);
        assert_eq!(sorted[2].0, 3);
    }

    #[test]
    fn test_counting_sort_by_key_stable() {
        let data: Vec<(u32, u32)> = vec![(2, 10), (1, 20), (2, 30)];
        let sorted = counting_sort_by_key(&data, 2);
        assert_eq!(sorted[0].0, 1);
        assert_eq!(sorted[1].0, 2);
        assert_eq!(sorted[2].0, 2);
        // Pairs with equal keys must keep their input order.
        assert_eq!(sorted[1].1, 10);
        assert_eq!(sorted[2].1, 30);
    }

    // --- bucket sorts ---

    #[test]
    fn test_histogram_bucket_sort_basic() {
        let mut data = vec![5.0, 3.0, 8.0, 1.0, 4.0, 7.0, 2.0, 6.0];
        let original = data.clone();
        histogram_bucket_sort(&mut data, 4);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&original, &data));
    }

    #[test]
    fn test_histogram_bucket_sort_single_bucket() {
        let mut data = vec![3.0, 1.0, 2.0, 4.0];
        let original = data.clone();
        histogram_bucket_sort(&mut data, 1);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&original, &data));
    }

    #[test]
    fn test_histogram_bucket_sort_all_equal() {
        let mut data = vec![5.0; 10];
        histogram_bucket_sort(&mut data, 4);
        assert!(is_sorted_f64(&data));
    }

    #[test]
    fn test_histogram_bucket_sort_large() {
        let mut data: Vec<f64> = (0..200u32).rev().map(|x| x as f64).collect();
        let original = data.clone();
        histogram_bucket_sort(&mut data, 20);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&original, &data));
    }

    #[test]
    fn test_adaptive_bucket_sort() {
        let mut data = vec![9.0, 3.0, 6.0, 1.0, 8.0, 4.0, 2.0, 7.0, 5.0];
        let orig = data.clone();
        adaptive_bucket_sort(&mut data, 3);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&orig, &data));
    }

    // --- SortValidation ---

    #[test]
    fn test_sort_validation_correct() {
        let orig = vec![3.0, 1.0, 4.0, 1.5, 9.0];
        let mut sorted = orig.clone();
        merge_sort_parallel(&mut sorted);
        let v = SortValidation::validate_f64(&orig, &sorted);
        assert!(v.is_correct());
        assert_eq!(v.inversions, 0);
        assert_eq!(v.n, 5);
    }

    #[test]
    fn test_sort_validation_unsorted() {
        let orig = vec![1.0, 3.0, 2.0];
        let not_sorted = vec![1.0, 3.0, 2.0];
        let v = SortValidation::validate_f64(&orig, &not_sorted);
        assert!(!v.is_sorted);
        assert!(v.is_permutation);
        assert!(!v.is_correct());
    }

    #[test]
    fn test_sort_validation_u32() {
        let orig = vec![5u32, 3, 8, 1];
        let mut sorted = orig.clone();
        radix_sort_u32(&mut sorted);
        let v = SortValidation::validate_u32(&orig, &sorted);
        assert!(v.is_correct());
    }

    // --- merge utilities ---

    #[test]
    fn test_merge_sorted_basic() {
        let a = vec![1.0, 3.0, 5.0];
        let b = vec![2.0, 4.0, 6.0];
        let m = merge_sorted(&a, &b);
        assert_eq!(m, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_merge_sorted_empty_left() {
        let a: Vec<f64> = vec![];
        let b = vec![1.0, 2.0, 3.0];
        let m = merge_sorted(&a, &b);
        assert_eq!(m, b);
    }

    #[test]
    fn test_merge_sorted_empty_right() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f64> = vec![];
        let m = merge_sorted(&a, &b);
        assert_eq!(m, a);
    }

    #[test]
    fn test_merge_sorted_u32() {
        let a = vec![1u32, 4, 7];
        let b = vec![2u32, 5, 8];
        let m = merge_sorted_u32(&a, &b);
        assert_eq!(m, vec![1, 2, 4, 5, 7, 8]);
    }

    #[test]
    fn test_k_way_merge() {
        let s1 = vec![1.0, 4.0, 7.0];
        let s2 = vec![2.0, 5.0, 8.0];
        let s3 = vec![3.0, 6.0, 9.0];
        let m = k_way_merge(&[s1, s2, s3]);
        assert!(is_sorted_f64(&m));
        assert_eq!(m.len(), 9);
    }

    #[test]
    fn test_k_way_merge_single() {
        // The single input is deliberately unsorted.
        let s = vec![vec![3.0, 1.0, 2.0]];
        let m = k_way_merge(&s);
        assert!(is_sorted_f64(&m));
    }

    // --- merge_sort_parallel_threshold ---

    #[test]
    fn test_merge_sort_parallel_threshold() {
        let mut data: Vec<f64> = (0..100u32).rev().map(|x| x as f64).collect();
        let orig = data.clone();
        merge_sort_parallel_threshold(&mut data, 32);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&orig, &data));
    }

    #[test]
    fn test_merge_sort_parallel_threshold_small() {
        let mut data = vec![3.0, 1.0, 2.0];
        let orig = data.clone();
        merge_sort_parallel_threshold(&mut data, 1024);
        assert!(is_sorted_f64(&data));
        assert!(is_permutation_f64(&orig, &data));
    }
}