// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library
// Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

// WGSL lowering marker for `encode.base64_encode`.
//
// Not a stub: this is a zero-overhead Category A marker for the real
// `Base64Encode::program()` IR. The generic lowering in `core/src/lower/wgsl`
// emits the compute shader from that IR.
//
// The IR is bit-exact against the conform CPU reference.
//
// ```wgsl
// @compute @workgroup_size(64, 1, 1)
// fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
//   let chunk = gid.x;
//   // Load up to three bytes, compute four 6-bit indices, map each index
//   // through nested select expressions, and store '=' for missing tails.
// }
// ```

/// Documents the IR-to-WGSL path for `encode.base64_encode`.
///
/// The operation owns no handwritten shader: `Base64Encode::program()` builds
/// pure IR from shifts, masks, arithmetic, and selects; `core/src/lower/wgsl`
/// lowers that IR into the final kernel without dispatch-specific glue here.
///
/// The value is the same operation identifier string that
/// `Base64Encode::SPEC` is constructed with.
pub const CATEGORY_A_WGSL_MARKER: &str = "encode.base64_encode";

/// The 64-symbol base64 alphabet (`A`-`Z`, `a`-`z`, `0`-`9`, `+`, `/`) as ASCII
/// bytes, indexed by 6-bit value. The CPU reference `base64_encode` looks each
/// output symbol up in this table.
pub const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Build an IR expression that maps a 6-bit index expression to a base64
/// symbol code.
///
/// The result is a chain of nested `Expr::select` nodes keyed on the ranges
/// `idx < 26`, `idx < 52`, `idx < 62`, and `idx == 62`, with `'/'` as the final
/// fallback. The branch values are `'A' + idx`, `'a' + (idx - 26)`,
/// `'0' + (idx - 52)`, `'+'`, and `'/'` respectively, all as `u32` constants or
/// arithmetic over `idx`.
///
/// NOTE(review): this assumes `Expr::select` takes `(condition, if_true,
/// if_false)`; with that order the mapping agrees with `ALPHABET` — confirm
/// against the `Expr` definition.
pub fn alphabet_char(idx: Expr) -> Expr {
    // Range predicates, tested in this order by the select chain below.
    let is_upper = Expr::lt(idx.clone(), Expr::u32(26));
    let is_lower = Expr::lt(idx.clone(), Expr::u32(52));
    let is_digit = Expr::lt(idx.clone(), Expr::u32(62));

    // Symbol code for each range.
    let upper = Expr::add(Expr::u32(u32::from(b'A')), idx.clone());
    let lower = Expr::add(
        Expr::u32(u32::from(b'a')),
        Expr::sub(idx.clone(), Expr::u32(26)),
    );
    let digit = Expr::add(
        Expr::u32(u32::from(b'0')),
        Expr::sub(idx.clone(), Expr::u32(52)),
    );

    // Only 62 and 63 remain once the three ranges above are excluded.
    let punctuation = Expr::select(
        Expr::eq(idx, Expr::u32(62)),
        Expr::u32(u32::from(b'+')),
        Expr::u32(u32::from(b'/')),
    );

    let digit_or_punctuation = Expr::select(is_digit, digit, punctuation);
    let lower_or_later = Expr::select(is_lower, lower, digit_or_punctuation);
    Expr::select(is_upper, upper, lower_or_later)
}

/// Encode bytes as canonical padded RFC 4648 base64.
#[must_use]
pub fn base64_encode(input: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(input.len().div_ceil(3) * 4);
    for chunk in input.chunks(3) {
        let a = chunk[0];
        let b = chunk.get(1).copied().unwrap_or(0);
        let c = chunk.get(2).copied().unwrap_or(0);
        out.push(ALPHABET[(a >> 2) as usize]);
        out.push(ALPHABET[(((a & 0x03) << 4) | (b >> 4)) as usize]);
        if chunk.len() > 1 {
            out.push(ALPHABET[(((b & 0x0f) << 2) | (c >> 6)) as usize]);
        } else {
            out.push(b'=');
        }
        if chunk.len() > 2 {
            out.push(ALPHABET[(c & 0x3f) as usize]);
        } else {
            out.push(b'=');
        }
    }
    out
}

/// RFC 4648 base64 encode operation.
///
/// A field-less marker type: the operation is described by its associated
/// `SPEC` constant and its IR is built by the associated `program()` function.
#[derive(Debug, Clone, Copy, Default)]
pub struct Base64Encode;

/// Build the IR predicate `input_idx + offset < len(input)`.
///
/// The expression refers to the IR variable named `input_idx` and to the
/// buffer named `input`, so it is only meaningful inside a program that binds
/// both (as `Base64Encode::program()` does).
pub fn has_input_byte(offset: u32) -> Expr {
    let byte_index = Expr::add(Expr::var("input_idx"), Expr::u32(offset));
    let input_len = Expr::buf_len("input");
    Expr::lt(byte_index, input_len)
}

impl Base64Encode {
    /// Declarative operation specification.
    pub const SPEC: OpSpec = OpSpec::composition_inlinable(
        "encode.base64_encode",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical lowerable IR program.
    #[must_use]
    pub fn program() -> Program {
        let chunk = Expr::var("chunk");
        let input_idx = Expr::mul(chunk.clone(), Expr::u32(3));
        let out_idx = Expr::mul(chunk, Expr::u32(4));
        Program::new(
            vec![
                BufferDecl::read("input", 0, DataType::Bytes),
                BufferDecl::output("out", 1, DataType::Bytes),
            ],
            [64, 1, 1],
            vec![
                Node::let_bind("chunk", Expr::gid_x()),
                Node::if_then(
                    Expr::lt(input_idx.clone(), Expr::buf_len("input")),
                    vec![
                        Node::let_bind("input_idx", input_idx),
                        Node::let_bind("out_idx", out_idx),
                        Node::let_bind("a", Expr::load("input", Expr::var("input_idx"))),
                        Node::let_bind(
                            "b",
                            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(1))),
                        ),
                        Node::let_bind(
                            "c",
                            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(2))),
                        ),
                        Node::let_bind("has_b", has_input_byte(1)),
                        Node::let_bind("has_c", has_input_byte(2)),
                        Node::store(
                            "out",
                            Expr::var("out_idx"),
                            alphabet_char(Expr::shr(Expr::var("a"), Expr::u32(2))),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(1)),
                            alphabet_char(Expr::bitor(
                                Expr::shl(
                                    Expr::bitand(Expr::var("a"), Expr::u32(0x03)),
                                    Expr::u32(4),
                                ),
                                Expr::shr(Expr::var("b"), Expr::u32(4)),
                            )),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(2)),
                            Expr::select(
                                Expr::var("has_b"),
                                alphabet_char(Expr::bitor(
                                    Expr::shl(
                                        Expr::bitand(Expr::var("b"), Expr::u32(0x0f)),
                                        Expr::u32(2),
                                    ),
                                    Expr::shr(Expr::var("c"), Expr::u32(6)),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(3)),
                            Expr::select(
                                Expr::var("has_c"),
                                alphabet_char(Expr::bitand(Expr::var("c"), Expr::u32(0x3f))),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                    ],
                ),
            ],
        )
    }
}

/// Algebraic laws declared for `encode.base64_encode`, passed to
/// `Base64Encode::SPEC`. The list is currently empty.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

// Unit tests.
// Unit tests extracted from `ops/encode/base64_encode/kernel.rs`.

/// Runs the committed known-answer vectors through the CPU reference.
///
/// Cases whose `op` is not `base64_encode` are skipped; for the rest, the
/// hex-decoded input is encoded and compared with the hex-decoded expected
/// output. A missing hex field is reported as an error rather than a panic.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            if case.op.as_deref() != Some("base64_encode") {
                return Ok(());
            }

            let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
            let input = crate::ops::fixtures::hex_to_bytes(input_hex)?;

            let expected_hex = case
                .expected_output_hex
                .as_ref()
                .ok_or("Fix: missing expected_output_hex")?;
            let expected = crate::ops::fixtures::hex_to_bytes(expected_hex)?;

            assert_eq!(base64_encode(&input), expected);
            Ok(())
        },
    )
}