// vyre-wgpu 0.1.0

// wgpu backend for vyre IR — implements VyreBackend, owns GPU runtime, buffer pool, pipeline cache
// Documentation
//! Supported decode formats and shader source selection.

use crate::engine::decode::DecodeRules;




/// Decode formats for which this module provides shader source.
///
/// The enum is marked `#[non_exhaustive]` so that further formats (for
/// instance Base32 or quoted-printable) can be introduced later without
/// breaking `match` expressions in downstream crates.
///
/// # Examples
///
/// ```
/// use vyre_wgpu::engine::decode::DecodeFormat;
///
/// assert!(matches!(DecodeFormat::Base64, DecodeFormat::Base64));
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeFormat {
    /// Base64 as specified by RFC 4648.
    Base64,
    /// Bytes written as pairs of hexadecimal digits.
    Hex,
    /// Percent-encoded bytes as used in URLs.
    Url,
    /// Escape sequences of the form `\xNN` and `\uNNNN`.
    Unicode,
}


// All helpers below are crate-private (see architecture_deep_audit.md #10/#13).
impl DecodeFormat {
    /// Returns the fixed label string for this format
    /// (`"vyre decode <format>"`).
    pub(crate) fn label(self) -> &'static str {
        match self {
            Self::Unicode => "vyre decode unicode",
            Self::Url => "vyre decode url",
            Self::Hex => "vyre decode hex",
            Self::Base64 => "vyre decode base64",
        }
    }

    /// Returns the minimum run length configured in `rules` for this format.
    ///
    /// Base64 and hex read their dedicated fields from `rules`; URL and
    /// Unicode have no such setting and always yield `0`.
    pub(crate) fn min_run(self, rules: &DecodeRules) -> u32 {
        match self {
            Self::Url => 0,
            Self::Unicode => 0,
            Self::Hex => rules.min_hex_run,
            Self::Base64 => rules.min_base64_run,
        }
    }

    /// Returns the id reported by the `SPEC` of the vyre decode op that
    /// corresponds to this format.
    pub(crate) fn op_id(self) -> &'static str {
        match self {
            Self::Unicode => vyre::ops::decode::unicode::UnicodeDecode::SPEC.id(),
            Self::Url => vyre::ops::decode::url::UrlDecode::SPEC.id(),
            Self::Hex => vyre::ops::decode::hex::HexDecode::SPEC.id(),
            Self::Base64 => vyre::ops::decode::base64::Base64Decode::SPEC.id(),
        }
    }

    /// Builds the complete WGSL source for this format: the shared
    /// `DECODE_WGSL_HEADER` followed by the format-specific body.
    pub(crate) fn wgsl(self) -> String {
        // Pick the per-format entry point first, then glue it onto the header.
        let body = match self {
            Self::Base64 => BASE64_WGSL_BODY,
            Self::Hex => HEX_WGSL_BODY,
            Self::Url => URL_WGSL_BODY,
            Self::Unicode => UNICODE_WGSL_BODY,
        };
        let mut source = String::with_capacity(DECODE_WGSL_HEADER.len() + body.len());
        source.push_str(DECODE_WGSL_HEADER);
        source.push_str(body);
        source
    }
}


/// WGSL prelude shared by every decode shader.
///
/// `DecodeFormat::wgsl` prepends this text to the format-specific body. It
/// declares:
///
/// - `Params` (uniform, binding 4): `input_len`, `min_run`, `max_regions`
///   and `output_size`.
/// - `RegionMeta`: source offset/length and destination offset/length of one
///   decoded region.
/// - Storage bindings 0–3: `input_words` (read-only), `regions`,
///   `output_words` and the atomic `counters` array.
/// - `read_byte(offset)`: extracts one byte from `input_words`, taking the
///   lowest 8 bits of each word first.
/// - `hex_value(byte)`: maps an ASCII hex digit (`0-9`, `A-F`, `a-f`) to
///   `0..=15` and returns the sentinel `0xffffffff` for anything else.
/// - `emit_region(...)`: reserves a region slot through `counters[0]` and
///   output space through `counters[1]`, then writes the region metadata and
///   up to three decoded bytes, each into its own `u32` element of
///   `output_words`.
///
/// Both counters are incremented *before* the corresponding limit check, so
/// after a dispatch they may exceed `max_regions` / `output_size`; a region
/// slot reserved before an output-space failure is left unwritten.
///
/// WGSL has no visibility modifiers, so the declarations inside the shader
/// text must not carry a Rust-style `pub` prefix.
pub const DECODE_WGSL_HEADER: &str = r"
struct Params {
    input_len: u32,
    min_run: u32,
    max_regions: u32,
    output_size: u32,
};

struct RegionMeta {
    src_offset: u32,
    src_len: u32,
    dst_offset: u32,
    dst_len: u32,
};

@group(0) @binding(0) var<storage, read> input_words: array<u32>;
@group(0) @binding(1) var<storage, read_write> regions: array<RegionMeta>;
@group(0) @binding(2) var<storage, read_write> output_words: array<u32>;
@group(0) @binding(3) var<storage, read_write> counters: array<atomic<u32>>;
@group(0) @binding(4) var<uniform> params: Params;

fn read_byte(offset: u32) -> u32 {
    let word = input_words[offset / 4u];
    let shift = (offset % 4u) * 8u;
    return (word >> shift) & 0xffu;
}

fn hex_value(byte: u32) -> u32 {
    if (byte >= 48u && byte <= 57u) { return byte - 48u; }
    if (byte >= 65u && byte <= 70u) { return byte - 55u; }
    if (byte >= 97u && byte <= 102u) { return byte - 87u; }
    return 0xffffffffu;
}

fn emit_region(src_offset: u32, src_len: u32, dst_len: u32, b0: u32, b1: u32, b2: u32) {
    let region_index = atomicAdd(&counters[0], 1u);
    if (region_index >= params.max_regions) { return; }
    let dst_offset = atomicAdd(&counters[1], dst_len);
    if (dst_offset + dst_len > params.output_size) { return; }
    regions[region_index] = RegionMeta(src_offset, src_len, dst_offset, dst_len);
    if (dst_len > 0u) { output_words[dst_offset] = b0; }
    if (dst_len > 1u) { output_words[dst_offset + 1u] = b1; }
    if (dst_len > 2u) { output_words[dst_offset + 2u] = b2; }
}
";


/// WGSL entry point for base64 decoding; appended to `DECODE_WGSL_HEADER`.
///
/// `b64_value` maps a byte of the standard alphabet (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`) to its 6-bit value and returns `0xffffffff` for any other byte
/// (including `=` padding).
///
/// `main` runs with a workgroup size of 64. Each invocation treats its global
/// x index as a byte offset, reads the four bytes starting there and, if all
/// four are alphabet characters, emits a region of four source bytes and
/// three decoded bytes. Every offset is tested independently, so regions from
/// neighbouring invocations may overlap.
///
/// WGSL has no visibility modifiers, so the functions must not carry a
/// Rust-style `pub` prefix.
pub const BASE64_WGSL_BODY: &str = r"
fn b64_value(byte: u32) -> u32 {
    if (byte >= 65u && byte <= 90u) { return byte - 65u; }
    if (byte >= 97u && byte <= 122u) { return byte - 71u; }
    if (byte >= 48u && byte <= 57u) { return byte + 4u; }
    if (byte == 43u) { return 62u; }
    if (byte == 47u) { return 63u; }
    return 0xffffffffu;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 3u >= params.input_len) { return; }
    let a = b64_value(read_byte(offset));
    let b = b64_value(read_byte(offset + 1u));
    let c = b64_value(read_byte(offset + 2u));
    let d = b64_value(read_byte(offset + 3u));
    if (a == 0xffffffffu || b == 0xffffffffu || c == 0xffffffffu || d == 0xffffffffu) { return; }
    let out0 = ((a << 2u) | (b >> 4u)) & 0xffu;
    let out1 = (((b & 15u) << 4u) | (c >> 2u)) & 0xffu;
    let out2 = (((c & 3u) << 6u) | d) & 0xffu;
    emit_region(offset, 4u, 3u, out0, out1, out2);
}
";


/// WGSL entry point for hex decoding; appended to `DECODE_WGSL_HEADER`.
///
/// `main` runs with a workgroup size of 64. Each invocation treats its global
/// x index as a byte offset, reads the two bytes starting there and, if both
/// are hex digits according to the header's `hex_value`, emits a region of
/// two source bytes and one decoded byte. Every offset is tested
/// independently, so regions from neighbouring invocations may overlap.
///
/// WGSL has no visibility modifiers, so the entry point must not carry a
/// Rust-style `pub` prefix.
pub const HEX_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 1u >= params.input_len) { return; }
    let hi = hex_value(read_byte(offset));
    let lo = hex_value(read_byte(offset + 1u));
    if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
    emit_region(offset, 2u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";


/// WGSL entry point for URL percent-decoding; appended to
/// `DECODE_WGSL_HEADER`.
///
/// `main` runs with a workgroup size of 64. Each invocation treats its global
/// x index as a byte offset and looks for `%` (byte 37) followed by two hex
/// digits; on a match it emits a region of three source bytes and one decoded
/// byte.
///
/// WGSL has no visibility modifiers, so the entry point must not carry a
/// Rust-style `pub` prefix.
pub const URL_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 2u >= params.input_len || read_byte(offset) != 37u) { return; }
    let hi = hex_value(read_byte(offset + 1u));
    let lo = hex_value(read_byte(offset + 2u));
    if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
    emit_region(offset, 3u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";


/// WGSL entry point for `\xNN` / `\uNNNN` escape decoding; appended to
/// `DECODE_WGSL_HEADER`.
///
/// `main` runs with a workgroup size of 64. Each invocation treats its global
/// x index as a byte offset and requires a backslash (byte 92) there:
///
/// - `\x` (byte 120) followed by two hex digits emits a region of four source
///   bytes and one decoded byte. A `\x` with invalid digits emits nothing.
/// - `\u` (byte 117) followed by four hex digits emits a region of six source
///   bytes and one decoded byte. All four digits are validated, but only the
///   last two contribute to the emitted byte.
///
/// NOTE(review): dropping the high byte of `\uNNNN` looks deliberate given
/// `dst_len` is 1, but confirm against the reference decoder.
///
/// WGSL has no visibility modifiers, so the entry point must not carry a
/// Rust-style `pub` prefix.
pub const UNICODE_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 3u >= params.input_len || read_byte(offset) != 92u) { return; }
    if (read_byte(offset + 1u) == 120u) {
        let hi = hex_value(read_byte(offset + 2u));
        let lo = hex_value(read_byte(offset + 3u));
        if (hi != 0xffffffffu && lo != 0xffffffffu) {
            emit_region(offset, 4u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
        }
        return;
    }
    if (offset + 5u >= params.input_len || read_byte(offset + 1u) != 117u) { return; }
    let h0 = hex_value(read_byte(offset + 2u));
    let h1 = hex_value(read_byte(offset + 3u));
    let h2 = hex_value(read_byte(offset + 4u));
    let h3 = hex_value(read_byte(offset + 5u));
    if (h0 == 0xffffffffu || h1 == 0xffffffffu || h2 == 0xffffffffu || h3 == 0xffffffffu) { return; }
    emit_region(offset, 6u, 1u, ((h2 << 4u) | h3) & 0xffu, 0u, 0u);
}
";