meganeura 0.2.0

E-graph-optimized neural network training on Blade
Documentation
// Cooperative-matrix matmul: each workgroup (WG) computes a 2×2 grid of tiles,
// i.e. a $OUTPUT_TILE×$OUTPUT_TILE block of the output matrix C.
// Dispatch: [ceil(m/$OUTPUT_TILE), ceil(n/$OUTPUT_TILE), 1]; workgroup size 64.
// Parameterized by tile size ($TILE_SIZE) and element type ($ELEM_TYPE).
// - 16×16 f16 path: RDNA3/Volta+ (VK_KHR_cooperative_matrix)
// -  8×8 f32 path:  Apple Silicon (simdgroup_matrix)

$ENABLE_F16
enable wgpu_cooperative_matrix;

// Uniform parameters for C[m×n] = A[m×k] × B[k×n].
// The byte layout must match the host-side struct; do not reorder fields.
struct Params {
    // Rows of A and of C.
    m: u32,
    // Columns of B and of C (also the row stride of C).
    n: u32,
    // Shared inner dimension (columns of A, rows of B).
    k: u32,
    // Unused by the shader; presumably pads the struct to 16 bytes to match the host layout.
    _pad: u32,
}

// Input operands, stored as f32. Values are converted to the staging element
// type (via the generator's cast placeholders) when copied into workgroup memory.
// NOTE(review): the element index of A and B is produced by generator
// placeholders, so their memory layout is not fixed by this file.
var<storage> matrix_a: array<f32>;
var<storage> matrix_b: array<f32>;
// Output C, addressed row-major with row stride n (see the c00..c11 offsets in main).
var<storage, read_write> matrix_c: array<f32>;
// Generator hook for an optional fused-add input binding (presumably empty when unused).
$FUSED_ADD_DECL
var<uniform> params: Params;
// Workgroup-shared staging tiles, refilled once per k step.
// Naming is swapped relative to the operands: shared_a0/shared_a1 receive
// tiles of matrix_b, and shared_b0/shared_b1 receive tiles of matrix_a.
var<workgroup> shared_a0: array<$ELEM_TYPE, $SHARED_SIZE>;
var<workgroup> shared_a1: array<$ELEM_TYPE, $SHARED_SIZE>;
var<workgroup> shared_b0: array<$ELEM_TYPE, $SHARED_SIZE>;
var<workgroup> shared_b1: array<$ELEM_TYPE, $SHARED_SIZE>;

// Entry point: one workgroup of 64 invocations computes a 2×2 grid of square
// tiles of C = A × B whose top-left corner is (tile_row, tile_col).
//
// The outer loop walks the inner dimension k one tile at a time:
//   1. all invocations cooperatively copy two A row-tiles and two B column-tiles
//      from storage into workgroup memory, writing zero for out-of-range elements;
//   2. a barrier, then four cooperative multiply-adds update the accumulators;
//   3. a second barrier, so the next step's staging cannot overwrite tiles that
//      other invocations may still be reading.
//
// NOTE(review): the generator placeholders in this function (the A/B index
// expressions, A_TRANSFORM_*, ACC_INIT) are substituted textually and
// presumably reference the local names used here (tr, cc, cc1, gr, tc, m, n,
// k, acc00..acc11, ...). Coordinate any rename with the generator.
@compute @workgroup_size(64)
fn main(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
    // Top-left corner of this workgroup's output block in C.
    let tile_row = wgid.x * $OUTPUT_TILE_U;
    let tile_col = wgid.y * $OUTPUT_TILE_U;
    let m = params.m;
    let n = params.n;
    let k = params.k;

    // C offsets for the 4 output tiles: row-major, row stride n.
    // cRC = the tile R tiles down and C tiles across from (tile_row, tile_col).
    let c00 = tile_row * n + tile_col;
    let c01 = tile_row * n + (tile_col + $TILE_SIZE_U);
    let c10 = (tile_row + $TILE_SIZE_U) * n + tile_col;
    let c11 = (tile_row + $TILE_SIZE_U) * n + (tile_col + $TILE_SIZE_U);

    // Validity flags for secondary tiles: true when the second tile column
    // (n1) / second tile row (m1) starts inside C. The primary tile needs no
    // flag: with the dispatch grid ceil(dim / output tile), tile_row < m and
    // tile_col < n always hold.
    let n1_valid = (tile_col + $TILE_SIZE_U) < n;
    let m1_valid = (tile_row + $TILE_SIZE_U) < m;

    // Initialize accumulators.
    // Generator hook: must declare the mutable accumulators acc00, acc01, acc10
    // and acc11 used below (presumably zero-filled, or seeded from the fused-add
    // input when one is declared — confirm).
    $ACC_INIT

    // Hoisted staging index components. Each invocation owns one column of the
    // staged tile (src_col, the low bits of lid.x) and starts at row base_row
    // (the high bits of lid.x). Presumably the mask is tile-1 and the shift is
    // log2(tile) — confirm against the generator.
    let src_col = lid.x & $TILE_MASK_U;
    let base_row = lid.x >> $TILE_SHIFT_U;
    // Column of C / B handled by this invocation in the first (cc) and second
    // (cc1) tile column, with their bounds checks against n.
    let cc = tile_col + src_col;
    let in_n = cc < n;
    let cc1 = cc + $TILE_SIZE_U;
    let in_n1 = cc1 < n;

    // t is the start of the current k-slice. It advances by one tile per iteration.
    var t = 0u;
    loop {
        if t >= k { break; }

        // Stage sa0: B[t:t+tile, tile_col:tile_col+tile] → shared_a0
        // Out-of-range elements are written as zero, so they add nothing to
        // the multiply-add below.
        let zero_val = $ELEM_ZERO;
        for (var e = 0u; e < $STAGING_ITERS_U; e++) {
            // 64 invocations fill 64 consecutive slots per iteration.
            // NOTE(review): flat equals (base_row + e * row_stride) * tile + src_col
            // only if row_stride * tile == 64 — confirm the generator guarantees this.
            let flat = lid.x + e * 64u;
            let tr = t + base_row + e * $ROW_STRIDE_U;
            let in_bounds = (tr < k) && in_n;
            if in_bounds {
                shared_a0[flat] = $CAST_OPEN matrix_b[$B_INDEX_0] $CAST_CLOSE;
            } else {
                shared_a0[flat] = zero_val;
            }
        }

        // Stage sa1: B[t:t+tile, tile_col+tile:tile_col+2*tile] → shared_a1
        // Same layout as sa0, but for the second tile column (bounds via in_n1).
        for (var e = 0u; e < $STAGING_ITERS_U; e++) {
            let flat = lid.x + e * 64u;
            let tr = t + base_row + e * $ROW_STRIDE_U;
            let in_bounds = (tr < k) && in_n1;
            if in_bounds {
                shared_a1[flat] = $CAST_OPEN matrix_b[$B_INDEX_1] $CAST_CLOSE;
            } else {
                shared_a1[flat] = zero_val;
            }
        }

        // Stage sb0: A[tile_row:tile_row+tile, t:t+tile] → shared_b0
        // For A, src_col indexes the k dimension (tc) and the staged rows
        // index m (gr).
        let tc = t + src_col;
        let in_k = tc < k;
        for (var e = 0u; e < $STAGING_ITERS_U; e++) {
            let flat = lid.x + e * 64u;
            let gr = tile_row + base_row + e * $ROW_STRIDE_U;
            let in_bounds = (gr < m) && in_k;
            if in_bounds {
                let a_val = matrix_a[$A_INDEX_0];
                shared_b0[flat] = $CAST_OPEN a_val $A_TRANSFORM_0 $CAST_CLOSE;
            } else {
                shared_b0[flat] = zero_val;
            }
        }

        // Stage sb1: A[tile_row+tile:tile_row+2*tile, t:t+tile] → shared_b1
        // Same as sb0, offset by one tile of rows.
        for (var e = 0u; e < $STAGING_ITERS_U; e++) {
            let flat = lid.x + e * 64u;
            let gr = tile_row + $TILE_SIZE_U + base_row + e * $ROW_STRIDE_U;
            let in_bounds = (gr < m) && in_k;
            if in_bounds {
                let a_val = matrix_a[$A_INDEX_1];
                shared_b1[flat] = $CAST_OPEN a_val $A_TRANSFORM_1 $CAST_CLOSE;
            } else {
                shared_b1[flat] = zero_val;
            }
        }

        // All staging writes must be visible before any invocation loads the tiles.
        workgroupBarrier();

        // Cooperative matrix multiply-add: C += A × B
        // shared_b{0,1} hold A-matrix row tiles; shared_a{0,1} hold B-matrix column tiles.
        // Load A data into role-A (left operand), B data into role-B (right operand).
        // The stride is the tile size because each staged tile is packed with
        // no padding between rows.
        // NOTE(review): the "T" suffix presumably selects a transposed layout in
        // the wgpu cooperative-matrix extension — confirm that it matches the
        // staging order above and the coopStoreT calls below.
        let a0 = coopLoadT<$COOP_AB>(&shared_b0[0], $TILE_SIZE_U);
        let a1 = coopLoadT<$COOP_AB>(&shared_b1[0], $TILE_SIZE_U);
        let b0 = coopLoadT<$COOP_BA>(&shared_a0[0], $TILE_SIZE_U);
        let b1 = coopLoadT<$COOP_BA>(&shared_a1[0], $TILE_SIZE_U);
        acc00 = coopMultiplyAdd(a0, b0, acc00);
        acc01 = coopMultiplyAdd(a0, b1, acc01);
        acc10 = coopMultiplyAdd(a1, b0, acc10);
        acc11 = coopMultiplyAdd(a1, b1, acc11);

        // Prevent the next iteration's staging from overwriting tiles that other
        // invocations may still be loading.
        workgroupBarrier();
        t += $TILE_SIZE_U;
    }

    // Store results.
    // Each store writes a whole tile into C with row stride n. The guards depend
    // only on wgid and params, so every invocation in the workgroup takes the
    // same branch.
    // NOTE(review): if m or n is not a multiple of the tile size, these full-tile
    // stores would touch elements past row m or column n. Past column n, they
    // alias the leading elements of the next row. Presumably the host only
    // selects this kernel for tile-aligned shapes — confirm.
    coopStoreT(acc00, &matrix_c[c00], n);
    if n1_valid {
        coopStoreT(acc01, &matrix_c[c01], n);
    }
    if m1_valid {
        coopStoreT(acc10, &matrix_c[c10], n);
    }
    if n1_valid && m1_valid {
        coopStoreT(acc11, &matrix_c[c11], n);
    }
}