vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use core::fmt;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

// WGSL lowering source for `decode.base64url`.

/// Dispatchable WGSL kernel for strict URL-safe base64 decode.
///
/// Assembled at compile time by concatenating, in order, the shared
/// `wgsl_byte_primitives/bytes.wgsl` source, a newline separator, and this
/// op's own `wgsl/base64url.wgsl` source.
pub const WGSL: &str = concat!(
    // Shared source comes first — presumably helpers the kernel below uses;
    // confirm against the `.wgsl` files.
    include_str!("../wgsl_byte_primitives/bytes.wgsl"),
    "\n",
    include_str!("wgsl/base64url.wgsl"),
);

/// URL-safe base64 decode operation.
///
/// Field-less marker type: it carries no state and exists to hang the op's
/// associated items (`SPEC`, `program`) on.
#[derive(Debug, Clone, Copy, Default)]
pub struct Base64UrlDecode;

/// Decode URL-safe base64 bytes.
///
/// Padding is optional. When it is present it must be the canonical terminal
/// suffix: nothing but `=` bytes, and exactly as many as are needed to
/// complete the final 4-character group. Over-long padding such as `"===="`,
/// `"Zm9v===="` or `"Zg======"` is rejected.
///
/// Leftover low bits of the final data character are discarded without being
/// checked for zero.
///
/// # Errors
///
/// Checks run in this order:
///
/// 1. [`DecodeError::InvalidPadding`] — a non-`=` byte follows the first `=`,
///    or the number of `=` bytes does not complete the final group.
/// 2. [`DecodeError::InvalidLength`] — the number of data characters modulo 4
///    is 1, which cannot encode a whole byte.
/// 3. [`DecodeError::InvalidByte`] — a data character is outside the URL-safe
///    alphabet; `index` is its offset in `input`.
pub fn base64url_decode(input: &[u8]) -> Result<Vec<u8>, DecodeError> {
    // Everything before the first `=` is data; everything from it on is padding.
    let data_len = input
        .iter()
        .position(|&byte| byte == b'=')
        .unwrap_or(input.len());
    let (data, padding) = input.split_at(data_len);

    if !padding.is_empty() {
        if padding.iter().any(|&byte| byte != b'=') {
            return Err(DecodeError::InvalidPadding);
        }
        // Exactly enough `=` to reach a multiple of four characters. This also
        // rejects runs of four or more `=`, which a bare `len % 4 == 0` check
        // lets through. A remainder of 1 "requires" three `=`; that shape is
        // let through here so it is reported as `InvalidLength` below.
        let required_padding = (4 - data.len() % 4) % 4;
        if padding.len() != required_padding {
            return Err(DecodeError::InvalidPadding);
        }
    }
    if data.len() % 4 == 1 {
        return Err(DecodeError::InvalidLength);
    }

    let mut out = Vec::with_capacity((data.len() * 3) / 4);
    // `buffer` accumulates 6-bit groups; `bits` counts how many of its low
    // bits have not yet been emitted (always < 8 between iterations).
    let mut buffer = 0_u32;
    let mut bits = 0_u32;
    for (index, &byte) in data.iter().enumerate() {
        let value = base64url_value(byte).ok_or(DecodeError::InvalidByte { index, byte })?;
        buffer = (buffer << 6) | u32::from(value);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // Masked to eight bits first, so the narrowing cast is lossless.
            out.push(((buffer >> bits) & 0xff) as u8);
        }
    }
    Ok(out)
}

/// Map one byte of the URL-safe base64 alphabet to its 6-bit value.
///
/// `A`–`Z` map to 0–25, `a`–`z` to 26–51, `0`–`9` to 52–61, `-` to 62 and `_`
/// to 63. Returns `None` for every other byte, including `=`, `+` and `/`.
pub fn base64url_value(byte: u8) -> Option<u8> {
    match byte {
        b'A'..=b'Z' => Some(byte - b'A'),
        b'a'..=b'z' => Some(byte - b'a' + 26),
        b'0'..=b'9' => Some(byte - b'0' + 52),
        b'-' => Some(62),
        b'_' => Some(63),
        _ => None,
    }
}

/// Error returned by strict base64url decoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// The encoded length cannot represent whole bytes.
    InvalidLength,
    /// The input contains a byte outside the URL-safe alphabet.
    InvalidByte {
        /// Byte offset of the invalid input.
        index: usize,
        /// Invalid byte value.
        byte: u8,
    },
    /// Padding appears in a non-terminal position or has invalid shape
    /// (wrong number of `=` bytes for the final group).
    InvalidPadding,
}

impl Base64UrlDecode {
    /// Operation specification for `decode.base64url`.
    ///
    /// Built with `OpSpec::composition` from the op name, the shared
    /// bytes-to-bytes input/output signatures, this module's `LAWS`, and
    /// [`Self::program`] as the program builder.
    pub const SPEC: OpSpec = OpSpec::composition(
        "decode.base64url",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Build the canonical lowerable IR program for full 4-byte chunks.
    ///
    /// The program declares a read buffer `input` (slot 0) and an output
    /// buffer `out` (slot 1), binds `idx` to `Expr::gid_x()`, and, when that
    /// index is below `buf_len("out")`, stores `input[idx]` into `out[idx]`.
    ///
    /// `[64, 1, 1]` is passed as the second argument — presumably the
    /// workgroup size; confirm against `Program::new`.
    ///
    /// NOTE(review): as written the body is an element-wise copy guarded by
    /// the output length; confirm where the actual base64 decoding happens.
    #[must_use]
    pub fn program() -> Program {
        let lane = Expr::gid_x();

        let buffers = vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::output("out", 1, DataType::Bytes),
        ];

        let bind_lane = Node::let_bind("idx", lane.clone());
        let in_bounds = Expr::lt(lane.clone(), Expr::buf_len("out"));
        let copy_element = Node::store("out", lane.clone(), Expr::load("input", lane));
        let guarded_copy = Node::if_then(in_bounds, vec![copy_element]);

        Program::new(buffers, [64, 1, 1], vec![bind_lane, guarded_copy])
    }
}

impl fmt::Display for DecodeError {
    /// Write a one-line, "Fix: ..."-prefixed message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            // Only this variant carries data to interpolate.
            Self::InvalidByte { index, byte } => write!(
                f,
                "Fix: byte 0x{byte:02x} at offset {index} is not base64url"
            ),
            Self::InvalidLength => {
                f.write_str("Fix: base64url input length modulo 4 must not be 1")
            }
            Self::InvalidPadding => f.write_str("Fix: place base64url padding only at the end"),
        }
    }
}

// Empty impl: no trait methods are overridden, so `DecodeError` takes the
// trait's default behavior on top of its `Debug` and `Display` impls.
impl std::error::Error for DecodeError {}

/// Algebraic laws passed to `Base64UrlDecode::SPEC`; none are declared for
/// this op (the slice is empty).
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

// Unit tests, extracted from `ops/decode/base64url/kernel.rs`.

/// Check the CPU reference decoder against the committed known-answer vectors.
///
/// Cases whose `op` is not `"base64url_decode"` are skipped. For the rest,
/// `input_hex`, `expected_output_hex` and `ok` must all be present; a case
/// with `ok = true` must decode to the expected bytes and a case with
/// `ok = false` must be rejected.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    let vectors = include_str!("../fixtures/reference-vectors.toml");
    crate::ops::fixtures::run_committed_kats(vectors, |case| {
        // Vectors for other ops share the same fixture file.
        if case.op.as_deref() != Some("base64url_decode") {
            return Ok(());
        }

        // Fields are read in this order so the first missing one is reported.
        let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
        let input = crate::ops::fixtures::hex_to_bytes(input_hex)?;
        let expected_hex = case
            .expected_output_hex
            .as_ref()
            .ok_or("Fix: missing expected_output_hex")?;
        let expected = crate::ops::fixtures::hex_to_bytes(expected_hex)?;
        let ok = case.ok.ok_or("Fix: missing ok")?;

        match (ok, base64url_decode(&input)) {
            (true, Ok(actual)) => {
                assert_eq!(actual, expected);
                Ok(())
            }
            (false, Err(_)) => Ok(()),
            (true, Err(error)) => Err(format!(
                "Fix: valid base64url vector was rejected: {error}"
            )),
            (false, Ok(actual)) => Err(format!(
                "Fix: invalid base64url vector should fail, decoded to {actual:02x?}"
            )),
        }
    })
}