runmat-accelerate 0.4.5

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL compute shader source that evaluates a polynomial at every point of a
/// chunk using Horner's scheme, with `f64` element storage.
///
/// Bindings (all in group 0):
/// - binding 0, `Coeffs` (read-only storage): polynomial coefficients. The
///   Horner recurrence starts from `Coeffs.data[0]`, so coefficients are
///   consumed highest-degree first.
/// - binding 1, `Points` (read-only storage): the evaluation points.
/// - binding 2, `Output` (read-write storage): one result per point, written
///   at the same index the point was read from.
/// - binding 3, `params` (uniform): see below.
///
/// `Params` fields as used by the shader:
/// - `chunk_len`: invocations whose `global_invocation_id.x` is `>= chunk_len`
///   return without reading or writing anything.
/// - `coeff_len`: number of coefficients folded into the accumulator.
/// - `offset`: added to `global_invocation_id.x` to form the index into both
///   `Points` and `Output`.
/// - `has_mu`: when non-zero, each point is first transformed as
///   `(x - mu_mean) / mu_scale`.
/// - `mu_mean`, `mu_scale`: centering and scaling values for that transform.
///
/// `@WG@` inside `@workgroup_size(...)` is a placeholder token, not valid WGSL;
/// presumably the caller substitutes the workgroup size before compiling the
/// shader (confirm against the pipeline-creation code).
///
/// NOTE(review): `Coeffs.data[0]` is read unconditionally, even when
/// `coeff_len` is 0. Confirm the host never dispatches with an empty
/// coefficient buffer.
/// NOTE(review): `f64` in shader code likely requires device/backend support
/// for 64-bit floats; verify which feature gate the host checks.
pub const POLYVAL_SHADER_F64: &str = r#"
struct Tensor {
    data: array<f64>,
};

struct Params {
    chunk_len: u32,
    coeff_len: u32,
    offset: u32,
    has_mu: u32,
    mu_mean: f64,
    mu_scale: f64,
};

@group(0) @binding(0) var<storage, read> Coeffs: Tensor;
@group(0) @binding(1) var<storage, read> Points: Tensor;
@group(0) @binding(2) var<storage, read_write> Output: Tensor;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let local_idx = gid.x;
    if (local_idx >= params.chunk_len) {
        return;
    }

    let idx = params.offset + local_idx;
    var x = Points.data[idx];
    if (params.has_mu != 0u) {
        x = (x - params.mu_mean) / params.mu_scale;
    }

    var acc = Coeffs.data[0u];
    var k: u32 = 1u;
    loop {
        if (k >= params.coeff_len) {
            break;
        }
        acc = acc * x + Coeffs.data[k];
        k = k + 1u;
    }

    Output.data[idx] = acc;
}
"#;

/// WGSL compute shader source that evaluates a polynomial at every point of a
/// chunk using Horner's scheme, with `f32` element storage.
///
/// Bindings (all in group 0):
/// - binding 0, `Coeffs` (read-only storage): polynomial coefficients. The
///   Horner recurrence starts from `Coeffs.data[0]`, so coefficients are
///   consumed highest-degree first.
/// - binding 1, `Points` (read-only storage): the evaluation points.
/// - binding 2, `Output` (read-write storage): one result per point, written
///   at the same index the point was read from.
/// - binding 3, `params` (uniform): see below.
///
/// `Params` fields as used by the shader:
/// - `chunk_len`: invocations whose `global_invocation_id.x` is `>= chunk_len`
///   return without reading or writing anything.
/// - `coeff_len`: number of coefficients folded into the accumulator.
/// - `offset`: added to `global_invocation_id.x` to form the index into both
///   `Points` and `Output`.
/// - `has_mu`: when non-zero, each point is first transformed as
///   `(x - mu_mean) / mu_scale`.
/// - `mu_mean`, `mu_scale`: centering and scaling values for that transform.
/// - `_pad0`, `_pad1`: never read by the shader; presumably present so the
///   uniform struct's size matches the host-side layout (confirm against the
///   Rust struct that fills this buffer).
///
/// `@WG@` inside `@workgroup_size(...)` is a placeholder token, not valid WGSL;
/// presumably the caller substitutes the workgroup size before compiling the
/// shader (confirm against the pipeline-creation code).
///
/// NOTE(review): `Coeffs.data[0]` is read unconditionally, even when
/// `coeff_len` is 0. Confirm the host never dispatches with an empty
/// coefficient buffer.
pub const POLYVAL_SHADER_F32: &str = r#"
struct Tensor {
    data: array<f32>,
};

struct Params {
    chunk_len: u32,
    coeff_len: u32,
    offset: u32,
    has_mu: u32,
    mu_mean: f32,
    mu_scale: f32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> Coeffs: Tensor;
@group(0) @binding(1) var<storage, read> Points: Tensor;
@group(0) @binding(2) var<storage, read_write> Output: Tensor;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let local_idx = gid.x;
    if (local_idx >= params.chunk_len) {
        return;
    }

    let idx = params.offset + local_idx;
    var x = Points.data[idx];
    if (params.has_mu != 0u) {
        x = (x - params.mu_mean) / params.mu_scale;
    }

    var acc = Coeffs.data[0u];
    var k: u32 = 1u;
    loop {
        if (k >= params.coeff_len) {
            break;
        }
        acc = acc * x + Coeffs.data[k];
        k = k + 1u;
    }

    Output.data[idx] = acc;
}
"#;