meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// GroupNorm backward w.r.t. input.
// Dispatch: [N * num_groups, 1, 1]  workgroup_size(256)
// grad_input[i] = inv_std * (w[c] * dout[i] - mean(w*dout) - xhat[i] * mean(w*dout*xhat))
// with both means taken over all elements of the (n, group) slice that contains i,
// where xhat = (x - mean) * inv_std
//
// Inputs: grad_out (src_a), input x (src_b), per-channel weight (bound to the
// `bias` slot); output: grad_input (dst).
// Params use the same encoding as the forward pass.

// Uniform parameters shared by both GroupNorm backward entry points.
// The field order and types must match the host-side uniform layout.
struct Params {
    batch: u32,        // N: number of batch items
    channels: u32,     // C: total channels; expected to be a multiple of num_groups (TODO confirm host validates)
    spatial: u32,      // number of spatial positions per channel (elements per (n, c) plane)
    num_groups: u32,   // G: number of normalization groups
    eps_bits: u32,     // epsilon stored as raw f32 bits; decoded with bitcast<f32>
    _pad0: u32,        // padding: total struct size is 8 x u32 = 32 bytes
    _pad1: u32,        // presumably kept for host-side uniform alignment; verify against caller
    _pad2: u32,
}

// Resource bindings. No @group/@binding attributes are written here;
// NOTE(review): presumably Blade assigns them by name when the pipeline is
// created — confirm before adding explicit attributes.
var<storage> src_a: array<f32>;     // grad_output, laid out [N, C, spatial]
var<storage> src_b: array<f32>;     // input x, laid out [N, C, spatial]
var<storage> bias: array<f32>;      // weight[C] (the slot is named `bias` but holds the GroupNorm weight)
var<storage, read_write> dst: array<f32>;  // grad_input, or [grad_weight[C], grad_bias[C]] for grad_weight_bias
var<uniform> params: Params;
// Scratch for the 256-wide tree reductions; two arrays so two sums reduce in one pass.
var<workgroup> wg_data: array<f32, 256>;
var<workgroup> wg_data2: array<f32, 256>;

// Gradient of GroupNorm with respect to its input.
// One workgroup per (batch item n, group) pair; the 256 threads stride over the
// channels_per_group * spatial elements of that group in three passes:
// statistics, gradient moments, and the final write.
@compute @workgroup_size(256)
fn grad_input(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
    let ng = wgid.x;
    // Workgroup-uniform early exit (depends only on workgroup_id and uniforms),
    // so every barrier below is still reached in uniform control flow.
    if ng >= params.batch * params.num_groups { return; }

    let n = ng / params.num_groups;
    let group = ng % params.num_groups;
    let tid = lid.x;
    let eps = bitcast<f32>(params.eps_bits);

    let channels_per_group = params.channels / params.num_groups;
    let group_size = channels_per_group * params.spatial;
    let c_start = group * channels_per_group;

    // Pass 1: compute mean and variance of x within group
    var sum_x = 0.0;
    var sum_x2 = 0.0;
    var j = tid;
    loop {
        if j >= group_size { break; }
        let c_local = j / params.spatial;
        let hw = j % params.spatial;
        let c = c_start + c_local;
        let idx = ((n * params.channels + c) * params.spatial) + hw;
        let v = src_b[idx];
        sum_x += v;
        sum_x2 += v * v;
        j += 256u;
    }
    wg_data[tid] = sum_x;
    wg_data2[tid] = sum_x2;
    workgroupBarrier();

    // Tree reduction of the 256 partial sums into slot 0.
    var stride = 128u;
    loop {
        if stride == 0u { break; }
        if tid < stride {
            wg_data[tid] += wg_data[tid + stride];
            wg_data2[tid] += wg_data2[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }
    let mean = wg_data[0] / f32(group_size);
    // Biased variance via E[x^2] - E[x]^2. In f32 this can cancel to a small
    // negative value for near-constant groups. If it drops below -eps,
    // inverseSqrt receives an out-of-domain argument (indeterminate, typically
    // NaN), so clamp at zero.
    // NOTE(review): keep these statistics in sync with the forward kernel.
    let variance = max(wg_data2[0] / f32(group_size) - mean * mean, 0.0);
    let inv_std = inverseSqrt(variance + eps);
    // Every thread must finish reading slot 0 before pass 2 overwrites the scratch.
    workgroupBarrier();

    // Pass 2: compute sum(w * dout) and sum(w * dout * xhat) within group
    var sum_wdy = 0.0;
    var sum_wdy_xhat = 0.0;
    j = tid;
    loop {
        if j >= group_size { break; }
        let c_local = j / params.spatial;
        let hw = j % params.spatial;
        let c = c_start + c_local;
        let idx = ((n * params.channels + c) * params.spatial) + hw;
        let dy = src_a[idx];
        let w = bias[c];
        let xhat = (src_b[idx] - mean) * inv_std;
        sum_wdy += w * dy;
        sum_wdy_xhat += w * dy * xhat;
        j += 256u;
    }
    wg_data[tid] = sum_wdy;
    wg_data2[tid] = sum_wdy_xhat;
    workgroupBarrier();

    stride = 128u;
    loop {
        if stride == 0u { break; }
        if tid < stride {
            wg_data[tid] += wg_data[tid + stride];
            wg_data2[tid] += wg_data2[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }
    // No trailing barrier needed: the scratch arrays are not written again.
    let mean_wdy = wg_data[0] / f32(group_size);
    let mean_wdy_xhat = wg_data2[0] / f32(group_size);

    // Pass 3: compute grad_input
    j = tid;
    loop {
        if j >= group_size { break; }
        let c_local = j / params.spatial;
        let hw = j % params.spatial;
        let c = c_start + c_local;
        let idx = ((n * params.channels + c) * params.spatial) + hw;
        let dy = src_a[idx];
        let w = bias[c];
        let xhat = (src_b[idx] - mean) * inv_std;
        dst[idx] = inv_std * (w * dy - mean_wdy - xhat * mean_wdy_xhat);
        j += 256u;
    }
}

// GroupNorm backward w.r.t. weight and bias.
// Dispatch: [C, 1, 1]  workgroup_size(256)
// grad_weight[c] = sum_{n,hw} grad_out[n,c,hw] * xhat[n,c,hw]
// grad_bias[c] = sum_{n,hw} grad_out[n,c,hw]
// dst layout: [grad_weight[C], grad_bias[C]] = 2*C elements
//
// Each workgroup handles one channel c. For each batch item n, all 256 threads
// cooperatively reduce the mean and variance of the group that contains c, then
// cooperatively accumulate dw and db over the spatial positions of channel c.

@compute @workgroup_size(256)
fn grad_weight_bias(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
    let c = wgid.x;
    // Workgroup-uniform early exit: the barriers below stay in uniform control flow.
    if c >= params.channels { return; }

    let tid = lid.x;
    let eps = bitcast<f32>(params.eps_bits);
    let channels_per_group = params.channels / params.num_groups;
    let group = c / channels_per_group;
    let c_start = group * channels_per_group;
    let group_size = channels_per_group * params.spatial;

    // Running totals over the batch. Every thread adds the same reduced value,
    // so all copies agree; only thread 0 writes the result out.
    var acc_dw = 0.0;
    var acc_db = 0.0;

    for (var n = 0u; n < params.batch; n++) {
        // Cooperative mean/var: 256 threads stride over group_size elements
        var local_sum = 0.0;
        var local_sum2 = 0.0;
        var j = tid;
        loop {
            if j >= group_size { break; }
            let cc = c_start + j / params.spatial;
            let hw = j % params.spatial;
            let idx = ((n * params.channels + cc) * params.spatial) + hw;
            let v = src_b[idx];
            local_sum += v;
            local_sum2 += v * v;
            j += 256u;
        }
        wg_data[tid] = local_sum;
        wg_data2[tid] = local_sum2;
        workgroupBarrier();

        // Tree reduction for mean and variance
        var stride = 128u;
        loop {
            if stride == 0u { break; }
            if tid < stride {
                wg_data[tid] += wg_data[tid + stride];
                wg_data2[tid] += wg_data2[tid + stride];
            }
            workgroupBarrier();
            stride >>= 1u;
        }
        let mean = wg_data[0] / f32(group_size);
        // Biased variance via E[x^2] - E[x]^2. Clamp at zero: f32 cancellation
        // can make it slightly negative for near-constant groups. Below -eps,
        // that would put inverseSqrt out of its domain (indeterminate, typically NaN).
        // NOTE(review): keep these statistics in sync with the forward kernel.
        let variance = max(wg_data2[0] / f32(group_size) - mean * mean, 0.0);
        let inv_std = inverseSqrt(variance + eps);
        // All threads must read slot 0 before the scratch is overwritten below.
        workgroupBarrier();

        // Cooperative accumulation of dw and db over spatial for this channel
        var local_dw = 0.0;
        var local_db = 0.0;
        j = tid;
        loop {
            if j >= params.spatial { break; }
            let idx = ((n * params.channels + c) * params.spatial) + j;
            let dy = src_a[idx];
            let xhat = (src_b[idx] - mean) * inv_std;
            local_dw += dy * xhat;
            local_db += dy;
            j += 256u;
        }
        wg_data[tid] = local_dw;
        wg_data2[tid] = local_db;
        workgroupBarrier();

        stride = 128u;
        loop {
            if stride == 0u { break; }
            if tid < stride {
                wg_data[tid] += wg_data[tid + stride];
                wg_data2[tid] += wg_data2[tid + stride];
            }
            workgroupBarrier();
            stride >>= 1u;
        }
        acc_dw += wg_data[0];
        acc_db += wg_data2[0];
        // Protect slot 0 from the next batch item's writes.
        workgroupBarrier();
    }

    // dst layout: [grad_weight[C], grad_bias[C]].
    if tid == 0u {
        dst[c] = acc_dw;
        dst[params.channels + c] = acc_db;
    }
}