rumus 0.3.1

A native-Rust deep learning framework with explicit memory safety and hardware acceleration
Documentation
// LayerNorm forward: fused mean + variance + normalize + affine.
//
// One workgroup per normalization instance (B*S workgroups).
// Three phases separated by workgroupBarrier():
//   1. Reduce → mean
//   2. Reduce → variance
//   3. Normalize + affine: y = γ * (x - mean) / sqrt(var + ε) + β
//
// Saves mean + invstd per instance for backward.

// Uniform parameters for the LayerNorm forward kernel.
// Four 4-byte fields, so the struct occupies exactly 16 bytes.
struct LayerNormParams {
    // Number of independent normalization instances; workgroups whose index
    // is >= this value exit immediately.
    num_instances: u32,
    // Number of elements normalized per instance (called D in the kernel);
    // also the stride between consecutive instances in the input/output.
    norm_size: u32,
    // Added to the variance before the square root.
    epsilon: f32,
    // Not read by the kernel; brings the struct to 16 bytes.
    _pad: u32,
}
// 16 bytes ✓

// Resource bindings for the LayerNorm forward kernel.
// NOTE(review): `scalar` is not declared in this chunk — presumably an element
// type alias/placeholder (e.g. f32) supplied elsewhere; confirm.

// Input values, indexed as [inst * norm_size + j].
@group(0) @binding(0) var<storage, read>       ln_input:  array<scalar>;
// Per-element scale, indexed by position j within an instance.
@group(0) @binding(1) var<storage, read>       ln_weight: array<scalar>; // γ [D]
// Per-element shift, indexed by position j within an instance.
@group(0) @binding(2) var<storage, read>       ln_bias:   array<scalar>; // β [D]
// Normalized output, same indexing as ln_input.
@group(0) @binding(3) var<storage, read_write> ln_output: array<scalar>;
// Per-instance statistics: [inst * 2] = mean, [inst * 2 + 1] = invstd.
@group(0) @binding(4) var<storage, read_write> ln_save:   array<scalar>; // [num_instances, 2]
// Kernel parameters (instance count, normalization width, epsilon).
@group(0) @binding(5) var<uniform>             ln_params: LayerNormParams;

// Workgroup scratch for the tree reductions: one slot per invocation of the
// 64-wide workgroup. Reused for both the mean and the variance reduction.
var<workgroup> shared_val: array<scalar, 64>;

// LayerNorm forward pass for a single normalization instance.
//
// Each workgroup (64 invocations) handles the instance selected by
// `wgid.x`; workgroups past `ln_params.num_instances` exit immediately.
// The instance's `norm_size` elements are read from `ln_input`, and:
//   * the mean and the biased variance (sum of squared deviations / D) are
//     computed with a strided per-invocation sum followed by a tree
//     reduction through `shared_val`;
//   * mean and 1 / sqrt(variance + epsilon) are stored to `ln_save` at
//     [2 * instance] and [2 * instance + 1];
//   * y = weight * (x - mean) * invstd + bias is written to `ln_output`.
//
// Parameters:
//   lid  - local invocation id; only `.x` (0..63) is used.
//   wgid - workgroup id; only `.x` is used, as the instance index.
@compute @workgroup_size(64)
fn layer_norm_forward_kernel(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {
    let row = wgid.x;
    // Whole workgroup leaves together, so later barriers stay uniform.
    if (row >= ln_params.num_instances) { return; }

    let width = ln_params.norm_size;
    let lane = lid.x;
    let row_start = row * width;

    // ---- Pass 1: mean ----
    // Each invocation sums elements lane, lane + 64, lane + 128, ...
    var partial_sum: scalar = scalar(0.0);
    for (var k = lane; k < width; k += 64u) {
        partial_sum += ln_input[row_start + k];
    }
    shared_val[lane] = partial_sum;
    workgroupBarrier();

    // Tree reduction: halve the active range each step (32, 16, ..., 1).
    for (var stride: u32 = 32u; stride > 0u; stride = stride >> 1u) {
        if (lane < stride) {
            shared_val[lane] += shared_val[lane + stride];
        }
        workgroupBarrier();
    }
    let mean = shared_val[0] / scalar(width);
    // Everyone must have read slot 0 before the scratch array is reused.
    workgroupBarrier();

    // ---- Pass 2: variance ----
    var partial_sq: scalar = scalar(0.0);
    for (var k = lane; k < width; k += 64u) {
        let diff = ln_input[row_start + k] - mean;
        partial_sq += diff * diff;
    }
    shared_val[lane] = partial_sq;
    workgroupBarrier();

    for (var stride: u32 = 32u; stride > 0u; stride = stride >> 1u) {
        if (lane < stride) {
            shared_val[lane] += shared_val[lane + stride];
        }
        workgroupBarrier();
    }
    let variance = shared_val[0] / scalar(width);
    let invstd = scalar(1.0) / sqrt(variance + scalar(ln_params.epsilon));
    workgroupBarrier();

    // A single invocation records the statistics for the backward pass.
    if (lane == 0u) {
        let save_at = row * 2u;
        ln_save[save_at] = mean;
        ln_save[save_at + 1u] = invstd;
    }

    // ---- Pass 3: normalize, then scale and shift ----
    for (var k = lane; k < width; k += 64u) {
        let x_hat = (ln_input[row_start + k] - mean) * invstd;
        ln_output[row_start + k] = ln_weight[k] * x_hat + ln_bias[k];
    }
}