runmat-accelerate 0.4.1

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/Spir-V)
Documentation
pub const TRIU_SHADER_F64: &str = r#"
struct Tensor {
    data: array<f64>,
};

struct TriuParams {
    len: u32,
    start: u32,
    rows: u32,
    cols: u32,
    plane: u32,
    diag_offset: i32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> input0: Tensor;
@group(0) @binding(1) var<storage, read_write> output: Tensor;
@group(0) @binding(2) var<uniform> params: TriuParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= params.len {
        return;
    }
    let index: u32 = params.start + gid.x;
    let plane = params.plane;
    if plane == 0u {
        output.data[index] = input0.data[index];
        return;
    }
    let rows = params.rows;
    let within = index % plane;
    let row = within % rows;
    let col = within / rows;

    let row_i = i32(row);
    let col_i = i32(col);
    let diff = col_i - row_i;
    if diff < params.diag_offset {
        output.data[index] = 0.0;
    } else {
        output.data[index] = input0.data[index];
    }
}
"#;

pub const TRIU_SHADER_F32: &str = r#"
struct Tensor {
    data: array<f32>,
};

struct TriuParams {
    len: u32,
    start: u32,
    rows: u32,
    cols: u32,
    plane: u32,
    diag_offset: i32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> input0: Tensor;
@group(0) @binding(1) var<storage, read_write> output: Tensor;
@group(0) @binding(2) var<uniform> params: TriuParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= params.len {
        return;
    }
    let index: u32 = params.start + gid.x;
    let plane = params.plane;
    if plane == 0u {
        output.data[index] = input0.data[index];
        return;
    }
    let rows = params.rows;
    let within = index % plane;
    let row = within % rows;
    let col = within / rows;

    let row_i = i32(row);
    let col_i = i32(col);
    let diff = col_i - row_i;
    if diff < params.diag_offset {
        output.data[index] = 0.0;
    } else {
        output.data[index] = input0.data[index];
    }
}
"#;