use metaltile::kernel;
// Generates one segmented GEMM kernel and registers a matching `BenchSpec`.
//
// Macro arguments:
// * `$name`  - identifier of the generated kernel function.
// * `$bm`, `$bn` - rows / columns of the output tile handled per threadgroup.
// * `$wm`, `$wn` - simdgroup grid inside the threadgroup (`$wm` rows x `$wn` columns).
// * `$subop` - benchmark sub-operation label. It is also appended to the MLX
//   kernel-name pattern, so it must spell the tile configuration the way MLX
//   does (e.g. "bm64_bn64_bk16_wm2_wn2").
//
// Kernel arguments (as indexed by the body below):
// * `a`   - read at `row * total_k + k`.
// * `b`   - read at `k * n + col`.
// * `segments` - pairs of `u32`; entry `2 * seg` is the first K index of
//   segment `seg`, entry `2 * seg + 1` is the end bound of the K loop.
// * `out` - written at `seg * m * n + row * n + col`, i.e. one `m * n` result
//   per segment.
//
// NOTE(review): the body contains no bounds checks. It looks like `m` / `n`
// are expected to be multiples of `$bm` / `$bn` and every segment length a
// multiple of 16 - confirm against the dispatch code.
macro_rules! steel_gemm_segmented_kernel {
    ($name:ident, $bm:literal, $bn:literal, $wm:literal, $wn:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            a: Tensor<T>,
            b: Tensor<T>,
            segments: Tensor<u32>,
            out: Tensor<T>,
            #[constexpr] m: u32,
            #[constexpr] n: u32,
            #[constexpr] total_k: u32,
        ) {
            let bm = $bm;
            let bn = $bn;
            let wm = $wm;
            let wn = $wn;
            // Sub-tile owned by a single simdgroup.
            let sub_m = bm / wm;
            let sub_n = bn / wn;
            // Number of 8x8 fragments per sub-tile along each axis.
            let n_fm = sub_m / 8u32;
            let n_fn = sub_n / 8u32;
            // Two 8-wide K fragments per K step of 16.
            let n_kf = 2u32;
            // Grid axis 0 selects the tile column, axis 1 the tile row and
            // axis 2 the segment.
            let tg_col = program_id::<0>();
            let tg_row = program_id::<1>();
            let seg = program_id::<2>();
            let sg_id = simd_group_id();
            // Position of this simdgroup inside the `wm` x `wn` grid.
            let sg_m = sg_id / wn;
            let sg_n = sg_id % wn;
            // Per-lane coordinates inside an 8x8 fragment: each lane handles
            // row `fm` and the two adjacent columns `fn0` / `fn1`.
            // NOTE(review): presumably mirrors the hardware simdgroup-matrix
            // element layout - confirm against the Metal documentation.
            let lane = simd_lane_id();
            let qid = lane / 4u32;
            let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
            let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
            let fn1 = fn0 + 1u32;
            // Origin of this simdgroup's sub-tile and of the threadgroup tile.
            let sub_m0 = sg_m * sub_m;
            let sub_n0 = sg_n * sub_n;
            let block_m0 = tg_row * bm;
            let block_n0 = tg_col * bn;
            // K range of this segment and offset of its output matrix.
            let k_start = load(segments[seg * 2u32]);
            let k_end = load(segments[seg * 2u32 + 1u32]);
            let out_base = seg * m * n;
            for _fm_i in range(0, n_fm, 1) {
                for _fn_i in range(0, n_fn, 1) {
                    // f32 accumulator fragment; both elements held by this
                    // lane start at zero.
                    let acc = simdgroup_alloc::<f32, 8, 8>();
                    simdgroup_elem_store(acc, 0, 0.0f32);
                    simdgroup_elem_store(acc, 1, 0.0f32);
                    // Top-left corner of the current 8x8 output fragment.
                    let m_row = block_m0 + sub_m0 + _fm_i * 8u32;
                    let n_col = block_n0 + sub_n0 + _fn_i * 8u32;
                    for kb in range(k_start, k_end, 16) {
                        for _kf in range(0, n_kf, 1) {
                            let kf = kb + _kf * 8u32;
                            let sub_a = simdgroup_alloc::<T, 8, 8>();
                            let sub_b = simdgroup_alloc::<T, 8, 8>();
                            // Fragment of `a`: rows from `m_row`, columns from `kf`.
                            simdgroup_elem_store(
                                sub_a,
                                0,
                                load(a[(m_row + fm) * total_k + kf + fn0]).cast::<T>(),
                            );
                            simdgroup_elem_store(
                                sub_a,
                                1,
                                load(a[(m_row + fm) * total_k + kf + fn1]).cast::<T>(),
                            );
                            // Fragment of `b`: rows from `kf`, columns from `n_col`.
                            simdgroup_elem_store(
                                sub_b,
                                0,
                                load(b[(kf + fm) * n + n_col + fn0]).cast::<T>(),
                            );
                            simdgroup_elem_store(
                                sub_b,
                                1,
                                load(b[(kf + fm) * n + n_col + fn1]).cast::<T>(),
                            );
                            // Presumably accumulates `sub_a * sub_b` into `acc`.
                            simdgroup_matmul(sub_a, sub_b, acc);
                        }
                    }
                    // Write this lane's two accumulator elements, converted
                    // from f32 to the tensor element type.
                    let r0 = simdgroup_elem_load(acc, 0);
                    let r1 = simdgroup_elem_load(acc, 1);
                    store(out[out_base + (m_row + fm) * n + n_col + fn0], r0.cast::<T>());
                    store(out[out_base + (m_row + fm) * n + n_col + fn1], r1.cast::<T>());
                }
            }
        }
        inventory::submit! { crate::spec::BenchSpec {
            op: "steel_gemm_segmented", subop: $subop,
            kernel_name: stringify!($name),
            kernel_ir: $name::kernel_ir_for,
            dtypes: crate::bench_types::FLOAT_DTYPES, tol: 1e-2f32,
            mlx_src: Some(include_str!(concat!(env!("OUT_DIR"), "/metal/steel/gemm/steel_gemm_segmented.metal"))),
            // The tile suffix comes from `$subop` rather than from
            // `stringify!($bm)` etc.: the tile arguments are passed as
            // suffixed literals (`64u32`), and `stringify!` keeps the suffix,
            // which would yield "bm64u32_..." instead of "bm64_...".
            mlx_pattern: Some(concat!("steel_segmented_mm_nn_{tn}_{tn}_", $subop)),
            shapes: &[],
            dispatch: crate::spec::BenchDispatch::SteelGemm {
                m: 4096, n: 4096, k: 4096,
                check_m: $bm as usize, check_n: $bn as usize, check_k: 16,
                bm: $bm as usize, bn: $bn as usize,
                // One simdgroup per grid cell; 32 is presumably the simdgroup width.
                tpg: ($wm * $wn * 32u32) as usize,
            },
            kernel_mode: Some(metaltile_core::ir::KernelMode::SimdGroup2D),
        }}
    };
}
// 64x64 threadgroup tile with a 2x2 simdgroup grid: each simdgroup covers a
// 32x32 sub-tile, i.e. 4x4 fragments of 8x8. The trailing string is stored as
// the `subop` field of the registered `BenchSpec`.
steel_gemm_segmented_kernel!(
mt_steel_gemm_segmented_64x64x16_2x2,
64u32,
64u32,
2u32,
2u32,
"bm64_bn64_bk16_wm2_wn2"
);
// 32x32 threadgroup tile with a 2x2 simdgroup grid: each simdgroup covers a
// 16x16 sub-tile, i.e. 2x2 fragments of 8x8. The trailing string is stored as
// the `subop` field of the registered `BenchSpec`.
steel_gemm_segmented_kernel!(
mt_steel_gemm_segmented_32x32x16_2x2,
32u32,
32u32,
2u32,
2u32,
"bm32_bn32_bk16_wm2_wn2"
);