// Read-write storage buffer of 64-bit words that receives the generated values.
// Each invocation of `main` writes up to two consecutive elements, starting at
// element `2 * id.x`.
@group(0) @binding(0) var<storage, read_write> output_buffer: array<u64>;
// Parameters read by `main` to seed each invocation's counter and key.
struct Params {
// Low 64 bits of the starting counter; `main` adds the invocation index to it.
c_lo: u64,
// High 64 bits of the starting counter; `main` increments it by one when the
// low-word addition wraps around.
c_hi: u64,
// Initial key; `main` works on a copy that it advances after every round.
k_base: u64,
// Not read by any code visible in this file — presumably layout padding;
// TODO confirm against the host-side struct definition.
_pad: u64,
}
// Uniform-bound instance of `Params` (group 0, binding 1) read by `main`.
@group(0) @binding(1) var<uniform> params: Params;
// Returns the upper 64 bits of the full 128-bit product `a * b`.
//
// Both operands are split into 32-bit limbs so that each of the four partial
// products fits in a u64; the carries are then propagated by hand.
fn mul_hi_u64(a: u64, b: u64) -> u64 {
    let mask32 = u64(0xFFFFFFFFu);

    // 32-bit limbs of each operand.
    let al = a & mask32;
    let ah = a >> 32u;
    let bl = b & mask32;
    let bh = b >> 32u;

    // The four 32x32 -> 64 partial products.
    let ll = al * bl;
    let lh = al * bh;
    let hl = ah * bl;
    let hh = ah * bh;

    // Middle column: carry out of the low product plus the low halves of the
    // cross terms. The sum is at most 3 * (2^32 - 1), so it cannot overflow.
    let mid = (ll >> 32u) + (lh & mask32) + (hl & mask32);

    // Everything that lands in bits 64 and above of the full product.
    let carry = (lh >> 32u) + (hl >> 32u) + (mid >> 32u);
    return hh + carry;
}
// Applies one mixing round to the two-word state `v` using round key `k`.
//
// `v.x` is multiplied by the 64-bit constant 0xD2B74407B1CE6E93. The new first
// word is the high half of that product XORed with `v.y` and `k`; the new
// second word is the low half of the product.
fn philox_round(v: vec2<u64>, k: u64) -> vec2<u64> {
    // The multiplier is assembled from two 32-bit halves.
    let multiplier = (u64(0xD2B74407u) << 32u) | u64(0xB1CE6E93u);

    let hi = mul_hi_u64(v.x, multiplier);
    let lo = v.x * multiplier;

    return vec2<u64>(hi ^ v.y ^ k, lo);
}
// Compute entry point: each invocation derives a 128-bit counter from
// `params` and its global x index, runs ten `philox_round` iterations on it,
// and stores the resulting two words at elements `2 * id.x` and
// `2 * id.x + 1` of `output_buffer`.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let lane = id.x;
    let total = arrayLength(&output_buffer);

    // Invocations whose first output slot lies past the end of the buffer do
    // nothing. The comparison is done in 64 bits so `lane * 2` cannot wrap.
    if u64(lane) * u64(2u) >= u64(total) {
        return;
    }
    // Safe in 32 bits now: the guard above proved `lane * 2 < total`.
    let slot = lane * 2u;

    // 128-bit counter = (c_hi : c_lo) + lane, carrying into the high word
    // when the low-word addition wraps.
    let ctr_lo = params.c_lo + u64(lane);
    var state = vec2<u64>(ctr_lo, params.c_hi);
    if ctr_lo < params.c_lo {
        state.y += u64(1u);
    }

    // The key starts at k_base and is advanced by the constant
    // 0x9E3779B97F4A7C15 after every round.
    let key_step = (u64(0x9E3779B9u) << 32u) | u64(0x7F4A7C15u);
    var key = params.k_base;
    for (var round = 0u; round < 10u; round = round + 1u) {
        state = philox_round(state, key);
        key += key_step;
    }

    output_buffer[slot] = state.x;
    // The second word is skipped when it would fall past the end of the
    // buffer (odd-length buffer, last pair).
    if slot + 1u < total {
        output_buffer[slot + 1u] = state.y;
    }
}