rumus 0.2.0

A native-Rust deep learning framework with explicit memory safety and hardware acceleration
Documentation
// BatchNorm2d forward: per-channel reduction over B*H*W.
//
// One workgroup per channel.  Three phases:
//   1. Reduce → mean[c]
//   2. Reduce → var[c]
//   3. Normalize + affine + update running stats
//
// During eval (is_training==0), uses running_mean/running_var instead.

struct BatchNorm2dParams {
    batch: u32,
    channels: u32,
    height: u32,
    width: u32,
    epsilon: f32,
    momentum: f32,
    is_training: u32,
    _pad: u32,
}
// 32 bytes ✓

@group(0) @binding(0) var<storage, read>       bn_input:        array<f32>;
@group(0) @binding(1) var<storage, read>       bn_weight:       array<f32>; // γ [C]
@group(0) @binding(2) var<storage, read>       bn_bias:         array<f32>; // β [C]
@group(0) @binding(3) var<storage, read_write> bn_running_mean: array<f32>; // [C]
@group(0) @binding(4) var<storage, read_write> bn_running_var:  array<f32>; // [C]
@group(0) @binding(5) var<storage, read_write> bn_output:       array<f32>;
@group(0) @binding(6) var<storage, read_write> bn_save:         array<f32>; // [C, 2]
@group(0) @binding(7) var<uniform>             bn_params:       BatchNorm2dParams;

var<workgroup> shared_val: array<f32, 64>;

@compute @workgroup_size(64)
fn batch_norm_forward_kernel(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {
    let c = wgid.x;
    if (c >= bn_params.channels) { return; }
    let tid = lid.x;
    let spatial = bn_params.height * bn_params.width;
    let n = bn_params.batch * spatial;

    var use_mean: f32;
    var use_invstd: f32;

    if (bn_params.is_training == 1u) {
        // ---- Phase 1: Mean ----
        var local_sum: f32 = 0.0;
        var idx = tid;
        while (idx < n) {
            let b = idx / spatial;
            let hw = idx % spatial;
            let flat = b * bn_params.channels * spatial + c * spatial + hw;
            local_sum += bn_input[flat];
            idx += 64u;
        }
        shared_val[tid] = local_sum;
        workgroupBarrier();
        var s: u32 = 32u;
        while (s > 0u) {
            if (tid < s) { shared_val[tid] += shared_val[tid + s]; }
            workgroupBarrier();
            s = s >> 1u;
        }
        let mean = shared_val[0] / f32(n);
        workgroupBarrier();

        // ---- Phase 2: Variance ----
        var local_var: f32 = 0.0;
        idx = tid;
        while (idx < n) {
            let b = idx / spatial;
            let hw = idx % spatial;
            let flat = b * bn_params.channels * spatial + c * spatial + hw;
            let diff = bn_input[flat] - mean;
            local_var += diff * diff;
            idx += 64u;
        }
        shared_val[tid] = local_var;
        workgroupBarrier();
        s = 32u;
        while (s > 0u) {
            if (tid < s) { shared_val[tid] += shared_val[tid + s]; }
            workgroupBarrier();
            s = s >> 1u;
        }
        let variance = shared_val[0] / f32(n);
        let invstd = 1.0 / sqrt(variance + bn_params.epsilon);
        workgroupBarrier();

        use_mean = mean;
        use_invstd = invstd;

        // Save + update running stats (thread 0 only).
        if (tid == 0u) {
            bn_save[c * 2u] = mean;
            bn_save[c * 2u + 1u] = invstd;
            let m = bn_params.momentum;
            bn_running_mean[c] = (1.0 - m) * bn_running_mean[c] + m * mean;
            bn_running_var[c]  = (1.0 - m) * bn_running_var[c]  + m * variance;
        }
    } else {
        // Eval mode: use running stats.
        use_mean = bn_running_mean[c];
        use_invstd = 1.0 / sqrt(bn_running_var[c] + bn_params.epsilon);
    }
    workgroupBarrier();

    // ---- Phase 3: Normalize + Affine ----
    let gamma = bn_weight[c];
    let beta = bn_bias[c];
    var idx2 = tid;
    while (idx2 < n) {
        let b = idx2 / spatial;
        let hw = idx2 % spatial;
        let flat = b * bn_params.channels * spatial + c * spatial + hw;
        let x_hat = (bn_input[flat] - use_mean) * use_invstd;
        bn_output[flat] = gamma * x_hat + beta;
        idx2 += 64u;
    }
}