use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier string for this operation.
///
/// Passed to `crate::invalid_output_program` when `radix_sort` rejects its
/// arguments, and used as the `generator` of the `Node::Region` that wraps the
/// emitted body.
pub const OP_ID: &str = "vyre-primitives::reduce::radix_sort";
#[must_use]
pub fn radix_sort(input: &str, output: &str, count: u32, bits: u32) -> Program {
if count == 0 {
return crate::invalid_output_program(
OP_ID,
output,
DataType::U32,
format!("Fix: radix_sort requires count > 0, got {count}."),
);
}
if bits > 32 {
return crate::invalid_output_program(
OP_ID,
output,
DataType::U32,
format!("Fix: radix_sort bits must be <= 32, got {bits}."),
);
}
let buffers = vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::U32).with_count(count),
BufferDecl::storage(output, 1, BufferAccess::ReadWrite, DataType::U32).with_count(count),
];
let t = Expr::InvocationId { axis: 0 };
let mask = if bits == 32 {
u32::MAX
} else if bits == 0 {
0
} else {
(1u32 << bits) - 1
};
let masked_key = |expr: Expr| {
if bits == 32 {
expr
} else {
Expr::bitand(expr, Expr::u32(mask))
}
};
let key_i = masked_key(Expr::load(input, Expr::var("i")));
let key_j = masked_key(Expr::load(input, Expr::var("j")));
let lower_key = Expr::lt(key_j.clone(), key_i.clone());
let stable_tie = Expr::and(
Expr::eq(key_j, key_i),
Expr::lt(Expr::var("j"), Expr::var("i")),
);
let body = vec![Node::if_then(
Expr::lt(t.clone(), Expr::u32(count)),
vec![
Node::let_bind("i", t.clone()),
Node::let_bind("rank", Expr::u32(0)),
Node::loop_for(
"j",
Expr::u32(0),
Expr::u32(count),
vec![Node::if_then(
Expr::or(lower_key, stable_tie),
vec![Node::assign(
"rank",
Expr::add(Expr::var("rank"), Expr::u32(1)),
)],
)],
),
Node::store(output, Expr::var("rank"), Expr::load(input, Expr::var("i"))),
],
)];
Program::wrapped(
buffers,
[256, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}],
)
}
/// CPU reference: returns a copy of `input` stably sorted by the low `bits`
/// bits of each value.
///
/// Convenience wrapper around [`cpu_ref_into`] that allocates fresh output and
/// scratch vectors on every call.
#[must_use]
pub fn cpu_ref(input: &[u32], bits: u32) -> Vec<u32> {
    let mut sorted: Vec<u32> = Vec::new();
    let mut workspace: Vec<u32> = Vec::new();
    cpu_ref_into(input, bits, &mut sorted, &mut workspace);
    sorted
}
/// CPU reference sort that writes into caller-provided buffers.
///
/// `out` is overwritten with the elements of `input`, stably sorted by the low
/// `bits` bits of each value (`bits` is clamped to 32). `scratch` is used as
/// the second ping-pong buffer. When `input` is empty or `bits` is 0, `out`
/// receives an unsorted copy of `input` and `scratch` is left untouched.
pub fn cpu_ref_into(input: &[u32], bits: u32, out: &mut Vec<u32>, scratch: &mut Vec<u32>) {
    let key_bits = bits.min(32);

    out.clear();
    out.extend_from_slice(input);
    if key_bits == 0 || out.is_empty() {
        return;
    }

    let len = out.len();
    scratch.clear();
    scratch.resize(len, 0);

    // One pass per started byte of key width, least significant byte first.
    let pass_count = ((key_bits + 7) / 8).min(4) as usize;
    for pass in 0..pass_count {
        // Even passes read `out` and write `scratch`; odd passes go back.
        if pass % 2 == 0 {
            radix_pass(out, scratch, key_bits, pass);
        } else {
            radix_pass(scratch, out, key_bits, pass);
        }
    }

    // After an odd number of passes the result sits in `scratch`.
    if pass_count % 2 == 1 {
        out.copy_from_slice(scratch.as_slice());
    }
}
/// Runs one stable counting-sort pass over a single digit of up to 8 bits.
///
/// The digit for `pass` starts at bit `pass * 8`; its width is the smaller of
/// 8 and the key bits remaining above that position. Elements of `src` are
/// written to `dst` grouped by ascending digit, preserving the relative order
/// of elements that share a digit. `dst` must be at least as long as `src`.
fn radix_pass(src: &[u32], dst: &mut [u32], bits: u32, pass: usize) {
    let shift = pass * 8;
    // The last digit of a partial-width key may be narrower than a byte.
    let digit_width = (bits as usize - shift).min(8);
    let digit_mask = (1u32 << digit_width) - 1;
    let digit_of = |key: u32| ((key >> shift) & digit_mask) as usize;

    // Histogram of digit occurrences.
    let mut slots = [0u32; 256];
    src.iter().for_each(|&key| slots[digit_of(key)] += 1);

    // Exclusive prefix sum: each slot becomes the first output index of its digit.
    let mut running = 0u32;
    for slot in slots.iter_mut() {
        let occupied = std::mem::replace(slot, running);
        running += occupied;
    }

    // Scatter in input order so equal digits keep their relative order.
    for &key in src {
        let cursor = &mut slots[digit_of(key)];
        dst[*cursor as usize] = key;
        *cursor += 1;
    }
}
/// Unit tests for the CPU reference sort and for the shape of the emitted program.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_ref_empty() {
        assert_eq!(cpu_ref(&[], 32), Vec::<u32>::new());
    }

    #[test]
    fn cpu_ref_single_element() {
        assert_eq!(cpu_ref(&[42], 32), vec![42]);
    }

    #[test]
    fn cpu_ref_already_sorted() {
        let input = vec![1, 2, 3, 4, 5];
        assert_eq!(cpu_ref(&input, 32), input);
    }

    #[test]
    fn cpu_ref_reverse_sorted() {
        let input = vec![5, 4, 3, 2, 1];
        assert_eq!(cpu_ref(&input, 32), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn cpu_ref_into_reuses_buffers() {
        let mut out = Vec::with_capacity(16);
        let mut scratch = Vec::with_capacity(16);
        cpu_ref_into(&[5, 4, 3, 2, 1], 32, &mut out, &mut scratch);
        let out_capacity = out.capacity();
        let scratch_capacity = scratch.capacity();
        assert_eq!(out, vec![1, 2, 3, 4, 5]);
        // A smaller second input must not reallocate either buffer.
        cpu_ref_into(&[3, 1, 2], 32, &mut out, &mut scratch);
        assert_eq!(out.capacity(), out_capacity);
        assert_eq!(scratch.capacity(), scratch_capacity);
        assert_eq!(out, vec![1, 2, 3]);
    }

    #[test]
    fn cpu_ref_stable_sort() {
        // Full-width keys: every value is distinct, so this part only checks
        // that the high half dominates the ordering.
        let input: Vec<u32> = vec![
            2 << 16,
            1 << 16,
            (2 << 16) | 1,
            (1 << 16) | 1,
            (2 << 16) | 2,
        ];
        let sorted = cpu_ref(&input, 32);
        let payloads: Vec<u16> = sorted.iter().map(|v| (*v & 0xFFFF) as u16).collect();
        assert_eq!(payloads, vec![0, 1, 0, 1, 2]);

        // Real ties: with `bits = 8` only the low byte is the key, and the
        // high bits tag each element with its input position. Elements with
        // equal keys must come out in ascending tag order.
        let keys = [2u32, 1, 2, 1, 2];
        let tied: Vec<u32> = keys
            .iter()
            .zip(0u32..)
            .map(|(&key, tag)| (tag << 8) | key)
            .collect();
        let tag_order: Vec<u32> = cpu_ref(&tied, 8).iter().map(|v| v >> 8).collect();
        assert_eq!(tag_order, vec![1, 3, 0, 2, 4]);
    }

    #[test]
    fn cpu_ref_duplicates() {
        let input = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5];
        let mut expected = input.clone();
        expected.sort_unstable();
        assert_eq!(cpu_ref(&input, 32), expected);
    }

    #[test]
    fn cpu_ref_partial_bits() {
        // Only the low 8 bits are compared; the high byte rides along.
        let input = vec![0x0100, 0x0001, 0x0200, 0x0002];
        assert_eq!(cpu_ref(&input, 8), vec![0x0100, 0x0200, 0x0001, 0x0002]);
    }

    #[test]
    fn cpu_ref_bits_zero_is_noop() {
        let input = vec![3, 1, 2];
        assert_eq!(cpu_ref(&input, 0), vec![3, 1, 2]);
    }

    #[test]
    fn cpu_ref_large_random() {
        let input: Vec<u32> = (0..1000u32).map(|i| i.wrapping_mul(0x9E3779B9)).collect();
        let mut expected = input.clone();
        expected.sort_unstable();
        assert_eq!(cpu_ref(&input, 32), expected);
    }

    #[test]
    fn emitted_program_has_expected_buffers() {
        let p = radix_sort("in", "out", 128, 32);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["in", "out"]);
    }

    #[test]
    fn emitted_program_small_count_ok() {
        let p = radix_sort("in", "out", 1, 32);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
    }

    #[test]
    fn bits_over_32_traps() {
        let p = radix_sort("in", "out", 10, 33);
        assert!(p.stats().trap());
    }

    #[test]
    fn cpu_ref_bits_over_32_clamps_to_full_key_sort() {
        assert_eq!(cpu_ref(&[2, 1], 33), vec![1, 2]);
    }
}