use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec, U32_INPUTS, U32_OUTPUTS};
/// Largest element count accepted by `BitonicSortU32::program_with_elements`,
/// and the size used by `BitonicSortU32::program`.
pub const MAX_ELEMENTS: u32 = 256;
/// Algebraic laws declared for the bitonic sort op; handed to
/// `OpSpec::composition` when building `BitonicSortU32::SPEC`.
pub(crate) const LAWS: &[AlgebraicLaw] = &[
// Presumably: sorting an already sorted buffer leaves it unchanged — confirm
// against the definition of `AlgebraicLaw::Idempotent`.
AlgebraicLaw::Idempotent,
];
/// Zero-sized marker type for the `sort.bitonic_sort_u32` operation.
///
/// Carries no state; it only namespaces the op's `SPEC` constant and the
/// functions that build its IR program.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BitonicSortU32;
impl BitonicSortU32 {
    /// Operation spec for this sort, built with `OpSpec::composition` from the
    /// op name `sort.bitonic_sort_u32`, the `U32_INPUTS` / `U32_OUTPUTS`
    /// signatures, the declared [`LAWS`], and [`Self::program`] as the program
    /// builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "sort.bitonic_sort_u32",
        U32_INPUTS,
        U32_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the sort program for [`MAX_ELEMENTS`] elements.
    #[must_use]
    pub fn program() -> Program {
        Self::program_with_elements(MAX_ELEMENTS)
    }

    /// Builds an IR program that sorts the first `n` entries of the `values`
    /// buffer in ascending order using a bitonic sorting network.
    ///
    /// The program is declared with a workgroup size of `[n, 1, 1]` and an
    /// `n`-element `shared` buffer. Each invocation copies one element into
    /// `shared`, the network runs with a barrier after every compare/exchange
    /// step, and each invocation then copies its element back to `values`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two or exceeds [`MAX_ELEMENTS`]
    /// (zero is rejected because it is not a power of two).
    #[must_use]
    pub fn program_with_elements(n: u32) -> Program {
        assert!(
            n.is_power_of_two() && n <= MAX_ELEMENTS,
            "n must be a power of two and <= {MAX_ELEMENTS}"
        );
        let log2_n = n.trailing_zeros();
        let local_idx = Expr::local_x();
        let n_expr = Expr::u32(n);

        // Phase 1: stage the input in the `shared` buffer.
        let load_shared = Node::if_then(
            Expr::lt(local_idx.clone(), n_expr.clone()),
            vec![Node::store(
                "shared",
                local_idx.clone(),
                Expr::load("values", local_idx.clone()),
            )],
        );

        // Phase 2: the network. For stage `s` (1..=log2 n) the direction block
        // size is k = 2^s; sub-step `sub` (0..s) compares elements that are
        // j = k >> (sub + 1) apart, i.e. j = k/2, k/4, ..., 1.
        let stage_var = Expr::var("stage");
        let sub_var = Expr::var("sub");
        let k = Expr::shl(Expr::u32(1), stage_var.clone());
        let j = Expr::shr(k.clone(), Expr::add(sub_var, Expr::u32(1)));
        let partner = Expr::bitxor(local_idx.clone(), j);
        // Blocks whose index has bit `k` clear sort ascending, the others
        // descending. In the last stage k == n, so every index sorts ascending.
        let ascending = Expr::eq(Expr::bitand(local_idx.clone(), k), Expr::u32(0));
        let a = Expr::load("shared", local_idx.clone());
        let b = Expr::load("shared", partner.clone());
        let swap_cond = Expr::select(
            ascending,
            Expr::gt(a.clone(), b.clone()),
            Expr::lt(a.clone(), b.clone()),
        );

        // Exchange the two slots with a three-store XOR swap.
        //
        // `a` and `b` are load *expressions*, not captured values: a plain
        // `shared[i] = b; shared[partner] = a` pair re-reads `shared[i]` after
        // the first store has already overwritten it, leaving `b` in both
        // slots. The XOR sequence is correct when the stores execute in order
        // and each load reads current memory. The two indices always differ
        // because j >= 1, which the XOR swap requires.
        // NOTE(review): if the IR offers a local let-binding node, capturing
        // `a` and `b` into locals before two plain stores would read better.
        let xor_of_pair = Expr::bitxor(a, b);
        let swap_body = vec![
            // shared[i] = a ^ b
            Node::store("shared", local_idx.clone(), xor_of_pair.clone()),
            // shared[partner] = (a ^ b) ^ b = a
            Node::store("shared", partner.clone(), xor_of_pair.clone()),
            // shared[i] = (a ^ b) ^ a = b
            Node::store("shared", local_idx.clone(), xor_of_pair),
        ];

        let inner_body = vec![
            // Only the lower index of each pair performs the exchange, so no
            // two invocations write the same pair of slots.
            Node::if_then(
                Expr::gt(partner, local_idx.clone()),
                vec![Node::if_then(swap_cond, swap_body)],
            ),
            // Placed outside the conditionals so every invocation reaches it.
            Node::barrier(),
        ];
        let outer_loop = Node::loop_for(
            "stage",
            Expr::u32(1),
            Expr::u32(log2_n + 1),
            vec![Node::loop_for("sub", Expr::u32(0), stage_var, inner_body)],
        );

        // Phase 3: copy the sorted data back to the `values` buffer.
        let store_global = Node::if_then(
            Expr::lt(local_idx.clone(), n_expr),
            vec![Node::store(
                "values",
                local_idx.clone(),
                Expr::load("shared", local_idx),
            )],
        );

        Program::new(
            vec![
                BufferDecl::read_write("values", 0, DataType::U32),
                BufferDecl::workgroup("shared", n, DataType::U32),
            ],
            [n, 1, 1],
            vec![load_shared, Node::barrier(), outer_loop, store_global],
        )
    }
}
/// CPU reference implementation of a bitonic sorting network: sorts `values`
/// in ascending order, in place.
///
/// Uses the same stage / sub-step schedule as the IR program built by
/// `BitonicSortU32::program_with_elements`: for each stage `s` the direction
/// block size is `k = 2^s`, and each sub-step compares elements `j` apart for
/// `j = k/2, k/4, ..., 1`.
///
/// Slices of length zero or one are already sorted and are returned untouched.
///
/// # Panics
///
/// Panics if `values.len()` is greater than one and not a power of two.
pub fn bitonic_sort_u32_reference(values: &mut [u32]) {
    let n = values.len();
    // Checked before the power-of-two assertion: `0usize.is_power_of_two()`
    // is false, so an empty slice would otherwise panic instead of being a
    // no-op.
    if n <= 1 {
        return;
    }
    assert!(n.is_power_of_two(), "slice length must be a power of two");

    // All index arithmetic stays in `usize`, so no narrowing cast is needed.
    let stages = n.trailing_zeros();
    for stage in 1..=stages {
        let k = 1usize << stage;
        for sub in 0..stage {
            let j = k >> (sub + 1);
            for i in 0..n {
                let partner = i ^ j;
                // Handle each pair once, from its lower index.
                if partner <= i {
                    continue;
                }
                // Bit `k` of the index selects the sort direction of the block.
                let ascending = i & k == 0;
                let (a, b) = (values[i], values[partner]);
                let out_of_order = if ascending { a > b } else { a < b };
                if out_of_order {
                    values.swap(i, partner);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Power-of-two element counts swept by the size-based tests.
    const POWER_OF_TWO_SIZES: [u32; 8] = [2, 4, 8, 16, 32, 64, 128, 256];

    /// Reverse-ordered input of every swept size comes back sorted.
    #[test]
    pub(crate) fn reference_sorts_powers_of_two() {
        for n in POWER_OF_TWO_SIZES {
            let mut input: Vec<u32> = (0..n).rev().collect();
            let expected = {
                let mut sorted = input.clone();
                sorted.sort();
                sorted
            };
            bitonic_sort_u32_reference(&mut input);
            assert_eq!(input, expected, "reference failed for n={n}");
        }
    }

    /// A fixed, unordered eight-element sample matches the std sort.
    #[test]
    pub(crate) fn reference_sorts_random() {
        let mut input = vec![42u32, 7, 1023, 0, 999_999, 1, 8, 3];
        let expected = {
            let mut sorted = input.clone();
            sorted.sort();
            sorted
        };
        bitonic_sort_u32_reference(&mut input);
        assert_eq!(input, expected);
    }

    /// Already sorted input is left as it was.
    #[test]
    pub(crate) fn reference_sorts_already_sorted() {
        let expected: Vec<u32> = (0u32..128).collect();
        let mut input = expected.clone();
        bitonic_sort_u32_reference(&mut input);
        assert_eq!(input, expected);
    }

    /// A slice of identical values is left as it was.
    #[test]
    pub(crate) fn reference_sorts_all_equal() {
        let expected = vec![7u32; 64];
        let mut input = expected.clone();
        bitonic_sort_u32_reference(&mut input);
        assert_eq!(input, expected);
    }

    /// A one-element slice is left as it was.
    #[test]
    pub(crate) fn reference_sorts_single_element() {
        let expected = vec![42u32];
        let mut input = expected.clone();
        bitonic_sort_u32_reference(&mut input);
        assert_eq!(input, expected);
    }

    /// Every swept size yields a program with the matching workgroup size and
    /// exactly two buffers.
    #[test]
    pub(crate) fn ir_program_builds_for_valid_sizes() {
        for n in POWER_OF_TWO_SIZES {
            let prog = BitonicSortU32::program_with_elements(n);
            assert_eq!(prog.workgroup_size(), [n, 1, 1]);
            assert_eq!(prog.buffers().len(), 2);
        }
    }

    /// The default-size program lowers to WGSL without error.
    #[test]
    pub(crate) fn ir_program_lowers_to_wgsl() {
        let lowered = crate::lower::wgsl::lower_anonymous(&BitonicSortU32::program());
        assert!(
            lowered.is_ok(),
            "WGSL lowering failed: {:?}",
            lowered.err()
        );
    }

    /// The lowered WGSL text contains both barrier calls.
    #[test]
    pub(crate) fn ir_program_wgsl_contains_barriers() {
        let program = BitonicSortU32::program_with_elements(16);
        let wgsl = crate::lower::wgsl::lower_anonymous(&program).expect("lowering");
        let required = [
            ("storageBarrier()", "missing storageBarrier"),
            ("workgroupBarrier()", "missing workgroupBarrier"),
        ];
        for (needle, complaint) in required {
            assert!(wgsl.contains(needle), "{complaint}");
        }
    }

    /// A non-power-of-two element count is rejected by the builder.
    #[test]
    #[should_panic(expected = "n must be a power of two")]
    pub(crate) fn rejects_non_power_of_two() {
        let _ = BitonicSortU32::program_with_elements(63);
    }

    /// Zero elements is rejected by the builder.
    #[test]
    #[should_panic(expected = "n must be a power of two")]
    pub(crate) fn rejects_zero() {
        let _ = BitonicSortU32::program_with_elements(0);
    }
}