numr 0.5.1

High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)
Documentation
// Fused activation-mul WGSL shaders (F32 only)
// Forward: out = activation(a) * b
// Backward: d_a = grad * b * activation'(a), d_b = grad * activation(a)

// ============================================================================
// Forward kernels: 2 inputs (a, b), 1 output, uniform params
// ============================================================================

// Uniform parameters shared by every forward kernel in this file.
struct FusedFwdParams {
    // Element count used as the bounds guard: invocations whose global x index
    // is >= numel return without reading or writing any buffer.
    numel: u32,
}

// Forward-pass resources (bind group 0, bindings 0-3).
// All three storage buffers are declared read_write, although the forward kernels
// in this file only read fwd_a / fwd_b and only write fwd_out.
@group(0) @binding(0) var<storage, read_write> fwd_a: array<f32>;   // activation input `a`
@group(0) @binding(1) var<storage, read_write> fwd_b: array<f32>;   // multiplier input `b`
@group(0) @binding(2) var<storage, read_write> fwd_out: array<f32>; // result: activation(a) * b
@group(0) @binding(3) var<uniform> fwd_params: FusedFwdParams;      // element count (numel)

// SiLU-gated multiply: fwd_out[i] = silu(fwd_a[i]) * fwd_b[i], where
// silu(x) = x * sigmoid(x). One invocation handles one element; invocations
// at or past fwd_params.numel exit without touching any buffer.
@compute @workgroup_size(256)
fn silu_mul_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= fwd_params.numel) {
        return;
    }
    let a_val = fwd_a[i];
    let gate = 1.0 / (1.0 + exp(-a_val)); // sigmoid(a_val)
    fwd_out[i] = a_val * gate * fwd_b[i];
}

// GELU-gated multiply using the tanh approximation:
//   fwd_out[i] = 0.5 * x * (1 + tanh(c * (x + k * x^3))) * fwd_b[i],  x = fwd_a[i]
// with c = 0.7978845608 (~sqrt(2/pi)) and k = 0.044715.
// One invocation handles one element; invocations at or past fwd_params.numel
// do nothing.
@compute @workgroup_size(256)
fn gelu_mul_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < fwd_params.numel) {
        let x = fwd_a[idx];
        let c = 0.7978845608;
        let k = 0.044715;
        // The tanh argument grows cubically in x. WGSL allows an indeterminate
        // result when an intermediate overflows, and a tanh built on exp() can
        // then return NaN instead of +/-1. f32 tanh is already saturated to
        // +/-1 well before |arg| = 15, so clamping there leaves every finite
        // result unchanged while keeping the argument in a safe range.
        let tanh_limit = 15.0;
        let inner = clamp(c * (x + k * x * x * x), -tanh_limit, tanh_limit);
        let t = tanh(inner);
        fwd_out[idx] = 0.5 * x * (1.0 + t) * fwd_b[idx];
    }
}

// ReLU-gated multiply: fwd_out[i] = max(0, fwd_a[i]) * fwd_b[i].
// One invocation handles one element; invocations at or past fwd_params.numel
// exit without touching any buffer.
@compute @workgroup_size(256)
fn relu_mul_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= fwd_params.numel) {
        return;
    }
    let rectified = max(0.0, fwd_a[i]);
    fwd_out[i] = rectified * fwd_b[i];
}

// Sigmoid-gated multiply: fwd_out[i] = sigmoid(fwd_a[i]) * fwd_b[i], where
// sigmoid(x) = 1 / (1 + exp(-x)). One invocation handles one element;
// invocations at or past fwd_params.numel exit without touching any buffer.
@compute @workgroup_size(256)
fn sigmoid_mul_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= fwd_params.numel) {
        return;
    }
    let gate = 1.0 / (1.0 + exp(-fwd_a[i]));
    fwd_out[i] = gate * fwd_b[i];
}

// ============================================================================
// Backward kernels: 3 inputs (grad, a, b), 2 outputs (d_a, d_b), uniform params
// ============================================================================

// Uniform parameters shared by every backward kernel in this file.
struct FusedBwdParams {
    // Element count used as the bounds guard: invocations whose global x index
    // is >= numel return without reading or writing any buffer.
    numel: u32,
}

// Backward-pass resources (bind group 0, bindings 0-5).
// These reuse binding numbers 0-3 of the forward resources; the forward entry
// points reference only the fwd_* variables and the backward entry points only
// the bwd_* variables.
// All storage buffers are declared read_write, although the backward kernels in
// this file only read bwd_grad / bwd_a / bwd_b and only write bwd_d_a / bwd_d_b.
@group(0) @binding(0) var<storage, read_write> bwd_grad: array<f32>; // incoming gradient w.r.t. the forward output
@group(0) @binding(1) var<storage, read_write> bwd_a: array<f32>;    // forward input `a`
@group(0) @binding(2) var<storage, read_write> bwd_b: array<f32>;    // forward input `b`
@group(0) @binding(3) var<storage, read_write> bwd_d_a: array<f32>;  // result: grad * b * activation'(a)
@group(0) @binding(4) var<storage, read_write> bwd_d_b: array<f32>;  // result: grad * activation(a)
@group(0) @binding(5) var<uniform> bwd_params: FusedBwdParams;       // element count (numel)

// Backward pass of the SiLU-gated multiply, one invocation per element:
//   s        = sigmoid(a)
//   silu'(a) = s * (1 + a * (1 - s))
//   d_b = grad * a * s
//   d_a = grad * b * silu'(a)
// Invocations at or past bwd_params.numel exit without touching any buffer.
@compute @workgroup_size(256)
fn silu_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= bwd_params.numel) {
        return;
    }
    let a_val = bwd_a[i];
    let upstream = bwd_grad[i];
    let b_val = bwd_b[i];
    let s = 1.0 / (1.0 + exp(-a_val));
    let activated = a_val * s;
    let slope = s * (1.0 + a_val * (1.0 - s));
    bwd_d_b[i] = upstream * activated;
    bwd_d_a[i] = upstream * b_val * slope;
}

// Backward pass of the GELU-gated multiply (tanh approximation), one
// invocation per element:
//   t        = tanh(c * (x + k * x^3)),  c = 0.7978845608 (~sqrt(2/pi)), k = 0.044715
//   gelu(x)  = 0.5 * x * (1 + t)
//   gelu'(x) = 0.5 * (1 + t) + 0.5 * x * (1 - t*t) * c * (1 + 3*k*x*x)
//   d_b = grad * gelu(a)
//   d_a = grad * b * gelu'(a)
// Invocations at or past bwd_params.numel do nothing.
@compute @workgroup_size(256)
fn gelu_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx < bwd_params.numel) {
        let x = bwd_a[idx];
        let g = bwd_grad[idx];
        let bv = bwd_b[idx];
        let c = 0.7978845608;
        let k = 0.044715;
        // The tanh argument grows cubically in x. WGSL allows an indeterminate
        // result when an intermediate overflows, and a tanh built on exp() can
        // then return NaN instead of +/-1, which would poison both gradients.
        // f32 tanh is already saturated to +/-1 well before |arg| = 15, so the
        // clamp leaves finite results unchanged and (1 - t*t) stays 0 there.
        let tanh_limit = 15.0;
        let inner = clamp(c * (x + k * x * x * x), -tanh_limit, tanh_limit);
        let t = tanh(inner);
        let gelu_val = 0.5 * x * (1.0 + t);
        // The polynomial factor deliberately uses the unclamped x: it is the
        // derivative of the true tanh argument, multiplied by sech^2 = 1 - t*t.
        let gelu_deriv = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * c * (1.0 + 3.0 * k * x * x);
        bwd_d_b[idx] = g * gelu_val;
        bwd_d_a[idx] = g * bv * gelu_deriv;
    }
}

// Backward pass of the ReLU-gated multiply, one invocation per element:
//   relu'(a) = 1 when a > 0, otherwise 0
//   d_b = grad * max(0, a)
//   d_a = grad * b * relu'(a)
// Invocations at or past bwd_params.numel exit without touching any buffer.
@compute @workgroup_size(256)
fn relu_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= bwd_params.numel) {
        return;
    }
    let a_val = bwd_a[i];
    let upstream = bwd_grad[i];
    let b_val = bwd_b[i];
    let rectified = max(0.0, a_val);
    // Step function: stays 0 unless the input is strictly positive.
    var slope = 0.0;
    if (a_val > 0.0) {
        slope = 1.0;
    }
    bwd_d_b[i] = upstream * rectified;
    bwd_d_a[i] = upstream * b_val * slope;
}

// Backward pass of the sigmoid-gated multiply, one invocation per element:
//   s           = sigmoid(a) = 1 / (1 + exp(-a))
//   sigmoid'(a) = s * (1 - s)
//   d_b = grad * s
//   d_a = grad * b * sigmoid'(a)
// Invocations at or past bwd_params.numel exit without touching any buffer.
@compute @workgroup_size(256)
fn sigmoid_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= bwd_params.numel) {
        return;
    }
    let a_val = bwd_a[i];
    let upstream = bwd_grad[i];
    let b_val = bwd_b[i];
    let s = 1.0 / (1.0 + exp(-a_val));
    let slope = s * (1.0 - s);
    bwd_d_b[i] = upstream * s;
    bwd_d_a[i] = upstream * b_val * slope;
}