meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Uniform parameters for the row-wise softmax kernel below.
// Field order and sizes define the host-visible layout: four u32s, 16 bytes.
struct Params {
    // Number of rows; `main` handles one row per invocation (row = gid.x).
    batch: u32,
    // Number of elements in each row; also the row stride in `src` / `dst`.
    features: u32,
    // Unused padding. Presumably keeps the struct at 16 bytes to match the
    // host-side struct / uniform alignment rules — confirm against the host code.
    _pad0: u32,
    _pad1: u32,
}

// Kernel resources. No @group/@binding attributes are written here; the bindings
// are presumably assigned by the host framework (by name) — confirm before renaming.
// Input values, indexed as src[row * params.features + j]; read-only (the default
// access mode of the storage address space).
var<storage> src: array<f32>;
// Output: softmax of each row of `src`, written at the same indices.
var<storage, read_write> dst: array<f32>;
// Matrix dimensions (row count and row length).
var<uniform> params: Params;

// Row-wise softmax: each invocation handles one row of the row-major
// `params.batch x params.features` matrix in `src`. It writes softmax(row) to
// the same positions in `dst`.
//
// The row maximum is subtracted before exponentiating for numerical stability.
// A row whose inputs are all -inf (e.g. a fully masked row) produces all zeros
// rather than NaN.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    // Surplus invocations in the last workgroup have no row. Empty rows
    // (features == 0) have nothing to normalize, and `src[offset]` below would
    // read an element that does not belong to this row.
    if row >= params.batch || params.features == 0u { return; }
    let offset = row * params.features;

    // Find max for numerical stability
    var max_val = src[offset];
    for (var j = 1u; j < params.features; j++) {
        max_val = max(max_val, src[offset + j]);
    }

    // If every input is -inf, subtracting the max would give
    // -inf - (-inf) = NaN for every element. Shift by 0 instead, so each exp()
    // is exp(-inf) = 0 and the row sums to 0 (handled by the guard below).
    // The bit pattern is compared (0xff800000 is f32 -inf) because WGSL lets
    // implementations assume floats are never infinite, so a float comparison
    // against infinity may be optimized away.
    let all_neg_inf = bitcast<u32>(max_val) == 0xff800000u;
    let shift = select(max_val, 0.0, all_neg_inf);

    // Compute exp(x - shift), store the unnormalized values, and accumulate the sum
    var sum_exp = 0.0;
    for (var j = 0u; j < params.features; j++) {
        let exp_val = exp(src[offset + j] - shift);
        dst[offset + j] = exp_val;
        sum_exp += exp_val;
    }

    // Normalize. With a finite max, the max element contributes exp(0) = 1, so
    // sum_exp is 0 only for the all -inf row. Divide by 1 there so the row stays
    // all zeros instead of becoming 0/0 = NaN.
    let safe_sum = select(sum_exp, 1.0, sum_exp == 0.0);
    for (var j = 0u; j < params.features; j++) {
        dst[offset + j] /= safe_sum;
    }
}