use metaltile::kernel;
// Probe kernel for `simdgroup_matmul`.
//
// Each lane derives a (row, col0/col1) coordinate pair from `simd_lane`,
// fills its two elements of fragment A with a `row == col` indicator and its
// two elements of fragment B with the ramp value `col * 8 + row`, runs
// `simdgroup_matmul(a, b, c)` on a zeroed C, and writes its two C elements to
// `out[row * 8 + col]`.
//
// NOTE(review): the lane -> (row, col) arithmetic looks like the per-lane
// ownership map of an 8x8 simdgroup matrix fragment — confirm against the
// target hardware layout.
// Assumes `simd_lane` is in 0..32, in which case the written indices stay
// within 0..=63, so `out` needs at least 64 elements — TODO confirm.
#[kernel]
pub fn mt_mma_probe_a_identity_b_gemm(out: Tensor<f32>) {
    // Per-lane coordinates inside the 8x8 tile.
    let lane_id = simd_lane;
    let quad = lane_id / 4u32;
    let row = ((lane_id / 2u32) % 4u32) + (quad & 4u32);
    let col0 = (lane_id % 2u32) * 2u32 + (quad & 2u32) * 2u32;
    let col1 = col0 + 1u32;

    // Fragment A: indicator of row == col.
    // Presumably `select(cond, x, y)` yields `x` when `cond` holds, which
    // would make A the identity matrix — verify against the DSL.
    let a = simdgroup_alloc::<f32, 8, 8>();
    let diag0 = select(row == col0, 1.0f32, 0.0f32);
    let diag1 = select(row == col1, 1.0f32, 0.0f32);
    simdgroup_elem_store(a, 0, diag0);
    simdgroup_elem_store(a, 1, diag1);

    // Fragment B: element at (row, col) carries col * 8 + row, i.e. the
    // transpose of the row-major linear index used for the output below.
    let b = simdgroup_alloc::<f32, 8, 8>();
    let ramp0 = (col0 * 8u32 + row).cast::<f32>();
    let ramp1 = (col1 * 8u32 + row).cast::<f32>();
    simdgroup_elem_store(b, 0, ramp0);
    simdgroup_elem_store(b, 1, ramp1);

    // Fragment C: both of this lane's elements start at zero.
    let c = simdgroup_alloc::<f32, 8, 8>();
    simdgroup_elem_store(c, 0, 0.0f32);
    simdgroup_elem_store(c, 1, 0.0f32);

    simdgroup_matmul(a, b, c);

    // Write this lane's two result elements at row-major offsets.
    let acc0 = simdgroup_elem_load(c, 0);
    let acc1 = simdgroup_elem_load(c, 1);
    let dst0 = row * 8u32 + col0;
    let dst1 = row * 8u32 + col1;
    store(out[dst0], acc0);
    store(out[dst1], acc1);
}
// Probe kernel for `simdgroup_matmul`.
//
// Each lane derives a (row, col0/col1) coordinate pair from `simd_lane`,
// fills its two elements of fragment A with a `row == col` indicator and its
// two elements of fragment B with the ramp value `row * 8 + col` (the same
// linear index used for the output), runs `simdgroup_matmul(a, b, c)` on a
// zeroed C, and writes its two C elements to `out[row * 8 + col]`.
//
// NOTE(review): the lane -> (row, col) arithmetic looks like the per-lane
// ownership map of an 8x8 simdgroup matrix fragment — confirm against the
// target hardware layout.
// Assumes `simd_lane` is in 0..32, in which case the written indices stay
// within 0..=63, so `out` needs at least 64 elements — TODO confirm.
#[kernel]
pub fn mt_mma_probe_a_identity_b_identity(out: Tensor<f32>) {
    // Per-lane coordinates inside the 8x8 tile.
    let lane_id = simd_lane;
    let quad = lane_id / 4u32;
    let row = ((lane_id / 2u32) % 4u32) + (quad & 4u32);
    let col0 = (lane_id % 2u32) * 2u32 + (quad & 2u32) * 2u32;
    let col1 = col0 + 1u32;

    // Row-major linear offsets of this lane's two elements; used both as the
    // B payload and as the destination index in `out`.
    let lin0 = row * 8u32 + col0;
    let lin1 = row * 8u32 + col1;

    // Fragment A: indicator of row == col.
    // Presumably `select(cond, x, y)` yields `x` when `cond` holds, which
    // would make A the identity matrix — verify against the DSL.
    let a = simdgroup_alloc::<f32, 8, 8>();
    let diag0 = select(row == col0, 1.0f32, 0.0f32);
    let diag1 = select(row == col1, 1.0f32, 0.0f32);
    simdgroup_elem_store(a, 0, diag0);
    simdgroup_elem_store(a, 1, diag1);

    // Fragment B: element at (row, col) carries its own row-major index.
    let b = simdgroup_alloc::<f32, 8, 8>();
    let ramp0 = lin0.cast::<f32>();
    let ramp1 = lin1.cast::<f32>();
    simdgroup_elem_store(b, 0, ramp0);
    simdgroup_elem_store(b, 1, ramp1);

    // Fragment C: both of this lane's elements start at zero.
    let c = simdgroup_alloc::<f32, 8, 8>();
    simdgroup_elem_store(c, 0, 0.0f32);
    simdgroup_elem_store(c, 1, 0.0f32);

    simdgroup_matmul(a, b, c);

    // Write this lane's two result elements at the row-major offsets.
    let acc0 = simdgroup_elem_load(c, 0);
    let acc1 = simdgroup_elem_load(c, 1);
    store(out[lin0], acc0);
    store(out[lin1], acc1);
}