runmat-accelerate 0.4.4

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL source for the single-precision centered Gram kernel.
///
/// For every column pair `(i, j)` with `i, j < params.cols`, the shader
/// accumulates `sum_k (M[k, i] - mean[i]) * (M[k, j] - mean[j])` over
/// `k < params.rows` using compensated (Kahan) summation, multiplies the sum
/// by `params.denom.x`, and stores it at `offset_out + i + j * ldc`. Matrix
/// element `(k, c)` is read from `offset_matrix + k + c * lda`, and the mean of
/// column `c` from `offset_means + c`.
///
/// The text contains `@MT@` placeholders (tile edge length and workgroup
/// size), so it is not valid WGSL as written; presumably the caller
/// substitutes a concrete integer before compiling — confirm against the
/// code that builds the pipeline.
pub const CENTERED_GRAM_SHADER_F32: &str = r#"
// Input matrix; element (row, col) is read at offset_matrix + row + col * lda.
struct Matrix {
    data: array<f32>,
};

// Per-column means; the mean of column c is read at offset_means + c.
struct Means {
    data: array<f32>,
};

// Result buffer; entry (i, j) is written at offset_out + i + j * ldc.
struct Output {
    data: array<f32>,
};

struct Params {
    rows: u32,
    cols: u32,
    lda: u32,
    ldc: u32,
    offset_matrix: u32,
    offset_means: u32,
    offset_out: u32,
    _pad0: u32,
    // Only denom.x is read: the scale applied to each accumulated sum.
    denom: vec4<f32>,
    _pad1: vec4<f32>,
    _pad2: vec4<f32>,
};

@group(0) @binding(0) var<storage, read> MatrixTensor: Matrix;
@group(0) @binding(1) var<storage, read> MeansTensor: Means;
@group(0) @binding(2) var<storage, read_write> OutTensor: Output;
@group(0) @binding(3) var<uniform> params: Params;

// Workgroup-shared staging tiles, refilled once per step along the row axis.
var<workgroup> tileA: array<array<f32, @MT@>, @MT@>;
var<workgroup> tileB: array<array<f32, @MT@>, @MT@>;

@compute @workgroup_size(@MT@, @MT@, 1)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let tile = @MT@u;
    // This invocation produces output entry (col_i, col_j).
    let col_i = wid.y * tile + lid.y;
    let col_j = wid.x * tile + lid.x;
    let rows = params.rows;
    let cols = params.cols;

    // Kahan summation state: acc is the running sum, comp the rounding error
    // (excess) introduced by the most recent addition into acc.
    var acc: f32 = 0.0;
    var comp: f32 = 0.0;
    let tiles_k = (rows + tile - 1u) / tile;
    for (var t: u32 = 0u; t < tiles_k; t = t + 1u) {
        let row_left = t * tile + lid.x;
        let row_right = t * tile + lid.y;

        // Out-of-range positions stay 0.0 so they add nothing to the sum.
        var left_val: f32 = 0.0;
        if (col_i < cols && row_left < rows) {
            let idx = params.offset_matrix + row_left + col_i * params.lda;
            let mean = MeansTensor.data[params.offset_means + col_i];
            left_val = MatrixTensor.data[idx] - mean;
        }

        var right_val: f32 = 0.0;
        if (col_j < cols && row_right < rows) {
            let idx = params.offset_matrix + row_right + col_j * params.lda;
            let mean = MeansTensor.data[params.offset_means + col_j];
            right_val = MatrixTensor.data[idx] - mean;
        }

        // tileA[y][x] holds centered (row t*tile + x, column col_i of lane y);
        // tileB[y][x] holds centered (row t*tile + y, column col_j of lane x).
        tileA[lid.y][lid.x] = left_val;
        tileB[lid.y][lid.x] = right_val;
        // Both tiles must be completely written before any lane reads them.
        workgroupBarrier();

        for (var p: u32 = 0u; p < tile; p = p + 1u) {
            let a_val = tileA[lid.y][p];
            let b_val = tileB[p][lid.x];
            let prod = a_val * b_val;
            let y = prod - comp;
            let t_sum = acc + y;
            comp = (t_sum - acc) - y;
            acc = t_sum;
        }
        // All reads must finish before the next step overwrites the tiles.
        workgroupBarrier();
    }

    if (col_i < cols && col_j < cols) {
        // comp is the excess already contained in acc, so the outstanding
        // correction is removed by subtracting it.
        var value = (acc - comp) * params.denom.x;
        // Flush tiny negative diagonal entries (above -1e-12) to exactly zero.
        if (col_i == col_j && value < 0.0 && value > -1.0e-12) {
            value = 0.0;
        }
        let idx_out = params.offset_out + col_i + col_j * params.ldc;
        OutTensor.data[idx_out] = value;
    }
}
"#;

/// WGSL source for the double-precision centered Gram kernel.
///
/// For every column pair `(i, j)` with `i, j < params.cols`, the shader
/// accumulates `sum_k (M[k, i] - mean[i]) * (M[k, j] - mean[j])` over
/// `k < params.rows` using compensated (Kahan) summation, multiplies the sum
/// by `params.denom.x`, and stores it at `offset_out + i + j * ldc`. Matrix
/// element `(k, c)` is read from `offset_matrix + k + c * lda`, and the mean of
/// column `c` from `offset_means + c`.
///
/// The text contains `@MT@` placeholders (tile edge length and workgroup
/// size), so it is not valid WGSL as written; presumably the caller
/// substitutes a concrete integer before compiling — confirm against the
/// code that builds the pipeline.
///
/// NOTE(review): `f64` is not a core WGSL type, so this source presumably
/// requires a backend/device with a double-precision extension enabled —
/// TODO confirm which device features the caller checks before using it.
pub const CENTERED_GRAM_SHADER_F64: &str = r#"
// Input matrix; element (row, col) is read at offset_matrix + row + col * lda.
struct Matrix {
    data: array<f64>,
};

// Per-column means; the mean of column c is read at offset_means + c.
struct Means {
    data: array<f64>,
};

// Result buffer; entry (i, j) is written at offset_out + i + j * ldc.
struct Output {
    data: array<f64>,
};

struct Params {
    rows: u32,
    cols: u32,
    lda: u32,
    ldc: u32,
    offset_matrix: u32,
    offset_means: u32,
    offset_out: u32,
    _pad0: u32,
    // Only denom.x is read: the scale applied to each accumulated sum.
    denom: vec2<f64>,
    _pad1: vec2<f64>,
    _pad2: vec2<f64>,
};

@group(0) @binding(0) var<storage, read> MatrixTensor: Matrix;
@group(0) @binding(1) var<storage, read> MeansTensor: Means;
@group(0) @binding(2) var<storage, read_write> OutTensor: Output;
@group(0) @binding(3) var<uniform> params: Params;

// Workgroup-shared staging tiles, refilled once per step along the row axis.
var<workgroup> tileA: array<array<f64, @MT@>, @MT@>;
var<workgroup> tileB: array<array<f64, @MT@>, @MT@>;

@compute @workgroup_size(@MT@, @MT@, 1)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let tile = @MT@u;
    // This invocation produces output entry (col_i, col_j).
    let col_i = wid.y * tile + lid.y;
    let col_j = wid.x * tile + lid.x;
    let rows = params.rows;
    let cols = params.cols;

    // Kahan summation state: acc is the running sum, comp the rounding error
    // (excess) introduced by the most recent addition into acc.
    var acc: f64 = 0.0;
    var comp: f64 = 0.0;
    let tiles_k = (rows + tile - 1u) / tile;
    for (var t: u32 = 0u; t < tiles_k; t = t + 1u) {
        let row_left = t * tile + lid.x;
        let row_right = t * tile + lid.y;

        // Out-of-range positions stay 0.0 so they add nothing to the sum.
        var left_val: f64 = 0.0;
        if (col_i < cols && row_left < rows) {
            let idx = params.offset_matrix + row_left + col_i * params.lda;
            let mean = MeansTensor.data[params.offset_means + col_i];
            left_val = MatrixTensor.data[idx] - mean;
        }

        var right_val: f64 = 0.0;
        if (col_j < cols && row_right < rows) {
            let idx = params.offset_matrix + row_right + col_j * params.lda;
            let mean = MeansTensor.data[params.offset_means + col_j];
            right_val = MatrixTensor.data[idx] - mean;
        }

        // tileA[y][x] holds centered (row t*tile + x, column col_i of lane y);
        // tileB[y][x] holds centered (row t*tile + y, column col_j of lane x).
        tileA[lid.y][lid.x] = left_val;
        tileB[lid.y][lid.x] = right_val;
        // Both tiles must be completely written before any lane reads them.
        workgroupBarrier();

        for (var p: u32 = 0u; p < tile; p = p + 1u) {
            let a_val = tileA[lid.y][p];
            let b_val = tileB[p][lid.x];
            let prod = a_val * b_val;
            let y = prod - comp;
            let t_sum = acc + y;
            comp = (t_sum - acc) - y;
            acc = t_sum;
        }
        // All reads must finish before the next step overwrites the tiles.
        workgroupBarrier();
    }

    if (col_i < cols && col_j < cols) {
        // comp is the excess already contained in acc, so the outstanding
        // correction is removed by subtracting it.
        var value = (acc - comp) * params.denom.x;
        // Flush tiny negative diagonal entries (above -1e-12) to exactly zero.
        if (col_i == col_j && value < 0.0 && value > -1.0e-12) {
            value = 0.0;
        }
        let idx_out = params.offset_out + col_i + col_j * params.ldc;
        OutTensor.data[idx_out] = value;
    }
}
"#;