vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec, U32_INPUTS, U32_OUTPUTS};

// Bitonic sort for u32 using a workgroup-local butterfly network.

/// Maximum number of elements sortable in one workgroup dispatch.
///
/// WebGPU guarantees at least 256 invocations per workgroup.
///
/// [`BitonicSortU32::program_with_elements`] asserts that the requested
/// element count does not exceed this value and uses that count as the
/// x-dimension of the workgroup size, so this is also the largest workgroup
/// the operation ever requests.
pub const MAX_ELEMENTS: u32 = 256;

/// Algebraic laws declared for `sort.bitonic_sort_u32`.
///
/// Passed to `OpSpec::composition` when building [`BitonicSortU32::SPEC`].
pub(crate) const LAWS: &[AlgebraicLaw] = &[
    // Idempotence: sorting an already-sorted array is a no-op, so applying the
    // operation twice yields the same result as applying it once.
    AlgebraicLaw::Idempotent,
];

/// In-place bitonic sort of a u32 buffer.
///
/// This is a field-less marker type: the operation is described by
/// [`BitonicSortU32::SPEC`] and its IR is built by
/// [`BitonicSortU32::program`] / [`BitonicSortU32::program_with_elements`].
///
/// # Limitations
///
/// The input buffer length must be a power of two and ≤ [`MAX_ELEMENTS`].
/// The dispatch must launch exactly `len` threads in one workgroup.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BitonicSortU32;

impl BitonicSortU32 {
    /// Declarative operation specification using the default [`MAX_ELEMENTS`].
    ///
    /// Registers the operation under the id `sort.bitonic_sort_u32` with
    /// [`Self::program`] as its program builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "sort.bitonic_sort_u32",
        U32_INPUTS,
        U32_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical IR program for [`MAX_ELEMENTS`] elements.
    ///
    /// # Examples
    ///
    /// ```
    /// use vyre::ops::sort::bitonic_sort_u32::BitonicSortU32;
    ///
    /// let program = BitonicSortU32::program();
    /// assert!(!program.entry().is_empty());
    /// ```
    #[must_use]
    pub fn program() -> Program {
        Self::program_with_elements(MAX_ELEMENTS)
    }

    /// Build the IR program for a specific element count.
    ///
    /// The generated program has three phases:
    ///
    /// 1. every invocation copies `values[local_x]` into the workgroup buffer
    ///    `shared`;
    /// 2. a bitonic network of `log2(n)` stages runs over `shared`, with a
    ///    barrier after every compare-exchange step;
    /// 3. every invocation copies `shared[local_x]` back to `values`.
    ///
    /// The workgroup size is `[n, 1, 1]`, i.e. one invocation per element.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two or exceeds [`MAX_ELEMENTS`].
    #[must_use]
    pub fn program_with_elements(n: u32) -> Program {
        assert!(
            n.is_power_of_two() && n <= MAX_ELEMENTS,
            "n must be a power of two and <= {MAX_ELEMENTS}"
        );
        let log2_n = n.trailing_zeros();
        let local_idx = Expr::local_x();
        let n_expr = Expr::u32(n);

        // Load from global storage into workgroup shared memory.
        let load_shared = Node::if_then(
            Expr::lt(local_idx.clone(), n_expr.clone()),
            vec![Node::store(
                "shared",
                local_idx.clone(),
                Expr::load("values", local_idx.clone()),
            )],
        );

        // Butterfly network.
        let stage_var = Expr::var("stage");
        let sub_var = Expr::var("sub");

        // k = 1 << stage        : size of the bitonic sequences being merged.
        // j = k >> (sub + 1)    : distance to the compare-exchange partner.
        // `sub` ranges over 0..stage, so j is always >= 1 and therefore
        // `partner != local_idx`.
        let k = Expr::shl(Expr::u32(1), stage_var.clone());
        let j = Expr::shr(k.clone(), Expr::add(sub_var.clone(), Expr::u32(1)));
        let partner = Expr::bitxor(local_idx.clone(), j.clone());

        // Blocks whose index has bit `k` clear sort ascending, the others
        // descending; in the final stage (k == n) every index is ascending.
        let ascending = Expr::eq(Expr::bitand(local_idx.clone(), k.clone()), Expr::u32(0));

        let a = Expr::load("shared", local_idx.clone());
        let b = Expr::load("shared", partner.clone());

        let swap_cond = Expr::select(
            ascending.clone(),
            Expr::gt(a.clone(), b.clone()),
            Expr::lt(a.clone(), b.clone()),
        );

        // `a` and `b` are load *expressions*, not captured values. A plain
        // two-store swap (`shared[i] = b; shared[p] = a`) re-reads slot `i`
        // after it has already been overwritten, leaving both slots holding
        // `b` and losing an element. The XOR swap below exchanges the two
        // slots using only loads and stores, so no temporary is required.
        // It is valid because `partner != local_idx` (see `j` above).
        // NOTE(review): assumes the lowering emits each `Expr` inline at its
        // use site, as the expression-tree API implies — confirm.
        let xor_of_slots = Expr::bitxor(a, b);
        let swap_body = vec![
            // shared[i] = a ^ b
            Node::store("shared", local_idx.clone(), xor_of_slots.clone()),
            // shared[p] = (a ^ b) ^ b = a
            Node::store("shared", partner.clone(), xor_of_slots.clone()),
            // shared[i] = (a ^ b) ^ a = b
            Node::store("shared", local_idx.clone(), xor_of_slots),
        ];

        let inner_body = vec![
            // Only the lower-indexed invocation of each pair performs the
            // compare-exchange, so each pair is touched by exactly one thread.
            Node::if_then(
                Expr::gt(partner.clone(), local_idx.clone()),
                vec![Node::if_then(swap_cond, swap_body)],
            ),
            // Make this step's writes visible before the next step reads them.
            Node::barrier(),
        ];

        // stage = 1..=log2(n); for each stage, sub = 0..stage.
        let outer_loop = Node::loop_for(
            "stage",
            Expr::u32(1),
            Expr::u32(log2_n + 1),
            vec![Node::loop_for("sub", Expr::u32(0), stage_var, inner_body)],
        );

        // Store back to global memory.
        let store_global = Node::if_then(
            Expr::lt(local_idx.clone(), n_expr),
            vec![Node::store(
                "values",
                local_idx.clone(),
                Expr::load("shared", local_idx),
            )],
        );

        Program::new(
            vec![
                BufferDecl::read_write("values", 0, DataType::U32),
                BufferDecl::workgroup("shared", n, DataType::U32),
            ],
            [n, 1, 1],
            vec![load_shared, Node::barrier(), outer_loop, store_global],
        )
    }
}

/// CPU reference bitonic sort (ascending).
///
/// Sorts `values` in place using the same stage / sub-step schedule as the
/// GPU program: for every stage `1..=log2(len)` and every sub-step
/// `0..stage`, each index `i` is compare-exchanged with `i ^ j`, where
/// `j = (1 << stage) >> (sub + 1)`.
///
/// Slices of length zero or one are already sorted and are returned
/// unchanged.
///
/// # Panics
///
/// Panics if `values.len()` is greater than one and not a power of two.
pub fn bitonic_sort_u32_reference(values: &mut [u32]) {
    let len = values.len();
    // Handle the trivial cases before the power-of-two check: zero is not a
    // power of two, but an empty slice needs no sorting.
    if len <= 1 {
        return;
    }
    assert!(len.is_power_of_two(), "slice length must be a power of two");

    // All index arithmetic stays in `usize`, so no narrowing casts are needed.
    for stage in 1..=len.trailing_zeros() {
        // Size of the bitonic sequences merged in this stage.
        let block = 1usize << stage;
        for sub in 0..stage {
            // Distance between the two elements of each compared pair.
            let stride = block >> (sub + 1);
            for i in 0..len {
                let partner = i ^ stride;
                // Visit each pair once, from its lower index.
                if partner <= i {
                    continue;
                }
                let ascending = i & block == 0;
                let (lo, hi) = (values[i], values[partner]);
                let out_of_order = if ascending { lo > hi } else { lo < hi };
                if out_of_order {
                    values.swap(i, partner);
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Reversed inputs of every supported size come out ascending.
    #[test]
    pub(crate) fn reference_sorts_powers_of_two() {
        for n in [2u32, 4, 8, 16, 32, 64, 128, 256] {
            let mut actual: Vec<u32> = (0..n).rev().collect();
            let mut sorted = actual.clone();
            sorted.sort_unstable();

            bitonic_sort_u32_reference(&mut actual);

            assert_eq!(actual, sorted, "reference failed for n={n}");
        }
    }

    /// An arbitrary unordered input matches the standard-library sort.
    #[test]
    pub(crate) fn reference_sorts_random() {
        let mut actual = vec![42u32, 7, 1023, 0, 999_999, 1, 8, 3];
        let mut sorted = actual.clone();
        sorted.sort_unstable();

        bitonic_sort_u32_reference(&mut actual);

        assert_eq!(actual, sorted);
    }

    /// Sorted input is left untouched.
    #[test]
    pub(crate) fn reference_sorts_already_sorted() {
        let original: Vec<u32> = (0..128u32).collect();
        let mut actual = original.clone();

        bitonic_sort_u32_reference(&mut actual);

        assert_eq!(actual, original);
    }

    /// A slice of identical values is left untouched.
    #[test]
    pub(crate) fn reference_sorts_all_equal() {
        let original = vec![7u32; 64];
        let mut actual = original.clone();

        bitonic_sort_u32_reference(&mut actual);

        assert_eq!(actual, original);
    }

    /// A one-element slice is left untouched.
    #[test]
    pub(crate) fn reference_sorts_single_element() {
        let original = vec![42u32];
        let mut actual = original.clone();

        bitonic_sort_u32_reference(&mut actual);

        assert_eq!(actual, original);
    }

    /// Every supported size builds a program with the expected shape.
    #[test]
    pub(crate) fn ir_program_builds_for_valid_sizes() {
        for n in [2, 4, 8, 16, 32, 64, 128, 256] {
            let built = BitonicSortU32::program_with_elements(n);
            assert_eq!(built.workgroup_size(), [n, 1, 1]);
            assert_eq!(built.buffers().len(), 2);
        }
    }

    /// The default program lowers to WGSL without error.
    #[test]
    pub(crate) fn ir_program_lowers_to_wgsl() {
        let lowered = crate::lower::wgsl::lower_anonymous(&BitonicSortU32::program());
        assert!(lowered.is_ok(), "WGSL lowering failed: {:?}", lowered.err());
    }

    /// The lowered WGSL contains both barrier calls.
    #[test]
    pub(crate) fn ir_program_wgsl_contains_barriers() {
        let sixteen = BitonicSortU32::program_with_elements(16);
        let wgsl = crate::lower::wgsl::lower_anonymous(&sixteen).expect("lowering");

        for barrier in ["storageBarrier", "workgroupBarrier"] {
            let call = format!("{barrier}()");
            assert!(wgsl.contains(&call), "missing {barrier}");
        }
    }

    /// Element counts that are not a power of two are rejected.
    #[test]
    #[should_panic(expected = "n must be a power of two")]
    pub(crate) fn rejects_non_power_of_two() {
        let _program = BitonicSortU32::program_with_elements(63);
    }

    /// Zero elements is rejected (zero is not a power of two).
    #[test]
    #[should_panic(expected = "n must be a power of two")]
    pub(crate) fn rejects_zero() {
        let _program = BitonicSortU32::program_with_elements(0);
    }
}

// WGSL lowering marker for `sort.bitonic_sort_u32`.
//
// The operation is a Category A composition expressed directly in vyre IR.
// Lowering is performed automatically by `core/src/lower/wgsl` using the
// loop, barrier, load/store, and bitwise expression emitters.