meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Conv2d forward: input[N,Ci,H,W] * kernel[Co,Ci,kH,kW] → output[N,Co,oH,oW]
// Dispatch: [ceil(oW/16), ceil(oH/16), N*Co]  workgroup_size(16,16,1)
//
// Optimizations:
//   - Shared memory for weights: one cooperative load per (ci, kernel) instead of 256 redundant reads

struct Params {
    batch: u32,
    in_channels: u32,
    in_h: u32,
    in_w: u32,
    out_channels: u32,
    kernel_h: u32,
    kernel_w: u32,
    stride: u32,
    padding_h: u32,
    out_h: u32,
    out_w: u32,
    padding_w: u32,
}

var<storage> src: array<f32>;       // input
var<storage> weight: array<f32>;    // kernel
var<storage, read_write> dst: array<f32>;  // output
var<uniform> params: Params;

// Shared memory for kernel weights: up to 7×7 kernel
var<workgroup> wg_weight: array<f32, 49>;

@compute @workgroup_size(16, 16, 1)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let ow = gid.x;
    let oh = gid.y;
    let nco = gid.z;  // n * out_channels + co

    let n = nco / params.out_channels;
    let co = nco % params.out_channels;
    let in_bounds = ow < params.out_w && oh < params.out_h && n < params.batch;

    let tid = lid.y * 16u + lid.x;
    let kernel_size = params.kernel_h * params.kernel_w;
    let i_padding_h = i32(params.padding_h);
    let i_padding_w = i32(params.padding_w);

    var sum = 0.0;

    for (var ci = 0u; ci < params.in_channels; ci++) {
        // Cooperative weight load into shared memory
        if tid < kernel_size {
            wg_weight[tid] = weight[(co * params.in_channels + ci) * kernel_size + tid];
        }
        workgroupBarrier();

        if in_bounds {
            for (var kh = 0u; kh < params.kernel_h; kh++) {
                for (var kw = 0u; kw < params.kernel_w; kw++) {
                    let ih = i32(oh * params.stride + kh) - i_padding_h;
                    let iw = i32(ow * params.stride + kw) - i_padding_w;

                    if ih >= 0 && u32(ih) < params.in_h && iw >= 0 && u32(iw) < params.in_w {
                        let in_idx = ((n * params.in_channels + ci) * params.in_h + u32(ih)) * params.in_w + u32(iw);
                        sum += src[in_idx] * wg_weight[kh * params.kernel_w + kw];
                    }
                }
            }
        }

        workgroupBarrier();
    }

    if in_bounds {
        let out_idx = ((n * params.out_channels + co) * params.out_h + oh) * params.out_w + ow;
        dst[out_idx] = sum;
    }
}