runmat-accelerate 0.4.5

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL compute shader computing a first-order difference
/// (`Input[i + stride_before] - Input[i]`) along one strided axis of an `f64` tensor.
///
/// Bindings (group 0):
/// - `0`: `Input` — read-only storage buffer of `f64` elements.
/// - `1`: `Output` — read-write storage buffer of `f64` elements.
/// - `2`: `params` — `DiffParams` uniform (eight `u32` fields).
///
/// Each invocation produces at most one output element. It is indexed by `gid.x` only
/// (`gid.y`/`gid.z` are ignored), so the dispatch must cover at least `total_out`
/// invocations along x. For output index `idx`:
/// - `segment_idx = idx / segment_out`, `offset = idx % segment_out`
/// - `before = segment_idx % stride_before`, `after = segment_idx / stride_before`
/// - `i0 = after * block + before + offset * stride_before`, `i1 = i0 + stride_before`
/// - `Output[idx] = Input[i1] - Input[i0]`
///
/// The invocation returns without writing `Output[idx]` when `segment_out == 0`,
/// `idx >= total_out`, `segment_idx >= segments`, or either input index is
/// `>= total_in`. `segment_len` and `_pad` are never read by the shader.
///
/// Presumed host contract (TODO confirm against the code that fills `DiffParams`):
/// - `stride_before` = product of the extents before the diff axis
/// - `segment_len` = extent of the diff axis, and `segment_out = segment_len - 1`
/// - `block = stride_before * segment_len`
/// - `segments = stride_before * (product of the extents after the axis)`
/// - `total_out = segments * segment_out`
/// - `total_in` = input element count
/// - `_pad` presumably pads the uniform to 32 bytes
///
/// The `@WG@` token is not valid WGSL. It must be textually replaced
/// (presumably with the workgroup size) before the module is compiled.
/// `f64` is not part of core WGSL. Presumably the device must enable 64-bit
/// float support (e.g. wgpu `SHADER_F64`); confirm at the call site.
///
/// NOTE(review): the input is addressed with `before` varying fastest, but the
/// output is written at `idx = segment_idx * segment_out + offset`, i.e. with
/// `offset` varying fastest. When `stride_before > 1` (diff along a non-leading
/// axis), these orderings differ. If callers expect the output in the same
/// (column-major) ordering as the input, this is a bug — confirm.
pub const DIFF_SHADER_F64: &str = r#"
struct Tensor {
    data: array<f64>,
};

struct DiffParams {
    stride_before: u32,
    segments: u32,
    segment_len: u32,
    segment_out: u32,
    block: u32,
    total_out: u32,
    total_in: u32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: DiffParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (params.segment_out == 0u) {
        return;
    }
    let idx = gid.x;
    if (idx >= params.total_out) {
        return;
    }
    let segment_idx = idx / params.segment_out;
    if (segment_idx >= params.segments) {
        return;
    }
    let offset = idx % params.segment_out;
    let before = segment_idx % params.stride_before;
    let after = segment_idx / params.stride_before;
    let base = after * params.block;
    let i0 = base + before + offset * params.stride_before;
    let i1 = i0 + params.stride_before;
    if (i1 >= params.total_in || i0 >= params.total_in) {
        return;
    }
    let a = Input.data[i1];
    let b = Input.data[i0];
    Output.data[idx] = a - b;
}
"#;

/// WGSL compute shader computing a first-order difference
/// (`Input[i + stride_before] - Input[i]`) along one strided axis of an `f32` tensor.
/// Apart from the element type, the kernel text matches `DIFF_SHADER_F64`.
///
/// Bindings (group 0):
/// - `0`: `Input` — read-only storage buffer of `f32` elements.
/// - `1`: `Output` — read-write storage buffer of `f32` elements.
/// - `2`: `params` — `DiffParams` uniform (eight `u32` fields).
///
/// Each invocation produces at most one output element. It is indexed by `gid.x` only
/// (`gid.y`/`gid.z` are ignored), so the dispatch must cover at least `total_out`
/// invocations along x. For output index `idx`:
/// - `segment_idx = idx / segment_out`, `offset = idx % segment_out`
/// - `before = segment_idx % stride_before`, `after = segment_idx / stride_before`
/// - `i0 = after * block + before + offset * stride_before`, `i1 = i0 + stride_before`
/// - `Output[idx] = Input[i1] - Input[i0]`
///
/// The invocation returns without writing `Output[idx]` when `segment_out == 0`,
/// `idx >= total_out`, `segment_idx >= segments`, or either input index is
/// `>= total_in`. `segment_len` and `_pad` are never read by the shader.
///
/// Presumed host contract (TODO confirm against the code that fills `DiffParams`):
/// - `stride_before` = product of the extents before the diff axis
/// - `segment_len` = extent of the diff axis, and `segment_out = segment_len - 1`
/// - `block = stride_before * segment_len`
/// - `segments = stride_before * (product of the extents after the axis)`
/// - `total_out = segments * segment_out`
/// - `total_in` = input element count
/// - `_pad` presumably pads the uniform to 32 bytes
///
/// The `@WG@` token is not valid WGSL. It must be textually replaced
/// (presumably with the workgroup size) before the module is compiled.
///
/// NOTE(review): the input is addressed with `before` varying fastest, but the
/// output is written at `idx = segment_idx * segment_out + offset`, i.e. with
/// `offset` varying fastest. When `stride_before > 1` (diff along a non-leading
/// axis), these orderings differ. If callers expect the output in the same
/// (column-major) ordering as the input, this is a bug — confirm.
pub const DIFF_SHADER_F32: &str = r#"
struct Tensor {
    data: array<f32>,
};

struct DiffParams {
    stride_before: u32,
    segments: u32,
    segment_len: u32,
    segment_out: u32,
    block: u32,
    total_out: u32,
    total_in: u32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: DiffParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (params.segment_out == 0u) {
        return;
    }
    let idx = gid.x;
    if (idx >= params.total_out) {
        return;
    }
    let segment_idx = idx / params.segment_out;
    if (segment_idx >= params.segments) {
        return;
    }
    let offset = idx % params.segment_out;
    let before = segment_idx % params.stride_before;
    let after = segment_idx / params.stride_before;
    let base = after * params.block;
    let i0 = base + before + offset * params.stride_before;
    let i1 = i0 + params.stride_before;
    if (i1 >= params.total_in || i0 >= params.total_in) {
        return;
    }
    let a = Input.data[i1];
    let b = Input.data[i0];
    Output.data[idx] = a - b;
}
"#;