// numr 0.5.1
// High-performance numerical computing with multi-backend GPU acceleration
// (CPU/CUDA/WebGPU). See the project documentation for details.
// Fused add + normalization operations. F32 only.
// Entry points:
//   - fused_add_rms_norm_f32: Add residual, then RMS normalize
//   - fused_add_layer_norm_f32: Add residual, then layer normalize
//   - fused_add_rms_norm_bwd_f32: Backward pass for fused add RMS norm
//   - fused_add_layer_norm_bwd_f32: Backward pass for fused add layer norm
//   - reduce_sum_rows_f32: Reduce d_weight/d_bias scratch buffers across batch dimension

// ============================================================================
// Workgroup Configuration
// ============================================================================

// Threads per workgroup. Must stay in sync with the @workgroup_size(256)
// attribute on every entry point and the fixed 256-element sizes of the
// var<workgroup> scratch arrays below.
const WORKGROUP_SIZE: u32 = 256u;

// ============================================================================
// RMS Normalization Structs
// ============================================================================

// Uniform parameters for the fused add + RMS norm kernels (fwd and bwd).
struct RmsNormParams {
    batch_size: u32,   // number of rows; one workgroup is dispatched per row
    hidden_size: u32,  // elements per row
    eps: f32,          // numerical-stability epsilon, added inside the sqrt
}

// Uniform parameters for the fused add + layer norm kernels (fwd and bwd).
struct LayerNormParams {
    batch_size: u32,   // number of rows; one workgroup is dispatched per row
    hidden_size: u32,  // elements per row
    eps: f32,          // numerical-stability epsilon, added to the variance
}

// Uniform parameters for reduce_sum_rows_f32.
struct ReduceSumParams {
    batch_size: u32,   // rows of scratch to sum over
    hidden_size: u32,  // columns, i.e. length of the reduced output
}

// ============================================================================
// Fused Add + RMS Norm (Forward)
// ============================================================================

// Bindings for fused_add_rms_norm_f32 (prefix "farn_").
// input/residual/output/pre_norm are [batch_size, hidden_size] row-major;
// weight is [hidden_size]. All declared read_write — presumably the host
// binds every buffer with a single usage flag; verify against the backend.
@group(0) @binding(0) var<storage, read_write> farn_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> farn_residual: array<f32>;
@group(0) @binding(2) var<storage, read_write> farn_weight: array<f32>;
@group(0) @binding(3) var<storage, read_write> farn_output: array<f32>;
@group(0) @binding(4) var<storage, read_write> farn_pre_norm: array<f32>;
@group(0) @binding(5) var<uniform> farn_params: RmsNormParams;

// Per-workgroup scratch for the sum-of-squares tree reduction.
var<workgroup> farn_shared: array<f32, 256>;

// Fused residual-add + RMS normalization. One workgroup handles one row:
//   pre_norm = input + residual            (also written out for backward)
//   output   = pre_norm / sqrt(mean(pre_norm^2) + eps) * weight
@compute @workgroup_size(256)
fn fused_add_rms_norm_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
                          @builtin(local_invocation_id) local_id: vec3<u32>,
                          @builtin(workgroup_id) group_id: vec3<u32>) {
    // workgroup_id is uniform across the workgroup, so this early return
    // keeps every barrier below in uniform control flow.
    let row = group_id.x;
    if (row >= farn_params.batch_size) {
        return;
    }

    let lane = local_id.x;
    let n = farn_params.hidden_size;
    let row_base = row * n;

    // Pass 1: write pre_norm = input + residual, accumulating a per-lane
    // partial sum of squares over a strided slice of the row.
    var local_sq: f32 = 0.0;
    for (var col: u32 = lane; col < n; col = col + WORKGROUP_SIZE) {
        let v = farn_input[row_base + col] + farn_residual[row_base + col];
        farn_pre_norm[row_base + col] = v;
        local_sq = local_sq + v * v;
    }

    farn_shared[lane] = local_sq;
    workgroupBarrier();

    // Tree-reduce the per-lane partials into farn_shared[0].
    var stride: u32 = WORKGROUP_SIZE / 2u;
    while (stride > 0u) {
        if (lane < stride) {
            farn_shared[lane] += farn_shared[lane + stride];
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // Every lane reads the reduced total: rms = sqrt(mean(x^2) + eps).
    let rms = sqrt(farn_shared[0] / f32(n) + farn_params.eps);
    workgroupBarrier();

    // Pass 2: scale each element by weight / rms.
    for (var col: u32 = lane; col < n; col = col + WORKGROUP_SIZE) {
        farn_output[row_base + col] = farn_pre_norm[row_base + col] / rms * farn_weight[col];
    }
}

// ============================================================================
// Fused Add + Layer Norm (Forward)
// ============================================================================

// Bindings for fused_add_layer_norm_f32 (prefix "faln_").
// input/residual/output/pre_norm are [batch_size, hidden_size] row-major;
// weight and bias are [hidden_size].
@group(0) @binding(0) var<storage, read_write> faln_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> faln_residual: array<f32>;
@group(0) @binding(2) var<storage, read_write> faln_weight: array<f32>;
@group(0) @binding(3) var<storage, read_write> faln_bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> faln_output: array<f32>;
@group(0) @binding(5) var<storage, read_write> faln_pre_norm: array<f32>;
@group(0) @binding(6) var<uniform> faln_params: LayerNormParams;

// Per-workgroup scratch: one array for the mean reduction, one for variance.
var<workgroup> faln_shared_mean: array<f32, 256>;
var<workgroup> faln_shared_var: array<f32, 256>;

// Fused residual-add + layer normalization. One workgroup handles one row:
//   pre_norm = input + residual            (also written out for backward)
//   output   = (pre_norm - mean) / sqrt(var + eps) * weight + bias
@compute @workgroup_size(256)
fn fused_add_layer_norm_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
                            @builtin(local_invocation_id) local_id: vec3<u32>,
                            @builtin(workgroup_id) group_id: vec3<u32>) {
    // workgroup_id is uniform across the workgroup, so this early return
    // keeps every barrier below in uniform control flow.
    let row = group_id.x;
    if (row >= faln_params.batch_size) {
        return;
    }

    let lane = local_id.x;
    let n = faln_params.hidden_size;
    let row_base = row * n;

    // Pass 1: write pre_norm = input + residual; per-lane partial sum
    // over a strided slice of the row, for the mean.
    var partial_sum: f32 = 0.0;
    for (var col: u32 = lane; col < n; col = col + WORKGROUP_SIZE) {
        let v = faln_input[row_base + col] + faln_residual[row_base + col];
        faln_pre_norm[row_base + col] = v;
        partial_sum = partial_sum + v;
    }

    faln_shared_mean[lane] = partial_sum;
    workgroupBarrier();

    // Tree-reduce the partials into faln_shared_mean[0].
    var stride: u32 = WORKGROUP_SIZE / 2u;
    while (stride > 0u) {
        if (lane < stride) {
            faln_shared_mean[lane] += faln_shared_mean[lane + stride];
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    let mean = faln_shared_mean[0] / f32(n);
    workgroupBarrier();

    // Pass 2: per-lane partial sum of squared deviations, for the variance.
    var partial_var: f32 = 0.0;
    for (var col: u32 = lane; col < n; col = col + WORKGROUP_SIZE) {
        let d = faln_pre_norm[row_base + col] - mean;
        partial_var = partial_var + d * d;
    }

    faln_shared_var[lane] = partial_var;
    workgroupBarrier();

    stride = WORKGROUP_SIZE / 2u;
    while (stride > 0u) {
        if (lane < stride) {
            faln_shared_var[lane] += faln_shared_var[lane + stride];
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // variance = mean of squared deviations; eps stabilizes the rsqrt.
    let variance = faln_shared_var[0] / f32(n);
    let inv_std = 1.0 / sqrt(variance + faln_params.eps);
    workgroupBarrier();

    // Pass 3: affine-transformed normalization.
    for (var col: u32 = lane; col < n; col = col + WORKGROUP_SIZE) {
        let normed = (faln_pre_norm[row_base + col] - mean) * inv_std;
        faln_output[row_base + col] = normed * faln_weight[col] + faln_bias[col];
    }
}

// ============================================================================
// Fused Add + RMS Norm (Backward)
// ============================================================================

// Bindings for fused_add_rms_norm_bwd_f32 (prefix "farnb_").
// grad/pre_norm/d_input_residual/d_weight_scratch are [batch_size, hidden_size]
// row-major; weight is [hidden_size]. d_weight_scratch holds per-row
// contributions that reduce_sum_rows_f32 later sums across the batch.
@group(0) @binding(0) var<storage, read_write> farnb_grad: array<f32>;
@group(0) @binding(1) var<storage, read_write> farnb_pre_norm: array<f32>;
@group(0) @binding(2) var<storage, read_write> farnb_weight: array<f32>;
@group(0) @binding(3) var<storage, read_write> farnb_d_input_residual: array<f32>;
@group(0) @binding(4) var<storage, read_write> farnb_d_weight_scratch: array<f32>;
@group(0) @binding(5) var<uniform> farnb_params: RmsNormParams;

// Per-workgroup scratch for the two simultaneous tree reductions.
var<workgroup> farnb_shared_sum_sq: array<f32, 256>;
var<workgroup> farnb_shared_dot: array<f32, 256>;

// Backward pass for fused add + RMS norm. One workgroup per batch row.
//
// Recomputes rms from the saved pre_norm activations, then:
//   d_input_residual = (grad * weight - pre_norm * coeff) / rms
//     where coeff = dot(grad * weight, pre_norm) / (rms^3 * hidden_size)
//   d_weight per-row contribution = grad * pre_norm / rms
//     (written to scratch; reduce_sum_rows_f32 sums the scratch over batch).
//
// The add is a plain sum, so input and residual receive identical gradients
// and share the single d_input_residual buffer.
@compute @workgroup_size(256)
fn fused_add_rms_norm_bwd_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
                              @builtin(local_invocation_id) local_id: vec3<u32>,
                              @builtin(workgroup_id) group_id: vec3<u32>) {
    let tid = local_id.x;
    let batch_idx = group_id.x;

    // workgroup_id is uniform, so returning here keeps all barriers below
    // in uniform control flow.
    if (batch_idx >= farnb_params.batch_size) {
        return;
    }

    let hidden_size = farnb_params.hidden_size;
    let eps = farnb_params.eps;
    let base_offset = batch_idx * hidden_size;

    // Phase 1: per-lane partial sums over a strided slice of the row:
    // sum_sq = sum(pre_norm^2), dot = sum(grad * weight * pre_norm).
    var sum_sq: f32 = 0.0;
    var dot: f32 = 0.0;
    var i: u32 = tid;
    while (i < hidden_size) {
        let pre_val = farnb_pre_norm[base_offset + i];
        sum_sq = sum_sq + pre_val * pre_val;
        dot = dot + farnb_grad[base_offset + i] * farnb_weight[i] * pre_val;
        i = i + WORKGROUP_SIZE;
    }

    farnb_shared_sum_sq[tid] = sum_sq;
    farnb_shared_dot[tid] = dot;
    workgroupBarrier();

    // Tree-reduce both partial sums in lockstep (one barrier serves both).
    for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
        if (tid < s) {
            farnb_shared_sum_sq[tid] = farnb_shared_sum_sq[tid] + farnb_shared_sum_sq[tid + s];
            farnb_shared_dot[tid] = farnb_shared_dot[tid] + farnb_shared_dot[tid + s];
        }
        workgroupBarrier();
    }

    // Recompute rms exactly as the forward pass: sqrt(mean(x^2) + eps).
    let total_sum_sq = farnb_shared_sum_sq[0];
    let total_dot = farnb_shared_dot[0];
    let rms = sqrt(total_sum_sq / f32(hidden_size) + eps);
    let inv_rms = 1.0 / rms;
    let inv_rms_cubed = inv_rms * inv_rms * inv_rms;
    let coeff = total_dot * inv_rms_cubed / f32(hidden_size);
    workgroupBarrier();

    // Phase 2: Compute d_input_residual and accumulate d_weight
    i = tid;
    while (i < hidden_size) {
        // d_input_residual = (grad * weight - pre_norm * coeff) * inv_rms
        farnb_d_input_residual[base_offset + i] =
            (farnb_grad[base_offset + i] * farnb_weight[i] - farnb_pre_norm[base_offset + i] * coeff) * inv_rms;

        // d_weight contribution for this row: grad * (pre_norm / rms),
        // i.e. grad * normalized. Rows are summed later by reduce_sum_rows_f32.
        farnb_d_weight_scratch[base_offset + i] = farnb_grad[base_offset + i] * farnb_pre_norm[base_offset + i] * inv_rms;

        i = i + WORKGROUP_SIZE;
    }
}

// ============================================================================
// Fused Add + Layer Norm (Backward)
// ============================================================================

// Bindings for fused_add_layer_norm_bwd_f32 (prefix "falnb_").
// grad/pre_norm/d_input_residual/d_*_scratch are [batch_size, hidden_size]
// row-major; weight and bias are [hidden_size]. Note: falnb_bias is bound
// for pipeline-layout symmetry but is not read by the kernel body below.
@group(0) @binding(0) var<storage, read_write> falnb_grad: array<f32>;
@group(0) @binding(1) var<storage, read_write> falnb_pre_norm: array<f32>;
@group(0) @binding(2) var<storage, read_write> falnb_weight: array<f32>;
@group(0) @binding(3) var<storage, read_write> falnb_bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> falnb_d_input_residual: array<f32>;
@group(0) @binding(5) var<storage, read_write> falnb_d_weight_scratch: array<f32>;
@group(0) @binding(6) var<storage, read_write> falnb_d_bias_scratch: array<f32>;
@group(0) @binding(7) var<uniform> falnb_params: LayerNormParams;

// Per-workgroup scratch, reused across phases: mean then sum(g*w),
// variance then sum(g*w*normalized).
var<workgroup> falnb_shared_mean: array<f32, 256>;
var<workgroup> falnb_shared_var: array<f32, 256>;

// Backward pass for fused add + layer norm. One workgroup per batch row.
//
// Recomputes mean/variance from the saved pre_norm activations, then:
//   d_input_residual = inv_std * (g*w - mean(g*w) - normalized * mean(g*w*normalized))
//   d_weight per-row contribution = grad * normalized
//   d_bias   per-row contribution = grad
// Scratch contributions are summed across the batch by reduce_sum_rows_f32.
// The add is a plain sum, so input and residual share d_input_residual.
@compute @workgroup_size(256)
fn fused_add_layer_norm_bwd_f32(@builtin(global_invocation_id) global_id: vec3<u32>,
                                @builtin(local_invocation_id) local_id: vec3<u32>,
                                @builtin(workgroup_id) group_id: vec3<u32>) {
    let tid = local_id.x;
    let batch_idx = group_id.x;

    // workgroup_id is uniform, so returning here keeps all barriers below
    // in uniform control flow.
    if (batch_idx >= falnb_params.batch_size) {
        return;
    }

    let hidden_size = falnb_params.hidden_size;
    let eps = falnb_params.eps;
    let base_offset = batch_idx * hidden_size;

    // Phase 1: mean of pre_norm over the row (strided per-lane partials,
    // then a tree reduction).
    var sum: f32 = 0.0;
    var i: u32 = tid;
    while (i < hidden_size) {
        sum = sum + falnb_pre_norm[base_offset + i];
        i = i + WORKGROUP_SIZE;
    }

    falnb_shared_mean[tid] = sum;
    workgroupBarrier();

    for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
        if (tid < s) {
            falnb_shared_mean[tid] = falnb_shared_mean[tid] + falnb_shared_mean[tid + s];
        }
        workgroupBarrier();
    }

    let mean = falnb_shared_mean[0] / f32(hidden_size);
    workgroupBarrier();

    // Phase 2: variance of pre_norm over the row.
    var var_sum: f32 = 0.0;
    i = tid;
    while (i < hidden_size) {
        let diff = falnb_pre_norm[base_offset + i] - mean;
        var_sum = var_sum + diff * diff;
        i = i + WORKGROUP_SIZE;
    }

    falnb_shared_var[tid] = var_sum;
    workgroupBarrier();

    for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
        if (tid < s) {
            falnb_shared_var[tid] = falnb_shared_var[tid] + falnb_shared_var[tid + s];
        }
        workgroupBarrier();
    }

    let variance = falnb_shared_var[0] / f32(hidden_size);
    let inv_std = 1.0 / sqrt(variance + eps);
    // BUGFIX: this barrier was missing. falnb_shared_var[tid] is rewritten
    // below; without a barrier, a fast thread could overwrite
    // falnb_shared_var[0] before a slow thread reads the variance above
    // (workgroup-memory data race). The mean read earlier is already
    // protected by its own barrier — this mirrors that pattern.
    workgroupBarrier();

    // Phase 3: per-lane partials of sum(g*w) and sum(g*w*normalized),
    // reusing both shared arrays now that every thread has consumed them.
    var sum_gs: f32 = 0.0;
    var sum_gs_n: f32 = 0.0;
    i = tid;
    while (i < hidden_size) {
        let normalized = (falnb_pre_norm[base_offset + i] - mean) * inv_std;
        let gs = falnb_grad[base_offset + i] * falnb_weight[i];
        sum_gs = sum_gs + gs;
        sum_gs_n = sum_gs_n + gs * normalized;
        i = i + WORKGROUP_SIZE;
    }

    falnb_shared_mean[tid] = sum_gs;
    falnb_shared_var[tid] = sum_gs_n;
    workgroupBarrier();

    for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {
        if (tid < s) {
            falnb_shared_mean[tid] = falnb_shared_mean[tid] + falnb_shared_mean[tid + s];
            falnb_shared_var[tid] = falnb_shared_var[tid] + falnb_shared_var[tid + s];
        }
        workgroupBarrier();
    }

    // Loop-invariant means, hoisted out of the per-element loop below.
    let mean_gs_val = falnb_shared_mean[0] / f32(hidden_size);
    let mean_gs_n_val = falnb_shared_var[0] / f32(hidden_size);
    workgroupBarrier();

    // Phase 4: per-element gradients.
    i = tid;
    while (i < hidden_size) {
        let normalized = (falnb_pre_norm[base_offset + i] - mean) * inv_std;
        let g = falnb_grad[base_offset + i];
        let gs = g * falnb_weight[i];

        // d_input_residual = inv_std * (g*w - mean_gs - normalized * mean_gs_n)
        falnb_d_input_residual[base_offset + i] = inv_std *
            (gs - mean_gs_val - normalized * mean_gs_n_val);

        // Per-row d_weight contribution: grad * normalized.
        falnb_d_weight_scratch[base_offset + i] = g * normalized;

        // Per-row d_bias contribution: grad.
        falnb_d_bias_scratch[base_offset + i] = g;

        i = i + WORKGROUP_SIZE;
    }
}

// ============================================================================
// Reduce Sum Rows (helper for backward)
// ============================================================================

// Bindings for reduce_sum_rows_f32 (prefix "rsr_").
// rsr_input is [batch_size, hidden_size] row-major scratch; rsr_output is
// the [hidden_size] column sums.
@group(0) @binding(0) var<storage, read_write> rsr_input: array<f32>;
@group(0) @binding(1) var<storage, read_write> rsr_output: array<f32>;
@group(0) @binding(2) var<uniform> rsr_params: ReduceSumParams;

// Collapse a [batch_size, hidden_size] scratch buffer to [hidden_size] by
// summing over the batch dimension. One thread per output column.
@compute @workgroup_size(256)
fn reduce_sum_rows_f32(@builtin(global_invocation_id) gid: vec3<u32>) {
    let col = gid.x;
    let width = rsr_params.hidden_size;
    // Threads past the last column (from rounded-up dispatch) do nothing.
    if (col >= width) {
        return;
    }

    // Walk down the column, accumulating one value per batch row.
    var acc: f32 = 0.0;
    var row: u32 = 0u;
    while (row < rsr_params.batch_size) {
        acc = acc + rsr_input[row * width + col];
        row = row + 1u;
    }
    rsr_output[col] = acc;
}