use crate::engine::decode::DecodeRules;
/// Encoded-text formats for which this module provides a WGSL decode shader.
///
/// Each variant selects one shader body (see [`DecodeFormat::wgsl`]).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeFormat {
    /// Base64 quads over the alphabet `A-Z`, `a-z`, `0-9`, `+`, `/`.
    Base64,
    /// Pairs of hexadecimal digits.
    Hex,
    /// Percent escapes of the form `%HH`.
    Url,
    /// Backslash escapes of the form `\xHH` and `\uHHHH`.
    Unicode,
}
impl DecodeFormat {
pub(crate) fn label(self) -> &'static str {
match self {
Self::Base64 => "vyre decode base64",
Self::Hex => "vyre decode hex",
Self::Url => "vyre decode url",
Self::Unicode => "vyre decode unicode",
}
}
pub(crate) fn min_run(self, rules: &DecodeRules) -> u32 {
match self {
Self::Base64 => rules.min_base64_run,
Self::Hex => rules.min_hex_run,
Self::Url | Self::Unicode => 0,
}
}
pub(crate) fn op_id(self) -> &'static str {
match self {
Self::Base64 => vyre::ops::decode::base64::Base64Decode::SPEC.id(),
Self::Hex => vyre::ops::decode::hex::HexDecode::SPEC.id(),
Self::Url => vyre::ops::decode::url::UrlDecode::SPEC.id(),
Self::Unicode => vyre::ops::decode::unicode::UnicodeDecode::SPEC.id(),
}
}
pub(crate) fn wgsl(self) -> String {
match self {
Self::Base64 => [DECODE_WGSL_HEADER, BASE64_WGSL_BODY].concat(),
Self::Hex => [DECODE_WGSL_HEADER, HEX_WGSL_BODY].concat(),
Self::Url => [DECODE_WGSL_HEADER, URL_WGSL_BODY].concat(),
Self::Unicode => [DECODE_WGSL_HEADER, UNICODE_WGSL_BODY].concat(),
}
}
}
/// Shared WGSL prelude that [`DecodeFormat::wgsl`] prepends to every
/// format-specific shader body.
///
/// The text must be plain WGSL: the language has no visibility modifiers and
/// `pub` is a reserved word, so declarations are written as bare `struct` /
/// `fn` (a Rust-style `pub` prefix makes the module fail to compile).
///
/// Declares, all in bind group 0:
///
/// * binding 0 — `input_words`: read-only storage, input bytes packed four
///   per `u32`, lowest byte first (see `read_byte`).
/// * binding 1 — `regions`: one `RegionMeta` per emitted region.
/// * binding 2 — `output_words`: decoded output; `emit_region` stores one
///   decoded byte per `u32` element.
/// * binding 3 — `counters`: atomics; index 0 counts emitted regions, index 1
///   is the running output offset.
/// * binding 4 — `params`: uniform `Params`. `min_run` is declared there but
///   is not read by the WGSL in this file.
///
/// Helpers:
///
/// * `read_byte(offset)` — extracts byte `offset` from `input_words`.
/// * `hex_value(byte)` — value of an ASCII hex digit (`0-9`, `A-F`, `a-f`),
///   or `0xffffffff` when `byte` is not a hex digit.
/// * `emit_region(..)` — reserves a region slot and `dst_len` output elements
///   via `atomicAdd`, then writes the metadata and up to three decoded bytes.
///   The region is dropped when the slot index reaches `max_regions` or the
///   output would exceed `output_size`; the counters are incremented before
///   those checks, so they can end above the limits.
///
/// NOTE(review): when the output-capacity check fails, the region slot has
/// already been reserved but is never written — confirm the host tolerates
/// such unwritten slots.
pub const DECODE_WGSL_HEADER: &str = r"
struct Params {
input_len: u32,
min_run: u32,
max_regions: u32,
output_size: u32,
};
struct RegionMeta {
src_offset: u32,
src_len: u32,
dst_offset: u32,
dst_len: u32,
};
@group(0) @binding(0) var<storage, read> input_words: array<u32>;
@group(0) @binding(1) var<storage, read_write> regions: array<RegionMeta>;
@group(0) @binding(2) var<storage, read_write> output_words: array<u32>;
@group(0) @binding(3) var<storage, read_write> counters: array<atomic<u32>>;
@group(0) @binding(4) var<uniform> params: Params;
fn read_byte(offset: u32) -> u32 {
let word = input_words[offset / 4u];
let shift = (offset % 4u) * 8u;
return (word >> shift) & 0xffu;
}
fn hex_value(byte: u32) -> u32 {
if (byte >= 48u && byte <= 57u) { return byte - 48u; }
if (byte >= 65u && byte <= 70u) { return byte - 55u; }
if (byte >= 97u && byte <= 102u) { return byte - 87u; }
return 0xffffffffu;
}
fn emit_region(src_offset: u32, src_len: u32, dst_len: u32, b0: u32, b1: u32, b2: u32) {
let region_index = atomicAdd(&counters[0], 1u);
if (region_index >= params.max_regions) { return; }
let dst_offset = atomicAdd(&counters[1], dst_len);
if (dst_offset + dst_len > params.output_size) { return; }
regions[region_index] = RegionMeta(src_offset, src_len, dst_offset, dst_len);
if (dst_len > 0u) { output_words[dst_offset] = b0; }
if (dst_len > 1u) { output_words[dst_offset + 1u] = b1; }
if (dst_len > 2u) { output_words[dst_offset + 2u] = b2; }
}
";
/// WGSL body of the Base64 decode shader; appended to
/// [`DECODE_WGSL_HEADER`] by [`DecodeFormat::wgsl`].
///
/// The text must be plain WGSL, which has no `pub` modifier (`pub` is a
/// reserved word there), so functions are declared with bare `fn`.
///
/// Each invocation (workgroup size 64) uses `gid.x` as a byte offset and
/// requires `offset + 3 < input_len`. It maps the four bytes at that offset
/// through `b64_value` (`A-Z`, `a-z`, `0-9`, `+`, `/`; anything else yields
/// `0xffffffff`). If all four are valid it emits one region of source
/// length 4 carrying the three decoded bytes. Padding (`=`) is not in the
/// alphabet, so quads containing it are skipped.
pub const BASE64_WGSL_BODY: &str = r"
fn b64_value(byte: u32) -> u32 {
if (byte >= 65u && byte <= 90u) { return byte - 65u; }
if (byte >= 97u && byte <= 122u) { return byte - 71u; }
if (byte >= 48u && byte <= 57u) { return byte + 4u; }
if (byte == 43u) { return 62u; }
if (byte == 47u) { return 63u; }
return 0xffffffffu;
}
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let offset = gid.x;
if (offset + 3u >= params.input_len) { return; }
let a = b64_value(read_byte(offset));
let b = b64_value(read_byte(offset + 1u));
let c = b64_value(read_byte(offset + 2u));
let d = b64_value(read_byte(offset + 3u));
if (a == 0xffffffffu || b == 0xffffffffu || c == 0xffffffffu || d == 0xffffffffu) { return; }
let out0 = ((a << 2u) | (b >> 4u)) & 0xffu;
let out1 = (((b & 15u) << 4u) | (c >> 2u)) & 0xffu;
let out2 = (((c & 3u) << 6u) | d) & 0xffu;
emit_region(offset, 4u, 3u, out0, out1, out2);
}
";
/// WGSL body of the hex decode shader; appended to [`DECODE_WGSL_HEADER`]
/// by [`DecodeFormat::wgsl`].
///
/// The text must be plain WGSL, which has no `pub` modifier (`pub` is a
/// reserved word there), so the entry point is declared with bare `fn`.
///
/// Each invocation (workgroup size 64) uses `gid.x` as a byte offset and
/// requires `offset + 1 < input_len`. If the two bytes at that offset are
/// both hex digits it emits one region of source length 2 carrying the
/// single decoded byte (`hi << 4 | lo`).
pub const HEX_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let offset = gid.x;
if (offset + 1u >= params.input_len) { return; }
let hi = hex_value(read_byte(offset));
let lo = hex_value(read_byte(offset + 1u));
if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
emit_region(offset, 2u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";
/// WGSL body of the URL percent-decode shader; appended to
/// [`DECODE_WGSL_HEADER`] by [`DecodeFormat::wgsl`].
///
/// The text must be plain WGSL, which has no `pub` modifier (`pub` is a
/// reserved word there), so the entry point is declared with bare `fn`.
///
/// Each invocation (workgroup size 64) uses `gid.x` as a byte offset and
/// requires `offset + 2 < input_len`. If the byte at that offset is `%`
/// (37) and the next two bytes are hex digits, it emits one region of
/// source length 3 carrying the single decoded byte.
pub const URL_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let offset = gid.x;
if (offset + 2u >= params.input_len || read_byte(offset) != 37u) { return; }
let hi = hex_value(read_byte(offset + 1u));
let lo = hex_value(read_byte(offset + 2u));
if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
emit_region(offset, 3u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";
/// WGSL body of the backslash-escape decode shader; appended to
/// [`DECODE_WGSL_HEADER`] by [`DecodeFormat::wgsl`].
///
/// The text must be plain WGSL, which has no `pub` modifier (`pub` is a
/// reserved word there), so the entry point is declared with bare `fn`.
///
/// Each invocation (workgroup size 64) uses `gid.x` as a byte offset,
/// requires `offset + 3 < input_len`, and requires a backslash (92) at that
/// offset. Two escape forms are recognised:
///
/// * `\xHH` — emits one region of source length 4 carrying the decoded byte.
///   Once `x` is seen the invocation returns, whether or not the digits
///   were valid.
/// * `\uHHHH` — additionally requires `offset + 5 < input_len`; all four
///   digits must be hex, and one region of source length 6 is emitted.
///
/// NOTE(review): for `\uHHHH` only the low byte (last two digits) is emitted;
/// the first two digits are validated but discarded. Confirm this truncation
/// is intended.
pub const UNICODE_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let offset = gid.x;
if (offset + 3u >= params.input_len || read_byte(offset) != 92u) { return; }
if (read_byte(offset + 1u) == 120u) {
let hi = hex_value(read_byte(offset + 2u));
let lo = hex_value(read_byte(offset + 3u));
if (hi != 0xffffffffu && lo != 0xffffffffu) {
emit_region(offset, 4u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
return;
}
if (offset + 5u >= params.input_len || read_byte(offset + 1u) != 117u) { return; }
let h0 = hex_value(read_byte(offset + 2u));
let h1 = hex_value(read_byte(offset + 3u));
let h2 = hex_value(read_byte(offset + 4u));
let h3 = hex_value(read_byte(offset + 5u));
if (h0 == 0xffffffffu || h1 == 0xffffffffu || h2 == 0xffffffffu || h3 == 0xffffffffu) { return; }
emit_region(offset, 6u, 1u, ((h2 << 4u) | h3) & 0xffu, 0u, 0u);
}
";