vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec, U32_INPUTS, U32_OUTPUTS};

// Software popcount (population count) via the classic SWAR-with-multiply
// trick. Exists as a second popcount op so `popcount_sw` can be
// cross-validated against the `Expr::popcount` hardware-intrinsic path:
// their outputs must agree bit-for-bit on every backend — the property
// checked by `verify::properties::tests::popcount_hw_matches_popcount_sw`.

/// Algebraic laws attached to `primitive.bitwise.popcount_sw` through
/// [`PopcountSw::SPEC`].
///
/// - `Bounded` with `lo = 0`, `hi = 32`: a `u32` has between 0 and 32 set bits
///   (inclusive). NOTE(review): assumes `Bounded` is an inclusive range — confirm.
/// - `Complement` against `primitive.bitwise.not` with `universe = 32`:
///   presumably encodes `popcount(x) + popcount(!x) == 32`. NOTE(review): confirm
///   against the `AlgebraicLaw::Complement` definition.
pub const LAWS: &[AlgebraicLaw] = &[
    AlgebraicLaw::Bounded { hi: 32, lo: 0 },
    AlgebraicLaw::Complement {
        universe: 32,
        complement_op: "primitive.bitwise.not",
    },
];

/// Software population-count operation.
///
/// Zero-sized marker type. The operation is described by [`PopcountSw::SPEC`]
/// and its IR is built by [`PopcountSw::program`], which counts set bits with
/// SWAR arithmetic instead of a hardware popcount intrinsic.
///
/// `PartialEq`/`Eq`/`Hash` are derived so the marker can be compared and used
/// as a map/set key, like other common-trait-complete public types.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PopcountSw;

impl PopcountSw {
    /// Declarative operation specification.
    pub const SPEC: OpSpec = OpSpec::composition_inlinable(
        "primitive.bitwise.popcount_sw",
        U32_INPUTS,
        U32_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical IR program using the SWAR algorithm
    /// (no hardware popcount intrinsic, just masks/shifts/adds/mul).
    ///
    /// # Examples
    ///
    /// ```
    /// use vyre::ops::primitive::bitwise::popcount_sw::PopcountSw;
    ///
    /// let program = PopcountSw::program();
    /// assert!(!program.entry().is_empty());
    /// ```
    #[must_use]
    pub fn program() -> Program {
        let idx = Expr::var("idx");
        let a = Expr::load("a", idx.clone());
        // t1 = a - ((a >> 1) & 0x55555555)
        let t1 = Expr::sub(
            a.clone(),
            Expr::bitand(Expr::shr(a, Expr::u32(1)), Expr::u32(0x5555_5555)),
        );
        // t2 = (t1 & 0x33333333) + ((t1 >> 2) & 0x33333333)
        let t2 = Expr::add(
            Expr::bitand(t1.clone(), Expr::u32(0x3333_3333)),
            Expr::bitand(Expr::shr(t1, Expr::u32(2)), Expr::u32(0x3333_3333)),
        );
        // t3 = (t2 + (t2 >> 4)) & 0x0f0f0f0f
        let t3 = Expr::bitand(
            Expr::add(t2.clone(), Expr::shr(t2, Expr::u32(4))),
            Expr::u32(0x0f0f_0f0f),
        );
        // result = (t3 * 0x01010101) >> 24 — broadcasts byte-wise popcounts
        // into the high byte via the multiplication, then extracts it.
        let result = Expr::shr(Expr::mul(t3, Expr::u32(0x0101_0101)), Expr::u32(24));

        Program::new(
            vec![
                BufferDecl::read("a", 0, DataType::U32),
                BufferDecl::output("out", 1, DataType::U32),
            ],
            [64, 1, 1],
            vec![
                Node::let_bind("idx", Expr::gid_x()),
                Node::if_then(
                    Expr::lt(idx.clone(), Expr::buf_len("out")),
                    vec![Node::store("out", idx, result)],
                ),
            ],
        )
    }
}

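// WGSL lowering marker for `primitive.bitwise.popcount_sw`. No dedicated
// lowering is needed: the IR built by `PopcountSw::program` lowers through
// the standard `emit_binop` paths. The WGSL below is an illustrative
// equivalent; the IR is one nested expression, so the actually emitted code
// may repeat subexpressions instead of using these `let` temporaries.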
//
// ```wgsl
// let t1 = _vyre_load_a(idx) - ((_vyre_load_a(idx) >> 1u) & 0x55555555u);
// let t2 = (t1 & 0x33333333u) + ((t1 >> 2u) & 0x33333333u);
// let t3 = (t2 + (t2 >> 4u)) & 0x0f0f0f0fu;
// _vyre_store_out(idx, ((t3 * 0x01010101u) >> 24u));
// ```