runmat-accelerate 0.4.4

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL compute kernel (f64 variant) that places a vector on one diagonal of a
/// column-major `size x size` output matrix.
///
/// Bindings: `0` = input vector (read-only storage), `1` = output matrix
/// (read-write storage), `2` = `DiagVecParams` uniform. `len` is the number of
/// vector elements to place, `size` is the square matrix dimension (also the
/// column stride), and `offset` selects the diagonal: `>= 0` shifts the column
/// right (on or above the main diagonal), `< 0` shifts the row down (below it).
///
/// One invocation handles one vector element (`gid.x`). `@WG@` is a textual
/// placeholder for the workgroup size; it is not valid WGSL, so the host must
/// replace it before compilation (substitution code not shown here).
///
/// Only the diagonal entries are written. Off-diagonal entries keep whatever the
/// output buffer already held (presumably zero-initialised by the caller —
/// confirm). Any element whose target `(row, col)` falls outside the matrix is
/// skipped rather than stored.
pub const DIAG_FROM_VECTOR_SHADER_F64: &str = r#"
struct Vector {
    data: array<f64>,
};

struct Matrix {
    data: array<f64>,
};

struct DiagVecParams {
    len: u32,
    size: u32,
    offset: i32,
    _pad: u32,
};

@group(0) @binding(0) var<storage, read> Input: Vector;
@group(0) @binding(1) var<storage, read_write> Output: Matrix;
@group(0) @binding(2) var<uniform> params: DiagVecParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.len {
        return;
    }
    var row = idx;
    var col = idx;
    if params.offset >= 0 {
        col = idx + u32(params.offset);
    } else {
        let shift = u32(-params.offset);
        row = idx + shift;
    }
    // Never store outside the size x size matrix. Out-of-bounds storage writes
    // may be redirected to another in-bounds element under robust buffer
    // access, which would silently corrupt the result if len, offset and size
    // are inconsistent.
    if row >= params.size || col >= params.size {
        return;
    }
    let base = row + col * params.size;
    Output.data[base] = Input.data[idx];
}
"#;

/// WGSL compute kernel (f32 variant) that places a vector on one diagonal of a
/// column-major `size x size` output matrix.
///
/// Bindings: `0` = input vector (read-only storage), `1` = output matrix
/// (read-write storage), `2` = `DiagVecParams` uniform. `len` is the number of
/// vector elements to place, `size` is the square matrix dimension (also the
/// column stride), and `offset` selects the diagonal: `>= 0` shifts the column
/// right (on or above the main diagonal), `< 0` shifts the row down (below it).
///
/// One invocation handles one vector element (`gid.x`). `@WG@` is a textual
/// placeholder for the workgroup size; it is not valid WGSL, so the host must
/// replace it before compilation (substitution code not shown here).
///
/// Only the diagonal entries are written. Off-diagonal entries keep whatever the
/// output buffer already held (presumably zero-initialised by the caller —
/// confirm). Any element whose target `(row, col)` falls outside the matrix is
/// skipped rather than stored.
pub const DIAG_FROM_VECTOR_SHADER_F32: &str = r#"
struct Vector {
    data: array<f32>,
};

struct Matrix {
    data: array<f32>,
};

struct DiagVecParams {
    len: u32,
    size: u32,
    offset: i32,
    _pad: u32,
};

@group(0) @binding(0) var<storage, read> Input: Vector;
@group(0) @binding(1) var<storage, read_write> Output: Matrix;
@group(0) @binding(2) var<uniform> params: DiagVecParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.len {
        return;
    }
    var row = idx;
    var col = idx;
    if params.offset >= 0 {
        col = idx + u32(params.offset);
    } else {
        let shift = u32(-params.offset);
        row = idx + shift;
    }
    // Never store outside the size x size matrix. Out-of-bounds storage writes
    // may be redirected to another in-bounds element under robust buffer
    // access, which would silently corrupt the result if len, offset and size
    // are inconsistent.
    if row >= params.size || col >= params.size {
        return;
    }
    let base = row + col * params.size;
    Output.data[base] = Input.data[idx];
}
"#;

/// WGSL compute kernel (f64 variant) that copies one diagonal of a column-major
/// matrix into a vector.
///
/// Bindings: `0` = input matrix (read-only storage), `1` = output vector
/// (read-write storage), `2` = `DiagExtractParams` uniform. `rows` is the column
/// stride of the input, `diag_len` is the number of elements to extract, and
/// `offset` selects the diagonal: `>= 0` shifts the column right (on or above
/// the main diagonal), `< 0` shifts the row down (below it).
///
/// `cols` is part of the uniform layout but is not read by the kernel. The
/// input's upper bound is taken from `arrayLength` of the bound buffer instead.
///
/// One invocation produces one output element (`gid.x`). `@WG@` is a textual
/// placeholder for the workgroup size that the host must replace before
/// compilation (substitution code not shown here). If an element's source
/// position lies outside the input, that output slot is left unwritten.
pub const DIAG_EXTRACT_SHADER_F64: &str = r#"
struct Matrix {
    data: array<f64>,
};

struct Vector {
    data: array<f64>,
};

struct DiagExtractParams {
    rows: u32,
    cols: u32,
    offset: i32,
    diag_len: u32,
};

@group(0) @binding(0) var<storage, read> Input: Matrix;
@group(0) @binding(1) var<storage, read_write> Output: Vector;
@group(0) @binding(2) var<uniform> params: DiagExtractParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.diag_len {
        return;
    }
    var row = idx;
    var col = idx;
    if params.offset >= 0 {
        col = idx + u32(params.offset);
    } else {
        let shift = u32(-params.offset);
        row = idx + shift;
    }
    let base = row + col * params.rows;
    // Never read outside the matrix. Out-of-bounds storage reads return
    // unspecified data under robust buffer access, so an inconsistent
    // diag_len or offset would otherwise produce silent garbage.
    if row >= params.rows || base >= arrayLength(&Input.data) {
        return;
    }
    Output.data[idx] = Input.data[base];
}
"#;

/// WGSL compute kernel (f32 variant) that copies one diagonal of a column-major
/// matrix into a vector.
///
/// Bindings: `0` = input matrix (read-only storage), `1` = output vector
/// (read-write storage), `2` = `DiagExtractParams` uniform. `rows` is the column
/// stride of the input, `diag_len` is the number of elements to extract, and
/// `offset` selects the diagonal: `>= 0` shifts the column right (on or above
/// the main diagonal), `< 0` shifts the row down (below it).
///
/// `cols` is part of the uniform layout but is not read by the kernel. The
/// input's upper bound is taken from `arrayLength` of the bound buffer instead.
///
/// One invocation produces one output element (`gid.x`). `@WG@` is a textual
/// placeholder for the workgroup size that the host must replace before
/// compilation (substitution code not shown here). If an element's source
/// position lies outside the input, that output slot is left unwritten.
pub const DIAG_EXTRACT_SHADER_F32: &str = r#"
struct Matrix {
    data: array<f32>,
};

struct Vector {
    data: array<f32>,
};

struct DiagExtractParams {
    rows: u32,
    cols: u32,
    offset: i32,
    diag_len: u32,
};

@group(0) @binding(0) var<storage, read> Input: Matrix;
@group(0) @binding(1) var<storage, read_write> Output: Vector;
@group(0) @binding(2) var<uniform> params: DiagExtractParams;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.diag_len {
        return;
    }
    var row = idx;
    var col = idx;
    if params.offset >= 0 {
        col = idx + u32(params.offset);
    } else {
        let shift = u32(-params.offset);
        row = idx + shift;
    }
    let base = row + col * params.rows;
    // Never read outside the matrix. Out-of-bounds storage reads return
    // unspecified data under robust buffer access, so an inconsistent
    // diag_len or offset would otherwise produce silent garbage.
    if row >= params.rows || base >= arrayLength(&Input.data) {
        return;
    }
    Output.data[idx] = Input.data[base];
}
"#;