meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Conv2d backward w.r.t. input: grad_output[N,Co,oH,oW] × kernel[Co,Ci,kH,kW] → grad_input[N,Ci,H,W]
// This is a "full convolution" of grad_output with the spatially flipped kernel.
// Dispatch: [ceil(W/16), ceil(H/16), N*Ci]  workgroup_size(16,16,1)
//
// Optimizations:
//   - Stride-1 fast path: skips the per-tap modulo/division that strided inputs need
//   - Shared memory for weights: each workgroup stages the [kH,kW] weight slice for the
//     current output channel once, instead of all 256 invocations reading every tap
//     from global memory

// Uniform parameters for the conv2d input-gradient kernel.
// NOTE(review): the field order here defines the uniform buffer layout; presumably a
// host-side struct mirrors it byte-for-byte, so do not reorder fields without updating it.
struct Params {
    batch: u32,         // N
    in_channels: u32,   // Ci
    in_h: u32,          // H  (height of grad_input)
    in_w: u32,          // W  (width of grad_input)
    out_channels: u32,  // Co
    kernel_h: u32,      // kH
    kernel_w: u32,      // kW
    stride: u32,        // single stride applied to both height and width
    padding_h: u32,     // padding along height
    out_h: u32,         // oH (height of grad_output)
    out_w: u32,         // oW (width of grad_output)
    padding_w: u32,     // padding along width (declared last, after out_w)
}

// Resource bindings. No @group/@binding attributes are written here; presumably the
// host framework assigns them (e.g. by name) — confirm before renaming anything.
var<storage> grad_out: array<f32>;    // grad_output [N,Co,oH,oW], read-only (default access)
var<storage> weight: array<f32>;      // kernel [Co,Ci,kH,kW], read-only (default access)
var<storage, read_write> dst: array<f32>;  // grad_input [N,Ci,H,W], written by main
var<uniform> params: Params;

// Workgroup-shared copy of one [kH,kW] weight slice (fixed co and ci), staged
// cooperatively by the workgroup. Capacity: 49 taps, i.e. kernels up to 7×7
// (expected to be ≤ 3×3 in practice).
var<workgroup> wg_weight: array<f32, 49>;

// Returns tap `k` (= kh * kernel_w + kw) of the [kH,kW] weight slice that starts at
// flat offset `w_base` in `weight`. When `use_shared` is true the slice has already
// been staged into `wg_weight`, so the shared copy is read instead of global memory.
fn weight_tap(w_base: u32, k: u32, use_shared: bool) -> f32 {
    if use_shared {
        return wg_weight[k];
    }
    return weight[w_base + k];
}

// One invocation computes one grad_input element (n, ci, ih, iw) by summing
// grad_output[n, co, oh, ow] * weight[co, ci, kh, kw] over every (co, kh, kw) whose
// forward-pass output position (oh, ow) maps back onto (ih, iw).
@compute @workgroup_size(16, 16, 1)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let iw = gid.x;
    let ih = gid.y;
    let nci = gid.z;  // n * in_channels + ci

    // workgroup_size.z == 1, so every invocation in a workgroup shares `nci` (and thus
    // `n` and `ci`); that is what lets the workgroup share one staged weight slice.
    let n = nci / params.in_channels;
    let ci = nci % params.in_channels;
    let in_bounds = iw < params.in_w && ih < params.in_h && n < params.batch;

    // Flat index of this invocation inside the 16×16 workgroup.
    let tid = lid.y * 16u + lid.x;
    let kernel_size = params.kernel_h * params.kernel_w;
    let i_padding_h = i32(params.padding_h);
    let i_padding_w = i32(params.padding_w);

    // wg_weight holds 49 taps (7×7). A larger kernel would index past its end and
    // silently produce wrong gradients, so such kernels read weights straight from
    // global memory. This flag depends only on uniform params, so the barriers
    // below remain in uniform control flow.
    let use_shared = kernel_size <= 49u;

    var sum = 0.0;

    if params.stride == 1u {
        // Fast path for stride = 1: every in-range offset is a valid output position,
        // so no modulo/division is needed.
        for (var co = 0u; co < params.out_channels; co++) {
            // Start of weight[co, ci, :, :].
            let w_base = (co * params.in_channels + ci) * kernel_size;

            // Cooperative load of the slice into shared memory (when it fits).
            if use_shared && tid < kernel_size {
                wg_weight[tid] = weight[w_base + tid];
            }
            workgroupBarrier();

            if in_bounds {
                let go_base = (n * params.out_channels + co) * params.out_h * params.out_w;

                for (var kh = 0u; kh < params.kernel_h; kh++) {
                    let oh = i32(ih) + i_padding_h - i32(kh);
                    if oh >= 0 && u32(oh) < params.out_h {
                        for (var kw = 0u; kw < params.kernel_w; kw++) {
                            let ow = i32(iw) + i_padding_w - i32(kw);
                            if ow >= 0 && u32(ow) < params.out_w {
                                sum += grad_out[go_base + u32(oh) * params.out_w + u32(ow)]
                                     * weight_tap(w_base, kh * params.kernel_w + kw, use_shared);
                            }
                        }
                    }
                }
            }

            // Keep the slice alive until every invocation has finished reading it.
            workgroupBarrier();
        }
    } else {
        // General path: a tap contributes only when the padded offset is a
        // non-negative multiple of the stride (i.e. lands exactly on an output sample).
        let i_stride = i32(params.stride);
        for (var co = 0u; co < params.out_channels; co++) {
            // Start of weight[co, ci, :, :].
            let w_base = (co * params.in_channels + ci) * kernel_size;

            if use_shared && tid < kernel_size {
                wg_weight[tid] = weight[w_base + tid];
            }
            workgroupBarrier();

            if in_bounds {
                let go_base = (n * params.out_channels + co) * params.out_h * params.out_w;

                for (var kh = 0u; kh < params.kernel_h; kh++) {
                    let h_off = i32(ih) + i_padding_h - i32(kh);
                    // `&&` short-circuits, so `%` only sees non-negative offsets.
                    if h_off >= 0 && (h_off % i_stride) == 0 {
                        let oh = u32(h_off) / params.stride;
                        if oh < params.out_h {
                            for (var kw = 0u; kw < params.kernel_w; kw++) {
                                let w_off = i32(iw) + i_padding_w - i32(kw);
                                if w_off >= 0 && (w_off % i_stride) == 0 {
                                    let ow = u32(w_off) / params.stride;
                                    if ow < params.out_w {
                                        sum += grad_out[go_base + oh * params.out_w + ow]
                                             * weight_tap(w_base, kh * params.kernel_w + kw, use_shared);
                                    }
                                }
                            }
                        }
                    }
                }
            }

            workgroupBarrier();
        }
    }

    if in_bounds {
        let in_idx = ((n * params.in_channels + ci) * params.in_h + ih) * params.in_w + iw;
        dst[in_idx] = sum;
    }
}