meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
struct Params {
    len: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

var<storage> src: array<f32>;
var<storage, read_write> dst: array<f32>;
var<uniform> params: Params;
var<workgroup> wg_data: array<f32, 256>;

@compute @workgroup_size(256)
fn sum_all(@builtin(local_invocation_id) lid: vec3<u32>) {
    let tid = lid.x;
    // Strided accumulation
    var acc = 0.0;
    var idx = tid;
    loop {
        if idx >= params.len { break; }
        acc += src[idx];
        idx += 256u;
    }
    wg_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction
    var stride = 128u;
    loop {
        if stride == 0u { break; }
        if tid < stride {
            wg_data[tid] += wg_data[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }

    if tid == 0u {
        dst[0] = wg_data[0];
    }
}

@compute @workgroup_size(256)
fn mean_all(@builtin(local_invocation_id) lid: vec3<u32>) {
    let tid = lid.x;
    // Strided accumulation
    var acc = 0.0;
    var idx = tid;
    loop {
        if idx >= params.len { break; }
        acc += src[idx];
        idx += 256u;
    }
    wg_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction
    var stride = 128u;
    loop {
        if stride == 0u { break; }
        if tid < stride {
            wg_data[tid] += wg_data[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }

    if tid == 0u {
        dst[0] = wg_data[0] / f32(params.len);
    }
}