runmat-accelerate 0.4.4

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL compute-shader source that writes a finite-difference gradient of an
/// `f64` buffer along one strided segment axis.
///
/// The text is a template: `@workgroup_size(@WG@)` contains the placeholder
/// `@WG@`, so the string is not valid shader source until that token is
/// replaced (presumably by the host with the workgroup size; confirm at the
/// call site).
///
/// Bindings, all in group 0:
/// - binding 0: `Input`, read-only storage, `array<f64>`.
/// - binding 1: `Output`, read-write storage, `array<f64>`.
/// - binding 2: `params`, uniform `GradientParams`.
///
/// `GradientParams` is declared as four `u32` fields (`stride_before`,
/// `segment_len`, `block`, `total`) followed by `spacing: f64` and three `f64`
/// padding fields that the shader never reads.
///
/// Each invocation handles the single element `idx = gid.x`:
/// - `idx >= total`: returns without writing.
/// - `segment_len <= 1`: writes `0.0`.
/// - Otherwise `idx` is split into `after = idx / block`,
///   `within = idx % block`, `before = within % stride_before` and
///   `k = within / stride_before`, and `center` is rebuilt from those parts.
///   - `k == 0`: forward difference,
///     `(Input[center + stride_before] - Input[center]) / spacing`.
///   - `k + 1 == segment_len`: backward difference,
///     `(Input[center] - Input[center - stride_before]) / spacing`.
///   - anything else: central difference,
///     `(Input[center + stride_before] - Input[center - stride_before]) / (2.0 * spacing)`.
///
/// NOTE(review): the shader does not check `block`, `stride_before` or
/// `spacing` against zero; the host is presumably expected to guarantee that.
/// NOTE(review): the first/last tests on `k` only line up with segment ends if
/// `block == stride_before * segment_len`; nothing here enforces it, so verify
/// against the code that fills `params`.
/// NOTE(review): `f64` is not a core WGSL scalar type, so this source
/// presumably depends on a backend extension for 64-bit floats. Confirm which
/// device feature gates its use.
pub const GRADIENT_SHADER_F64: &str = r#"
struct Tensor {
    data: array<f64>,
};

struct GradientParams {
    stride_before: u32,
    segment_len: u32,
    block: u32,
    total: u32,
    spacing: f64,
    _pad0: f64,
    _pad1: f64,
    _pad2: f64,
};

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: GradientParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.total) {
        return;
    }
    if (params.segment_len <= 1u) {
        Output.data[idx] = 0.0;
        return;
    }

    let block = params.block;
    let after = idx / block;
    let within = idx % block;
    let before = within % params.stride_before;
    let k = within / params.stride_before;
    let base = after * block;
    let center = base + before + k * params.stride_before;

    if (k == 0u) {
        let right = center + params.stride_before;
        Output.data[idx] = (Input.data[right] - Input.data[center]) / params.spacing;
        return;
    }
    if (k + 1u == params.segment_len) {
        let left = center - params.stride_before;
        Output.data[idx] = (Input.data[center] - Input.data[left]) / params.spacing;
        return;
    }

    let left = center - params.stride_before;
    let right = center + params.stride_before;
    Output.data[idx] = (Input.data[right] - Input.data[left]) / (2.0 * params.spacing);
}
"#;

/// WGSL compute-shader source that writes a finite-difference gradient of an
/// `f32` buffer along one strided segment axis.
///
/// The text is a template: `@workgroup_size(@WG@)` contains the placeholder
/// `@WG@`, so the string is not valid shader source until that token is
/// replaced (presumably by the host with the workgroup size; confirm at the
/// call site).
///
/// Bindings, all in group 0:
/// - binding 0: `Input`, read-only storage, `array<f32>`.
/// - binding 1: `Output`, read-write storage, `array<f32>`.
/// - binding 2: `params`, uniform `GradientParams`.
///
/// `GradientParams` packs its values into two vectors, which the shader
/// unpacks as follows:
/// - `meta0: vec4<u32>`: `x` = `stride_before`, `y` = `segment_len`,
///   `z` = `block`, `w` = `total`.
/// - `meta1: vec4<f32>`: `x` = `spacing`; the `y`, `z` and `w` lanes are not
///   read by this shader.
///
/// Each invocation handles the single element `idx = gid.x`:
/// - `idx >= total`: returns without writing.
/// - `segment_len <= 1`: writes `0.0`.
/// - Otherwise `idx` is split into `after = idx / block`,
///   `within = idx % block`, `before = within % stride_before` and
///   `k = within / stride_before`, and `center` is rebuilt from those parts.
///   - `k == 0`: forward difference,
///     `(Input[center + stride_before] - Input[center]) / spacing`.
///   - `k + 1 == segment_len`: backward difference,
///     `(Input[center] - Input[center - stride_before]) / spacing`.
///   - anything else: central difference,
///     `(Input[center + stride_before] - Input[center - stride_before]) / (2.0 * spacing)`.
///
/// NOTE(review): the shader does not check `block`, `stride_before` or
/// `spacing` against zero; the host is presumably expected to guarantee that.
/// NOTE(review): the first/last tests on `k` only line up with segment ends if
/// `block == stride_before * segment_len`; nothing here enforces it, so verify
/// against the code that fills `params`.
pub const GRADIENT_SHADER_F32: &str = r#"
struct Tensor {
    data: array<f32>,
};

struct GradientParams {
    meta0: vec4<u32>,
    meta1: vec4<f32>,
};

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: GradientParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let stride_before = params.meta0.x;
    let segment_len = params.meta0.y;
    let block = params.meta0.z;
    let total = params.meta0.w;
    let spacing = params.meta1.x;

    if (idx >= total) {
        return;
    }
    if (segment_len <= 1u) {
        Output.data[idx] = 0.0;
        return;
    }

    let after = idx / block;
    let within = idx % block;
    let before = within % stride_before;
    let k = within / stride_before;
    let base = after * block;
    let center = base + before + k * stride_before;

    if (k == 0u) {
        let right = center + stride_before;
        Output.data[idx] = (Input.data[right] - Input.data[center]) / spacing;
        return;
    }
    if (k + 1u == segment_len) {
        let left = center - stride_before;
        Output.data[idx] = (Input.data[center] - Input.data[left]) / spacing;
        return;
    }

    let left = center - stride_before;
    let right = center + stride_before;
    Output.data[idx] = (Input.data[right] - Input.data[left]) / (2.0 * spacing);
}
"#;