runmat-accelerate 0.4.4

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/Spir-V)
Documentation
use crate::backend::wgpu::types::NumericPrecision;

pub fn triangular_linsolve_shader(
    precision: NumericPrecision,
    transpose: bool,
    lower: bool,
) -> String {
    let scalar_ty = match precision {
        NumericPrecision::F64 => "f64",
        NumericPrecision::F32 => "f32",
    };
    let a_at = if transpose {
        "return A.data[col + row * rows];"
    } else {
        "return A.data[row + col * rows];"
    };
    let body = if lower {
        format!(
            "var sum: {scalar_ty} = B.data[idx];
             var k: u32 = 0u;
             loop {{
               if (k >= target_row) {{ break; }}
               sum = sum - a_at(target_row, k, rows) * Prev.data[k + col * rows];
               k = k + 1u;
             }}
             let diag = a_at(target_row, target_row, rows);
             Out.data[idx] = sum / diag;"
        )
    } else {
        format!(
            "var sum: {scalar_ty} = B.data[idx];
             var k: u32 = target_row + 1u;
             loop {{
               if (k >= rows) {{ break; }}
               sum = sum - a_at(target_row, k, rows) * Prev.data[k + col * rows];
               k = k + 1u;
             }}
             let diag = a_at(target_row, target_row, rows);
             Out.data[idx] = sum / diag;"
        )
    };
    format!(
        r#"
struct Tensor {{ data: array<{scalar_ty}>, }};

struct Params {{
    len: u32,
    offset: u32,
    total: u32,
    rows: u32,
    rhs_cols: u32,
    target_row: u32,
    _pad0: u32,
    _pad1: u32,
}};

@group(0) @binding(0) var<storage, read> A: Tensor;
@group(0) @binding(1) var<storage, read> B: Tensor;
@group(0) @binding(2) var<storage, read> Prev: Tensor;
@group(0) @binding(3) var<storage, read_write> Out: Tensor;
@group(0) @binding(4) var<uniform> params: Params;

fn a_at(row: u32, col: u32, rows: u32) -> {scalar_ty} {{
    {a_at}
}}

@compute @workgroup_size(512)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let local = gid.x;
    if (local >= params.len) {{
        return;
    }}
    let idx = params.offset + local;
    if (idx >= params.total) {{
        return;
    }}
    let rows: u32 = params.rows;
    let cols_rhs: u32 = params.rhs_cols;
    let target_row: u32 = params.target_row;
    let row = idx % rows;
    let col = idx / rows;
    if (col >= cols_rhs) {{
        return;
    }}
    if (row != target_row) {{
        Out.data[idx] = Prev.data[idx];
        return;
    }}
    {body}
}}
"#
    )
}