meganeura 0.2.0

E-graph-optimized neural network training on Blade
Documentation
// Uniform block of hyperparameters and sizes for one optimizer-step dispatch.
// The field order and types form the host-visible layout: 8 x 4-byte fields = 32 bytes.
struct Params {
    // Number of elements to update; invocations with index >= len exit early.
    len: u32,
    // Learning rate multiplying the bias-corrected update.
    lr: f32,
    // Decay rate of the first-moment (mean of gradient) estimate.
    beta1: f32,
    // Decay rate of the second-moment (mean of squared gradient) estimate.
    beta2: f32,
    // Added to sqrt(v_hat) in the update denominator to avoid division by zero.
    eps: f32,
    // Step counter, used as the exponent in the bias corrections 1 - beta^step.
    // Stored as f32 because it is passed straight to pow().
    // NOTE(review): step == 0 would make those divisors zero; presumably the host starts at 1 — confirm.
    step: f32,
    // Explicit padding that rounds the struct from 24 to 32 bytes (a multiple of 16).
    // Presumably this keeps the host-side struct and the uniform-buffer size in agreement — confirm.
    _pad0: u32,
    _pad1: u32,
}

// Shader resources. No @group/@binding attributes are written here.
// NOTE(review): presumably the host framework (Blade) assigns bindings by name — confirm.

// Parameters being optimized; updated in place by main().
var<storage, read_write> param: array<f32>;
// Gradient for each parameter element; read-only in this shader.
var<storage> grad: array<f32>;
// Running first-moment estimate, one per element; read and overwritten each step.
var<storage, read_write> m: array<f32>;
// Running second-moment estimate, one per element; read and overwritten each step.
var<storage, read_write> v: array<f32>;
// Hyperparameters and element count for this dispatch.
var<uniform> params: Params;

// Per-element Adam update: one invocation handles one parameter index.
//
// Reads grad[idx], advances the moment estimates m[idx] and v[idx] in place,
// then applies the bias-corrected step to param[idx]:
//   m_t   = beta1 * m + (1 - beta1) * g
//   v_t   = beta2 * v + (1 - beta2) * g * g
//   m_hat = m_t / (1 - beta1^step)
//   v_hat = v_t / (1 - beta2^step)
//   param = param - lr * m_hat / (sqrt(v_hat) + eps)
// Only gid.x is used, so the dispatch must cover params.len along x in
// workgroups of 256. Invocations past params.len do nothing.
// NOTE(review): params.step == 0 makes the bias-correction divisors zero;
// presumably the host starts counting steps at 1 — confirm.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    // Surplus invocations in the last workgroup would index past the buffers.
    if idx >= params.len {
        return;
    }

    // Bias-correction divisors depend only on the uniforms.
    let bias1 = 1.0 - pow(params.beta1, params.step);
    let bias2 = 1.0 - pow(params.beta2, params.step);

    // Exponential moving averages of the gradient and its square.
    let g_t = grad[idx];
    let m_t = params.beta1 * m[idx] + (1.0 - params.beta1) * g_t;
    let v_t = params.beta2 * v[idx] + (1.0 - params.beta2) * g_t * g_t;

    // Persist the raw (uncorrected) moments for the next step.
    m[idx] = m_t;
    v[idx] = v_t;

    let m_corr = m_t / bias1;
    let v_corr = v_t / bias2;
    let denom = sqrt(v_corr) + params.eps;
    param[idx] -= params.lr * m_corr / denom;
}