meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Uniform parameters read by the kernel in this file.
struct Params {
    // Number of `dst` elements to produce; invocations whose index is
    // >= total exit without writing.
    total: u32,
    // Number of leading `indices` entries (and matching `src` rows) scanned
    // for every output element.
    seq_len: u32,
    // Row width: used to split a flat output index into (row, column) and as
    // the row stride when reading `src`.
    embed_dim: u32,
    // Not read by this shader. Presumably rounds the struct up to 16 bytes
    // for uniform-buffer layout — TODO confirm against the host-side struct.
    _pad: u32,
}

// Resource bindings used by `main`.
// NOTE(review): no @group/@binding attributes are declared here; binding
// slots are presumably assigned by the host framework — confirm.

// Per-source-row destination row; `main` reads entries [0, params.seq_len).
var<storage> indices: array<u32>;
// Source values, read at `s * params.embed_dim + col`.
var<storage> src: array<f32>;
// Output buffer; `main` writes one element per invocation at index < params.total.
var<storage, read_write> dst: array<f32>;
// Kernel parameters (see `Params`).
var<uniform> params: Params;

// Gather-style accumulation: one invocation per output element.
//
// The flat invocation index is split into (target_row, column) using
// params.embed_dim. The invocation then walks the first params.seq_len
// entries of `indices` and adds src[row, column] for every row whose index
// equals target_row. Each invocation writes only its own dst element.
//
// Dispatch: [ceil(vocab*embed/256), 1, 1]
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    // Invocations past the end of the output have nothing to write.
    if flat >= params.total {
        return;
    }

    let width = params.embed_dim;
    let target_row = flat / width;
    let column = flat % width;

    var acc: f32 = 0.0;
    // Offset of (row, column) in `src`; advanced by one row per iteration.
    var src_offset = column;
    for (var row = 0u; row < params.seq_len; row += 1u) {
        if indices[row] == target_row {
            acc += src[src_offset];
        }
        src_offset += width;
    }

    dst[flat] = acc;
}