vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

// WGSL lowering marker for `encode.base32_encode`.
//
// Not a stub: this is a zero-overhead Category A marker for the real
// `Base32Encode::program()` IR. The generic lowering in `core/src/lower/wgsl`
// emits the compute shader from that IR.
//
// The IR is bit-exact against the conform CPU reference.
//
// ```wgsl
// @compute @workgroup_size(64, 1, 1)
// fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
//   let chunk = gid.x;
//   // Load up to five bytes, compute eight 5-bit indices, map each index
//   // through nested select expressions, and store '=' for missing tails.
// }
// ```

/// RFC 4648 base32 alphabet.
///
/// A 5-bit index `i` (`0..32`) maps to `ALPHABET[i]`: `'A'..='Z'` for indices 0–25 and
/// `'2'..='7'` for indices 26–31. The IR helper `alphabet_char` reproduces the same
/// mapping arithmetically instead of through a table lookup.
pub const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

/// Map a 5-bit index to its base32 alphabet character as an IR expression.
///
/// The resulting expression evaluates to the ASCII code of `ALPHABET[idx]`:
/// indices below 26 become `'A' + idx`, and the remaining ones become `'2' + (idx - 26)`.
/// `idx` is expected to already be a 5-bit value; it is not masked here.
pub fn alphabet_char(idx: Expr) -> Expr {
    let is_letter = Expr::lt(idx.clone(), Expr::u32(26));
    let letter = Expr::add(Expr::u32(u32::from(b'A')), idx.clone());
    let digit = Expr::add(Expr::u32(u32::from(b'2')), Expr::sub(idx, Expr::u32(26)));
    Expr::select(is_letter, letter, digit)
}

/// Encode bytes as canonical padded RFC 4648 base32.
#[must_use]
pub fn base32_encode(input: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(input.len().div_ceil(5) * 8);
    for chunk in input.chunks(5) {
        let a = chunk[0];
        let b = chunk.get(1).copied().unwrap_or(0);
        let c = chunk.get(2).copied().unwrap_or(0);
        let d = chunk.get(3).copied().unwrap_or(0);
        let e = chunk.get(4).copied().unwrap_or(0);
        out.push(ALPHABET[(a >> 3) as usize]);
        out.push(ALPHABET[(((a & 0x07) << 2) | (b >> 6)) as usize]);
        if chunk.len() > 1 {
            out.push(ALPHABET[((b >> 1) & 0x1f) as usize]);
        } else {
            out.push(b'=');
        }
        if chunk.len() > 1 {
            out.push(ALPHABET[(((b & 0x01) << 4) | (c >> 4)) as usize]);
        } else {
            out.push(b'=');
        }
        if chunk.len() > 2 {
            out.push(ALPHABET[(((c & 0x0f) << 1) | (d >> 7)) as usize]);
        } else {
            out.push(b'=');
        }
        if chunk.len() > 3 {
            out.push(ALPHABET[((d >> 2) & 0x1f) as usize]);
        } else {
            out.push(b'=');
        }
        if chunk.len() > 3 {
            out.push(ALPHABET[(((d & 0x03) << 3) | (e >> 5)) as usize]);
        } else {
            out.push(b'=');
        }
        if chunk.len() > 4 {
            out.push(ALPHABET[(e & 0x1f) as usize]);
        } else {
            out.push(b'=');
        }
    }
    out
}

/// RFC 4648 base32 encode operation.
///
/// Zero-sized marker type: the operation's IR is built by [`Base32Encode::program`] and
/// registered through [`Base32Encode::SPEC`]. The byte-level CPU encoder in this module
/// is [`base32_encode`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Base32Encode;

/// IR predicate checking whether `input_idx + offset` is within input bounds.
///
/// The expression refers to a variable named `input_idx`, which must already be bound in
/// the enclosing program scope, and compares against the length of the `input` buffer.
pub fn has_input_byte(offset: u32) -> Expr {
    let position = Expr::add(Expr::var("input_idx"), Expr::u32(offset));
    let input_len = Expr::buf_len("input");
    Expr::lt(position, input_len)
}

impl Base32Encode {
    /// Declarative operation specification.
    pub const SPEC: OpSpec = OpSpec::composition_inlinable(
        "encode.base32_encode",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical lowerable IR program.
    #[must_use]
    pub fn program() -> Program {
        let chunk = Expr::var("chunk");
        let input_idx = Expr::mul(chunk.clone(), Expr::u32(5));
        let out_idx = Expr::mul(chunk, Expr::u32(8));
        Program::new(
            vec![
                BufferDecl::read("input", 0, DataType::Bytes),
                BufferDecl::output("out", 1, DataType::Bytes),
            ],
            [64, 1, 1],
            vec![
                Node::let_bind("chunk", Expr::gid_x()),
                Node::if_then(
                    Expr::lt(input_idx.clone(), Expr::buf_len("input")),
                    vec![
                        Node::let_bind("input_idx", input_idx),
                        Node::let_bind("out_idx", out_idx),
                        Node::let_bind("a", Expr::load("input", Expr::var("input_idx"))),
                        Node::let_bind(
                            "b",
                            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(1))),
                        ),
                        Node::let_bind(
                            "c",
                            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(2))),
                        ),
                        Node::let_bind(
                            "d",
                            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(3))),
                        ),
                        Node::let_bind(
                            "e",
                            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(4))),
                        ),
                        Node::let_bind("has_b", has_input_byte(1)),
                        Node::let_bind("has_c", has_input_byte(2)),
                        Node::let_bind("has_d", has_input_byte(3)),
                        Node::let_bind("has_e", has_input_byte(4)),
                        Node::store(
                            "out",
                            Expr::var("out_idx"),
                            alphabet_char(Expr::shr(Expr::var("a"), Expr::u32(3))),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(1)),
                            alphabet_char(Expr::bitor(
                                Expr::shl(
                                    Expr::bitand(Expr::var("a"), Expr::u32(0x07)),
                                    Expr::u32(2),
                                ),
                                Expr::shr(Expr::var("b"), Expr::u32(6)),
                            )),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(2)),
                            Expr::select(
                                Expr::var("has_b"),
                                alphabet_char(Expr::bitand(
                                    Expr::shr(Expr::var("b"), Expr::u32(1)),
                                    Expr::u32(0x1f),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(3)),
                            Expr::select(
                                Expr::var("has_b"),
                                alphabet_char(Expr::bitor(
                                    Expr::shl(
                                        Expr::bitand(Expr::var("b"), Expr::u32(0x01)),
                                        Expr::u32(4),
                                    ),
                                    Expr::shr(Expr::var("c"), Expr::u32(4)),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(4)),
                            Expr::select(
                                Expr::var("has_c"),
                                alphabet_char(Expr::bitor(
                                    Expr::shl(
                                        Expr::bitand(Expr::var("c"), Expr::u32(0x0f)),
                                        Expr::u32(1),
                                    ),
                                    Expr::shr(Expr::var("d"), Expr::u32(7)),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(5)),
                            Expr::select(
                                Expr::var("has_d"),
                                alphabet_char(Expr::bitand(
                                    Expr::shr(Expr::var("d"), Expr::u32(2)),
                                    Expr::u32(0x1f),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(6)),
                            Expr::select(
                                Expr::var("has_d"),
                                alphabet_char(Expr::bitor(
                                    Expr::shl(
                                        Expr::bitand(Expr::var("d"), Expr::u32(0x03)),
                                        Expr::u32(3),
                                    ),
                                    Expr::shr(Expr::var("e"), Expr::u32(5)),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(7)),
                            Expr::select(
                                Expr::var("has_e"),
                                alphabet_char(Expr::bitand(
                                    Expr::var("e"),
                                    Expr::u32(0x1f),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                    ],
                ),
            ],
        )
    }
}

/// Algebraic laws for base32 encode (none declared).
///
/// Passed to `OpSpec::composition_inlinable` by [`Base32Encode::SPEC`]; the slice is
/// empty, so no algebraic law is claimed for this operation.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

// Unit tests extracted from `ops/encode/base32_encode/kernel.rs`.

/// Check the CPU encoder [`base32_encode`] against every committed `base32_encode` KAT.
///
/// Cases for other ops are skipped. A case missing its input or expected-output hex
/// reports a `Fix: ...` error instead of panicking.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::fixtures::{hex_to_bytes, run_committed_kats};

    let vectors = include_str!("../fixtures/reference-vectors.toml");
    run_committed_kats(vectors, |case| {
        if case.op.as_deref() != Some("base32_encode") {
            return Ok(());
        }
        let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
        let input = hex_to_bytes(input_hex)?;
        let expected_hex = case
            .expected_output_hex
            .as_ref()
            .ok_or("Fix: missing expected_output_hex")?;
        let expected = hex_to_bytes(expected_hex)?;
        assert_eq!(base32_encode(&input), expected);
        Ok(())
    })
}