runmat-accelerate 0.4.5

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
/// WGSL compute-shader template that fills an output tensor by tiling an
/// input tensor of `f64` elements.
///
/// Bindings (all in group 0):
/// - binding 0: `Input`, read-only storage buffer of `f64` values.
/// - binding 1: `Output`, read-write storage buffer of `f64` values.
/// - binding 2: `params`, uniform block holding `len`, `offset`, `rank` and
///   three `MAX_RANK`-entry (128) tables: `base_shape`, `new_shape` and
///   `base_strides`. Each table entry is a `u32` followed by three `u32`
///   pads, i.e. 16 bytes per entry (presumably to satisfy the 16-byte array
///   stride WGSL requires in the uniform address space).
///
/// Per invocation, with `idx = global_invocation_id.x`:
/// 1. Invocations with `idx >= params.len` return without writing.
/// 2. `global_index = params.offset + idx` is the output element to produce.
/// 3. For each `dim` in `0..params.rank`, `global_index` is peeled apart with
///    `%` and `/` by `new_shape[dim]`, so dimension 0 varies fastest. The
///    resulting coordinate is wrapped modulo `base_shape[dim]` and multiplied
///    by `base_strides[dim]`; the products are summed into the source index.
///    A zero entry in `new_shape` or `base_shape` yields coordinate 0 rather
///    than dividing by zero.
/// 4. `Output.data[global_index]` is assigned `Input.data[src_index]`.
///
/// Only the first `params.rank` entries of each table are read.
///
/// The text contains the literal token `@WG@` inside `@workgroup_size(...)`;
/// it is not valid WGSL as-is and is presumably substituted with the actual
/// workgroup size before compilation (TODO confirm against the caller).
///
/// NOTE(review): `f64` is not a core WGSL scalar type; this shader presumably
/// needs a backend/device feature that enables 64-bit floats — confirm.
pub const REPMAT_SHADER_F64: &str = r#"
const MAX_RANK: u32 = 128u;

struct PackedValue {
    value: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

alias PackedArray = array<PackedValue, MAX_RANK>;

struct Tensor {
    data: array<f64>,
};

struct Params {
    len: u32,
    offset: u32,
    rank: u32,
    _pad: u32,
    base_shape: PackedArray,
    new_shape: PackedArray,
    base_strides: PackedArray,
};

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.len {
        return;
    }
    let global_index = params.offset + idx;

    var remaining = global_index;
    var src_index: u32 = 0u;
    var dim: u32 = 0u;

    loop {
        if dim >= params.rank {
            break;
        }
        let size = params.new_shape[dim].value;
        var coord: u32 = 0u;
        if size != 0u {
            coord = remaining % size;
            remaining = remaining / size;
        }
        let base = params.base_shape[dim].value;
        var orig_coord: u32 = 0u;
        if base != 0u {
            orig_coord = coord % base;
        }
        let stride = params.base_strides[dim].value;
        src_index = src_index + orig_coord * stride;
        dim = dim + 1u;
    }

    Output.data[global_index] = Input.data[src_index];
}
"#;

/// WGSL compute-shader template that fills an output tensor by tiling an
/// input tensor of `f32` elements.
///
/// Bindings (all in group 0):
/// - binding 0: `Input`, read-only storage buffer of `f32` values.
/// - binding 1: `Output`, read-write storage buffer of `f32` values.
/// - binding 2: `params`, uniform block holding `len`, `offset`, `rank` and
///   three `MAX_RANK`-entry (128) tables: `base_shape`, `new_shape` and
///   `base_strides`. Each table entry is a `u32` followed by three `u32`
///   pads, i.e. 16 bytes per entry (presumably to satisfy the 16-byte array
///   stride WGSL requires in the uniform address space).
///
/// Per invocation, with `idx = global_invocation_id.x`:
/// 1. Invocations with `idx >= params.len` return without writing.
/// 2. `global_index = params.offset + idx` is the output element to produce.
/// 3. For each `dim` in `0..params.rank`, `global_index` is peeled apart with
///    `%` and `/` by `new_shape[dim]`, so dimension 0 varies fastest. The
///    resulting coordinate is wrapped modulo `base_shape[dim]` and multiplied
///    by `base_strides[dim]`; the products are summed into the source index.
///    A zero entry in `new_shape` or `base_shape` yields coordinate 0 rather
///    than dividing by zero.
/// 4. `Output.data[global_index]` is assigned `Input.data[src_index]`.
///
/// Only the first `params.rank` entries of each table are read.
///
/// The text contains the literal token `@WG@` inside `@workgroup_size(...)`;
/// it is not valid WGSL as-is and is presumably substituted with the actual
/// workgroup size before compilation (TODO confirm against the caller).
pub const REPMAT_SHADER_F32: &str = r#"
const MAX_RANK: u32 = 128u;

struct PackedValue {
    value: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

alias PackedArray = array<PackedValue, MAX_RANK>;

struct Tensor {
    data: array<f32>,
};

struct Params {
    len: u32,
    offset: u32,
    rank: u32,
    _pad: u32,
    base_shape: PackedArray,
    new_shape: PackedArray,
    base_strides: PackedArray,
};

@group(0) @binding(0) var<storage, read> Input: Tensor;
@group(0) @binding(1) var<storage, read_write> Output: Tensor;
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(@WG@)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= params.len {
        return;
    }
    let global_index = params.offset + idx;

    var remaining = global_index;
    var src_index: u32 = 0u;
    var dim: u32 = 0u;

    loop {
        if dim >= params.rank {
            break;
        }
        let size = params.new_shape[dim].value;
        var coord: u32 = 0u;
        if size != 0u {
            coord = remaining % size;
            remaining = remaining / size;
        }
        let base = params.base_shape[dim].value;
        var orig_coord: u32 = 0u;
        if base != 0u {
            orig_coord = coord % base;
        }
        let stride = params.base_strides[dim].value;
        src_index = src_index + orig_coord * stride;
        dim = dim + 1u;
    }

    Output.data[global_index] = Input.data[src_index];
}
"#;