use vyre::lower::wgsl;
use vyre::ops::decode::hex_decode_strict::{hex_decode_strict, HexDecodeStrict};
use vyre::ops::encode::base64_encode::{base64_encode, Base64Encode};
use vyre::ops::encode::hex_encode_lower::{hex_encode_lower, HexEncodeLower};
/// The library lowercase-hex encoder agrees with the byte-by-byte reference
/// decomposition on 32 deterministic pseudo-random inputs.
#[test]
fn hex_encode_composition_matches_cpu_on_32_random_inputs() {
    let cases = random_inputs(0x4845_5845, 32);
    cases
        .iter()
        .for_each(|case| assert_eq!(hex_encode_lower(case), hex_encode_decomposition(case)));
}
/// The library Base64 encoder agrees with the reference decomposition on 32
/// deterministic pseudo-random inputs.
#[test]
fn base64_encode_composition_matches_cpu_on_32_random_inputs() {
    let inputs = random_inputs(0x6436_3421, 32);
    for input in inputs.iter() {
        // Library result first, then reference, matching the original evaluation order.
        let actual = base64_encode(input);
        let expected = base64_encode_decomposition(input);
        assert_eq!(actual, expected);
    }
}
/// Round-trip check: hex produced by `hex_encode_lower` must be accepted by the
/// strict decoder and decode to the same bytes as the reference decomposition.
#[test]
fn hex_decode_composition_matches_cpu_on_32_random_inputs() {
    let inputs = random_inputs(0xDEC0_DE32, 32);
    for raw in &inputs {
        let encoded = hex_encode_lower(raw);
        let strict = hex_decode_strict(&encoded).expect("generated hex must decode");
        assert_eq!(strict, hex_decode_decomposition(&encoded));
    }
}
/// Every encode/decode composition must lower through the generic IR -> WGSL path.
///
/// For each op the lowered shader must be a compute shader (`@compute`). It must
/// also contain none of the removed handwritten kernel names. Each shader is
/// checked against every name, including its own op's former kernel.
#[test]
fn encoding_compositions_lower_to_wgsl_without_handwritten_kernels() {
    // Names of the handwritten kernels these compositions replaced. The original
    // check omitted `hex_decode_strict`, so the decode entry was never checked for
    // its own former kernel.
    const REMOVED_KERNELS: [&str; 3] = ["hex_encode_lower", "base64_encode", "hex_decode_strict"];

    for (name, shader) in [
        (
            "encode.hex_encode_lower",
            wgsl::lower(&HexEncodeLower::SPEC.program().expect("hex encode program"))
                .expect("hex encode lowers"),
        ),
        (
            "encode.base64_encode",
            wgsl::lower(&Base64Encode::SPEC.program().expect("base64 encode program"))
                .expect("base64 encode lowers"),
        ),
        (
            "decode.hex_decode_strict",
            wgsl::lower(&HexDecodeStrict::SPEC.program().expect("hex decode program"))
                .expect("hex decode lowers"),
        ),
    ] {
        assert!(
            shader.contains("@compute"),
            "{name} must lower through the generic IR path"
        );
        for kernel in REMOVED_KERNELS {
            assert!(
                !shader.contains(kernel),
                "{name} must not embed the removed handwritten kernel name `{kernel}`"
            );
        }
    }
}
/// Produces `count` deterministic pseudo-random byte vectors from the seed `state`.
///
/// Case `i` has length `i * 7 % 41`, so lengths vary from 0 up to 40 bytes. A
/// 32-bit LCG (multiplier 1_664_525, increment 1_013_904_223, modulus 2^32)
/// drives the bytes. Each byte is the top 8 bits of the state, because the low
/// bits of a power-of-two-modulus LCG have short periods. One state sequence is
/// shared across all cases.
fn random_inputs(mut state: u32, count: usize) -> Vec<Vec<u8>> {
    (0..count)
        .map(|case| {
            let len = case * 7 % 41;
            (0..len)
                .map(|_| {
                    state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                    (state >> 24) as u8
                })
                .collect::<Vec<u8>>()
        })
        .collect()
}
/// Reference lowercase hex encoder.
///
/// Each input byte becomes two ASCII hex digits, high nibble first, so the
/// output is exactly twice as long as the input.
fn hex_encode_decomposition(input: &[u8]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(2 * input.len());
    encoded.extend(
        input
            .iter()
            .flat_map(|&byte| [hex_ascii(byte >> 4), hex_ascii(byte & 0x0f)]),
    );
    encoded
}
/// Converts a nibble (expected range `0..=15`) to its lowercase ASCII hex digit.
fn hex_ascii(nibble: u8) -> u8 {
    match nibble {
        0..=9 => b'0' + nibble,
        // Subtract before adding so the intermediate stays small.
        _ => b'a' + (nibble - 10),
    }
}
fn base64_encode_decomposition(input: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(input.len().div_ceil(3) * 4);
for chunk in input.chunks(3) {
let a = chunk[0];
let b = chunk.get(1).copied().unwrap_or(0);
let c = chunk.get(2).copied().unwrap_or(0);
out.push(base64_ascii(a >> 2));
out.push(base64_ascii(((a & 0x03) << 4) | (b >> 4)));
out.push(if chunk.len() > 1 {
base64_ascii(((b & 0x0f) << 2) | (c >> 6))
} else {
b'='
});
out.push(if chunk.len() > 2 {
base64_ascii(c & 0x3f)
} else {
b'='
});
}
out
}
/// Maps a 6-bit index to its character in the standard Base64 alphabet.
///
/// Any index outside `0..=63` falls back to `/`, the same as index 63.
fn base64_ascii(index: u8) -> u8 {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    ALPHABET.get(usize::from(index)).copied().unwrap_or(b'/')
}
/// Lenient reference hex decoder.
///
/// Consecutive byte pairs are decoded high nibble first. A trailing odd byte is
/// ignored, and characters that are not hex digits decode as nibble 0. Only the
/// valid output of `hex_encode_lower` is fed to it here.
fn hex_decode_decomposition(input: &[u8]) -> Vec<u8> {
    let pairs = input.chunks_exact(2);
    let mut decoded = Vec::with_capacity(pairs.len());
    for pair in pairs {
        let (high, low) = (hex_nibble(pair[0]), hex_nibble(pair[1]));
        decoded.push((high << 4) | low);
    }
    decoded
}
/// Returns the value of an ASCII hex digit (either case), or 0 for any other byte.
fn hex_nibble(byte: u8) -> u8 {
    // `char::from(u8)` is a Latin-1 mapping. Non-ASCII chars never satisfy
    // `to_digit(16)`, so only 0-9, a-f and A-F produce a digit. The digit is
    // below 16, so `as u8` is lossless.
    char::from(byte)
        .to_digit(16)
        .map_or(0, |digit| digit as u8)
}