meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Uniform parameters for the rotary-position-embedding (RoPE) backward kernel `main`.
// Eight u32 fields (32 bytes). Field order and padding must match the host-side
// struct that fills this uniform — NOTE(review): confirm against the Rust definition.
struct Params {
    seq: u32,        // number of rows; main processes seq * (dim / 2) element pairs
    dim: u32,        // elements per row; also the row stride into src/dst
    theta_bits: u32, // rotary base theta as raw f32 bits, decoded with bitcast<f32>
    pos_offset: u32, // added to the row index to form the position used for the angle
    head_dim: u32,   // elements per head; element k of a head pairs with k + head_dim / 2
    _pad0: u32,      // padding, not read by this shader
    _pad1: u32,      // padding, not read by this shader
    _pad2: u32,      // padding, not read by this shader
}

// Resource bindings for the kernel. No @group/@binding attributes are written here;
// presumably the host framework assigns them by variable name — NOTE(review): confirm.
var<storage> src: array<f32>;       // grad_output: incoming gradient of the forward (rotated) output; read-only access
var<storage, read_write> dst: array<f32>;  // grad_input: gradient of the forward input, written by main
var<uniform> params: Params;  // shapes, rotary base and position offset (see Params)

// RoPE backward: each invocation owns one rotated element pair and applies the
// transposed (inverse) rotation to the incoming gradient.
//
// Work is flattened over seq * (dim / 2) pairs. Within a row, pairs are grouped
// per head; pair k of a head couples element k with element k + head_dim / 2,
// matching the forward kernel's layout.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    let pairs_per_row = params.dim / 2u;
    let pair_count = params.seq * pairs_per_row;
    if (flat >= pair_count) {
        return;
    }

    // Decompose the flat pair index into (row, head, pair-within-head).
    let token = flat / pairs_per_row;
    let position = token + params.pos_offset;
    let slot = flat % pairs_per_row;
    let base_theta = bitcast<f32>(params.theta_bits);

    let head_half = params.head_dim / 2u;
    let head_idx = slot / head_half;
    let k = slot % head_half;

    // Angle = position * theta^(-2k / head_dim), identical to the forward pass.
    let freq = pow(base_theta, -2.0 * f32(k) / f32(params.head_dim));
    let phase = f32(position) * freq;
    let c = cos(phase);
    let s = sin(phase);

    // Addresses of the two coupled elements for this pair.
    let head_start = token * params.dim + head_idx * params.head_dim;
    let lo = head_start + k;
    let hi = head_start + k + head_half;
    let grad_lo = src[lo];
    let grad_hi = src[hi];

    // Multiply by R^T (the forward rotation is orthogonal, so R^T = R^-1):
    //   dx_lo =  dy_lo * cos + dy_hi * sin
    //   dx_hi = -dy_lo * sin + dy_hi * cos
    dst[lo] = grad_lo * c + grad_hi * s;
    dst[hi] = -grad_lo * s + grad_hi * c;
}