vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! WGSL lowering source for `crypto.chacha20_block`.


/// Portable WGSL compute shader for the ChaCha20 block function.
///
/// Bindings (all in group 0):
/// - binding 0, `input_words` (`var<storage, read>`, `array<u32>`): the
///   initial state words. At most the first 16 words are read; if fewer are
///   bound, the remaining state words are zero.
/// - binding 1, `output_words` (`var<storage, read_write>`, `array<u32>`):
///   receives the result words. At most 16 words are written, and never more
///   than the bound array holds.
///
/// Entry point: `crypto_chacha20_block`, workgroup size `(1, 1, 1)`. Only the
/// invocation with `global_invocation_id.x == 0` does any work.
///
/// The shader runs 10 iterations, each applying four quarter rounds on the
/// column index sets followed by four on the diagonal index sets, then adds
/// the original state to the permuted state word by word.
pub const WGSL: &str = r##"@group(0) @binding(0) var<storage, read> input_words: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_words: array<u32>;

// One quarter round on words a, b, c, d of the state: add / xor / rotate-left
// by 16, 12, 8 and 7 bits. Works on a copy and returns the updated state.
fn crypto_chacha20_quarter_round(s_in: array<u32, 16>, a: u32, b: u32, c: u32, d: u32) -> array<u32, 16> {
    var s = s_in;
    s[a] = s[a] + s[b];
    s[d] = s[d] ^ s[a];
    s[d] = (s[d] << 16u) | (s[d] >> 16u);

    s[c] = s[c] + s[d];
    s[b] = s[b] ^ s[c];
    s[b] = (s[b] << 12u) | (s[b] >> 20u);

    s[a] = s[a] + s[b];
    s[d] = s[d] ^ s[a];
    s[d] = (s[d] << 8u) | (s[d] >> 24u);

    s[c] = s[c] + s[d];
    s[b] = s[b] ^ s[c];
    s[b] = (s[b] << 7u) | (s[b] >> 25u);
    return s;
}

@compute @workgroup_size(1, 1, 1)
fn crypto_chacha20_block(@builtin(global_invocation_id) id: vec3<u32>) {
    if (id.x != 0u) { return; }

    var state: array<u32, 16>;
    var orig: array<u32, 16>;
    // Never read past the end of the bound input buffer.
    let words = min(arrayLength(&input_words), 16u);

    for (var i = 0u; i < 16u; i = i + 1u) {
        state[i] = 0u;
        orig[i] = 0u;
    }
    for (var i = 0u; i < words; i = i + 1u) {
        let w = input_words[i];
        state[i] = w;
        orig[i] = w;
    }

    // 10 double rounds: four column quarter rounds, then four diagonal ones.
    for (var r = 0u; r < 10u; r = r + 1u) {
        state = crypto_chacha20_quarter_round(state, 0u, 4u, 8u, 12u);
        state = crypto_chacha20_quarter_round(state, 1u, 5u, 9u, 13u);
        state = crypto_chacha20_quarter_round(state, 2u, 6u, 10u, 14u);
        state = crypto_chacha20_quarter_round(state, 3u, 7u, 11u, 15u);

        state = crypto_chacha20_quarter_round(state, 0u, 5u, 10u, 15u);
        state = crypto_chacha20_quarter_round(state, 1u, 6u, 11u, 12u);
        state = crypto_chacha20_quarter_round(state, 2u, 7u, 8u, 13u);
        state = crypto_chacha20_quarter_round(state, 3u, 4u, 9u, 14u);
    }

    // Never write past the end of the bound output buffer: an out-of-bounds
    // store may be dropped or redirected to another in-bounds element, which
    // could corrupt words that were already written correctly.
    let out_len = min(arrayLength(&output_words), 16u);
    for (var i = 0u; i < out_len; i = i + 1u) {
        output_words[i] = state[i] + orig[i];
    }
}"##;