runmat-accelerate 0.4.1

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL source template for the single-precision (`f32`) stochastic-evolution
/// compute kernel.
///
/// # What the shader does
///
/// Each invocation of the `main` entry point handles one element:
///
/// 1. `gid.x` is the index within the current chunk; the invocation returns
///    early when `gid.x >= params.chunk`.
/// 2. The element index is `params.offset + gid.x`; the invocation also
///    returns early when that index is `>= params.len`.
/// 3. The starting value is read from `Input.data[index]`.
/// 4. For `iter` in `0..params.steps` the running value is multiplied by
///    `exp(params.drift + params.scale * noise)`.
/// 5. The final value is written to `Output.data[index]`.
///
/// # Random numbers
///
/// `noise` comes from a Philox-style counter-based generator (10 rounds, see
/// `philox` / `philox_round` in the shader) evaluated with counter
/// `(index, iter, 0, 0)` and key `(params.key0, params.key1)`. Only the first
/// two output words are used: each is mapped to a uniform sample
/// `(bits + 0.5) / 2^32`, clamped to `[1.0e-8, 0.99999994]`, and the pair is
/// combined as `sqrt(-2 * log(u1)) * cos(2π * u2)` (the cosine branch of a
/// Box–Muller transform).
///
/// # Bindings (group 0)
///
/// * binding 0 — `Input`: read-only storage buffer, `array<f32>`.
/// * binding 1 — `Output`: read-write storage buffer, `array<f32>`.
/// * binding 2 — `params`: uniform `StochasticEvolutionParams` with fields, in
///   order, `offset`, `chunk`, `len`, `steps`, `key0`, `key1`, `_pad0`,
///   `_pad1` (all `u32`), then `drift` and `scale` (`f32`).
///
/// # Template placeholder
///
/// The text contains the literal token `@WG@` inside `@workgroup_size(...)`.
/// NOTE(review): presumably the caller replaces `@WG@` with the workgroup size
/// before handing the source to the shader compiler — confirm at the use site.
pub const STOCHASTIC_EVOLUTION_SHADER_F32: &str = r#"
struct Tensor {
    data: array<f32>,
};

struct StochasticEvolutionParams {
    offset: u32,
    chunk: u32,
    len: u32,
    steps: u32,
    key0: u32,
    key1: u32,
    _pad0: u32,
    _pad1: u32,
    drift: f32,
    scale: f32,
}

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: StochasticEvolutionParams;

const PHILOX_M0: u32 = 0xD2511F53u;
const PHILOX_M1: u32 = 0xCD9E8D57u;
const PHILOX_W0: u32 = 0x9E3779B9u;
const PHILOX_W1: u32 = 0xBB67AE85u;
const INV_U32: f32 = 1.0 / 4294967296.0;
const TWO_PI: f32 = 6.2831855;
const MIN_UNIFORM: f32 = 1.0e-8;
const ALMOST_ONE: f32 = 0.99999994;

fn mul_hi_u32(a: u32, b: u32) -> u32 {
    let a_hi = a >> 16u;
    let a_lo = a & 0xFFFFu;
    let b_hi = b >> 16u;
    let b_lo = b & 0xFFFFu;
    let p0 = a_lo * b_lo;
    let p1 = a_lo * b_hi;
    let p2 = a_hi * b_lo;
    let p3 = a_hi * b_hi;
    let mid = (p0 >> 16u) + (p1 & 0xFFFFu) + (p2 & 0xFFFFu);
    return p3 + (p1 >> 16u) + (p2 >> 16u) + (mid >> 16u);
}

fn philox_round(counter: vec4<u32>, key: vec2<u32>) -> vec4<u32> {
    let hi0 = mul_hi_u32(PHILOX_M0, counter.x);
    let lo0 = PHILOX_M0 * counter.x;
    let hi1 = mul_hi_u32(PHILOX_M1, counter.z);
    let lo1 = PHILOX_M1 * counter.z;
    return vec4<u32>(
        hi1 ^ counter.y ^ key.x,
        lo1,
        hi0 ^ counter.w ^ key.y,
        lo0,
    );
}

fn philox(counter: vec4<u32>, key: vec2<u32>) -> vec4<u32> {
    var ctr = counter;
    var k = key;
    for (var i: u32 = 0u; i < 10u; i = i + 1u) {
        ctr = philox_round(ctr, k);
        k = vec2<u32>(k.x + PHILOX_W0, k.y + PHILOX_W1);
    }
    return ctr;
}

fn uniform_from_bits(bits: u32) -> f32 {
    let sample = (f32(bits) + 0.5) * INV_U32;
    return clamp(sample, MIN_UNIFORM, ALMOST_ONE);
}

fn normal_from_bits(bits0: u32, bits1: u32) -> f32 {
    let u1 = uniform_from_bits(bits0);
    let u2 = uniform_from_bits(bits1);
    let radius = sqrt(-2.0 * log(u1));
    let angle = TWO_PI * u2;
    return radius * cos(angle);
}

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.chunk {
        return;
    }
    let global_idx = params.offset + idx;
    if global_idx >= params.len {
        return;
    }

    var ctr = vec4<u32>(global_idx, 0u, 0u, 0u);
    let key = vec2<u32>(params.key0, params.key1);
    var steps: u32 = params.steps;
    var current = Input.data[global_idx];

    var iter: u32 = 0u;
    loop {
        if iter >= steps {
            break;
        }
        ctr.y = iter;
        let rnd = philox(ctr, key);
        let noise = normal_from_bits(rnd.x, rnd.y);
        let delta = params.drift + params.scale * noise;
        current = current * exp(delta);
        iter = iter + 1u;
    }

    Output.data[global_idx] = current;
}
"#;

/// WGSL source template for the double-precision (`f64`) stochastic-evolution
/// compute kernel.
///
/// # What the shader does
///
/// Each invocation of the `main` entry point handles one element:
///
/// 1. `gid.x` is the index within the current chunk; the invocation returns
///    early when `gid.x >= params.chunk`.
/// 2. The element index is `params.offset + gid.x`; the invocation also
///    returns early when that index is `>= params.len`.
/// 3. The starting value is read from `Input.data[index]`.
/// 4. For `iter` in `0..params.steps` the running value is multiplied by
///    `exp(params.drift + params.scale * noise)`.
/// 5. The final value is written to `Output.data[index]`.
///
/// # Random numbers
///
/// `noise` comes from a Philox-style counter-based generator (10 rounds, see
/// `philox` / `philox_round` in the shader) evaluated with counter
/// `(index, iter, 0, 0)` and key `(params.key0, params.key1)`. Only the first
/// two output words are used: each is mapped to a uniform sample
/// `(bits + 0.5) / 2^32`, clamped to `[1.0e-8, 0.99999994]`, and the pair is
/// combined as `sqrt(-2 * log(u1)) * cos(2π * u2)` (the cosine branch of a
/// Box–Muller transform).
///
/// The uniform and normal samples are computed in `f32`; only the resulting
/// normal sample is widened with `f64(...)` before being combined with the
/// `f64` `drift` and `scale`, so the noise carries single-precision resolution.
///
/// # Bindings (group 0)
///
/// * binding 0 — `Input`: read-only storage buffer, `array<f64>`.
/// * binding 1 — `Output`: read-write storage buffer, `array<f64>`.
/// * binding 2 — `params`: uniform `StochasticEvolutionParams` with fields, in
///   order, `offset`, `chunk`, `len`, `steps`, `key0`, `key1`, `_pad0`
///   (all `u32`), then `drift` and `scale` (`f64`).
///
/// # Requirements and template placeholder
///
/// NOTE(review): `f64` is used for buffers, uniforms and `exp`; this presumably
/// requires a device/compiler with 64-bit float shader support — confirm how
/// the caller gates use of this variant.
///
/// The text contains the literal token `@WG@` inside `@workgroup_size(...)`.
/// NOTE(review): presumably the caller replaces `@WG@` with the workgroup size
/// before handing the source to the shader compiler — confirm at the use site.
pub const STOCHASTIC_EVOLUTION_SHADER_F64: &str = r#"
struct Tensor {
    data: array<f64>,
};

struct StochasticEvolutionParams {
    offset: u32,
    chunk: u32,
    len: u32,
    steps: u32,
    key0: u32,
    key1: u32,
    _pad0: u32,
    drift: f64,
    scale: f64,
}

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: StochasticEvolutionParams;

const PHILOX_M0: u32 = 0xD2511F53u;
const PHILOX_M1: u32 = 0xCD9E8D57u;
const PHILOX_W0: u32 = 0x9E3779B9u;
const PHILOX_W1: u32 = 0xBB67AE85u;
const INV_U32: f32 = 1.0 / 4294967296.0;
const TWO_PI: f32 = 6.2831855;
const MIN_UNIFORM: f32 = 1.0e-8;
const ALMOST_ONE: f32 = 0.99999994;

fn mul_hi_u32(a: u32, b: u32) -> u32 {
    let a_hi = a >> 16u;
    let a_lo = a & 0xFFFFu;
    let b_hi = b >> 16u;
    let b_lo = b & 0xFFFFu;
    let p0 = a_lo * b_lo;
    let p1 = a_lo * b_hi;
    let p2 = a_hi * b_lo;
    let p3 = a_hi * b_hi;
    let mid = (p0 >> 16u) + (p1 & 0xFFFFu) + (p2 & 0xFFFFu);
    return p3 + (p1 >> 16u) + (p2 >> 16u) + (mid >> 16u);
}

fn philox_round(counter: vec4<u32>, key: vec2<u32>) -> vec4<u32> {
    let hi0 = mul_hi_u32(PHILOX_M0, counter.x);
    let lo0 = PHILOX_M0 * counter.x;
    let hi1 = mul_hi_u32(PHILOX_M1, counter.z);
    let lo1 = PHILOX_M1 * counter.z;
    return vec4<u32>(
        hi1 ^ counter.y ^ key.x,
        lo1,
        hi0 ^ counter.w ^ key.y,
        lo0,
    );
}

fn philox(counter: vec4<u32>, key: vec2<u32>) -> vec4<u32> {
    var ctr = counter;
    var k = key;
    for (var i: u32 = 0u; i < 10u; i = i + 1u) {
        ctr = philox_round(ctr, k);
        k = vec2<u32>(k.x + PHILOX_W0, k.y + PHILOX_W1);
    }
    return ctr;
}

fn uniform_from_bits(bits: u32) -> f32 {
    let sample = (f32(bits) + 0.5) * INV_U32;
    return clamp(sample, MIN_UNIFORM, ALMOST_ONE);
}

fn normal_from_bits(bits0: u32, bits1: u32) -> f32 {
    let u1 = uniform_from_bits(bits0);
    let u2 = uniform_from_bits(bits1);
    let radius = sqrt(-2.0 * log(u1));
    let angle = TWO_PI * u2;
    return radius * cos(angle);
}

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.chunk {
        return;
    }
    let global_idx = params.offset + idx;
    if global_idx >= params.len {
        return;
    }

    var ctr = vec4<u32>(global_idx, 0u, 0u, 0u);
    let key = vec2<u32>(params.key0, params.key1);
    var steps: u32 = params.steps;
    var current = Input.data[global_idx];

    var iter: u32 = 0u;
    loop {
        if iter >= steps {
            break;
        }
        ctr.y = iter;
        let rnd = philox(ctr, key);
        let noise_f32 = normal_from_bits(rnd.x, rnd.y);
        let noise = f64(noise_f32);
        let delta = params.drift + params.scale * noise;
        current = current * exp(delta);
        iter = iter + 1u;
    }

    Output.data[global_idx] = current;
}
"#;