// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library
// Documentation
use vyre::lower::wgsl;
use vyre::ops::decode::hex_decode_strict::{hex_decode_strict, HexDecodeStrict};
use vyre::ops::encode::base64_encode::{base64_encode, Base64Encode};
use vyre::ops::encode::hex_encode_lower::{hex_encode_lower, HexEncodeLower};

/// Compares the library's `hex_encode_lower` against the local
/// `hex_encode_decomposition` reference on 32 generated byte strings.
#[test]
fn hex_encode_composition_matches_cpu_on_32_random_inputs() {
    let cases = random_inputs(0x4845_5845, 32);
    for bytes in &cases {
        let produced = hex_encode_lower(bytes);
        let reference = hex_encode_decomposition(bytes);
        assert_eq!(produced, reference);
    }
}

/// Compares the library's `base64_encode` against the local
/// `base64_encode_decomposition` reference on 32 generated byte strings.
#[test]
fn base64_encode_composition_matches_cpu_on_32_random_inputs() {
    let cases = random_inputs(0x6436_3421, 32);
    for bytes in &cases {
        let produced = base64_encode(bytes);
        let reference = base64_encode_decomposition(bytes);
        assert_eq!(produced, reference);
    }
}

/// Hex-encodes 32 generated byte strings with `hex_encode_lower`, then checks
/// that `hex_decode_strict` and the local `hex_decode_decomposition` reference
/// produce the same bytes from that encoded text.
#[test]
fn hex_decode_composition_matches_cpu_on_32_random_inputs() {
    let cases = random_inputs(0xDEC0_DE32, 32);
    for bytes in &cases {
        let hex = hex_encode_lower(bytes);
        // The input was produced by the encoder, so a decode failure is a bug.
        let decoded = hex_decode_strict(&hex).expect("generated hex must decode");
        let reference = hex_decode_decomposition(&hex);
        assert_eq!(decoded, reference);
    }
}

/// Lowers the `SPEC.program()` of each encoding/decoding op through
/// `wgsl::lower` and checks two textual properties of every result: it
/// contains `@compute`, and it contains neither `hex_encode_lower` nor
/// `base64_encode`.
#[test]
fn encoding_compositions_lower_to_wgsl_without_handwritten_kernels() {
    for (name, shader) in [
        (
            "encode.hex_encode_lower",
            wgsl::lower(&HexEncodeLower::SPEC.program().expect("hex encode program"))
                .expect("hex encode lowers"),
        ),
        (
            "encode.base64_encode",
            wgsl::lower(&Base64Encode::SPEC.program().expect("base64 encode program"))
                .expect("base64 encode lowers"),
        ),
        (
            "decode.hex_decode_strict",
            wgsl::lower(&HexDecodeStrict::SPEC.program().expect("hex decode program"))
                .expect("hex decode lowers"),
        ),
    ] {
        let has_compute_marker = shader.contains("@compute");
        assert!(
            has_compute_marker,
            "{name} must lower through the generic IR path"
        );

        // NOTE(review): only the two encoder names are searched for here;
        // `hex_decode_strict` is not checked — confirm whether that is intended.
        let names_handwritten_kernel =
            shader.contains("hex_encode_lower") || shader.contains("base64_encode");
        assert!(
            !names_handwritten_kernel,
            "{name} must not embed the removed handwritten kernel name"
        );
    }
}

/// Builds `count` deterministic pseudo-random byte strings from `state`.
///
/// Case `i` has length `i * 7 % 41`. Every byte is the top 8 bits of a 32-bit
/// state that is advanced with `state * 1_664_525 + 1_013_904_223` (wrapping)
/// before each byte is taken; the state carries over from one case to the next.
fn random_inputs(mut state: u32, count: usize) -> Vec<Vec<u8>> {
    (0..count)
        .map(|case| {
            let len = case * 7 % 41;
            (0..len)
                .map(|_| {
                    state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                    (state >> 24) as u8
                })
                .collect::<Vec<u8>>()
        })
        .collect()
}

/// Reference hex encoder: emits two `hex_ascii` characters per input byte,
/// high nibble first, then low nibble.
fn hex_encode_decomposition(input: &[u8]) -> Vec<u8> {
    input
        .iter()
        .flat_map(|&byte| [hex_ascii(byte >> 4), hex_ascii(byte & 0x0f)])
        .collect()
}

/// Maps a nibble value to its lowercase hexadecimal ASCII byte:
/// `0..=9` become `b'0'..=b'9'`, anything larger is offset from `b'a'`.
fn hex_ascii(nibble: u8) -> u8 {
    match nibble {
        0..=9 => b'0' + nibble,
        _ => b'a' + (nibble - 10),
    }
}

/// Reference base64 encoder: turns each group of up to three input bytes into
/// four output bytes via `base64_ascii`, writing `b'='` in the third and/or
/// fourth position when the group has fewer than three bytes.
fn base64_encode_decomposition(input: &[u8]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // `chunks` never yields an empty slice, so the first byte always exists.
        let first = group[0];
        let second = group.get(1).copied();
        let third = group.get(2).copied();
        // Missing bytes contribute zero bits to the sextets that straddle them.
        let (b, c) = (second.unwrap_or(0), third.unwrap_or(0));

        encoded.extend([
            base64_ascii(first >> 2),
            base64_ascii(((first & 0x03) << 4) | (b >> 4)),
            second.map_or(b'=', |_| base64_ascii(((b & 0x0f) << 2) | (c >> 6))),
            third.map_or(b'=', |_| base64_ascii(c & 0x3f)),
        ]);
    }
    encoded
}

/// Maps a sextet index to its ASCII byte: `A-Z` for 0..=25, `a-z` for
/// 26..=51, `0-9` for 52..=61, `+` for 62, and `/` for every other value.
fn base64_ascii(index: u8) -> u8 {
    if index <= 25 {
        b'A' + index
    } else if index <= 51 {
        b'a' + (index - 26)
    } else if index <= 61 {
        b'0' + (index - 52)
    } else if index == 62 {
        b'+'
    } else {
        b'/'
    }
}

/// Reference hex decoder: combines each complete pair of input bytes into one
/// output byte using `hex_nibble` (first byte is the high nibble). A trailing
/// unpaired byte is ignored.
fn hex_decode_decomposition(input: &[u8]) -> Vec<u8> {
    let mut decoded = Vec::with_capacity(input.len() / 2);
    for pair in input.chunks_exact(2) {
        let high = hex_nibble(pair[0]);
        let low = hex_nibble(pair[1]);
        decoded.push((high << 4) | low);
    }
    decoded
}

/// Converts one ASCII hex digit (`0-9`, `A-F`, `a-f`) to its value 0..=15.
/// Any other byte yields 0.
fn hex_nibble(byte: u8) -> u8 {
    char::from(byte)
        .to_digit(16)
        // A radix-16 digit is always below 16, so the narrowing cast is lossless.
        .map_or(0, |digit| digit as u8)
}