numr 0.5.2

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
// I32 broadcast binary operations

struct BroadcastBinaryParams {
    numel: u32,
    ndim: u32,
}

@group(0) @binding(0) var<storage, read_write> broadcast_a: array<i32>;
@group(0) @binding(1) var<storage, read_write> broadcast_b: array<i32>;
@group(0) @binding(2) var<storage, read_write> broadcast_out: array<i32>;
@group(0) @binding(3) var<storage, read_write> broadcast_a_strides: array<u32>;
@group(0) @binding(4) var<storage, read_write> broadcast_b_strides: array<u32>;
@group(0) @binding(5) var<storage, read_write> broadcast_out_strides: array<u32>;
@group(0) @binding(6) var<uniform> broadcast_params: BroadcastBinaryParams;

@compute @workgroup_size(256)
fn broadcast_add_i32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= broadcast_params.numel) { return; }
    var remaining = idx;
    var a_offset: u32 = 0u;
    var b_offset: u32 = 0u;
    for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {
        let stride = broadcast_out_strides[d];
        let coord = remaining / stride;
        remaining = remaining % stride;
        a_offset = a_offset + coord * broadcast_a_strides[d];
        b_offset = b_offset + coord * broadcast_b_strides[d];
    }
    broadcast_out[idx] = broadcast_a[a_offset] + broadcast_b[b_offset];
}

@compute @workgroup_size(256)
fn broadcast_sub_i32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= broadcast_params.numel) { return; }
    var remaining = idx;
    var a_offset: u32 = 0u;
    var b_offset: u32 = 0u;
    for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {
        let stride = broadcast_out_strides[d];
        let coord = remaining / stride;
        remaining = remaining % stride;
        a_offset = a_offset + coord * broadcast_a_strides[d];
        b_offset = b_offset + coord * broadcast_b_strides[d];
    }
    broadcast_out[idx] = broadcast_a[a_offset] - broadcast_b[b_offset];
}

@compute @workgroup_size(256)
fn broadcast_mul_i32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= broadcast_params.numel) { return; }
    var remaining = idx;
    var a_offset: u32 = 0u;
    var b_offset: u32 = 0u;
    for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {
        let stride = broadcast_out_strides[d];
        let coord = remaining / stride;
        remaining = remaining % stride;
        a_offset = a_offset + coord * broadcast_a_strides[d];
        b_offset = b_offset + coord * broadcast_b_strides[d];
    }
    broadcast_out[idx] = broadcast_a[a_offset] * broadcast_b[b_offset];
}

@compute @workgroup_size(256)
fn broadcast_div_i32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= broadcast_params.numel) { return; }
    var remaining = idx;
    var a_offset: u32 = 0u;
    var b_offset: u32 = 0u;
    for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {
        let stride = broadcast_out_strides[d];
        let coord = remaining / stride;
        remaining = remaining % stride;
        a_offset = a_offset + coord * broadcast_a_strides[d];
        b_offset = b_offset + coord * broadcast_b_strides[d];
    }
    broadcast_out[idx] = broadcast_a[a_offset] / broadcast_b[b_offset];
}

@compute @workgroup_size(256)
fn broadcast_max_i32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= broadcast_params.numel) { return; }
    var remaining = idx;
    var a_offset: u32 = 0u;
    var b_offset: u32 = 0u;
    for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {
        let stride = broadcast_out_strides[d];
        let coord = remaining / stride;
        remaining = remaining % stride;
        a_offset = a_offset + coord * broadcast_a_strides[d];
        b_offset = b_offset + coord * broadcast_b_strides[d];
    }
    broadcast_out[idx] = max(broadcast_a[a_offset], broadcast_b[b_offset]);
}

@compute @workgroup_size(256)
fn broadcast_min_i32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= broadcast_params.numel) { return; }
    var remaining = idx;
    var a_offset: u32 = 0u;
    var b_offset: u32 = 0u;
    for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {
        let stride = broadcast_out_strides[d];
        let coord = remaining / stride;
        remaining = remaining % stride;
        a_offset = a_offset + coord * broadcast_a_strides[d];
        b_offset = b_offset + coord * broadcast_b_strides[d];
    }
    broadcast_out[idx] = min(broadcast_a[a_offset], broadcast_b[b_offset]);
}