// runmat-accelerate 0.4.4
//
// Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
// Documentation
/// Shader source template (WGSL syntax) for a small-`k` matrix multiply over
/// `f64` elements.
///
/// The text is a template, not a finished shader: every `@MT@` token (used in
/// `@workgroup_size(@MT@, @MT@, 1)` and in `let tile = @MT@u;`) has to be
/// replaced by an integer literal before compilation.
/// NOTE(review): the substitution site is outside this view — confirm that the
/// same value is used for dispatch sizing.
/// NOTE(review): `f64` is not part of core WGSL; presumably the backend enables
/// a 64-bit float feature before using this variant — confirm.
///
/// Resource interface (all in bind group 0):
/// - binding 0: `A`, read-only storage, `array<f64>`
/// - binding 1: `B`, read-only storage, `array<f64>`
/// - binding 2: `Out`, read-write storage, `array<f64>`
/// - binding 3: `params`, uniform `Params` made of ten `u32` fields in the
///   order `m, n, k, lda, ldb, ldc, offset_a, offset_b, offset_out, flags`
///
/// Per-invocation behavior:
/// - `global_row = workgroup_id.y * tile + local_invocation_id.y` and
///   `global_col = workgroup_id.x * tile + local_invocation_id.x`; the
///   invocation returns without writing when `global_row >= m` or
///   `global_col >= n`.
/// - Offsets and leading dimensions are added directly to array indices, so
///   they are expressed in elements, not bytes.
/// - `A` is read at `offset_a + global_row + kk * lda`, or at
///   `offset_a + kk + global_row * lda` when bit 0 of `flags` is set.
/// - `B` is read at `offset_b + kk + global_col * ldb`, or at
///   `offset_b + global_col + kk * ldb` when bit 1 of `flags` is set.
/// - The sum of products is stored to
///   `Out[offset_out + global_row + global_col * ldc]`, i.e. a column-major
///   layout with leading dimension `ldc`. The store overwrites the element;
///   it does not accumulate into the previous contents.
///
/// The reduction loop is bounded by `SMALL_K_MAX` (8) and exits early once
/// `kk >= k`, so at most the first 8 terms are summed. For `k > 8` the
/// remaining terms are ignored without any error signal.
/// NOTE(review): presumably the caller only selects this shader when `k <= 8`
/// — verify at the dispatch site.
pub const MATMUL_SMALLK_SHADER_F64: &str = r#"
struct Tensor { data: array<f64>, };
struct Params {
    m: u32,
    n: u32,
    k: u32,
    lda: u32,
    ldb: u32,
    ldc: u32,
    offset_a: u32,
    offset_b: u32,
    offset_out: u32,
    flags: u32,
};

const SMALL_K_MAX: u32 = 8u;

@group(0) @binding(0) var<storage, read> A: Tensor;
@group(0) @binding(1) var<storage, read> B: Tensor;
@group(0) @binding(2) var<storage, read_write> Out: Tensor;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(@MT@, @MT@, 1)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let tile = @MT@u;
    let global_row = wid.y * tile + lid.y;
    let global_col = wid.x * tile + lid.x;

    if global_row >= params.m || global_col >= params.n {
        return;
    }

    let lda = params.lda;
    let ldb = params.ldb;
    let ldc = params.ldc;
    let base_a = params.offset_a;
    let base_b = params.offset_b;
    let base_out = params.offset_out;
    let transpose_a = (params.flags & 1u) != 0u;
    let transpose_b = (params.flags & 2u) != 0u;

    var acc: f64 = 0.0;
    let k = params.k;
    for (var kk: u32 = 0u; kk < SMALL_K_MAX; kk = kk + 1u) {
        if (kk >= k) {
            break;
        }
        var a_idx = base_a + global_row + kk * lda;
        if (transpose_a) {
            a_idx = base_a + kk + global_row * lda;
        }
        var b_idx = base_b + kk + global_col * ldb;
        if (transpose_b) {
            b_idx = base_b + global_col + kk * ldb;
        }
        acc = acc + A.data[a_idx] * B.data[b_idx];
    }

    let out_idx = base_out + global_row + global_col * ldc;
    Out.data[out_idx] = acc;
}
"#;

/// Shader source template (WGSL syntax) for a small-`k` matrix multiply over
/// `f32` elements.
///
/// The text is a template, not a finished shader: every `@MT@` token (used in
/// `@workgroup_size(@MT@, @MT@, 1)` and in `let tile = @MT@u;`) has to be
/// replaced by an integer literal before compilation.
/// NOTE(review): the substitution site is outside this view — confirm that the
/// same value is used for dispatch sizing.
///
/// Resource interface (all in bind group 0):
/// - binding 0: `A`, read-only storage, `array<f32>`
/// - binding 1: `B`, read-only storage, `array<f32>`
/// - binding 2: `Out`, read-write storage, `array<f32>`
/// - binding 3: `params`, uniform `Params` made of ten `u32` fields in the
///   order `m, n, k, lda, ldb, ldc, offset_a, offset_b, offset_out, flags`
///
/// Per-invocation behavior:
/// - `global_row = workgroup_id.y * tile + local_invocation_id.y` and
///   `global_col = workgroup_id.x * tile + local_invocation_id.x`; the
///   invocation returns without writing when `global_row >= m` or
///   `global_col >= n`.
/// - Offsets and leading dimensions are added directly to array indices, so
///   they are expressed in elements, not bytes.
/// - `A` is read at `offset_a + global_row + kk * lda`, or at
///   `offset_a + kk + global_row * lda` when bit 0 of `flags` is set.
/// - `B` is read at `offset_b + kk + global_col * ldb`, or at
///   `offset_b + global_col + kk * ldb` when bit 1 of `flags` is set.
/// - The sum of products is stored to
///   `Out[offset_out + global_row + global_col * ldc]`, i.e. a column-major
///   layout with leading dimension `ldc`. The store overwrites the element;
///   it does not accumulate into the previous contents.
///
/// The reduction loop is bounded by `SMALL_K_MAX` (8) and exits early once
/// `kk >= k`, so at most the first 8 terms are summed. For `k > 8` the
/// remaining terms are ignored without any error signal.
/// NOTE(review): presumably the caller only selects this shader when `k <= 8`
/// — verify at the dispatch site.
pub const MATMUL_SMALLK_SHADER_F32: &str = r#"
struct Tensor { data: array<f32>, };
struct Params {
    m: u32,
    n: u32,
    k: u32,
    lda: u32,
    ldb: u32,
    ldc: u32,
    offset_a: u32,
    offset_b: u32,
    offset_out: u32,
    flags: u32,
};

const SMALL_K_MAX: u32 = 8u;

@group(0) @binding(0) var<storage, read> A: Tensor;
@group(0) @binding(1) var<storage, read> B: Tensor;
@group(0) @binding(2) var<storage, read_write> Out: Tensor;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(@MT@, @MT@, 1)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let tile = @MT@u;
    let global_row = wid.y * tile + lid.y;
    let global_col = wid.x * tile + lid.x;

    if global_row >= params.m || global_col >= params.n {
        return;
    }

    let lda = params.lda;
    let ldb = params.ldb;
    let ldc = params.ldc;
    let base_a = params.offset_a;
    let base_b = params.offset_b;
    let base_out = params.offset_out;
    let transpose_a = (params.flags & 1u) != 0u;
    let transpose_b = (params.flags & 2u) != 0u;

    var acc: f32 = 0.0;
    let k = params.k;
    for (var kk: u32 = 0u; kk < SMALL_K_MAX; kk = kk + 1u) {
        if (kk >= k) {
            break;
        }
        var a_idx = base_a + global_row + kk * lda;
        if (transpose_a) {
            a_idx = base_a + kk + global_row * lda;
        }
        var b_idx = base_b + kk + global_col * ldb;
        if (transpose_b) {
            b_idx = base_b + global_col + kk * ldb;
        }
        acc = acc + A.data[a_idx] * B.data[b_idx];
    }

    let out_idx = base_out + global_row + global_col * ldc;
    Out.data[out_idx] = acc;
}
"#;