meganeura 0.2.0

E-graph-optimized neural network training on Blade
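The WGSL compute shader below computes softmax cross-entropy loss and its gradient in one fused pass, using the log-sum-exp trick for numerical stability. logits and labels are row-major [batch, features] buffers; labels may be one-hot or soft targets, provided each row sums to 1 (see the note after the shader).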
struct Params {
    batch: u32,
    features: u32,
    // Pad the struct to 16 bytes to satisfy uniform-buffer alignment.
    _pad0: u32,
    _pad1: u32,
}

// Blade resolves bindings from the host-side data layouts, so the
// declarations carry no @group/@binding attributes.
var<storage> logits: array<f32>;               // [batch * features] input scores
var<storage> labels: array<f32>;               // [batch * features] target distribution
var<storage, read_write> grad_out: array<f32>; // [batch * features] dL/dlogits
var<storage, read_write> loss_out: array<f32>; // mean loss written to element 0
var<uniform> params: Params;

// A single invocation walks the whole batch serially; simple, but see the
// parallel sketch after the shader for larger batches.
@compute @workgroup_size(1)
fn main() {
    var total_loss = 0.0;
    for (var b = 0u; b < params.batch; b++) {
        let offset = b * params.features;

        // Find the row max first; subtracting it keeps exp() from overflowing.
        var max_val = logits[offset];
        for (var j = 1u; j < params.features; j++) {
            max_val = max(max_val, logits[offset + j]);
        }

        // Stable log-sum-exp: log(sum(exp(x - m))) + m == log(sum(exp(x))).
        var sum_exp = 0.0;
        for (var j = 0u; j < params.features; j++) {
            sum_exp += exp(logits[offset + j] - max_val);
        }
        let log_sum_exp = log(sum_exp) + max_val;

        // Accumulate cross-entropy and write the softmax gradient,
        // both averaged over the batch.
        for (var j = 0u; j < params.features; j++) {
            let log_softmax = logits[offset + j] - log_sum_exp;
            let softmax = exp(log_softmax);
            total_loss -= labels[offset + j] * log_softmax;
            grad_out[offset + j] = (softmax - labels[offset + j]) / f32(params.batch);
        }
    }
    // Mean loss over the batch.
    loss_out[0] = total_loss / f32(params.batch);
}
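
For reference, with batch size B, feature count F, logits z, and labels y, the shader computes

    L = -\frac{1}{B} \sum_{b=1}^{B} \sum_{j=1}^{F} y_{bj} \left( z_{bj} - \log \sum_{k=1}^{F} e^{z_{bk}} \right),
    \qquad
    \frac{\partial L}{\partial z_{bj}} = \frac{1}{B} \left( \operatorname{softmax}(z_b)_j - y_{bj} \right).

The closed-form gradient relies on \sum_j y_{bj} = 1; for an unnormalized label row the true gradient gains a factor of \sum_j y_{bj} on the softmax term, so normalize labels on the host.

The @workgroup_size(1) entry point processes the batch serially. Below is a minimal sketch of a parallel variant with one invocation per batch row; it assumes loss_out is resized to batch elements holding per-row partial losses (reduced on the host or in a second pass), and the entry-point name main_per_row is hypothetical, not part of the shipped shader:

@compute @workgroup_size(64)
fn main_per_row(@builtin(global_invocation_id) gid: vec3<u32>) {
    let b = gid.x;
    // The dispatch is rounded up to a multiple of 64; skip the overshoot.
    if (b >= params.batch) {
        return;
    }
    let offset = b * params.features;

    // Same stable log-sum-exp as the serial kernel, restricted to one row.
    var max_val = logits[offset];
    for (var j = 1u; j < params.features; j++) {
        max_val = max(max_val, logits[offset + j]);
    }
    var sum_exp = 0.0;
    for (var j = 0u; j < params.features; j++) {
        sum_exp += exp(logits[offset + j] - max_val);
    }
    let log_sum_exp = log(sum_exp) + max_val;

    var row_loss = 0.0;
    for (var j = 0u; j < params.features; j++) {
        let log_softmax = logits[offset + j] - log_sum_exp;
        row_loss -= labels[offset + j] * log_softmax;
        grad_out[offset + j] = (exp(log_softmax) - labels[offset + j]) / f32(params.batch);
    }
    // Per-row partial loss; sum over rows and divide by the batch size elsewhere.
    loss_out[b] = row_loss;
}

Dispatch ceil(batch / 64) workgroups on the x axis; the gradient output matches the serial kernel's.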