numr 0.5.2

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
// Auto-generated scatter operations for u32

// Invocations per workgroup. Same value as the literal 256 written in the
// @workgroup_size attributes of the entry points in this file; the constant
// itself is not referenced anywhere in the visible source.
const WORKGROUP_SIZE: u32 = 256u;

// Uniform parameters for scatter_u32. Shapes and strides are packed into
// vec4<u32>, so at most four dimensions can be described; get_shape() selects
// the component for a given dimension.
struct ScatterParams {
    // Number of dimensions the kernel iterates over.
    ndim: u32,
    // Dimension along which values from `indices` replace the source coordinate.
    dim: u32,
    // Number of source elements; invocations with a global id at or beyond this exit early.
    src_total: u32,
    // Never read by the shader. NOTE(review): presumably mirrors padding in the
    // host-side struct so the vec4 fields start on a 16-byte boundary; confirm.
    _padding: u32,
    // Per-dimension extents of the output; used to range-check index values.
    output_shape: vec4<u32>,
    // Per-dimension element strides used to build the destination offset.
    output_strides: vec4<u32>,
    // Per-dimension extents of the source. Not read by the visible kernel.
    src_shape: vec4<u32>,
    // Per-dimension element strides used to split the linear source index into coordinates.
    src_strides: vec4<u32>,
}

// Resources used by scatter_u32 (bind group 0).
// Values to scatter; only read by the kernel although declared read_write.
@group(0) @binding(0) var<storage, read_write> src: array<u32>;
// Destination coordinate along params.dim for each source element, addressed
// with the same linear index as `src`; only read by the kernel.
@group(0) @binding(1) var<storage, read_write> indices: array<i32>;
// Scatter destination; one element is written per in-range source element.
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
// Shape/stride description of source and output.
@group(0) @binding(3) var<uniform> params: ScatterParams;

// Returns the component of `arr` for dimension `d`: 0 -> x, 1 -> y, 2 -> z,
// and any d >= 3 -> w.
fn get_shape(arr: vec4<u32>, d: u32) -> u32 {
    // Clamping keeps the dynamic index in range and maps every d >= 3 onto w,
    // the same fallback the branch-per-component form provides.
    let component = min(d, 3u);
    return arr[component];
}

// Scatter kernel: one invocation per source element.
//
// The linear source index gid.x is split into per-dimension coordinates using
// params.src_strides. Along params.dim the coordinate is replaced by
// indices[gid.x]; along every other dimension the source coordinate is kept.
// The resulting coordinates are combined with params.output_strides to form
// the destination offset, and src[gid.x] is stored there.
//
// An element is skipped (nothing is written) when:
//   * gid.x is not a valid source index,
//   * params.ndim exceeds the four dimensions a vec4 can describe,
//   * the index value is negative, or
//   * any destination coordinate is outside params.output_shape.
//
// NOTE(review): the same linear index addresses both `src` and `indices` and is
// decomposed by repeated divide/modulo, which presumes densely packed,
// outermost-first strides for both; confirm against the host code.
// If several source elements map to the same destination, which value ends up
// stored is not determined by this shader.
@compute @workgroup_size(256)
fn scatter_u32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.src_total) {
        return;
    }

    // Shapes and strides are vec4s; a larger ndim cannot be represented and
    // would otherwise keep re-reading the w component for every extra pass.
    if (params.ndim > 4u) {
        return;
    }

    var remaining = idx;
    var dst_offset: u32 = 0u;

    for (var d: u32 = 0u; d < params.ndim; d = d + 1u) {
        let src_stride = get_shape(params.src_strides, d);
        let coord = remaining / src_stride;
        remaining = remaining % src_stride;

        // Destination coordinate: the source coordinate, except along the
        // scatter dimension where the index buffer supplies it.
        var dst_coord = coord;
        if (d == params.dim) {
            let index_val = indices[idx];
            if (index_val < 0) {
                return;
            }
            dst_coord = u32(index_val);
        }

        // Range-check every dimension, not only the scatter dimension, so a
        // source extent larger than the output extent cannot alias into an
        // unrelated output element.
        if (dst_coord >= get_shape(params.output_shape, d)) {
            return;
        }

        dst_offset = dst_offset + dst_coord * get_shape(params.output_strides, d);
    }

    output[dst_offset] = src[idx];
}

// Copy kernel for initializing output from input
// These reuse group 0 slots 0 and 1, which the scatter resources also occupy.
// copy_u32 references only the copy_* variables and scatter_u32 only the
// scatter ones, so no single entry point sees two variables on the same slot.
// Source of the copy; only read by copy_u32 although declared read_write.
@group(0) @binding(0) var<storage, read_write> copy_src: array<u32>;
// Destination of the copy.
@group(0) @binding(1) var<storage, read_write> copy_dst: array<u32>;

// Uniform parameters for copy_u32.
struct CopyParams {
    // Number of elements to copy; invocations with a global id at or beyond this do nothing.
    numel: u32,
}

// Bound at group 0 slot 2 and read by copy_u32.
@group(0) @binding(2) var<uniform> copy_params: CopyParams;

// Element-wise copy: invocation i stores copy_src[i] into copy_dst[i].
// Invocations whose global id is not below copy_params.numel write nothing.
@compute @workgroup_size(256)
fn copy_u32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let element = gid.x;
    if (element >= copy_params.numel) {
        return;
    }
    copy_dst[element] = copy_src[element];
}