// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library
// Documentation
use core::fmt;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

// WGSL lowering marker for `decode.hex_decode_strict`.
//
// Not a stub: this is a zero-overhead Category A marker for the real
// `HexDecodeStrict::program()` IR. The generic lowering in
// `core/src/lower/wgsl` emits the compute shader, and the valid-input IR is
// bit-exact against the conform CPU reference.
//
// ```wgsl
// @compute @workgroup_size(64, 1, 1)
// fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
//   let pair = gid.x * 2u;
//   // Validate 0-9/a-f/A-F, combine two nibbles, otherwise write 0u.
// }
// ```

/// Marker constant documenting the IR-to-WGSL path for `decode.hex_decode_strict`.
///
/// The operation owns no handwritten shader: `HexDecodeStrict::program()`
/// builds pure IR from range checks, shifts, masks, arithmetic, and selects;
/// `core/src/lower/wgsl` lowers that IR into the final kernel without
/// dispatch-specific glue here.
///
/// The value is the same identifier string that `HexDecodeStrict::SPEC`
/// passes as its first argument to `OpSpec::composition_inlinable`.
pub const CATEGORY_A_WGSL_MARKER: &str = "decode.hex_decode_strict";

/// Combines `left` and `right` into a single IR expression via `Expr::and`.
///
/// Free-function shorthand used by the predicate builders in this file
/// (`in_range` and the validity check inside `HexDecodeStrict::program`).
pub fn and(left: Expr, right: Expr) -> Expr {
    Expr::and(left, right)
}

/// Error returned by [`hex_decode_strict`] when the input is not valid hex.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// The input length is odd; hex strings must contain exactly two
    /// characters per output byte.
    OddLength,
    /// The input contains a byte that is not one of `0-9`, `a-f`, or `A-F`.
    InvalidByte {
        /// Byte offset of the invalid byte within the input slice.
        index: usize,
        /// The offending byte value as it appeared in the input.
        byte: u8,
    },
}

/// Decodes a slice of ASCII hexadecimal characters into raw bytes.
///
/// Every two input characters produce one output byte; the first character
/// of each pair supplies the high nibble and the second the low nibble.
/// Both upper- and lower-case digits are accepted.
///
/// # Errors
///
/// * [`DecodeError::OddLength`] if `input.len()` is odd. This is checked
///   before any character is inspected.
/// * [`DecodeError::InvalidByte`] for the first non-hexadecimal byte
///   encountered, carrying its offset and value.
pub fn hex_decode_strict(input: &[u8]) -> Result<Vec<u8>, DecodeError> {
    if input.len() % 2 != 0 {
        return Err(DecodeError::OddLength);
    }

    // Translate one character, tagging a failure with where it occurred.
    let nibble = |index: usize, byte: u8| {
        hex_value(byte).ok_or(DecodeError::InvalidByte { index, byte })
    };

    input
        .chunks_exact(2)
        .enumerate()
        .map(|(pair_index, pair)| -> Result<u8, DecodeError> {
            let offset = pair_index * 2;
            // High nibble is validated first so the earliest bad byte wins.
            let high = nibble(offset, pair[0])?;
            let low = nibble(offset + 1, pair[1])?;
            Ok((high << 4) | low)
        })
        .collect()
}

/// Strict hexadecimal decode operation.
///
/// Zero-sized handle for the operation: its `impl` block carries the
/// declarative `SPEC` and the `program()` IR builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct HexDecodeStrict;

/// Builds the IR expression that converts one ASCII hex character to its
/// nibble value.
///
/// The result is a chain of three `Expr::select` nodes keyed on `is_digit`,
/// `is_lower_hex`, and `is_upper_hex`, in that order:
///
/// * digit: `byte - '0'`
/// * lower-case letter: `byte - 'a' + 10`
/// * upper-case letter: `byte - 'A' + 10`
/// * anything else: the constant `0`
///
/// Callers that need to reject invalid input must test it separately (see
/// `is_hex_byte`); this expression alone cannot distinguish `'0'` from an
/// invalid byte.
pub fn hex_nibble(byte: Expr) -> Expr {
    let digit_test = is_digit(byte.clone());
    let digit_value = Expr::sub(byte.clone(), Expr::u32(u32::from(b'0')));

    let lower_test = is_lower_hex(byte.clone());
    let lower_value = Expr::add(
        Expr::sub(byte.clone(), Expr::u32(u32::from(b'a'))),
        Expr::u32(10),
    );

    let upper_test = is_upper_hex(byte.clone());
    // Final use of `byte`, so it is moved rather than cloned.
    let upper_value = Expr::add(Expr::sub(byte, Expr::u32(u32::from(b'A'))), Expr::u32(10));

    // Assemble from the innermost fallback outwards.
    let upper_or_zero = Expr::select(upper_test, upper_value, Expr::u32(0));
    let letter_or_zero = Expr::select(lower_test, lower_value, upper_or_zero);
    Expr::select(digit_test, digit_value, letter_or_zero)
}

/// Returns the numeric value (`0..=15`) of a single ASCII hexadecimal
/// character, or `None` if `byte` is not one of `0-9`, `a-f`, or `A-F`.
pub fn hex_value(byte: u8) -> Option<u8> {
    // Each accepted range maps to the amount that must be subtracted from the
    // character to obtain its value; letters start at 10, hence the `- 10`.
    let offset = match byte {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(byte - offset)
}

impl fmt::Display for DecodeError {
    /// Writes a one-line, human-readable description of the error.
    ///
    /// Messages carry the `Fix: ` prefix; the invalid-byte message reports
    /// the byte as two lower-case hex digits together with its offset.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidByte { index, byte } => write!(
                f,
                "Fix: byte 0x{byte:02x} at offset {index} is not hexadecimal"
            ),
            // No interpolation needed, so the literal is written directly.
            Self::OddLength => {
                f.write_str("Fix: hex input must contain an even number of bytes")
            }
        }
    }
}

// `DecodeError` carries no underlying source error, so the trait's default
// method implementations are used unchanged.
impl std::error::Error for DecodeError {}

impl HexDecodeStrict {
    /// Declarative operation specification for `decode.hex_decode_strict`.
    ///
    /// Built with `OpSpec::composition_inlinable` from the bytes-to-bytes
    /// input/output signatures, the `LAWS` table, and [`Self::program`] as
    /// the IR builder.
    pub const SPEC: OpSpec = OpSpec::composition_inlinable(
        "decode.hex_decode_strict",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the canonical lowerable IR program for paired hex bytes.
    ///
    /// The program declares a read buffer `input` (binding 0) and an output
    /// buffer `out` (binding 1), both of `DataType::Bytes`, with a
    /// `[64, 1, 1]` workgroup size. Each invocation binds `idx` to the
    /// global x id and, when `idx` is below the length of `out`:
    ///
    /// 1. loads `input[idx * 2]` and `input[idx * 2 + 1]`,
    /// 2. checks both with `is_hex_byte` and checks that the length of
    ///    `input` is even,
    /// 3. converts both with `hex_nibble`,
    /// 4. stores `(hi << 4) | lo` to `out[idx]` if all three checks hold,
    ///    and `0` otherwise.
    #[must_use]
    pub fn program() -> Program {
        let out_slot = Expr::var("idx");
        let first_char = Expr::mul(out_slot.clone(), Expr::u32(2));

        let buffers = vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::output("out", 1, DataType::Bytes),
        ];

        let bind_slot = Node::let_bind("idx", Expr::gid_x());
        // Invocations past the end of `out` do nothing.
        let slot_in_bounds = Expr::lt(out_slot.clone(), Expr::buf_len("out"));

        let bind_first_char = Node::let_bind("input_idx", first_char);
        let bind_hi_byte = Node::let_bind("hi_byte", Expr::load("input", Expr::var("input_idx")));
        let bind_lo_byte = Node::let_bind(
            "lo_byte",
            Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(1))),
        );
        let bind_hi_valid = Node::let_bind("hi_valid", is_hex_byte(Expr::var("hi_byte")));
        let bind_lo_valid = Node::let_bind("lo_valid", is_hex_byte(Expr::var("lo_byte")));

        let length_is_even = Expr::eq(
            Expr::rem(Expr::buf_len("input"), Expr::u32(2)),
            Expr::u32(0),
        );
        let bind_even_len = Node::let_bind("even_len", length_is_even);

        let bind_hi = Node::let_bind("hi", hex_nibble(Expr::var("hi_byte")));
        let bind_lo = Node::let_bind("lo", hex_nibble(Expr::var("lo_byte")));

        // A pair is accepted only when the whole input has even length and
        // both of its characters are hexadecimal.
        let pair_accepted = and(
            Expr::var("even_len"),
            and(Expr::var("hi_valid"), Expr::var("lo_valid")),
        );
        let packed_byte = Expr::bitor(
            Expr::shl(Expr::var("hi"), Expr::u32(4)),
            Expr::var("lo"),
        );
        let write_result = Node::store(
            "out",
            out_slot,
            Expr::select(pair_accepted, packed_byte, Expr::u32(0)),
        );

        let decode_pair = vec![
            bind_first_char,
            bind_hi_byte,
            bind_lo_byte,
            bind_hi_valid,
            bind_lo_valid,
            bind_even_len,
            bind_hi,
            bind_lo,
            write_result,
        ];

        Program::new(
            buffers,
            [64, 1, 1],
            vec![bind_slot, Node::if_then(slot_in_bounds, decode_pair)],
        )
    }
}

/// Builds an IR predicate testing `low <= value` and `value <= high`.
///
/// Both bounds are widened to `u32` constants before comparison, and the
/// two `Expr::le` comparisons are joined with [`and`].
pub fn in_range(value: Expr, low: u8, high: u8) -> Expr {
    let lower_bound = Expr::u32(u32::from(low));
    let not_below = Expr::le(lower_bound, value.clone());

    let upper_bound = Expr::u32(u32::from(high));
    let not_above = Expr::le(value, upper_bound);

    and(not_below, not_above)
}

/// Builds an IR predicate that checks `byte` against the ASCII range
/// `'0'..='9'` using [`in_range`].
pub fn is_digit(byte: Expr) -> Expr {
    in_range(byte, b'0', b'9')
}

/// Builds an IR predicate that accepts any ASCII hexadecimal character.
///
/// The result combines [`is_digit`], [`is_lower_hex`], and [`is_upper_hex`]
/// with [`or`], grouped as `digit or (lower or upper)`.
pub fn is_hex_byte(byte: Expr) -> Expr {
    let digit = is_digit(byte.clone());
    let lower = is_lower_hex(byte.clone());
    // Final use of `byte`, so it is moved rather than cloned.
    let upper = is_upper_hex(byte);

    let letter = or(lower, upper);
    or(digit, letter)
}

/// Builds an IR predicate that checks `byte` against the ASCII range
/// `'a'..='f'` using [`in_range`].
pub fn is_lower_hex(byte: Expr) -> Expr {
    in_range(byte, b'a', b'f')
}

/// Builds an IR predicate that checks `byte` against the ASCII range
/// `'A'..='F'` using [`in_range`].
pub fn is_upper_hex(byte: Expr) -> Expr {
    in_range(byte, b'A', b'F')
}

/// Algebraic laws declared for `decode.hex_decode_strict`, passed to
/// `HexDecodeStrict::SPEC`. The table is currently empty.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

/// Combines `left` and `right` into a single IR expression via `Expr::or`.
///
/// Free-function shorthand used by `is_hex_byte` to join its range checks.
pub fn or(left: Expr, right: Expr) -> Expr {
    Expr::or(left, right)
}

// Unit tests.
// Unit tests extracted from `ops/decode/hex_decode_strict/kernel.rs`.

/// Replays the committed known-answer vectors for `hex_decode_strict`
/// against the CPU reference decoder.
///
/// Vectors whose `op` is not `hex_decode_strict` are skipped. For the rest,
/// a vector marked `ok` must decode to `expected_output_hex`, and a vector
/// not marked `ok` must be rejected by [`hex_decode_strict`].
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    let vectors = include_str!("../fixtures/reference-vectors.toml");
    crate::ops::fixtures::run_committed_kats(vectors, |case| {
        if case.op.as_deref() != Some("hex_decode_strict") {
            return Ok(());
        }

        let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
        let input = crate::ops::fixtures::hex_to_bytes(input_hex)?;

        let expected_hex = case
            .expected_output_hex
            .as_ref()
            .ok_or("Fix: missing expected_output_hex")?;
        let expected = crate::ops::fixtures::hex_to_bytes(expected_hex)?;

        let should_decode = case.ok.ok_or("Fix: missing ok")?;

        match hex_decode_strict(&input) {
            Ok(actual) if should_decode => assert_eq!(actual, expected),
            Ok(actual) => {
                return Err(format!(
                    "Fix: invalid hex decode vector should fail, decoded to {actual:02x?}"
                ));
            }
            Err(error) if should_decode => {
                return Err(format!(
                    "Fix: valid hex decode vector was rejected: {error}"
                ));
            }
            // Rejection was the expected outcome.
            Err(_) => {}
        }
        Ok(())
    })
}