// meganeura 0.2.0
//
// E-graph optimized neural network training on Blade
// Documentation
// LayerNorm gradient shaders.
// Two entry points:
//   layer_norm_grad_wb: gradient wrt weight and bias
//   layer_norm_grad_x: gradient wrt input
// Params: m=rows, n=cols, k=eps_bits

// Uniform parameters shared by both entry points.
// Field names are generic (m/n/k) because the same struct layout is
// presumably reused across kernels in this project — confirm before renaming.
struct Params {
    m: u32,    // number of rows in the input/grad matrices
    n: u32,    // number of columns (the normalized dimension)
    k: u32,    // eps_bits: f32 epsilon reinterpreted as u32 (see bitcast below)
    _pad: u32, // pads the struct to 16 bytes for uniform-buffer alignment
}

// Resource bindings. No @group/@binding attributes appear here — NOTE(review):
// this looks like Blade's name-based binding model where the host resolves
// resources by variable name; confirm against the host-side pipeline setup.
// Despite its name, `bias` holds the LayerNorm *weight* (gamma), as used by
// layer_norm_grad_x below.
var<storage> src_a: array<f32>;  // dy (grad_output), rows x cols, row-major
var<storage> src_b: array<f32>;  // x (original input), rows x cols, row-major
var<storage> bias: array<f32>;   // w (weight/gamma), length cols
var<storage, read_write> dst: array<f32>;  // output buffer (layout per entry point)
var<uniform> params: Params;

// Gradients with respect to the LayerNorm weight and bias:
//   grad_weight[j] = sum_i(dy[i,j] * normed[i,j])
//   grad_bias[j]   = sum_i(dy[i,j])
// Output layout: dst[0..cols] = grad_weight, dst[cols..2*cols] = grad_bias
// (dst must therefore be at least 2*cols elements long).
//
// One invocation per output column. Dispatch: [cols, 1, 1]
@compute @workgroup_size(1)
fn layer_norm_grad_wb(@builtin(global_invocation_id) gid: vec3<u32>) {
    let j = gid.x;
    let rows = params.m;
    let cols = params.n;
    // eps travels through the push constants as raw f32 bits.
    let eps = bitcast<f32>(params.k);
    if j >= cols { return; }

    var grad_w = 0.0;
    var grad_b = 0.0;

    for (var i = 0u; i < rows; i++) {
        let offset = i * cols;
        // Recompute mean and rstd for this row rather than reading saved
        // statistics (none are provided by the bindings).
        var sum = 0.0;
        for (var c = 0u; c < cols; c++) {
            sum += src_b[offset + c];
        }
        let mean = sum / f32(cols);
        var var_sum = 0.0;
        for (var c = 0u; c < cols; c++) {
            let diff = src_b[offset + c] - mean;
            var_sum += diff * diff;
        }
        let rstd = inverseSqrt(var_sum / f32(cols) + eps);
        let normed = (src_b[offset + j] - mean) * rstd;

        grad_w += src_a[offset + j] * normed;
        grad_b += src_a[offset + j];
    }

    dst[j] = grad_w;
    // Fix: grad_bias was declared in the output layout above but was never
    // computed or written; store it in the second half of dst.
    dst[cols + j] = grad_b;
}

// Gradient with respect to the LayerNorm input x.
//
// With n = cols, w = weight (the `bias` binding), x_hat = normalized input:
//   grad_x[i,j] = rstd * ( dy[i,j]*w[j]
//                          - (1/n) * sum_j(dy[i,:]*w)
//                          - x_hat[i,j] * (1/n) * sum_j(dy[i,:]*w*x_hat[i,:]) )
//
// One invocation handles one full row. Dispatch: [rows, 1, 1]
@compute @workgroup_size(1)
fn layer_norm_grad_x(@builtin(global_invocation_id) gid: vec3<u32>) {
    let r = gid.x;
    let n_rows = params.m;
    let n_cols = params.n;
    // eps arrives as raw f32 bits inside a u32 field.
    let eps = bitcast<f32>(params.k);
    if r >= n_rows { return; }

    let base = r * n_cols;

    // Pass 1: row mean of x (statistics are not cached anywhere, so they
    // are recomputed from src_b).
    var acc = 0.0;
    for (var c = 0u; c < n_cols; c++) {
        acc += src_b[base + c];
    }
    let mu = acc / f32(n_cols);

    // Pass 2: reciprocal standard deviation of the row.
    var sq_acc = 0.0;
    for (var c = 0u; c < n_cols; c++) {
        let d = src_b[base + c] - mu;
        sq_acc += d * d;
    }
    let rstd = inverseSqrt(sq_acc / f32(n_cols) + eps);

    // Pass 3: the two row-wise reductions the backward formula needs.
    var sum_dyw = 0.0;       // sum_j(dy[j] * w[j])
    var sum_dyw_xhat = 0.0;  // sum_j(dy[j] * w[j] * x_hat[j])
    for (var c = 0u; c < n_cols; c++) {
        let g = src_a[base + c] * bias[c];
        let xhat = (src_b[base + c] - mu) * rstd;
        sum_dyw += g;
        sum_dyw_xhat += g * xhat;
    }

    // Pass 4: emit the gradient for every column of this row.
    let recip_n = 1.0 / f32(n_cols);
    for (var c = 0u; c < n_cols; c++) {
        let xhat = (src_b[base + c] - mu) * rstd;
        let g = src_a[base + c] * bias[c];
        dst[base + c] = rstd * (g - recip_n * sum_dyw - xhat * recip_n * sum_dyw_xhat);
    }
}