// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

// WGSL lowering marker for `encode.hex_encode_lower`.
//
// Not a stub: this is a zero-overhead Category A marker for the real
// `HexEncodeLower::program()` IR. The generic lowering in
// `core/src/lower/wgsl` emits the compute shader from that IR.
//
// The IR is bit-exact against the conform CPU reference.
//
// ```wgsl
// @compute @workgroup_size(64, 1, 1)
// fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
//   let byte = input[gid.x];
//   // Store lowercase ASCII for byte >> 4 and byte & 0x0f.
// }
// ```

/// Documents the IR-to-WGSL path for `encode.hex_encode_lower`.
///
/// The operation owns no handwritten shader: `HexEncodeLower::program()` builds
/// pure IR from shifts, masks, arithmetic, and selects; `core/src/lower/wgsl`
/// lowers that IR into the final kernel without dispatch-specific glue here.
///
/// The value is the same operation id string that `HexEncodeLower::SPEC`
/// passes as the first argument to `OpSpec::composition_inlinable`.
pub const CATEGORY_A_WGSL_MARKER: &str = "encode.hex_encode_lower";

/// Lowercase hexadecimal digit table: entry `n` is the ASCII byte for the
/// nibble value `n` (`0..=15`). Used by the CPU reference `hex_encode_lower`.
pub const HEX: &[u8; 16] = b"0123456789abcdef";

/// Build the IR expression that maps a nibble expression to a lowercase hex
/// ASCII code.
///
/// The result is an `Expr::select` over the condition `nibble < 10` with two
/// candidate values: `'0' + nibble` and `'a' + (nibble - 10)`. Presumably
/// `Expr::select(cond, a, b)` picks `a` when `cond` holds, which matches the
/// CPU reference table (`0-9`, then `a-f`) — confirm against the `Expr` docs.
///
/// `nibble` is expected to evaluate to a value in `0..=15`; no masking is
/// applied here.
pub fn hex_char(nibble: Expr) -> Expr {
    // Condition separating the decimal digits from the letters.
    let below_ten = Expr::lt(nibble.clone(), Expr::u32(10));

    // Candidate for 0..=9: offset from ASCII '0'.
    let decimal_code = Expr::add(Expr::u32(u32::from(b'0')), nibble.clone());

    // Candidate for 10..=15: offset from ASCII 'a' after removing the 10 bias.
    let letter_offset = Expr::sub(nibble, Expr::u32(10));
    let letter_code = Expr::add(Expr::u32(u32::from(b'a')), letter_offset);

    Expr::select(below_ten, decimal_code, letter_code)
}

/// Encode bytes as lowercase hexadecimal ASCII.
///
/// Every input byte yields two output bytes: the digit for its high nibble
/// followed by the digit for its low nibble, both looked up in [`HEX`]. The
/// returned vector therefore has length `input.len() * 2`; an empty input
/// produces an empty vector.
#[must_use]
pub fn hex_encode_lower(input: &[u8]) -> Vec<u8> {
    // Reserve the exact output size up front: two ASCII digits per byte.
    let mut encoded = Vec::with_capacity(input.len() * 2);
    encoded.extend(input.iter().flat_map(|&byte| {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0f);
        [HEX[high], HEX[low]]
    }));
    encoded
}

/// Lowercase hexadecimal encode operation.
///
/// Field-less marker type: it carries no state and exists to host the
/// associated `SPEC` constant and the `program()` IR builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct HexEncodeLower;

impl HexEncodeLower {
    /// Declarative operation specification.
    pub const SPEC: OpSpec = OpSpec::composition_inlinable(
        "encode.hex_encode_lower",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical lowerable IR program.
    #[must_use]
    pub fn program() -> Program {
        let idx = Expr::var("idx");
        let out_idx = Expr::mul(idx.clone(), Expr::u32(2));
        Program::new(
            vec![
                BufferDecl::read("input", 0, DataType::Bytes),
                BufferDecl::output("out", 1, DataType::Bytes),
            ],
            [64, 1, 1],
            vec![
                Node::let_bind("idx", Expr::gid_x()),
                Node::if_then(
                    Expr::lt(idx.clone(), Expr::buf_len("input")),
                    vec![
                        Node::let_bind("byte", Expr::load("input", idx)),
                        Node::let_bind("out_idx", out_idx),
                        Node::store(
                            "out",
                            Expr::var("out_idx"),
                            hex_char(Expr::shr(Expr::var("byte"), Expr::u32(4))),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(1)),
                            hex_char(Expr::bitand(Expr::var("byte"), Expr::u32(0x0f))),
                        ),
                    ],
                ),
            ],
        )
    }
}

/// Algebraic laws passed to `HexEncodeLower::SPEC`.
///
/// The table is currently empty: no laws are declared for this operation.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

// Unit tests, extracted from `ops/encode/hex_encode_lower/kernel.rs`.

/// Checks the CPU reference `hex_encode_lower` against the committed
/// known-answer vectors in `../fixtures/reference-vectors.toml`.
///
/// Cases whose `op` is not `hex_encode_lower` are skipped. A case missing
/// `input_hex` or `expected_output_hex` is reported as an error rather than a
/// panic.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    let vectors = include_str!("../fixtures/reference-vectors.toml");
    crate::ops::fixtures::run_committed_kats(vectors, |case| {
        // Vectors for other operations share the fixture file; ignore them.
        if case.op.as_deref() != Some("hex_encode_lower") {
            return Ok(());
        }

        // Decode and encode the input first so a missing input is reported
        // before a missing expected output.
        let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
        let actual = hex_encode_lower(&crate::ops::fixtures::hex_to_bytes(input_hex)?);

        let expected_hex = case
            .expected_output_hex
            .as_ref()
            .ok_or("Fix: missing expected_output_hex")?;
        let expected = crate::ops::fixtures::hex_to_bytes(expected_hex)?;

        assert_eq!(actual, expected);
        Ok(())
    })
}