use metaltile::{bench_kernel, kernel};
// Patch-embedding projection computed as a GEMM with 8x8 simdgroup matrix
// multiply-accumulate (mathematically a Conv2d whose kernel size and stride both
// equal the patch size).
//
// For every patch `p` and hidden unit `h` it writes
//   out[p * hidden + h] = bias[h] + sum_k A[p][k] * weight[h * patch_dim + k]
// where `patch_dim = in_ch * patch_h * patch_w` and `k` decomposes as
//   k = ic * (patch_h * patch_w) + py * patch_w + px,
//   A[p][k] = image[ic * in_h * in_w + (py0(p) + py) * in_w + (px0(p) + px)],
//   py0(p) = (p / patches_w) * patch_h,  px0(p) = (p % patches_w) * patch_w,
//   patches_w = in_w / patch_w.
// Products are accumulated in f32; inputs are staged in `T` and the result is cast
// back to `T` on store.
//
// Work decomposition: each threadgroup produces one 32 (patches) x 32 (hidden)
// output tile. `tgid_y` selects the patch tile and `tgid_x` the hidden tile. Its
// simdgroups form a 2x2 grid (`sm` = row half, `sn` = column half), and each owns a
// 16x16 sub-tile held in four 8x8 f32 accumulators.
//
// NOTE(review): the tile-staging loops assume exactly 128 threads (4 simdgroups)
// per threadgroup, so that `lane_in_tg / 4` covers all 32 tile rows. Confirm this
// against the dispatch configuration.
// NOTE(review): only the K dimension (`patch_dim`) is bounds-checked. Neither the
// patch index nor the hidden index is guarded, so the kernel appears to require
// the patch count and `hidden` to be multiples of 32. Otherwise the image, weight
// and bias reads and the `out` writes go out of range.
#[bench_kernel(
op="patch_embed",
subop="mma",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
// Four tensors plus six compile-time shape parameters exceed clippy's default
// argument-count limit; the kernel signature needs all of them.
#[allow(clippy::too_many_arguments)]
pub fn patch_embed_mma<T>(
// Input image, indexed as [in_ch][in_h][in_w] (see `img_idx` below).
image: Tensor<T>,
// Projection weights, indexed as a row-major [hidden][patch_dim] matrix.
weight: Tensor<T>,
// Per-hidden-unit bias, indexed by hidden unit.
bias: Tensor<T>,
// Output, indexed as a row-major [patch][hidden] matrix.
out: Tensor<T>,
#[constexpr] in_ch: u32,
#[constexpr] in_h: u32,
#[constexpr] in_w: u32,
#[constexpr] patch_h: u32,
#[constexpr] patch_w: u32,
#[constexpr] hidden: u32,
) {
// Tile coordinates: tgid_x -> 32-wide hidden tile, tgid_y -> 32-wide patch tile.
let h_tile = tgid_x;
let pat_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// 2x2 simdgroup grid over the 32x32 tile: sm picks the 16-row half, sn the 16-column half.
let sm = sg / 2u32;
let sn = sg & 1u32;
// Flat thread index within the threadgroup, used to distribute the tile loads.
let lane_in_tg = sg * 32u32 + lane;
// Per-lane coordinates inside an 8x8 fragment: fragment slot 0 holds element
// (fm, fn0) and slot 1 holds (fm, fn1), as used consistently in the loads and the
// epilogue below. This appears to follow Metal's simdgroup_matrix per-lane element
// layout; confirm it against the DSL.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Row pitch of the staging tiles: 32 K values plus 4 elements of padding
// (presumably to reduce threadgroup-memory bank conflicts).
let stride = 36u32;
// "as" stages the patch (A) tile as [patch_row][k]; "bs" stages the weight (B) tile
// as [hidden_row][k]. 1152 = 32 rows * 36 stride. The four f32 accumulators are
// zeroed in both slots this lane owns.
threadgroup_alloc("as", 1152, T); threadgroup_alloc("bs", 1152, T); let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// A and B operand fragments, reloaded for each 8-wide K step.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
let phw = patch_h * patch_w;
// patch_dim is the GEMM K dimension; patches_w is the number of patches per image
// row; input_plane is the size of one channel plane.
let patch_dim = in_ch * phw; let patches_w = in_w / patch_w; let input_plane = in_h * in_w;
// A-tile staging: each thread loads 8 consecutive K values of one patch row
// (32 rows x 4 quads x 8 values = one 32x32 tile).
let a_pat_row = lane_in_tg / 4u32;
let a_k_quad = lane_in_tg & 3u32;
let a_k_base = a_k_quad * 8u32;
// Global patch index, and the top-left image pixel (py0, px0) of that patch
// (px0 is (global_pat % patches_w) * patch_w, written without `%`).
let global_pat = pat_tile * 32u32 + a_pat_row;
let py0 = (global_pat / patches_w) * patch_h;
let px0 = (global_pat - (global_pat / patches_w) * patches_w) * patch_w;
// B-tile staging: the same 32 x 4 x 8 scheme over weight rows (one hidden unit per row).
let b_h_row = lane_in_tg / 4u32; let b_k_quad = lane_in_tg & 3u32;
let b_k_base = b_k_quad * 8u32;
let global_h = h_tile * 32u32 + b_h_row;
let w_h_base = global_h * patch_dim;
// March over K in 32-wide chunks. A partial final chunk is zero-padded.
for kb in range(0u32, patch_dim, 32u32) {
// Gather this thread's 8 A values directly from the image (on-the-fly im2col).
// An out-of-range k loads from index 0 of the patch and then stores 0, so padded
// columns contribute nothing to the product.
for i in range(0u32, 8u32, 1u32) {
let kt = kb + a_k_base + i;
let in_bounds = kt < patch_dim;
let kt_safe = select(in_bounds, kt, 0u32);
// Split k into (channel, row within patch, column within patch).
let ic = kt_safe / phw;
let rem_kt = kt_safe - ic * phw;
let py = rem_kt / patch_w;
let px = rem_kt - py * patch_w;
let img_idx = ic * input_plane + (py0 + py) * in_w + (px0 + px);
let raw = load(image[img_idx]).cast::<f32>();
let val = select(in_bounds, raw, 0.0f32).cast::<T>();
threadgroup_store("as", a_pat_row * stride + a_k_base + i, val);
}
// Load this thread's 8 B (weight) values, with the same K zero-padding.
for i in range(0u32, 8u32, 1u32) {
let kt = kb + b_k_base + i;
let in_bounds = kt < patch_dim;
let kt_safe = select(in_bounds, kt, 0u32);
let w_idx = w_h_base + kt_safe;
let raw = load(weight[w_idx]).cast::<f32>();
let val = select(in_bounds, raw, 0.0f32).cast::<T>();
threadgroup_store("bs", b_h_row * stride + b_k_base + i, val);
}
// Make the fully staged tiles visible to all simdgroups before fragment loads.
threadgroup_barrier();
// Tile rows of this simdgroup's two A fragments and tile columns of its two B
// fragments.
let row_a0 = sm * 16u32 + fm;
let row_a1 = sm * 16u32 + 8u32 + fm;
let col_b0 = sn * 16u32;
let col_b1 = sn * 16u32 + 8u32;
// K step 0..7.
// A fragment element (fm, fn) = as[row][k = fn].
// B fragment element (fm, fn) = bs[column = fn][k = fm], i.e. B is read
// transposed out of the [hidden][k] tile.
// simdgroup_barrier_mem_none is, per its name, a simdgroup execution barrier with
// no memory fence.
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + fm));
simdgroup_barrier_mem_none();
// Accumulate the 2x2 block of fragment products. c_fRC pairs A fragment R with
// B fragment C.
// NOTE(review): this relies on simdgroup_matmul(a, b, c) accumulating in place
// (c += a * b); otherwise only the last K step would survive. Confirm against the
// DSL.
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// K step 8..15 (same pattern, K offset 8).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// K step 16..23 (K offset 16).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// K step 24..31 (K offset 24).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// Ensure every simdgroup has finished reading "as"/"bs" before the next chunk
// overwrites them.
threadgroup_barrier();
}
// Epilogue: add the bias and write this simdgroup's 16x16 sub-tile.
// c_fRC covers sub-tile rows 8*R.. and columns 8*C...
// b00/b01 are the biases for columns fn0/fn1 of the first 8-column half, and
// b10/b11 those of the second half.
// NOTE(review): the bias reads and output stores are unguarded (see the
// tile-multiple note at the top of this kernel).
let out_pat_base = pat_tile * 32u32 + sm * 16u32;
let out_h_base = h_tile * 32u32 + sn * 16u32;
let b00 = load(bias[out_h_base + fn0]).cast::<f32>();
let b01 = load(bias[out_h_base + fn1]).cast::<f32>();
let b10 = load(bias[out_h_base + 8u32 + fn0]).cast::<f32>();
let b11 = load(bias[out_h_base + 8u32 + fn1]).cast::<f32>();
store(
out[(out_pat_base + fm) * hidden + out_h_base + fn0],
(simdgroup_elem_load(c_f00, 0) + b00).cast::<T>(),
);
store(
out[(out_pat_base + fm) * hidden + out_h_base + fn1],
(simdgroup_elem_load(c_f00, 1) + b01).cast::<T>(),
);
store(
out[(out_pat_base + fm) * hidden + out_h_base + 8u32 + fn0],
(simdgroup_elem_load(c_f01, 0) + b10).cast::<T>(),
);
store(
out[(out_pat_base + fm) * hidden + out_h_base + 8u32 + fn1],
(simdgroup_elem_load(c_f01, 1) + b11).cast::<T>(),
);
store(
out[(out_pat_base + 8u32 + fm) * hidden + out_h_base + fn0],
(simdgroup_elem_load(c_f10, 0) + b00).cast::<T>(),
);
store(
out[(out_pat_base + 8u32 + fm) * hidden + out_h_base + fn1],
(simdgroup_elem_load(c_f10, 1) + b01).cast::<T>(),
);
store(
out[(out_pat_base + 8u32 + fm) * hidden + out_h_base + 8u32 + fn0],
(simdgroup_elem_load(c_f11, 0) + b10).cast::<T>(),
);
store(
out[(out_pat_base + 8u32 + fm) * hidden + out_h_base + 8u32 + fn1],
(simdgroup_elem_load(c_f11, 1) + b11).cast::<T>(),
);
}