use metaltile::{bench_kernel, kernel};
// Conv2d computed as an implicit GEMM with simdgroup 8x8 matrix multiply-accumulate.
//
// GEMM view (derived from the index arithmetic in this body):
//   M = output pixels, flattened as px = n * (out_h * out_w) + oh * out_w + ow
//   N = output channels (out_ch)
//   K = in_ch * kh * kw, flattened as kt = ic * (kh * kw) + ky * kw + kx
//
// Tensor indexing used below:
//   input  : n * (in_ch*in_h*in_w) + ic * (in_h*in_w) + ih * in_w + iw   (NCHW-contiguous)
//   weight : oc * K + kt                                   ([out_ch][in_ch][kh][kw] contiguous)
//   out    : px * out_ch + oc                              (pixel-major, channel innermost)
// ih = oh + ky and iw = ow + kx: stride 1, no padding, no dilation. Input reads stay in
// range only if out_h <= in_h - kh + 1 and out_w <= in_w - kw + 1.
// NOTE(review): the output is indexed channel-innermost while the input is indexed
// channel-outer (NCHW); confirm this matches the reference the bench compares against.
//
// Tiling: threadgroup (tgid_x, tgid_y) produces a 32-pixel x 32-channel output tile
// (tgid_y selects the pixel tile, tgid_x the channel tile). Simdgroup sg owns the 16x16
// quadrant (sm = sg / 2, sn = sg & 1) and holds it in four 8x8 f32 accumulators.
// Staged values round-trip through f32 and are stored as T; accumulation is in f32 and
// results are cast back to T on store.
// Assumes the launch uses exactly 4 simdgroups of 32 lanes (128 threads): only then does
// lane_in_tg span 0..127 so the cooperative loads fill the full 32x32 A and B tiles and
// sm stays within the 32-pixel tile — TODO confirm against the dispatch configuration.
//
// NOTE(review): there are no bounds checks on the pixel index (global_px) or the output
// channel (global_oc). Unless the total pixel count (batch * out_h * out_w) and out_ch
// are both multiples of 32, input/weight loads read out of range and output stores land
// in other pixels' channels or past the end of `out`. The batch size is not a parameter,
// so the pixel bound cannot be checked inside the kernel as written. Only the K tail
// (total_k not a multiple of 32) is handled, by zero-filling.
//
// bench_kernel metadata is consumed by the metaltile bench harness (not visible here);
// `tol` is presumably the comparison tolerance against the reference implementation.
#[bench_kernel(
op="conv2d",
subop="mma",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
// Every shape dimension is passed as its own constexpr parameter.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_mma<T>(
input: Tensor<T>,
weight: Tensor<T>,
out: Tensor<T>,
#[constexpr] in_ch: u32,
#[constexpr] in_h: u32,
#[constexpr] in_w: u32,
#[constexpr] out_ch: u32,
#[constexpr] out_h: u32,
#[constexpr] out_w: u32,
#[constexpr] kh: u32,
#[constexpr] kw: u32,
) {
// Grid mapping: x picks the 32-wide output-channel tile, y the 32-wide pixel tile.
let oc_tile = tgid_x;
let px_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// 2x2 arrangement of simdgroups: sm selects the 16-pixel half, sn the 16-channel half.
let sm = sg / 2u32;
let sn = sg & 1u32;
// Flat thread index within the threadgroup, used by the cooperative tile loads.
let lane_in_tg = sg * 32u32 + lane;
// (fm, fn0) and (fm, fn1): row/column inside an 8x8 simdgroup matrix of the two elements
// this lane holds (element slots 0 and 1). Across the 32 lanes these pairs cover all 64
// positions exactly once. Presumably mirrors the hardware simdgroup_matrix per-lane
// layout — confirm against metaltile's simdgroup_elem_* documentation.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Row pitch of the staged tiles: 32 K-values plus 4 padding elements
// (presumably to spread rows across threadgroup-memory banks).
let stride = 36u32;
// "as" = A tile [32 pixels][stride], "bs" = B tile [32 out-channels][stride]; 32 * 36 = 1152.
// c_fRC are the simdgroup's 8x8 f32 accumulators (rows = pixels, cols = channels); R/C = 1
// means the +8 pixel / +8 channel sub-block. Each lane zeroes its two elements of each.
threadgroup_alloc("as", 1152, T); threadgroup_alloc("bs", 1152, T); let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments: a_f0/a_f1 hold pixel rows row_a0/row_a1 (8 each) of the current
// 8-wide K slice; b_f0/b_f1 hold channel columns col_b0/col_b1 of the same slice.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
let kk = kh * kw;
let total_k = in_ch * kk; let out_hw = out_h * out_w;
// A-tile load assignment: 4 threads per pixel row, each covering 8 consecutive K-values.
let a_px_row = lane_in_tg / 4u32; let a_k_quad = lane_in_tg & 3u32; let a_k_base = a_k_quad * 8u32;
// Decompose this thread's pixel into (n, oh, ow). NOTE(review): global_px is not checked
// against the total pixel count.
let global_px = px_tile * 32u32 + a_px_row;
let n_px = global_px / out_hw;
let rem_px = global_px - n_px * out_hw;
let oh_px = rem_px / out_w;
let ow_px = rem_px - oh_px * out_w;
let in_n_stride = in_ch * in_h * in_w;
let px_in_base = n_px * in_n_stride;
// B-tile load assignment: 4 threads per output channel, each covering 8 consecutive K-values.
let b_oc_row = lane_in_tg / 4u32;
let b_k_quad = lane_in_tg & 3u32;
let b_k_base = b_k_quad * 8u32;
// NOTE(review): global_oc is not checked against out_ch.
let global_oc = oc_tile * 32u32 + b_oc_row;
let w_oc_base = global_oc * total_k;
// Reduction over K in 32-wide chunks.
for kb in range(0u32, total_k, 32u32) {
// Stage A: gather 8 consecutive K-values of this thread's pixel from the input window
// (on-the-fly im2col).
for i in range(0u32, 8u32, 1u32) {
let kt = kb + a_k_base + i;
let in_bounds = kt < total_k;
// Past the K tail: clamp to kt = 0 so the address stays valid; the value is zeroed below.
let kt_safe = select(in_bounds, kt, 0u32);
// kt -> (input channel ic, kernel row ky, kernel column kx).
let ic = kt_safe / kk;
let rem_kt = kt_safe - ic * kk;
let ky = rem_kt / kw;
let kx = rem_kt - ky * kw;
let ih = oh_px + ky;
let iw = ow_px + kx;
let in_idx = px_in_base + ic * in_h * in_w + ih * in_w + iw;
let raw = load(input[in_idx]).cast::<f32>();
let val = select(in_bounds, raw, 0.0f32).cast::<T>();
threadgroup_store("as", a_px_row * stride + a_k_base + i, val);
}
// Stage B: copy 8 consecutive K-values of this thread's output-channel filter row.
for i in range(0u32, 8u32, 1u32) {
let kt = kb + b_k_base + i;
let in_bounds = kt < total_k;
let kt_safe = select(in_bounds, kt, 0u32);
let w_idx = w_oc_base + kt_safe;
let raw = load(weight[w_idx]).cast::<f32>();
let val = select(in_bounds, raw, 0.0f32).cast::<T>();
threadgroup_store("bs", b_oc_row * stride + b_k_base + i, val);
}
// All staged tile data must be written before any simdgroup reads fragments from it.
threadgroup_barrier();
// This simdgroup's pixel rows (A) and channel columns (B) within the 32x32 tile.
let row_a0 = sm * 16u32 + fm;
let row_a1 = sm * 16u32 + 8u32 + fm;
let col_b0 = sn * 16u32;
let col_b1 = sn * 16u32 + 8u32;
// K sub-slice 0..8. A fragment element (pixel row, k col) is read row-major from "as".
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + fn1));
simdgroup_barrier_mem_none();
// B fragment element (k row = fm, channel col = fn) is read transposed from "bs",
// which is laid out [channel][k].
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + fm));
simdgroup_barrier_mem_none();
// c_fRC += a_fR * b_fC (presumably multiply-accumulate: the accumulators are zeroed once
// and reused across every K step). Each accumulator is independent, so this call order
// does not affect results.
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// K sub-slice 8..16: same pattern, K offset +8.
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// K sub-slice 16..24.
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// K sub-slice 24..32.
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// Every simdgroup must finish reading "as"/"bs" before the next iteration overwrites them.
threadgroup_barrier();
}
// Write-back: each lane stores its two elements of each accumulator, cast to T, at
// out[pixel * out_ch + channel]. NOTE(review): unguarded — see the out_ch / pixel-count
// note above the kernel.
let out_px_base = px_tile * 32u32 + sm * 16u32;
let out_oc_base = oc_tile * 32u32 + sn * 16u32;
store(
out[(out_px_base + fm) * out_ch + out_oc_base + fn0],
simdgroup_elem_load(c_f00, 0).cast::<T>(),
);
store(
out[(out_px_base + fm) * out_ch + out_oc_base + fn1],
simdgroup_elem_load(c_f00, 1).cast::<T>(),
);
store(
out[(out_px_base + fm) * out_ch + out_oc_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_px_base + fm) * out_ch + out_oc_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_px_base + 8u32 + fm) * out_ch + out_oc_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_px_base + 8u32 + fm) * out_ch + out_oc_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_px_base + 8u32 + fm) * out_ch + out_oc_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_px_base + 8u32 + fm) * out_ch + out_oc_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}