Skip to main content

oxiphysics_gpu/
gpu_sort.rs

1#![allow(clippy::if_same_then_else, clippy::ptr_arg)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! GPU-style parallel sorting algorithms (CPU simulation).
6//!
7//! This module provides GPU-pattern algorithms for sorting and related
8//! operations on `f32`/`u32` data:
9//! - Bitonic sort (power-of-2 padded)
10//! - LSD radix sort for `u32` and `f32`
11//! - Exclusive prefix sum (Blelloch scan style)
12//! - Parallel histogram
13//! - Counting sort
14//! - Morton Z-order sort for 3-D point clouds
15//! - [`GpuSortBuffer`] — key+value buffer with `sort_pairs`
16//! - Parallel merge
17
18#![allow(dead_code)]
19
20// ─────────────────────────────────────────────────────────────────────────────
21// Bitonic sort (f32)
22// ─────────────────────────────────────────────────────────────────────────────
23
24/// Sort a `Vec`f32` in ascending order using the bitonic sort algorithm.
25///
26/// If `data.len()` is not a power of two, the vector is padded to the next
27/// power of two with `f32::MAX` sentinel values, sorted, then truncated.
28pub fn bitonic_sort(data: &mut Vec<f32>) {
29    let orig = data.len();
30    if orig <= 1 {
31        return;
32    }
33    let padded = next_pow2(orig);
34    data.resize(padded, f32::MAX);
35    bitonic_sort_slice_f32(data);
36    data.truncate(orig);
37}
38
/// Sort a `Vec<T>` in ascending order of the `u32` key returned by `key_fn`.
///
/// The sort runs on `(key, original index)` pairs rather than on the elements
/// themselves. The pair list is padded to a power of two with the sentinel
/// `(u32::MAX, usize::MAX)`, pushed through a bitonic network, and the first
/// `data.len()` pairs are then used to rebuild `data`.
///
/// Pairs are compared as a whole, so the original index acts as a tie-breaker:
/// - elements with equal keys keep their input order (the sort is stable);
/// - sentinels compare strictly greater than every real pair, even when a real
///   key equals `u32::MAX`, so they always end up in the discarded tail.
///
/// `key_fn` is called exactly once per element, and each element is cloned
/// exactly once.
pub fn bitonic_sort_by_key<T: Clone>(data: &mut Vec<T>, key_fn: impl Fn(&T) -> u32) {
    let orig = data.len();
    if orig <= 1 {
        return;
    }

    // Build (key, original-index) pairs and pad to a power of two.
    let mut pairs: Vec<(u32, usize)> = data
        .iter()
        .enumerate()
        .map(|(i, v)| (key_fn(v), i))
        .collect();
    pairs.resize(orig.next_power_of_two(), (u32::MAX, usize::MAX));

    // Bitonic network: `k` is the size of the bitonic sequences being merged,
    // `j` the compare-exchange distance within the current merge step.
    let n = pairs.len();
    let mut k = 2;
    while k <= n {
        let mut j = k / 2;
        while j >= 1 {
            for i in 0..n {
                let l = i ^ j;
                if l > i {
                    // Blocks alternate direction so that neighbouring blocks
                    // form a bitonic sequence for the next merge.
                    let ascending = (i & k) == 0;
                    let out_of_order = if ascending {
                        pairs[i] > pairs[l]
                    } else {
                        pairs[i] < pairs[l]
                    };
                    if out_of_order {
                        pairs.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }

    // The first `orig` pairs are exactly the real elements, in sorted order.
    let sorted: Vec<T> = pairs[..orig]
        .iter()
        .map(|&(_, idx)| data[idx].clone())
        .collect();
    *data = sorted;
}
92
/// Internal: in-place ascending bitonic sort of a slice.
///
/// The slice length must be a power of two (or zero); for other lengths the
/// partner index `lo ^ stride` can fall outside the slice and indexing panics.
fn bitonic_sort_slice_f32(data: &mut [f32]) {
    let len = data.len();
    // `block` is the size of the bitonic sequences merged in this stage.
    let mut block = 2;
    while block <= len {
        // `stride` is the compare-exchange distance inside the current merge.
        let mut stride = block / 2;
        while stride >= 1 {
            for lo in 0..len {
                let hi = lo ^ stride;
                if hi <= lo {
                    // Each pair is handled once, from its lower index.
                    continue;
                }
                let (a, b) = (data[lo], data[hi]);
                // Even-numbered blocks sort ascending, odd ones descending.
                let misplaced = if (lo & block) == 0 { a > b } else { a < b };
                if misplaced {
                    data.swap(lo, hi);
                }
            }
            stride /= 2;
        }
        block *= 2;
    }
}
116
117// ─────────────────────────────────────────────────────────────────────────────
118// Radix sort
119// ─────────────────────────────────────────────────────────────────────────────
120
/// Least-significant-digit radix sort for `u32`, ascending.
///
/// Runs four passes, one per byte from the lowest to the highest. Each pass
/// is a stable counting sort on that byte, scattering into a scratch buffer
/// that is then swapped with `data`. Because the number of passes is even,
/// the sorted values end up in the vector the caller passed in.
pub fn radix_sort_u32(data: &mut Vec<u32>) {
    let len = data.len();
    if len < 2 {
        return;
    }
    let mut scratch = vec![0u32; len];
    for shift in (0..4u32).map(|pass| pass * 8) {
        let digit = |v: u32| ((v >> shift) & 0xFF) as usize;

        // How many values carry each byte value in this position.
        let mut tally = [0usize; 256];
        for &v in data.iter() {
            tally[digit(v)] += 1;
        }

        // Exclusive running total: first output slot for each byte value.
        let mut cursor = [0usize; 256];
        let mut running = 0;
        for (slot, &count) in cursor.iter_mut().zip(tally.iter()) {
            *slot = running;
            running += count;
        }

        // Scatter in input order, which keeps the pass stable.
        for &v in data.iter() {
            let slot = &mut cursor[digit(v)];
            scratch[*slot] = v;
            *slot += 1;
        }
        std::mem::swap(data, &mut scratch);
    }
}
150
151/// Sort `f32` values by reinterpreting bits and flipping the sign bit for
152/// negatives so that the full IEEE 754 ordering maps to unsigned integer order.
153///
154/// After sorting, the bits are un-flipped back to valid `f32` values.
155pub fn radix_sort_f32(data: &mut Vec<f32>) {
156    if data.len() <= 1 {
157        return;
158    }
159    // Map f32 to a sortable u32.
160    let mut keys: Vec<u32> = data.iter().map(|&v| f32_to_sort_key(v)).collect();
161    radix_sort_u32(&mut keys);
162    for (dst, k) in data.iter_mut().zip(keys.iter()) {
163        *dst = sort_key_to_f32(*k);
164    }
165}
166
/// Map an `f32` to a `u32` whose unsigned order follows the float order.
///
/// Values with a clear sign bit get the sign bit set, which places them in
/// the upper half of the key space. Values with the sign bit set have every
/// bit inverted, which places them in the lower half and reverses their
/// order so that larger magnitudes yield smaller keys.
#[inline]
fn f32_to_sort_key(v: f32) -> u32 {
    const SIGN: u32 = 0x8000_0000;
    let raw = v.to_bits();
    match raw & SIGN {
        0 => raw | SIGN, // sign bit clear: move above all negatives
        _ => !raw,       // sign bit set: invert everything
    }
}
178
/// Inverse of [`f32_to_sort_key`]: recover the `f32` encoded in a sort key.
///
/// A key with its top bit set came from a value whose sign bit was clear, so
/// the top bit is cleared again. Any other key came from a value whose sign
/// bit was set, so all bits are inverted back.
#[inline]
fn sort_key_to_f32(key: u32) -> f32 {
    const SIGN: u32 = 0x8000_0000;
    let raw = if (key & SIGN) == 0 {
        !key // originally had the sign bit set
    } else {
        key & !SIGN // originally had the sign bit clear
    };
    f32::from_bits(raw)
}
189
190// ─────────────────────────────────────────────────────────────────────────────
191// Prefix sum (Blelloch exclusive scan)
192// ─────────────────────────────────────────────────────────────────────────────
193
/// Exclusive prefix sum of `u32` values.
///
/// Returns a new `Vec<u32>` of the same length as `data` in which
/// `result[i]` is the sum of `data[0..i]`; `result[0]` is therefore `0`
/// whenever the input is non-empty. Sums wrap around on `u32` overflow.
///
/// The scan is computed sequentially in a single left-to-right pass.
pub fn prefix_sum(data: &[u32]) -> Vec<u32> {
    data.iter()
        .scan(0u32, |running, &v| {
            // Emit the total *before* this element, then fold it in.
            let before = *running;
            *running = running.wrapping_add(v);
            Some(before)
        })
        .collect()
}
207
208// ─────────────────────────────────────────────────────────────────────────────
209// Histogram
210// ─────────────────────────────────────────────────────────────────────────────
211
/// Compute a histogram of `data` with `n_bins` equal-width bins.
///
/// The binned range is `[0, max + 1)`, where `max` is the largest value in
/// `data`; a value `v` is counted in bin `v * n_bins / (max + 1)` (integer
/// division, evaluated in `u64`), capped at the last bin.
///
/// Returns a `Vec<u32>` of length `n_bins`. For empty `data` every bin is 0.
///
/// # Panics
/// Panics if `n_bins == 0`.
pub fn histogram(data: &[u32], n_bins: usize) -> Vec<u32> {
    assert!(n_bins > 0, "n_bins must be > 0");
    let mut bins = vec![0u32; n_bins];
    // Exclusive upper bound of the binned range; nothing to count if empty.
    let upper = match data.iter().max() {
        Some(&largest) => u64::from(largest) + 1,
        None => return bins,
    };
    let bin_count = n_bins as u64;
    let last = n_bins - 1;
    for &v in data {
        let slot = (u64::from(v) * bin_count / upper) as usize;
        bins[slot.min(last)] += 1;
    }
    bins
}
234
235// ─────────────────────────────────────────────────────────────────────────────
236// Counting sort
237// ─────────────────────────────────────────────────────────────────────────────
238
/// Counting sort for `u32` values expected to lie in `0..=max_val`.
///
/// Allocates a tally of `max_val + 1` counters, counts every value, and then
/// rewrites `data` in ascending order from the tally. Time and space are
/// O(n + max_val).
///
/// Values greater than `max_val` are counted in the last slot, so they come
/// out of the sort as `max_val`; callers must pass a true upper bound if the
/// original values are to be preserved.
pub fn counting_sort(data: &mut Vec<u32>, max_val: u32) {
    if data.len() < 2 {
        return;
    }
    let size = (max_val as usize).saturating_add(1);
    let last = size - 1;
    let mut tally = vec![0u32; size];
    for &v in data.iter() {
        // Out-of-range values are folded into the final counter.
        tally[(v as usize).min(last)] += 1;
    }

    // Refill `data` front to back: `count` copies of each value in turn.
    let mut slots = data.iter_mut();
    for (value, &count) in tally.iter().enumerate() {
        for slot in slots.by_ref().take(count as usize) {
            *slot = value as u32;
        }
    }
}
261
262// ─────────────────────────────────────────────────────────────────────────────
263// Morton sort (Z-order curve for 3-D points)
264// ─────────────────────────────────────────────────────────────────────────────
265
266/// Sort 3-D `f32` points by their Morton Z-order (space-filling curve) code.
267///
268/// Coordinates are quantised to 10-bit integers before interleaving, which
269/// gives 30-bit Morton codes that fit in a `u32`.
270pub fn morton_sort_3d(points: &mut Vec<[f32; 3]>) {
271    if points.len() <= 1 {
272        return;
273    }
274    // Compute bounding box.
275    let mut lo = [f32::INFINITY; 3];
276    let mut hi = [f32::NEG_INFINITY; 3];
277    for p in points.iter() {
278        for d in 0..3 {
279            lo[d] = lo[d].min(p[d]);
280            hi[d] = hi[d].max(p[d]);
281        }
282    }
283    let scale: Vec<f32> = (0..3)
284        .map(|d| {
285            let range = hi[d] - lo[d];
286            if range > 0.0 { 1023.0 / range } else { 0.0 }
287        })
288        .collect();
289
290    let mut pairs: Vec<(u32, usize)> = points
291        .iter()
292        .enumerate()
293        .map(|(i, p)| {
294            let ix = ((p[0] - lo[0]) * scale[0]) as u32;
295            let iy = ((p[1] - lo[1]) * scale[1]) as u32;
296            let iz = ((p[2] - lo[2]) * scale[2]) as u32;
297            (morton3(ix.min(1023), iy.min(1023), iz.min(1023)), i)
298        })
299        .collect();
300
301    pairs.sort_unstable_by_key(|&(code, _)| code);
302
303    let old = points.clone();
304    for (i, &(_, idx)) in pairs.iter().enumerate() {
305        points[i] = old[idx];
306    }
307}
308
309/// Interleave the lower 10 bits of x, y, z into a 30-bit Morton code.
310fn morton3(x: u32, y: u32, z: u32) -> u32 {
311    spread_bits(x) | (spread_bits(y) << 1) | (spread_bits(z) << 2)
312}
313
/// Spread the lower 10 bits of `v` so that input bit `i` lands at output
/// bit `3 * i`, leaving two zero bits between consecutive input bits.
///
/// Bits above the lowest 10 are discarded first.
fn spread_bits(mut v: u32) -> u32 {
    // (shift, mask) pairs applied in order; each round moves the upper part
    // of every remaining group further left and masks off the overlap.
    const STEPS: [(u32, u32); 4] = [
        (16, 0x030000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    ];
    v &= 0x3FF; // keep lower 10 bits
    for &(shift, mask) in STEPS.iter() {
        v = (v | (v << shift)) & mask;
    }
    v
}
323
324// ─────────────────────────────────────────────────────────────────────────────
325// GpuSortBuffer
326// ─────────────────────────────────────────────────────────────────────────────
327
/// Parallel key and value arrays that can be sorted together.
///
/// `sort_pairs` reorders both arrays by key with an LSD radix sort so that
/// every value stays attached to its key.
#[derive(Debug, Clone)]
pub struct GpuSortBuffer {
    /// Sort keys.
    pub keys: Vec<u32>,
    /// Value associated with the key at the same index.
    pub values: Vec<u32>,
}

impl GpuSortBuffer {
    /// Build a buffer from matching key and value arrays.
    ///
    /// # Panics
    /// Panics if `keys` and `values` have different lengths.
    pub fn new(keys: Vec<u32>, values: Vec<u32>) -> Self {
        assert_eq!(
            keys.len(),
            values.len(),
            "keys and values must have equal length"
        );
        GpuSortBuffer { keys, values }
    }

    /// Build a buffer holding no pairs.
    pub fn empty() -> Self {
        GpuSortBuffer {
            keys: vec![],
            values: vec![],
        }
    }

    /// Number of key-value pairs, taken from the key array.
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns `true` if the buffer holds no keys.
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }

    /// Sort keys ascending and move each value along with its key.
    ///
    /// Four stable counting-sort passes are made, one per key byte from the
    /// lowest to the highest, so pairs with equal keys keep their relative
    /// order.
    ///
    /// # Panics
    /// The fields are public, so the arrays can be made to differ in length;
    /// if `values` is shorter than `keys`, indexing it panics.
    pub fn sort_pairs(&mut self) {
        let count = self.len();
        if count < 2 {
            return;
        }
        let mut key_scratch = vec![0u32; count];
        let mut val_scratch = vec![0u32; count];
        for shift in (0..4u32).map(|pass| pass * 8) {
            let digit = |k: u32| ((k >> shift) & 0xFF) as usize;

            // Occurrences of each byte value at this position.
            let mut tally = [0usize; 256];
            for &k in self.keys.iter() {
                tally[digit(k)] += 1;
            }

            // Exclusive running total: first output slot per byte value.
            let mut cursor = [0usize; 256];
            let mut running = 0;
            for (slot, &n) in cursor.iter_mut().zip(tally.iter()) {
                *slot = running;
                running += n;
            }

            // Scatter keys and values to the same slot, in input order.
            for (i, &k) in self.keys.iter().enumerate() {
                let slot = &mut cursor[digit(k)];
                key_scratch[*slot] = k;
                val_scratch[*slot] = self.values[i];
                *slot += 1;
            }
            std::mem::swap(&mut self.keys, &mut key_scratch);
            std::mem::swap(&mut self.values, &mut val_scratch);
        }
    }

    /// Append one key-value pair at the end of the buffer.
    pub fn push(&mut self, key: u32, value: u32) {
        self.keys.push(key);
        self.values.push(value);
    }
}
409
410// ─────────────────────────────────────────────────────────────────────────────
411// Parallel merge
412// ─────────────────────────────────────────────────────────────────────────────
413
/// Merge two sorted `f32` slices into a new sorted `Vec<f32>`.
///
/// Both inputs must already be in non-decreasing order. The merge is a
/// single two-cursor pass, O(n + m); when the heads compare equal the
/// element from `left` is taken first.
pub fn parallel_merge(left: &[f32], right: &[f32]) -> Vec<f32> {
    let mut merged = Vec::with_capacity(left.len() + right.len());
    let (mut rest_left, mut rest_right) = (left, right);

    // Peel the smaller head off whichever side holds it until one side runs out.
    while let (Some((&a, tail_left)), Some((&b, tail_right))) =
        (rest_left.split_first(), rest_right.split_first())
    {
        if a <= b {
            merged.push(a);
            rest_left = tail_left;
        } else {
            merged.push(b);
            rest_right = tail_right;
        }
    }

    // At most one of these is non-empty.
    merged.extend_from_slice(rest_left);
    merged.extend_from_slice(rest_right);
    merged
}
434
435// ─────────────────────────────────────────────────────────────────────────────
436// Internal helpers
437// ─────────────────────────────────────────────────────────────────────────────
438
/// Round `n` up to the next power of two (returns 1 for 0).
///
/// A value that is already a power of two is returned unchanged.
///
/// # Panics
/// Panics if no power of two greater than or equal to `n` fits in a `usize`,
/// i.e. when `n` exceeds the largest representable power of two. A manual
/// shift loop would never terminate for such inputs, because the candidate
/// shifts out to zero and stays below `n` forever.
fn next_pow2(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_pow2: no power of two >= n fits in usize")
}
450
451// ─────────────────────────────────────────────────────────────────────────────
452// Tests
453// ─────────────────────────────────────────────────────────────────────────────
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when no element is greater than its successor.
    fn is_non_decreasing<T: PartialOrd>(items: &[T]) -> bool {
        items.windows(2).all(|pair| pair[0] <= pair[1])
    }

    // ---- next_pow2 ----------------------------------------------------------

    #[test]
    fn test_next_pow2_zero() {
        assert_eq!(1, next_pow2(0));
    }

    #[test]
    fn test_next_pow2_one() {
        assert_eq!(1, next_pow2(1));
    }

    #[test]
    fn test_next_pow2_exact() {
        assert_eq!(8, next_pow2(8));
    }

    #[test]
    fn test_next_pow2_non_exact() {
        assert_eq!(16, next_pow2(9));
    }

    // ---- bitonic_sort -------------------------------------------------------

    #[test]
    fn test_bitonic_sort_power_of_two() {
        let mut values = vec![4.0f32, 2.0, 7.0, 1.0, 8.0, 3.0, 6.0, 5.0];
        bitonic_sort(&mut values);
        assert_eq!(values, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_bitonic_sort_non_power_of_two() {
        let mut values = vec![5.0f32, 3.0, 1.0, 4.0, 2.0];
        bitonic_sort(&mut values);
        assert_eq!(values, [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_bitonic_sort_empty() {
        let mut values = Vec::<f32>::new();
        bitonic_sort(&mut values);
        assert!(values.is_empty());
    }

    #[test]
    fn test_bitonic_sort_single() {
        let mut values = vec![42.0f32];
        bitonic_sort(&mut values);
        assert_eq!(values, [42.0]);
    }

    #[test]
    fn test_bitonic_sort_already_sorted() {
        let mut values = vec![1.0f32, 2.0, 3.0, 4.0];
        bitonic_sort(&mut values);
        assert_eq!(values, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_bitonic_sort_reverse() {
        let mut values = vec![8.0f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        bitonic_sort(&mut values);
        assert_eq!(values, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_bitonic_sort_duplicates() {
        let mut values = vec![3.0f32, 1.0, 3.0, 2.0, 1.0];
        bitonic_sort(&mut values);
        assert_eq!(values, [1.0, 1.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_bitonic_sort_large_non_pow2() {
        // 100 descending values: 100.0, 99.0, ..., 1.0.
        let mut values: Vec<f32> = (0..100).map(|i| (100 - i) as f32).collect();
        bitonic_sort(&mut values);
        assert!(is_non_decreasing(&values));
    }

    // ---- bitonic_sort_by_key ------------------------------------------------

    #[test]
    fn test_bitonic_sort_by_key_u32() {
        let mut values = vec![30u32, 10, 20, 40];
        bitonic_sort_by_key(&mut values, |x| *x);
        assert_eq!(values, [10, 20, 30, 40]);
    }

    #[test]
    fn test_bitonic_sort_by_key_empty() {
        let mut values = Vec::<u32>::new();
        bitonic_sort_by_key(&mut values, |x| *x);
        assert!(values.is_empty());
    }

    #[test]
    fn test_bitonic_sort_by_key_single() {
        let mut values = vec![42u32];
        bitonic_sort_by_key(&mut values, |x| *x);
        assert_eq!(values, [42]);
    }

    // ---- radix_sort_u32 -----------------------------------------------------

    #[test]
    fn test_radix_sort_u32_basic() {
        let mut values = vec![5u32, 3, 8, 1, 9, 2, 7, 4, 6];
        radix_sort_u32(&mut values);
        assert_eq!(values, [1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_radix_sort_u32_empty() {
        let mut values = Vec::<u32>::new();
        radix_sort_u32(&mut values);
        assert!(values.is_empty());
    }

    #[test]
    fn test_radix_sort_u32_single() {
        let mut values = vec![42u32];
        radix_sort_u32(&mut values);
        assert_eq!(values, [42]);
    }

    #[test]
    fn test_radix_sort_u32_large_values() {
        let mut values = vec![u32::MAX, 0u32, u32::MAX / 2, 1u32];
        radix_sort_u32(&mut values);
        // Only the extremes are pinned here.
        assert_eq!(values[0], 0);
        assert_eq!(values[3], u32::MAX);
    }

    #[test]
    fn test_radix_sort_u32_duplicates() {
        let mut values = vec![3u32, 1, 3, 2, 1];
        radix_sort_u32(&mut values);
        assert_eq!(values, [1, 1, 2, 3, 3]);
    }

    // ---- radix_sort_f32 -----------------------------------------------------

    #[test]
    fn test_radix_sort_f32_positive_only() {
        let mut values = vec![3.0f32, 1.0, 4.0, 1.5, 0.5];
        radix_sort_f32(&mut values);
        assert!(is_non_decreasing(&values));
    }

    #[test]
    fn test_radix_sort_f32_with_negatives() {
        let mut values = vec![1.0f32, -2.0, 0.5, -0.5, 3.0, -1.0];
        radix_sort_f32(&mut values);
        for (i, pair) in values.windows(2).enumerate() {
            assert!(
                pair[0] <= pair[1],
                "not sorted at {i}: {} > {}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn test_radix_sort_f32_empty() {
        let mut values = Vec::<f32>::new();
        radix_sort_f32(&mut values);
        assert!(values.is_empty());
    }

    #[test]
    fn test_f32_sort_key_roundtrip_positive() {
        let original = 3.125f32;
        let restored = sort_key_to_f32(f32_to_sort_key(original));
        assert_eq!(restored, original);
    }

    #[test]
    fn test_f32_sort_key_roundtrip_negative() {
        let original = -2.719f32;
        let restored = sort_key_to_f32(f32_to_sort_key(original));
        assert_eq!(restored, original);
    }

    // ---- prefix_sum ---------------------------------------------------------

    #[test]
    fn test_prefix_sum_basic() {
        let sums = prefix_sum(&[1u32, 2, 3, 4]);
        assert_eq!(sums, [0, 1, 3, 6]);
    }

    #[test]
    fn test_prefix_sum_empty() {
        assert!(prefix_sum(&[]).is_empty());
    }

    #[test]
    fn test_prefix_sum_single() {
        let sums = prefix_sum(&[5u32]);
        assert_eq!(sums, [0]);
    }

    #[test]
    fn test_prefix_sum_all_ones() {
        let ones = vec![1u32; 5];
        assert_eq!(prefix_sum(&ones), [0, 1, 2, 3, 4]);
    }

    // ---- histogram ----------------------------------------------------------

    #[test]
    fn test_histogram_basic() {
        let samples = [0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let bins = histogram(&samples, 5);
        assert_eq!(bins.len(), 5);
        // Every sample must be counted exactly once.
        assert_eq!(bins.iter().sum::<u32>(), 10);
    }

    #[test]
    fn test_histogram_empty_data() {
        let bins = histogram(&[], 4);
        assert_eq!(bins, [0, 0, 0, 0]);
    }

    #[test]
    fn test_histogram_all_same() {
        let samples = vec![5u32; 10];
        let bins = histogram(&samples, 3);
        assert_eq!(bins.iter().sum::<u32>(), 10);
    }

    // ---- counting_sort ------------------------------------------------------

    #[test]
    fn test_counting_sort_basic() {
        let mut values = vec![3u32, 1, 4, 1, 5, 9, 2, 6];
        counting_sort(&mut values, 9);
        assert!(is_non_decreasing(&values));
    }

    #[test]
    fn test_counting_sort_empty() {
        let mut values = Vec::<u32>::new();
        counting_sort(&mut values, 10);
        assert!(values.is_empty());
    }

    #[test]
    fn test_counting_sort_single() {
        let mut values = vec![7u32];
        counting_sort(&mut values, 10);
        assert_eq!(values, [7]);
    }

    #[test]
    fn test_counting_sort_duplicates() {
        let mut values = vec![2u32, 2, 2, 1, 1];
        counting_sort(&mut values, 2);
        assert_eq!(values, [1, 1, 2, 2, 2]);
    }

    #[test]
    fn test_counting_sort_all_zero() {
        let mut values = vec![0u32; 5];
        counting_sort(&mut values, 0);
        assert_eq!(values, vec![0u32; 5]);
    }

    // ---- morton_sort_3d -----------------------------------------------------

    #[test]
    fn test_morton_sort_3d_basic() {
        let mut cloud = vec![
            [1.0f32, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
        ];
        morton_sort_3d(&mut cloud);
        // The origin has Morton code 0, so it must lead the sorted cloud.
        assert_eq!(cloud[0], [0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_morton_sort_3d_empty() {
        let mut cloud = Vec::<[f32; 3]>::new();
        morton_sort_3d(&mut cloud);
        assert!(cloud.is_empty());
    }

    #[test]
    fn test_morton_sort_3d_single() {
        let mut cloud = vec![[1.0f32, 2.0, 3.0]];
        morton_sort_3d(&mut cloud);
        assert_eq!(cloud, [[1.0, 2.0, 3.0]]);
    }

    #[test]
    fn test_morton3_origin() {
        assert_eq!(0, morton3(0, 0, 0));
    }

    #[test]
    fn test_morton3_unit_x() {
        // A set x bit must show up somewhere in the interleaved code.
        assert_ne!(morton3(1, 0, 0), 0);
    }

    #[test]
    fn test_spread_bits_zero() {
        assert_eq!(0, spread_bits(0));
    }

    #[test]
    fn test_spread_bits_one() {
        // Input bit 0 stays at output bit 0.
        assert_eq!(spread_bits(1) & 1, 1);
    }

    // ---- GpuSortBuffer ------------------------------------------------------

    #[test]
    fn test_gpu_sort_buffer_sort_pairs_basic() {
        let keys = vec![3u32, 1, 4, 1, 5, 9, 2, 6];
        let values: Vec<u32> = (0..keys.len() as u32).collect();
        let mut buffer = GpuSortBuffer::new(keys, values);
        buffer.sort_pairs();
        assert!(is_non_decreasing(&buffer.keys));
    }

    #[test]
    fn test_gpu_sort_buffer_empty() {
        let mut buffer = GpuSortBuffer::empty();
        buffer.sort_pairs();
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_gpu_sort_buffer_push() {
        let mut buffer = GpuSortBuffer::empty();
        buffer.push(5, 100);
        buffer.push(2, 200);
        assert_eq!(buffer.len(), 2);
        buffer.sort_pairs();
        assert_eq!(buffer.keys[0], 2);
        assert_eq!(buffer.values[0], 200);
    }

    #[test]
    fn test_gpu_sort_buffer_values_follow_keys() {
        let mut buffer = GpuSortBuffer::new(vec![30u32, 10, 20], vec![3u32, 1, 2]);
        buffer.sort_pairs();
        assert_eq!(buffer.keys, [10, 20, 30]);
        assert_eq!(buffer.values, [1, 2, 3]);
    }

    #[test]
    fn test_gpu_sort_buffer_len_is_empty() {
        let buffer = GpuSortBuffer::empty();
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    // ---- parallel_merge -----------------------------------------------------

    #[test]
    fn test_parallel_merge_basic() {
        let merged = parallel_merge(&[1.0f32, 3.0, 5.0], &[2.0f32, 4.0, 6.0]);
        assert_eq!(merged, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_parallel_merge_empty_left() {
        let merged = parallel_merge(&[], &[1.0f32, 2.0]);
        assert_eq!(merged, [1.0, 2.0]);
    }

    #[test]
    fn test_parallel_merge_empty_right() {
        let merged = parallel_merge(&[1.0f32, 2.0], &[]);
        assert_eq!(merged, [1.0, 2.0]);
    }

    #[test]
    fn test_parallel_merge_both_empty() {
        let merged: Vec<f32> = parallel_merge(&[], &[]);
        assert!(merged.is_empty());
    }

    #[test]
    fn test_parallel_merge_unequal_lengths() {
        let merged = parallel_merge(&[1.0f32, 10.0], &[2.0f32, 3.0, 4.0, 5.0]);
        assert_eq!(merged.len(), 6);
        assert!(is_non_decreasing(&merged));
    }

    #[test]
    fn test_parallel_merge_duplicates() {
        let merged = parallel_merge(&[1.0f32, 2.0, 2.0], &[2.0f32, 3.0]);
        assert_eq!(merged.len(), 5);
        assert!(is_non_decreasing(&merged));
    }
}