runmat-accelerate 0.4.4

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL compute shader that factors a symmetric Gram matrix `G` (k x k,
/// `k = params.cols`, at most `MAX_K = 64`) as `G = R^T R` with `R` upper
/// triangular, and also computes `R^-1`.
///
/// Bindings (group 0):
/// - `0` `Gram` (read): `G`, column-major with leading dimension
///   `params.stride`. Only the diagonal and the strict upper triangle
///   (`row < col`) are read.
/// - `1` `OutR` (read_write): receives `R`, same layout as `Gram`.
/// - `2` `OutRInv` (read_write): receives `R^-1`, same layout as `Gram`.
/// - `3` `params` (uniform): `cols` = k, `stride` = leading dimension,
///   padded to 16 bytes.
///
/// Only the leading k x k block of each output is written. If `k == 0` or
/// `k > MAX_K`, the shader returns without writing anything.
///
/// Each squared pivot is clamped to at least `EPS` before the square root.
/// For finite input, every diagonal entry of `R` is therefore at least
/// `sqrt(EPS)` (about 1e-3).
///
/// Each workgroup is a single invocation that performs the whole
/// computation. The dot products use Kahan-compensated accumulation.
///
/// NOTE(review): the two `MAX_K x MAX_K` f32 workgroup arrays total 32 KiB.
/// That exceeds WebGPU's default `maxComputeWorkgroupStorageSize` (16 KiB);
/// presumably the device is created with a raised limit — confirm.
/// NOTE(review): `stride >= k` is assumed but not validated here.
pub const QR_POWER_ITER_CHOL_SHADER: &str = r#"
const MAX_K: u32 = 64u;
const EPS: f32 = 1.0e-6;

struct Matrix {
    data: array<f32>,
};

struct Params {
    cols: u32,
    stride: u32,
    _pad0: vec2<u32>,
};

@group(0) @binding(0)
var<storage, read> Gram : Matrix;

@group(0) @binding(1)
var<storage, read_write> OutR : Matrix;

@group(0) @binding(2)
var<storage, read_write> OutRInv : Matrix;

@group(0) @binding(3)
var<uniform> params : Params;

// Scratch copies of R and R^-1, indexed [row][col]. Only the leading k x k
// block is ever read or written.
var<workgroup> R_local : array<array<f32, MAX_K>, MAX_K>;
var<workgroup> RInv_local : array<array<f32, MAX_K>, MAX_K>;

// G[row, col] from the column-major Gram buffer with leading dimension `stride`.
fn gram_at(row: u32, col: u32, stride: u32) -> f32 {
    return Gram.data[row + col * stride];
}

@compute @workgroup_size(1, 1, 1)
fn main(@builtin(local_invocation_index) local_index: u32) {
    // Single-invocation workgroup; the guard keeps the kernel correct if the
    // workgroup size is ever increased.
    if (local_index != 0u) {
        return;
    }

    let k = params.cols;
    if (k == 0u || k > MAX_K) {
        return;
    }
    let stride = params.stride;

    // Clear the k x k working region. Every read below stays inside it, so
    // entries outside it are never observed.
    for (var col: u32 = 0u; col < k; col = col + 1u) {
        for (var row: u32 = 0u; row < k; row = row + 1u) {
            R_local[row][col] = 0.0;
            RInv_local[row][col] = 0.0;
        }
    }

    // Cholesky factorisation G = R^T R, one row of R at a time, with
    // compensated (Kahan) subtraction. After t = fl(sum - y), the carry
    // c = (t - sum) + y is the exact rounding error, t = (sum - y) + c.
    // The true running value is therefore t - c, so c is ADDED to the next
    // subtrahend.
    for (var j: u32 = 0u; j < k; j = j + 1u) {
        var sum = gram_at(j, j, stride);
        var c = 0.0;
        for (var p: u32 = 0u; p < j; p = p + 1u) {
            let term = R_local[p][j] * R_local[p][j];
            let y = term + c;
            let t = sum - y;
            c = (t - sum) + y;
            sum = t;
        }
        // Clamp the squared pivot so R stays finite and invertible for
        // (near-)singular input.
        sum = max(sum, EPS);
        let diag = sqrt(sum);
        R_local[j][j] = diag;

        if (diag > EPS) {
            for (var i: u32 = j + 1u; i < k; i = i + 1u) {
                var off = gram_at(j, i, stride);
                var c_off = 0.0;
                for (var p: u32 = 0u; p < j; p = p + 1u) {
                    let term = R_local[p][j] * R_local[p][i];
                    let y = term + c_off;
                    let t = off - y;
                    c_off = (t - off) + y;
                    off = t;
                }
                R_local[j][i] = off / diag;
            }
        } else {
            // Because of the clamp above, this is only reachable when diag is
            // NaN. WGSL does not guarantee NaN semantics, so this is a
            // best-effort guard that zeroes the rest of the row.
            for (var i: u32 = j + 1u; i < k; i = i + 1u) {
                R_local[j][i] = 0.0;
            }
        }
        for (var i: u32 = 0u; i < j; i = i + 1u) {
            R_local[j][i] = 0.0;
        }
    }

    // Invert the upper-triangular factor column by column using back
    // substitution. For column j: RInv[j][j] = 1 / R[j][j]. Then, for
    // row = j-1 down to 0:
    //   RInv[row][j] = -(sum_{p=row+1..j} R[row][p] * RInv[p][j]) / R[row][row]
    // The inner sum uses Kahan compensated addition.
    for (var j: u32 = 0u; j < k; j = j + 1u) {
        let diag = R_local[j][j];
        let inv_diag = select(0.0, 1.0 / diag, diag > EPS);
        RInv_local[j][j] = inv_diag;

        var row = i32(j) - 1;
        loop {
            if (row < 0) {
                break;
            }
            let row_u = u32(row);
            var sum = 0.0;
            var c_sum = 0.0;
            for (var p: u32 = row_u + 1u; p <= j; p = p + 1u) {
                let term = R_local[row_u][p] * RInv_local[p][j];
                let y = term - c_sum;
                let t = sum + y;
                c_sum = (t - sum) - y;
                sum = t;
            }
            let diag_row = max(R_local[row_u][row_u], EPS);
            RInv_local[row_u][j] = -sum / diag_row;
            row = row - 1;
        }
    }

    // Zero lower triangles explicitly for determinism.
    for (var col: u32 = 0u; col < k; col = col + 1u) {
        for (var row: u32 = col + 1u; row < k; row = row + 1u) {
            RInv_local[row][col] = 0.0;
            R_local[row][col] = 0.0;
        }
    }

    // Write the k x k blocks back in the same column-major layout as Gram.
    for (var col: u32 = 0u; col < k; col = col + 1u) {
        for (var row: u32 = 0u; row < k; row = row + 1u) {
            let idx = row + col * stride;
            OutR.data[idx] = R_local[row][col];
            OutRInv.data[idx] = RInv_local[row][col];
        }
    }
}
"#;