meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Precompute rsqrt for RmsNorm: rsqrt[row] = inverseSqrt(sum(x²)/K + eps)
//
// Dispatch: [rows, 1, 1], WG=256
// Each workgroup cooperatively computes rsqrt for one row using
// shared-memory tree reduction over the K dimension.

struct Params {
    rows: u32,
    cols: u32,
    eps_bits: u32,
    _pad: u32,
}

var<storage> src: array<f32>;              // X [rows, cols]
var<storage, read_write> dst: array<f32>;  // rsqrt [rows]
var<uniform> params: Params;
var<workgroup> wg_sum: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
    let row = wgid.x;
    let tid = lid.x;
    let cols = params.cols;
    let eps = bitcast<f32>(params.eps_bits);

    if row >= params.rows { return; }

    // Cooperative sum of squares: each thread sums a stride of the row
    let offset = row * cols;
    var ss = 0.0;
    var j = tid;
    loop {
        if j >= cols { break; }
        let v = src[offset + j];
        ss += v * v;
        j += 256u;
    }
    wg_sum[tid] = ss;
    workgroupBarrier();

    // Tree reduction
    var stride = 128u;
    loop {
        if stride == 0u { break; }
        if tid < stride {
            wg_sum[tid] += wg_sum[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }

    if tid == 0u {
        dst[row] = inverseSqrt(wg_sum[0] / f32(cols) + eps);
    }
}