use metaltile::{bench_kernel, kernel};
// conv3d_mma: tiled 3-D convolution kernel built on 8x8 simdgroup matrix fragments.
//
// Each threadgroup produces a 32 (output positions) x 32 (output channels) tile of `out`.
// The reduction runs over a flattened index `kt` in [0, total_k), where
// total_k = in_ch * kd * kh * kw and kt decomposes as (ic, kz, ky, kx).
//
// Indexing used by the code below (flat offsets):
//   input : n * (in_ch*in_d*in_h*in_w) + ic * (in_d*in_h*in_w) + id * (in_h*in_w) + ih * in_w + iw
//           with id = od + kz, ih = oh + ky, iw = ow + kx (no stride, dilation or padding
//           term appears in the index math).
//   weight: oc * total_k + kt
//   out   : pv * out_ch + oc, where pv is the flattened (n, od, oh, ow) output position.
//
// Parameters:
//   input, weight, out       - tensors of element type T, indexed as described above.
//   in_ch, in_d, in_h, in_w  - input channel count and spatial extents.
//   out_ch                   - output channel count (row pitch of `out`).
//   out_d, out_h, out_w      - output spatial extents, used to decompose pv.
//   kd, kh, kw               - filter extents.
//
// `tgid_x`, `tgid_y` and `simd_lane` are not declared in this block; presumably they are
// injected by the #[kernel] macro -- TODO confirm.
//
// NOTE(review): the thread-to-row mapping (lane_in_tg / 4 covering rows 0..31) and the
// sm/sn split look like they expect exactly 4 simdgroups of 32 lanes per threadgroup --
// confirm against the dispatch configuration.
// NOTE(review): there is no bounds guard on the output-position tile or the output-channel
// tile (only the reduction index is guarded), so the kernel appears to assume the number of
// output positions and out_ch are multiples of 32 -- confirm against callers.
//
// The bench_kernel attribute registers this kernel under op "conv3d" / subop "mma"; the
// meaning of class, tol and kernel_mode is defined by that macro and not visible here.
#[bench_kernel(
op="conv3d",
subop="mma",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn conv3d_mma<T>(
input: Tensor<T>,
weight: Tensor<T>,
out: Tensor<T>,
#[constexpr] in_ch: u32,
#[constexpr] in_d: u32,
#[constexpr] in_h: u32,
#[constexpr] in_w: u32,
#[constexpr] out_ch: u32,
#[constexpr] out_d: u32,
#[constexpr] out_h: u32,
#[constexpr] out_w: u32,
#[constexpr] kd: u32,
#[constexpr] kh: u32,
#[constexpr] kw: u32,
) {
// Threadgroup coordinates select the 32-wide output-channel tile (x) and the
// 32-wide output-position tile (y).
let oc_tile = tgid_x;
let pv_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// sm / sn pick which 16x16 quadrant of the 32x32 tile this simdgroup computes
// (sm = row half, sn = column half).
let sm = sg / 2u32;
let sn = sg & 1u32;
// Linear thread index within the threadgroup (32 lanes per simdgroup).
let lane_in_tg = sg * 32u32 + lane;
// fm and fn0/fn1 are used below as the row and the two adjacent column coordinates
// (inside an 8x8 fragment) of the two elements this lane reads and writes.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Row pitch of the staged tiles: 32 used columns + 4 extra (32 rows * 36 = 1152 elements).
// Presumably the padding is there to reduce bank conflicts -- TODO confirm.
let stride = 36u32;
threadgroup_alloc("as", 1152, T);
threadgroup_alloc("bs", 1152, T);
// Four f32 accumulator fragments forming this simdgroup's 16x16 output quadrant
// (c_fRC: R = row block, C = column block). Each lane zeroes its elements 0 and 1.
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments in element type T: two row blocks of A, two column blocks of B.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
// Sizes used to decompose the flattened reduction index and output position.
let khw = kh * kw;
let kdhw = kd * khw; let total_k = in_ch * kdhw; let out_hw = out_h * out_w;
let out_dhw = out_d * out_hw;
// Staging assignment for "as": each thread fills 8 consecutive k columns
// (a_k_base .. a_k_base + 7) of one tile row (a_pv_row).
let a_pv_row = lane_in_tg / 4u32;
let a_k_quad = lane_in_tg & 3u32;
let a_k_base = a_k_quad * 8u32;
// Decompose this thread's output position into (n, od, oh, ow).
let global_pv = pv_tile * 32u32 + a_pv_row;
let n_pv = global_pv / out_dhw;
let rem_pv = global_pv - n_pv * out_dhw;
let od_pv = rem_pv / out_hw;
let rem_hw = rem_pv - od_pv * out_hw;
let oh_pv = rem_hw / out_w;
let ow_pv = rem_hw - oh_pv * out_w;
// Input strides and the flat offset of batch element n_pv.
let in_plane = in_h * in_w;
let in_vol = in_d * in_plane;
let in_n_stride = in_ch * in_vol;
let pv_in_base = n_pv * in_n_stride;
// Staging assignment for "bs": same row / 8-column split as "as", with rows being
// output channels.
let b_oc_row = lane_in_tg / 4u32;
let b_k_quad = lane_in_tg & 3u32;
let b_k_base = b_k_quad * 8u32;
let global_oc = oc_tile * 32u32 + b_oc_row;
let w_oc_base = global_oc * total_k;
// Main reduction loop: 32 reduction indices per iteration.
for kb in range(0u32, total_k, 32u32) {
// Stage the input tile: gather 8 input values for this thread's output position.
for i in range(0u32, 8u32, 1u32) {
let kt = kb + a_k_base + i;
let in_bounds = kt < total_k;
// Past the end of the reduction range, index with kt = 0 and zero the value below.
let kt_safe = select(in_bounds, kt, 0u32);
// Decompose kt into input channel and filter offsets (ic, kz, ky, kx).
let ic = kt_safe / kdhw;
let rem_kt = kt_safe - ic * kdhw;
let kz = rem_kt / khw;
let rem_kh = rem_kt - kz * khw;
let ky = rem_kh / kw;
let kx = rem_kh - ky * kw;
let id = od_pv + kz;
let ih = oh_pv + ky;
let iw = ow_pv + kx;
let in_idx = pv_in_base + ic * in_vol + id * in_plane + ih * in_w + iw;
let raw = load(input[in_idx]).cast::<f32>();
let val = select(in_bounds, raw, 0.0f32).cast::<T>();
threadgroup_store("as", a_pv_row * stride + a_k_base + i, val);
}
// Stage the weight tile: 8 consecutive weights of this thread's output channel.
for i in range(0u32, 8u32, 1u32) {
let kt = kb + b_k_base + i;
let in_bounds = kt < total_k;
let kt_safe = select(in_bounds, kt, 0u32);
let w_idx = w_oc_base + kt_safe;
let raw = load(weight[w_idx]).cast::<f32>();
let val = select(in_bounds, raw, 0.0f32).cast::<T>();
threadgroup_store("bs", b_oc_row * stride + b_k_base + i, val);
}
// Presumably makes every thread's tile writes visible before any fragment load.
threadgroup_barrier();
// Rows of "as" and rows of "bs" (output channels) that feed this simdgroup's quadrant.
let row_a0 = sm * 16u32 + fm;
let row_a1 = sm * 16u32 + 8u32 + fm;
let col_b0 = sn * 16u32;
let col_b1 = sn * 16u32 + 8u32;
// The 32-wide k slab is consumed in four unrolled sub-steps of 8 (offsets 0, 8, 16, 24).
// A elements are read as as[row][k + fn]; B elements as bs[oc + fn][k + fm], i.e. fm
// indexes the reduction dimension and fn the output channel for B.
// NOTE(review): simdgroup_matmul(a, b, c) is presumably c += a * b, since the
// accumulators are zero-initialised once and never otherwise updated -- confirm.
// The accumulators are visited in the order c_f00, c_f01, c_f11, c_f10 in every sub-step.
// k sub-step 0 (columns 0..7).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k sub-step 1 (columns 8..15).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k sub-step 2 (columns 16..23).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// k sub-step 3 (columns 24..31).
simdgroup_elem_store(a_f0, 0, threadgroup_load("as", row_a0 * stride + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("as", row_a0 * stride + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("as", row_a1 * stride + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("as", row_a1 * stride + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("bs", (col_b0 + fn0) * stride + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("bs", (col_b0 + fn1) * stride + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("bs", (col_b1 + fn0) * stride + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("bs", (col_b1 + fn1) * stride + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// Presumably keeps the next iteration's staging writes from overtaking these reads.
threadgroup_barrier();
}
// Write-back: each lane stores its two elements of each accumulator, cast to T.
// Row = output position, column = output channel; row pitch is out_ch.
let out_pv_base = pv_tile * 32u32 + sm * 16u32;
let out_oc_base = oc_tile * 32u32 + sn * 16u32;
// c_f00: rows +0..7, channels +0..7.
store(
out[(out_pv_base + fm) * out_ch + out_oc_base + fn0],
simdgroup_elem_load(c_f00, 0).cast::<T>(),
);
store(
out[(out_pv_base + fm) * out_ch + out_oc_base + fn1],
simdgroup_elem_load(c_f00, 1).cast::<T>(),
);
// c_f01: rows +0..7, channels +8..15.
store(
out[(out_pv_base + fm) * out_ch + out_oc_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_pv_base + fm) * out_ch + out_oc_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
// c_f10: rows +8..15, channels +0..7.
store(
out[(out_pv_base + 8u32 + fm) * out_ch + out_oc_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_pv_base + 8u32 + fm) * out_ch + out_oc_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
// c_f11: rows +8..15, channels +8..15.
store(
out[(out_pv_base + 8u32 + fm) * out_ch + out_oc_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_pv_base + 8u32 + fm) * out_ch + out_oc_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}