numr 0.5.2

High-performance numerical computing with multiple compute backends (CPU, CUDA, WebGPU)
// Fused GEMM + bias + activation epilogue. f32 only.
// C = activation(A @ B + bias)
// A is M x K, B is K x N, C is M x N, all row-major. bias has length N
// and is broadcast across the rows of C (and shared across batches in
// the batched entry point).
// activation_type in params: 0=None, 1=ReLU, 2=GELU (tanh approximation),
// 3=SiLU, 4=Sigmoid, 5=Tanh
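//
// Dispatch geometry (derived from the kernels below): each 16x16
// workgroup computes one 16x16 tile of C, so a launch uses
// ceil(N / 16) workgroups in x and ceil(M / 16) in y; the batched
// entry point additionally uses batch_size workgroups in z.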

const TILE_SIZE: u32 = 16u;

// Workgroup-shared tiles holding one TILE_SIZE x TILE_SIZE block of A
// and of B for the current iteration of the K loop.
var<workgroup> tile_a: array<array<f32, TILE_SIZE>, TILE_SIZE>;
var<workgroup> tile_b: array<array<f32, TILE_SIZE>, TILE_SIZE>;

struct GemmEpilogueParams {
    M: u32,               // rows of A and C
    K: u32,               // columns of A, rows of B
    N: u32,               // columns of B and C
    batch_size: u32,      // number of batches (batched entry point only)
    activation_type: u32, // see the table in the header comment
    _pad0: u32,           // explicit padding: keeps the uniform at 32 bytes
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<storage, read_write> a: array<f32>;    // A: [batch,] M x K
@group(0) @binding(1) var<storage, read_write> b: array<f32>;    // B: [batch,] K x N
@group(0) @binding(2) var<storage, read_write> bias: array<f32>; // bias: N
@group(0) @binding(3) var<storage, read_write> c: array<f32>;    // C: [batch,] M x N
@group(0) @binding(4) var<uniform> params: GemmEpilogueParams;

fn apply_activation(x: f32, act_type: u32) -> f32 {
    switch act_type {
        case 1u: { // ReLU
            return max(x, 0.0);
        }
        case 2u: { // GELU, tanh approximation:
                   // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            let s = 0.7978845608; // sqrt(2 / pi)
            let co = 0.044715;
            let inner = s * (x + co * x * x * x);
            return 0.5 * x * (1.0 + tanh(inner));
        }
        case 3u: { // SiLU (swish): x * sigmoid(x)
            return x / (1.0 + exp(-x));
        }
        case 4u: { // Sigmoid
            return 1.0 / (1.0 + exp(-x));
        }
        case 5u: { // Tanh
            return tanh(x);
        }
        default: { // 0 or unknown: identity
            return x;
        }
    }
}

@compute @workgroup_size(TILE_SIZE, TILE_SIZE, 1)
fn gemm_bias_act_f32(@builtin(local_invocation_id) local_id: vec3<u32>,
                     @builtin(workgroup_id) group_id: vec3<u32>) {
    let M = params.M;
    let K = params.K;
    let N = params.N;
    let row = group_id.y * TILE_SIZE + local_id.y;
    let col = group_id.x * TILE_SIZE + local_id.x;

    var sum: f32 = 0.0;
    let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;

    // Each iteration stages one TILE_SIZE-wide slice of A and B in
    // workgroup memory. Out-of-range invocations load zeros rather than
    // returning early, so every invocation reaches both barriers below.
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        let a_col = t * TILE_SIZE + local_id.x;
        if (row < M && a_col < K) {
            tile_a[local_id.y][local_id.x] = a[row * K + a_col];
        } else {
            tile_a[local_id.y][local_id.x] = 0.0;
        }
        let b_row = t * TILE_SIZE + local_id.y;
        if (b_row < K && col < N) {
            tile_b[local_id.y][local_id.x] = b[b_row * N + col];
        } else {
            tile_b[local_id.y][local_id.x] = 0.0;
        }
        workgroupBarrier(); // tiles fully loaded before use
        for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {
            sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x];
        }
        workgroupBarrier(); // tiles fully consumed before the next load
    }

    if (row < M && col < N) {
        // Fused epilogue: bias (broadcast along rows of C) + activation.
        c[row * N + col] = apply_activation(sum + bias[col], params.activation_type);
    }
}

@compute @workgroup_size(TILE_SIZE, TILE_SIZE, 1)
fn gemm_bias_act_batched_f32(@builtin(local_invocation_id) local_id: vec3<u32>,
                              @builtin(workgroup_id) group_id: vec3<u32>) {
    let M = params.M;
    let K = params.K;
    let N = params.N;
    let batch = group_id.z;
    // group_id.z is uniform across the workgroup, so this early return
    // cannot desynchronize the barriers in the tile loop below.
    if (batch >= params.batch_size) { return; }

    let row = group_id.y * TILE_SIZE + local_id.y;
    let col = group_id.x * TILE_SIZE + local_id.x;
    // Per-batch base offsets into the flattened A, B, and C buffers.
    let a_off = batch * M * K;
    let b_off = batch * K * N;
    let c_off = batch * M * N;

    var sum: f32 = 0.0;
    let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;

    // Same tiling scheme as gemm_bias_act_f32, applied at per-batch offsets.
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        let a_col = t * TILE_SIZE + local_id.x;
        if (row < M && a_col < K) {
            tile_a[local_id.y][local_id.x] = a[a_off + row * K + a_col];
        } else {
            tile_a[local_id.y][local_id.x] = 0.0;
        }
        let b_row = t * TILE_SIZE + local_id.y;
        if (b_row < K && col < N) {
            tile_b[local_id.y][local_id.x] = b[b_off + b_row * N + col];
        } else {
            tile_b[local_id.y][local_id.x] = 0.0;
        }
        workgroupBarrier();
        for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {
            sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x];
        }
        workgroupBarrier();
    }

    if (row < M && col < N) {
        c[c_off + row * N + col] = apply_activation(sum + bias[col], params.activation_type);
    }
}
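
For orientation, here is a host-side sketch of how a dispatch for these entry points could be sized and how the uniform data could be laid out. It is an illustrative sketch only, not numr's actual host API: the names HostGemmEpilogueParams, TILE, and workgroup_counts are hypothetical, and the only facts taken from the shader are the 16x16 tile, the parameter struct layout, and the grid mapping (x over columns of C, y over rows, z over batches).

// Hypothetical host-side sketch (Rust). Mirrors the WGSL
// GemmEpilogueParams layout and computes the workgroup grid.

/// Must match the WGSL struct field-for-field; 8 x u32 = 32 bytes.
#[repr(C)]
#[derive(Clone, Copy)]
struct HostGemmEpilogueParams {
    m: u32,
    k: u32,
    n: u32,
    batch_size: u32,
    activation_type: u32, // 0=None, 1=ReLU, 2=GELU, 3=SiLU, 4=Sigmoid, 5=Tanh
    _pad: [u32; 3],
}

const TILE: u32 = 16; // must agree with TILE_SIZE in the shader

/// Workgroup grid for a (possibly batched) M x K @ K x N launch:
/// x covers columns of C, y covers rows, z covers batches.
fn workgroup_counts(m: u32, n: u32, batch_size: u32) -> (u32, u32, u32) {
    (n.div_ceil(TILE), m.div_ceil(TILE), batch_size.max(1))
}

fn main() {
    // Example: 128 x 512 @ 512 x 256 with ReLU, batch of 4.
    let params = HostGemmEpilogueParams {
        m: 128,
        k: 512,
        n: 256,
        batch_size: 4,
        activation_type: 1, // ReLU
        _pad: [0; 3],
    };
    let (gx, gy, gz) = workgroup_counts(params.m, params.n, params.batch_size);
    assert_eq!((gx, gy, gz), (16, 8, 4)); // 256/16, 128/16, 4 batches
}

A call like workgroup_counts(m, n, 1) sizes the grid for gemm_bias_act_f32; passing the real batch count sizes it for gemm_bias_act_batched_f32.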