meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
// Split (backward of Concat): slice grad_output [N, Ca+Cb, H, W] along the channel axis.
// Entry point split_a extracts channels [0..Ca); split_b extracts channels [Ca..Ca+Cb).
// Dispatch: [ceil(total_out / 256), 1, 1] with workgroup_size(256), where total_out is
// N * Ca * H * W for split_a and N * Cb * H * W for split_b.

// Shape parameters shared by both split entry points.
// The source tensor is indexed as [batch, channels_a + channels_b, spatial].
struct Params {
    // Number of samples N in the batch.
    batch: u32,
    // Channel count Ca of the first slice (written by split_a).
    channels_a: u32,
    // Channel count Cb of the second slice (written by split_b).
    channels_b: u32,
    // Elements per channel plane; used as the innermost stride.
    spatial: u32,  // H * W
}

// Input gradient of the concatenated tensor; only read by the entry points.
// NOTE(review): no @group/@binding attributes here — presumably resources are
// bound by name on the host side; confirm against the pipeline setup.
var<storage> src: array<f32>;        // grad_output [N, Ca+Cb, H, W]
// Output slice, written densely at indices [0, total_a) by split_a or
// [0, total_b) by split_b.
var<storage, read_write> dst: array<f32>;
// Tensor dimensions used to decompose and recompose flat indices.
var<uniform> params: Params;

// Copies channels [0, channels_a) of every sample from `src` into a densely
// packed `dst` laid out as [batch, channels_a, spatial]. One invocation
// handles one output element; invocations past the end do nothing.
@compute @workgroup_size(256)
fn split_a(@builtin(global_invocation_id) gid: vec3<u32>) {
    let out_idx = gid.x;
    let plane = params.spatial;
    let ca = params.channels_a;

    // Guard first: the divisions below are only reached when the output is non-empty.
    if out_idx >= params.batch * ca * plane {
        return;
    }

    // Decompose the flat output index into (sample, channel, pixel).
    let sample = out_idx / (ca * plane);
    let channel = (out_idx / plane) % ca;
    let pixel = out_idx % plane;

    // Recompose against the wider source layout [batch, Ca + Cb, spatial].
    let src_plane = sample * (ca + params.channels_b) + channel;
    dst[out_idx] = src[src_plane * plane + pixel];
}

// Copies channels [channels_a, channels_a + channels_b) of every sample from
// `src` into a densely packed `dst` laid out as [batch, channels_b, spatial].
// One invocation handles one output element; invocations past the end do nothing.
@compute @workgroup_size(256)
fn split_b(@builtin(global_invocation_id) gid: vec3<u32>) {
    let out_idx = gid.x;
    let plane = params.spatial;
    let cb_count = params.channels_b;

    // Guard first: the divisions below are only reached when the output is non-empty.
    if out_idx >= params.batch * cb_count * plane {
        return;
    }

    // Decompose the flat output index into (sample, local channel, pixel).
    let sample = out_idx / (cb_count * plane);
    let local_channel = (out_idx / plane) % cb_count;
    let pixel = out_idx % plane;

    // The second slice starts channels_a channels into each source sample.
    let src_channel = params.channels_a + local_channel;
    let src_plane = sample * (params.channels_a + cb_count) + src_channel;
    dst[out_idx] = src[src_plane * plane + pixel];
}