vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec, U32_INPUTS, U32_OUTPUTS};

// Blelloch-style single-workgroup inclusive prefix sum over u32.


/// Fixed workgroup size for the scan.
///
/// This is both the number of invocations in the single dispatched workgroup
/// and the length of the `shared` scratch buffer. Input length must not exceed
/// this value; elements past it are never scanned or written.
///
/// Must be a power of two: the Blelloch tree combines elements at power-of-two
/// strides. This is enforced at compile time by the assertion below.
pub const WORKGROUP_SIZE: u32 = 256;

// Compile-time guard: the up-sweep/down-sweep tree requires a power-of-two size.
const _: () = assert!(
    WORKGROUP_SIZE.is_power_of_two(),
    "WORKGROUP_SIZE must be a power of two"
);

/// Algebraic laws advertised for `scan.prefix_sum_inclusive` (passed to its
/// `OpSpec`). They describe the combining operation of the scan, u32 addition.
pub const LAWS: &[AlgebraicLaw] = &[
    // Addition can be regrouped freely, which is what lets the tree scan
    // reorder partial sums.
    AlgebraicLaw::Associative,
    // Adding 0 is a no-op; the scan relies on this for zero-padding and for
    // clearing the tree root.
    AlgebraicLaw::Identity { element: 0 },
];

/// Blelloch-algorithm inclusive prefix sum (scan) over a u32 buffer.
///
/// `out[i] = input[0] + input[1] + ... + input[i]` for every `i < n`, where
/// `n` is the length of `input`.
///
/// # Buffers
///
/// - `input` (binding 0): read-only u32 source.
/// - `out` (binding 1): u32 destination.
/// - `shared`: workgroup-local scratch array of [`WORKGROUP_SIZE`] u32s.
///
/// # Limits
///
/// Only one workgroup participates. `input` may hold at most
/// [`WORKGROUP_SIZE`] elements, and the caller must dispatch exactly one
/// workgroup shaped `[WORKGROUP_SIZE, 1, 1]`.
///
/// # Algorithm
///
/// 1. Each invocation copies its element into `shared`; lanes past the end
///    of `input` contribute 0.
/// 2. Up-sweep: reduce pairs in place, building a tree of partial sums.
/// 3. Overwrite the tree root with 0 to seed an exclusive scan.
/// 4. Down-sweep: push partial sums back down, leaving an exclusive scan.
/// 5. Add each lane's original value to turn the exclusive scan inclusive.
/// 6. Write the result to `out`, but only for lanes with `lid < n`.
#[derive(Clone, Copy, Debug, Default)]
pub struct PrefixSumInclusiveU32;

impl PrefixSumInclusiveU32 {
    /// Declarative operation specification.
    ///
    /// Registers the op as the composition `scan.prefix_sum_inclusive` with
    /// u32 inputs and outputs, the [`LAWS`] it satisfies, and
    /// [`Self::program`] as its IR builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "scan.prefix_sum_inclusive",
        U32_INPUTS,
        U32_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical IR program.
    ///
    /// Emits a single-workgroup Blelloch scan. Both sweeps are unrolled, one
    /// level per power of two below [`WORKGROUP_SIZE`]. Their strides are
    /// derived from that constant rather than hard-coded, so the tree always
    /// matches the declared workgroup size and the `shared` buffer length.
    /// `WORKGROUP_SIZE` is assumed to be a power of two.
    ///
    /// Buffers: `input` (binding 0, read), `out` (binding 1, output), and the
    /// workgroup-local scratch array `shared` of `WORKGROUP_SIZE` u32s.
    #[must_use]
    pub fn program() -> Program {
        // log2(WORKGROUP_SIZE): number of levels in the reduction tree.
        let levels = WORKGROUP_SIZE.trailing_zeros();
        // 1, 2, 4, ..., WORKGROUP_SIZE / 2 -- one value per tree level.
        let level_strides = || (0..levels).map(|level| 1u32 << level);

        // 5 prologue nodes, 2 per up-sweep level, 2 for the root clear,
        // 2 per down-sweep level and 2 epilogue nodes.
        let mut entry = Vec::with_capacity(9 + 4 * levels as usize);

        // lid = local_invocation_id.x
        entry.push(Node::let_bind("lid", Expr::local_x()));
        // n = length of input buffer
        entry.push(Node::let_bind("n", Expr::buf_len("input")));

        // Remember the original value. Lanes beyond the input load 0 so they do
        // not affect the scan tree.
        // NOTE(review): if `Expr::select` lowers to WGSL `select`, both operands
        // are evaluated, so `input[lid]` may be read for `lid >= n`. WGSL
        // bounds such accesses and the value is discarded here.
        entry.push(Node::let_bind(
            "original",
            Expr::select(
                Expr::lt(Expr::var("lid"), Expr::var("n")),
                Expr::load("input", Expr::var("lid")),
                Expr::u32(0),
            ),
        ));

        // Load into shared memory.
        entry.push(Node::store("shared", Expr::var("lid"), Expr::var("original")));
        entry.push(Node::barrier());

        // Up-sweep phase, unrolled so barriers stay in uniform control flow.
        // At each level, lane `lid` with (lid + 1) % (2 * stride) == 0 adds in
        // shared[lid - stride]. That slot is never written at the same level,
        // so the level is race-free.
        for stride in level_strides() {
            entry.push(Node::if_then(
                Expr::eq(
                    Expr::rem(Expr::add(Expr::var("lid"), Expr::u32(1)), Expr::u32(stride * 2)),
                    Expr::u32(0),
                ),
                vec![Node::store(
                    "shared",
                    Expr::var("lid"),
                    Expr::add(
                        Expr::load("shared", Expr::var("lid")),
                        Expr::load("shared", Expr::sub(Expr::var("lid"), Expr::u32(stride))),
                    ),
                )],
            ));
            entry.push(Node::barrier());
        }

        // Set root to 0, producing an exclusive-scan basis.
        entry.push(Node::if_then(
            Expr::eq(Expr::var("lid"), Expr::u32(WORKGROUP_SIZE - 1)),
            vec![Node::store("shared", Expr::var("lid"), Expr::u32(0))],
        ));
        entry.push(Node::barrier());

        // Down-sweep phase, unrolled. `d` active lanes per level (1, 2, 4, ...)
        // with `offset` halving from WORKGROUP_SIZE / 2 down to 1. Each active
        // lane swaps-and-adds the pair (ai, bi) = (offset*(2*lid+1)-1,
        // offset*(2*lid+2)-1).
        for d in level_strides() {
            let offset = WORKGROUP_SIZE / (2 * d);
            entry.push(Node::if_then(
                Expr::lt(Expr::var("lid"), Expr::u32(d)),
                vec![
                    Node::let_bind(
                        "ai",
                        Expr::sub(
                            Expr::mul(
                                Expr::u32(offset),
                                Expr::add(Expr::mul(Expr::u32(2), Expr::var("lid")), Expr::u32(1)),
                            ),
                            Expr::u32(1),
                        ),
                    ),
                    Node::let_bind(
                        "bi",
                        Expr::sub(
                            Expr::mul(
                                Expr::u32(offset),
                                Expr::add(Expr::mul(Expr::u32(2), Expr::var("lid")), Expr::u32(2)),
                            ),
                            Expr::u32(1),
                        ),
                    ),
                    Node::let_bind("t", Expr::load("shared", Expr::var("ai"))),
                    Node::store("shared", Expr::var("ai"), Expr::load("shared", Expr::var("bi"))),
                    Node::store(
                        "shared",
                        Expr::var("bi"),
                        Expr::add(Expr::load("shared", Expr::var("bi")), Expr::var("t")),
                    ),
                ],
            ));
            entry.push(Node::barrier());
        }

        // Exclusive scan is now in shared memory. Add original input to get inclusive.
        entry.push(Node::let_bind("exclusive", Expr::load("shared", Expr::var("lid"))));
        // Only lanes that correspond to real input elements write output.
        entry.push(Node::if_then(
            Expr::lt(Expr::var("lid"), Expr::var("n")),
            vec![Node::store(
                "out",
                Expr::var("lid"),
                Expr::add(Expr::var("exclusive"), Expr::var("original")),
            )],
        ));

        Program::new(
            vec![
                BufferDecl::read("input", 0, DataType::U32),
                BufferDecl::output("out", 1, DataType::U32),
                BufferDecl::workgroup("shared", WORKGROUP_SIZE, DataType::U32),
            ],
            [WORKGROUP_SIZE, 1, 1],
            entry,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir;
    use crate::lower::wgsl;

    // Test functions are discovered by the harness; they need no visibility.

    #[test]
    fn spec_builds_non_empty_program() {
        let program = PrefixSumInclusiveU32::program();
        assert!(!program.entry().is_empty());
        assert_eq!(program.workgroup_size(), [WORKGROUP_SIZE, 1, 1]);
    }

    #[test]
    fn workgroup_size_is_power_of_two() {
        // The Blelloch tree levels in `program` are derived from
        // log2(WORKGROUP_SIZE); a non-power-of-two size would drop lanes.
        assert!(WORKGROUP_SIZE.is_power_of_two());
    }

    #[test]
    fn program_validates_cleanly() {
        let program = PrefixSumInclusiveU32::program();
        let errors = ir::validate(&program);
        assert!(errors.is_empty(), "validation failed: {errors:?}");
    }

    #[test]
    fn program_lowers_to_wgsl() {
        let program = PrefixSumInclusiveU32::program();
        let wgsl = wgsl::lower_anonymous(&program).expect("WGSL lowering must succeed");
        assert!(wgsl.contains("workgroupBarrier"));
        assert!(wgsl.contains("var<workgroup> shared"));
    }
}

// WGSL lowering marker for `scan.prefix_sum_inclusive`.
//
// This operation is a Category A composition built entirely from vyre IR
// primitives (workgroup memory, barriers, arithmetic, and control flow).
// The generic WGSL lowerer emits the Blelloch scan without backend-specific
// intrinsic handling.
//
// The generated shader uses:
// - `var<workgroup> shared: array<u32, 256>`
// - `workgroupBarrier()` after the initial load into `shared`, after every
//   unrolled up-sweep and down-sweep level, and after the tree root is cleared
// - `@workgroup_size(256, 1, 1)`