rng-pack 0.4.0

Random number generator variety pack
// Output storage buffer (group 0, binding 0): a runtime-sized array of u64
// that the compute entry point writes its generated words into.
@group(0) @binding(0) var<storage, read_write> output_buffer: array<u64>;

// Uniform inputs for the generator, bound at group 0, binding 1.
struct Params {
    // Low 64 bits of the starting counter; the invocation index is added to
    // this word.
    c_lo: u64,
    // High 64 bits of the starting counter; bumped by one when the low-word
    // addition wraps.
    c_hi: u64,
    // Key value used for the first round; later rounds use it plus a fixed
    // per-round increment.
    k_base: u64,
    // Not read by the shader code visible here; presumably layout padding --
    // confirm against the host-side struct.
    _pad: u64,
}
@group(0) @binding(1) var<uniform> params: Params;

// Returns the upper 64 bits of the full 128-bit product `a * b`.
//
// Both operands are split into 32-bit halves so that every partial product
// fits in a u64. The halves are then recombined column by column, carrying
// the overflow of the middle column into the high word.
fn mul_hi_u64(a: u64, b: u64) -> u64 {
    let mask32 = u64(0xFFFFFFFFu);

    let al = a & mask32;
    let ah = a >> 32u;
    let bl = b & mask32;
    let bh = b >> 32u;

    // The two cross terms, each weighted by 2^32 in the full product.
    let cross_lh = al * bh;
    let cross_hl = ah * bl;

    // Everything that lands in bits 32..63 of the product. At most three
    // 32-bit values are summed, so this cannot overflow a u64; whatever
    // spills past bit 31 is the carry into the high word.
    let mid = ((al * bl) >> 32u) + (cross_lh & mask32) + (cross_hl & mask32);

    return ah * bh + (cross_lh >> 32u) + (cross_hl >> 32u) + (mid >> 32u);
}

// Applies one multiply-and-mix round to the two-word state `v` with round
// key `k`.
//
// `v.x` is multiplied by the fixed 64-bit constant 0xD2B74407B1CE6E93. The
// returned vector holds (high half of the product ^ v.y ^ k) in `.x` and the
// low half of the product in `.y`.
fn philox_round(v: vec2<u64>, k: u64) -> vec2<u64> {
    // The multiplier is assembled from two 32-bit halves.
    let multiplier = (u64(0xD2B74407u) << 32u) | u64(0xB1CE6E93u);

    let mixed_hi = mul_hi_u64(v.x, multiplier) ^ v.y ^ k;
    // Plain u64 multiplication keeps only the low 64 bits of the product.
    let product_lo = v.x * multiplier;

    return vec2<u64>(mixed_hi, product_lo);
}

// Compute entry point: each invocation produces one pair of 64-bit words.
//
// Invocation `id.x` forms a 128-bit counter from (params.c_hi, params.c_lo)
// plus `id.x`, passes it through ten `philox_round` calls with a key that
// starts at params.k_base and advances by a fixed increment after every
// round, then stores the resulting words at output_buffer[2 * id.x] and
// output_buffer[2 * id.x + 1]. Invocations whose first slot lies past the end
// of the buffer do nothing; the second slot is skipped if it would be out of
// bounds.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let lane = id.x;
    let total = arrayLength(&output_buffer);

    // Widen before doubling so the bounds test itself cannot wrap in 32 bits.
    if u64(lane) * u64(2u) >= u64(total) { return; }

    // The guard above guarantees 2 * lane < total, so this fits in a u32.
    let slot = lane * 2u;

    // 128-bit counter = (c_hi, c_lo) + lane, with carry out of the low word
    // detected by the wrapped sum being smaller than the original.
    let ctr_lo = params.c_lo + u64(lane);
    var ctr_hi = params.c_hi;
    if ctr_lo < params.c_lo { ctr_hi += u64(1u); }

    var state = vec2<u64>(ctr_lo, ctr_hi);
    var key = params.k_base;
    // Per-round key increment 0x9E3779B97F4A7C15, built from 32-bit halves.
    let key_step = (u64(0x9E3779B9u) << 32u) | u64(0x7F4A7C15u);

    for (var round = 0; round < 10; round++) {
        state = philox_round(state, key);
        key += key_step;
    }

    output_buffer[slot] = state.x;
    // With an odd-length buffer the last invocation has room for one word only.
    if slot + 1u < total {
        output_buffer[slot + 1u] = state.y;
    }
}