use core::mem;
/// Sorts `inds` in ascending order and applies the same permutation to `vals`.
///
/// The two slices form a structure-of-arrays: `vals[i]` is the payload that
/// belongs to key `inds[i]`, and that pairing is preserved by the sort. The
/// sort is stable: entries with equal keys keep their original relative order.
///
/// Inputs of length 0 or 1 are returned untouched, inputs of up to 32 elements
/// are handed to `insertion_sort_u32_soa`, and anything larger is sorted with a
/// least-significant-digit radix sort that makes four passes of 8 bits each.
/// The radix path allocates two scratch buffers of `inds.len()` elements.
///
/// # Panics
///
/// Panics if `inds.len() != vals.len()`.
#[inline(always)]
pub fn radix_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    assert_eq!(inds.len(), vals.len());
    let n = inds.len();
    if n <= 1 {
        return;
    }
    if n <= 32 {
        insertion_sort_u32_soa(inds, vals);
        return;
    }

    let mut inds_tmp = vec![0u32; n];
    // Seed the scratch payload buffer with real values. An all-zero bit pattern
    // is not a valid `N` for every `Copy` type (references, `NonZero*`, ...),
    // so the buffer must never be filled with `mem::zeroed()`.
    let mut vals_tmp: Vec<N> = vals.to_vec();

    // Ping-pong between the caller's slices and the scratch buffers. The number
    // of passes is even, so the final result lands back in `inds` / `vals`.
    let mut src_inds: &mut [u32] = inds;
    let mut src_vals: &mut [N] = vals;
    let mut dst_inds: &mut [u32] = &mut inds_tmp;
    let mut dst_vals: &mut [N] = &mut vals_tmp;

    for shift in [0u32, 8, 16, 24] {
        // Histogram of the current 8-bit digit.
        let mut offsets = [0usize; 256];
        for &key in src_inds.iter() {
            offsets[((key >> shift) & 0xFF) as usize] += 1;
        }

        // Exclusive prefix sum: `offsets[d]` becomes the first output slot for
        // digit `d`.
        let mut running = 0usize;
        for slot in offsets.iter_mut() {
            let bucket_len = *slot;
            *slot = running;
            running += bucket_len;
        }

        // Stable scatter: walking the source front to back keeps equal digits
        // in their existing order.
        for (&key, &val) in src_inds.iter().zip(src_vals.iter()) {
            let bucket = ((key >> shift) & 0xFF) as usize;
            let pos = offsets[bucket];
            offsets[bucket] = pos + 1;
            // SAFETY: `offsets` started as the exclusive prefix sums of the
            // digit histogram of `src_inds`, whose counts total `n`. Each
            // bucket's cursor is advanced at most once per key carrying that
            // digit, so `pos` stays inside the bucket's range and is `< n`.
            // Both destination slices have length `n`.
            unsafe {
                *dst_inds.get_unchecked_mut(pos) = key;
                *dst_vals.get_unchecked_mut(pos) = val;
            }
        }

        mem::swap(&mut src_inds, &mut dst_inds);
        mem::swap(&mut src_vals, &mut dst_vals);
    }
}
/// Stable insertion sort over a key slice and its parallel payload slice.
///
/// Orders `inds` ascending; every swap performed on `inds` is mirrored on
/// `vals` so that each payload stays attached to its key. Only strictly
/// smaller keys are moved left, so equal keys keep their relative order.
#[inline(always)]
fn insertion_sort_u32_soa<N: Copy>(inds: &mut [u32], vals: &mut [N]) {
    for start in 1..inds.len() {
        // Bubble the element at `start` leftwards until its predecessor is
        // not greater than it.
        let mut cur = start;
        while cur > 0 {
            let prev = cur - 1;
            if inds[prev] <= inds[cur] {
                break;
            }
            inds.swap(cur, prev);
            vals.swap(cur, prev);
            cur = prev;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: sorts `(key, original index, value)` triples
    /// by key and then by original index, which yields the stable ordering the
    /// radix sort is expected to reproduce.
    fn baseline_stable_sort<N: Copy>(inds: &[u32], vals: &[N]) -> (Vec<u32>, Vec<N>) {
        let mut pairs: Vec<(u32, usize, N)> = inds
            .iter()
            .copied()
            .enumerate()
            .map(|(i, k)| (k, i, vals[i]))
            .collect();
        // The index tie-breaker makes the order total, so an unstable sort is
        // sufficient here.
        pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
        let mut out_k = Vec::with_capacity(pairs.len());
        let mut out_v = Vec::with_capacity(pairs.len());
        for (k, _i, v) in pairs {
            out_k.push(k);
            out_v.push(v);
        }
        (out_k, out_v)
    }

    /// Asserts that `keys` is in non-decreasing order.
    fn assert_sorted(keys: &[u32]) {
        for i in 1..keys.len() {
            assert!(keys[i - 1] <= keys[i], "not sorted at {i}: {} > {}", keys[i - 1], keys[i]);
        }
    }

    /// Small deterministic xorshift generator so the tests need no external
    /// crates. The seed must be non-zero.
    struct Rng(u32);

    impl Rng {
        fn new(seed: u32) -> Self {
            Self(seed)
        }

        fn next_u32(&mut self) -> u32 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            self.0 = x;
            x
        }
    }

    #[test]
    fn radix_sort_handles_empty_and_single() {
        let mut inds: Vec<u32> = vec![];
        let mut vals: Vec<u16> = vec![];
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert!(inds.is_empty());
        assert!(vals.is_empty());
        let mut inds = vec![42u32];
        let mut vals = vec![7u16];
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_eq!(inds, vec![42u32]);
        assert_eq!(vals, vec![7u16]);
    }

    #[test]
    fn radix_sort_works_on_duplicates_and_preserves_pairing() {
        let mut inds = vec![3u32, 1, 3, 2, 1, 3, 0];
        let mut vals: Vec<u32> = (0..inds.len() as u32).collect();
        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v, "radix sort should be stable in this implementation");
    }

    #[test]
    fn radix_sort_matches_baseline_many_sizes() {
        let mut rng = Rng::new(0x1234_5678);
        for &n in &[0usize, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 33, 63, 64, 65, 127, 128, 129, 1024] {
            let mut inds = Vec::with_capacity(n);
            let mut vals = Vec::with_capacity(n);
            for i in 0..n {
                let k = rng.next_u32() & 0x00FF_FFFF;
                inds.push(k);
                vals.push((i as u32) ^ 0xA5A5_5A5A);
            }
            let (base_k, base_v) = baseline_stable_sort(&inds, &vals);
            radix_sort_u32_soa(&mut inds, &mut vals);
            assert_sorted(&inds);
            assert_eq!(inds, base_k, "keys mismatch at n={n}");
            assert_eq!(vals, base_v, "vals mismatch at n={n}");
        }
    }

    #[test]
    fn radix_sort_extremes() {
        let mut inds = vec![
            0u32,
            u32::MAX,
            1,
            u32::MAX - 1,
            0,
            2,
            u32::MAX,
        ];
        let mut vals: Vec<u32> = (0..inds.len() as u32).collect();
        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v);
    }

    // The tests above either stay at or below the 32-element insertion-sort
    // threshold or mask keys to 24 bits, so they never drive the most
    // significant radix pass with non-zero digits. The tests below do.

    #[test]
    fn radix_sort_full_width_keys_above_threshold() {
        let mut rng = Rng::new(0x9E37_79B9);
        for &n in &[33usize, 100, 257, 1000] {
            let mut inds = Vec::with_capacity(n);
            let mut vals = Vec::with_capacity(n);
            for i in 0..n {
                inds.push(rng.next_u32());
                vals.push(i as u32);
            }
            let (base_k, base_v) = baseline_stable_sort(&inds, &vals);
            radix_sort_u32_soa(&mut inds, &mut vals);
            assert_sorted(&inds);
            assert_eq!(inds, base_k, "keys mismatch at n={n}");
            assert_eq!(vals, base_v, "vals mismatch at n={n}");
        }
    }

    #[test]
    fn radix_sort_is_stable_above_threshold() {
        // Only 15 distinct keys across 200 entries, differing in both the
        // lowest and the highest byte.
        let n = 200u32;
        let mut inds: Vec<u32> = (0..n).map(|i| ((i % 5) << 24) | (i % 3)).collect();
        let mut vals: Vec<u32> = (0..n).collect();
        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v, "equal keys must keep their original order");
    }

    #[test]
    fn radix_sort_extremes_above_threshold() {
        let pattern = [0u32, u32::MAX, 1, u32::MAX - 1, 0, 2, u32::MAX];
        let mut inds: Vec<u32> = pattern.iter().copied().cycle().take(70).collect();
        let mut vals: Vec<u32> = (0..inds.len() as u32).collect();
        let (base_k, base_v) = baseline_stable_sort(&inds, &vals);
        radix_sort_u32_soa(&mut inds, &mut vals);
        assert_sorted(&inds);
        assert_eq!(inds, base_k);
        assert_eq!(vals, base_v);
    }
}