use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
// Quantized matrix-multiply kernel written in the metaltile kernel DSL.
//
// Weights `w` are read as `u32` words holding four 8-bit values each; every
// value is dequantized as `scale * byte + bias` and multiplied against the
// activations `x`, with the result written to `out`.
//
// Work split visible in the index math below:
//   * one threadgroup produces a 32x32 tile of `out` (rows from `tgid_y`,
//     columns from `tgid_x`);
//   * each simdgroup owns one 16x16 sub-tile, arranged as a 2x2 grid.
//
// Parameters:
//   w          - packed weights, indexed as `row * (k / 4) + pack`, where
//                `row` is an output column index.
//   scales     - per-group scale, indexed as `row * gs_per_row + group`.
//   biases     - per-group bias, same indexing as `scales`.
//   x          - activations, indexed as `row * k + col`.
//   out        - output, indexed as `row * n + col`.
//   k          - reduction length (constexpr).
//   n          - row stride of `out` (constexpr).
//   gs_per_row - number of scale/bias entries per weight row (constexpr).
//
// NOTE(review): the staging math assumes 4 simdgroups of 32 lanes (128
// threads) per threadgroup, and no bounds checks are present, so presumably
// the host guarantees that the output dimensions and `k` are multiples of 32
// - confirm against the dispatch code.
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn mt_qmm_mma_mpp_int8<T>(
w: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
x: Tensor<T>,
mut out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
// Lane within the simdgroup, simdgroup index, and the flattened thread index
// inside the threadgroup (32 lanes per simdgroup).
let lane = simd_lane;
let sg = simd_group_id();
let lane_in_tg = sg * 32u32 + lane;
// 2x2 simdgroup grid: `sm` selects the 16-row band, `sn` the 16-column band.
let sm = sg / 2u32;
let sn = sg & 1u32;
let sg_m_base = sm * 16u32;
let sg_n_base = sn * 16u32;
// Origin of this threadgroup's 32x32 output tile.
let x_m_base = tgid_y * 32u32;
let w_n_base = tgid_x * 32u32;
// Threadgroup staging buffers: "Xs" and "Ws" hold 1152 elements each
// (32 rows * stride 36), "OutScratch" holds 1024 f32 (4 simdgroups * 16 * 16).
// The same statement line also opens the `coop_tile_setup` call for the "gemm"
// tile. NOTE(review): 16/16/32 are presumably the M/N/K tile extents and the
// three booleans are DSL-defined flags whose meaning is not visible in this
// file - confirm against the DSL documentation.
threadgroup_alloc("Xs", 1152u32, coop_stage(T)); threadgroup_alloc("Ws", 1152u32, coop_stage(T)); threadgroup_alloc("OutScratch", 1024u32, f32); coop_tile_setup(
"gemm",
16u32,
16u32,
32u32,
coop_stage(T),
"accumulate",
"simdgroup",
f32,
false,
true,
false,
);
// Presumably clears the "gemm" accumulator before the k-loop - confirm.
coop_tile_zero("gemm");
// Per-thread staging coordinates: 4 threads cooperate on one staged row, each
// covering 8 consecutive k-elements (`x_k_base` = 0, 8, 16 or 24). Rows in the
// staging buffers use a stride of 36 elements. `packs_per_row` is the number
// of u32 words per weight row (4 bytes per word).
let x_m_row = lane_in_tg / 4u32; let x_k_quad = lane_in_tg & 3u32; let x_k_base = x_k_quad * 8u32; let x_ws_base = x_m_row * 36u32 + x_k_base; let packs_per_row = k / 4u32;
// Weight row handled by this thread, plus its base offsets into the
// scale/bias tables and into the packed weight buffer.
let wn_plus_wr = w_n_base + x_m_row;
let sb_base = wn_plus_wr * gs_per_row;
let w_pack_row_base = wn_plus_wr * packs_per_row;
// Per-simdgroup offsets into the staging buffers and the output scratch area.
let xs_sg_off = sg_m_base * 36u32;
let ws_sg_off = sg_n_base * 36u32;
let sg_scratch_off = sg * 256u32;
// Walk the reduction dimension in slabs of 32 elements.
for kb in range(0u32, k, 32u32) {
// Stage this thread's 8 activation values into "Xs" (cast to f32 on load).
let x_row_dev_base = (x_m_base + x_m_row) * k + kb + x_k_base;
for _i in range(0u32, 8u32, 1u32) {
let xv = load(x[x_row_dev_base + _i]).cast::<f32>();
threadgroup_store("Xs", x_ws_base + _i, xv);
}
// Stage this thread's 8 weights: two packed u32 words of four bytes each.
let w_kb_off = kb / 4u32 + x_k_quad * 2u32;
for _pi in range(0u32, 2u32, 1u32) {
let pack_dev = w_pack_row_base + w_kb_off + _pi;
let packed = load(w[pack_dev]);
let k_off = kb + x_k_quad * 8u32 + _pi * 4u32;
// NOTE(review): the group index divides by a hard-coded 32, i.e. a group
// size of 32 elements; this presumably requires `gs_per_row == k / 32` -
// confirm with the host-side quantization layout.
let g = k_off / 32u32; let sb_off = sb_base + g;
let scale = load(scales[sb_off]).cast::<f32>();
let bias = load(biases[sb_off]).cast::<f32>();
for _bi in range(0u32, 4u32, 1u32) {
// Extract byte `_bi` (lowest byte first) and dequantize it.
let byte_val = ((packed >> (_bi * 8u32)) & 255u32).cast::<f32>();
threadgroup_store("Ws", x_ws_base + _pi * 4u32 + _bi, scale * byte_val + bias);
}
}
// Barrier before the coop-tile loads read the staged slabs.
threadgroup_barrier();
coop_tile_load_a("gemm", "Xs", true, coop_stage(T), 36u32, 16u32, xs_sg_off);
coop_tile_load_b("gemm", "Ws", true, coop_stage(T), 36u32, 16u32, ws_sg_off);
coop_tile_run("gemm");
// Barrier before the next iteration overwrites "Xs"/"Ws".
threadgroup_barrier();
}
// Write the simdgroup's result tile into its 256-element slot of "OutScratch".
coop_tile_store_c("gemm", "OutScratch", true, f32, 16u32, 16u32, sg_scratch_off);
threadgroup_barrier();
// Copy the scratch tile to `out`: each lane writes 8 consecutive columns of one
// row (two lanes per row, 16 rows per simdgroup), casting back to `T`.
let out_m_base = x_m_base + sg_m_base;
let out_n_base = w_n_base + sg_n_base;
let o_row = lane / 2u32;
let o_col_base = (lane & 1u32) * 8u32;
for _i in range(0u32, 8u32, 1u32) {
let col = o_col_base + _i;
let v = threadgroup_load("OutScratch", sg_scratch_off + o_row * 16u32 + col);
store(out[(out_m_base + o_row) * n + (out_n_base + col)], v.cast::<T>());
}
}
// Submits a `BenchSpec` for `mt_qmm_mma_mpp_int8` through `inventory`, so the
// kernel is discoverable under op "quantized" / subop "qmm_mma_mpp_int8".
inventory::submit! {
BenchSpec {
op: "quantized",
subop: "qmm_mma_mpp_int8",
kernel_name: "mt_qmm_mma_mpp_int8",
// IR constructor generated for the kernel; invoked with a `DType`.
kernel_ir: mt_qmm_mma_mpp_int8::kernel_ir_for,
dtypes: &[DType::F32, DType::F16, DType::BF16],
// Presumably the comparison tolerance used by the harness - confirm.
tol: 5e-2,
// No MLX source or pattern is supplied for this kernel.
mlx_src: None,
mlx_pattern: None,
// No shapes are listed here.
shapes: &[],
dispatch: BenchDispatch::Generic,
kernel_mode: Some(KernelMode::Reduction),
}
}
#[cfg(test)]
mod tests {
    use metaltile_codegen::msl::MslGenerator;
    use metaltile_core::ir::Op;

    use super::*;

    /// For every supported dtype the generated IR must carry the expected
    /// signature and be expressed through cooperative-tile ops only, with no
    /// inline-MSL op anywhere in the body or the nested blocks.
    #[test]
    fn kernel_ir_constructs_and_uses_coop_tile_ops() {
        const PARAM_NAMES: [&str; 5] = ["w", "scales", "biases", "x", "out"];
        const CONSTEXPR_NAMES: [&str; 3] = ["k", "n", "gs_per_row"];

        for dtype in [DType::F32, DType::F16, DType::BF16] {
            let ir = mt_qmm_mma_mpp_int8::kernel_ir_for(dtype);
            assert_eq!(ir.name, "mt_qmm_mma_mpp_int8");

            // Tensor parameters, in declaration order; the last one is the output.
            assert_eq!(ir.params.len(), PARAM_NAMES.len());
            for (idx, expected) in PARAM_NAMES.iter().enumerate() {
                assert_eq!(ir.params[idx].name, *expected);
            }
            assert!(ir.params[4].is_output);

            // Compile-time constants, in declaration order.
            assert_eq!(ir.constexprs.len(), CONSTEXPR_NAMES.len());
            for (idx, expected) in CONSTEXPR_NAMES.iter().enumerate() {
                assert_eq!(ir.constexprs[idx].name.name(), *expected);
            }

            // Gather every op from the top-level body and all nested blocks once.
            let ops: Vec<_> = std::iter::once(&ir.body)
                .chain(ir.blocks.values())
                .flat_map(|block| block.ops.iter())
                .collect();

            assert!(!ops.iter().any(|op| matches!(op, Op::InlineMsl { .. })));
            assert!(ops.iter().any(|op| matches!(op, Op::CoopTileSetup { .. })));
            assert!(ops.iter().any(|op| matches!(op, Op::CoopTileLoadA { .. })));
            assert!(ops.iter().any(|op| matches!(op, Op::CoopTileLoadB { .. })));
            assert!(ops.iter().any(|op| matches!(op, Op::CoopTileRun { .. })));
            assert!(ops.iter().any(|op| matches!(op, Op::CoopTileStoreC { .. })));
        }
    }

    /// The first `CoopTileSetup` op of the bf16 variant must record F16 as its
    /// activation dtype.
    #[test]
    fn bf16_stages_through_half() {
        let ir = mt_qmm_mma_mpp_int8::kernel_ir_for(DType::BF16);

        // Find the first setup op, scanning the body and then the nested blocks.
        let mut staged = None;
        'search: for block in std::iter::once(&ir.body).chain(ir.blocks.values()) {
            for op in block.ops.iter() {
                if let Op::CoopTileSetup { act_dtype, .. } = op {
                    staged = Some(*act_dtype);
                    break 'search;
                }
            }
        }

        let staged = staged.expect("CoopTileSetup present");
        assert_eq!(staged, DType::F16, "bf16 activation must stage as half for matmul2d");
    }

    /// Generated MSL must pull in the MPP header, use the matmul2d descriptor,
    /// declare the renamed kernel, and type the `Xs` staging buffer as expected.
    #[test]
    fn codegen_emits_mpp_include_and_kernel_decl() {
        // (dtype, kernel-name suffix, expected element type of the `Xs` buffer)
        let cases = [
            (DType::F32, "f32", "float"),
            (DType::F16, "f16", "half"),
            (DType::BF16, "bf16", "half"),
        ];

        for (dtype, suffix, stage_ty) in cases {
            let kernel_name = format!("mt_qmm_mma_mpp_int8_{suffix}");
            let mut ir = mt_qmm_mma_mpp_int8::kernel_ir_for(dtype);
            ir.name = kernel_name.clone();

            let msl = MslGenerator::default().generate(&ir).expect("codegen");

            assert!(
                msl.contains("MetalPerformancePrimitives/MetalPerformancePrimitives.h"),
                "MPP include missing:\n{msl}"
            );
            assert!(msl.contains("mpp::tensor_ops::matmul2d_descriptor"));

            let kernel_decl = format!("kernel void {kernel_name}");
            assert!(msl.contains(&kernel_decl));

            let xs_decl = format!("threadgroup {stage_ty} Xs");
            assert!(msl.contains(&xs_decl));
        }
    }
}