use metaltile::kernel;
// Generates one split-K GEMM partial-product kernel together with its
// benchmark registration.
//
// Macro arguments:
// - `$name`:  identifier of the generated `#[kernel]` function.
// - `$bm`, `$bn`: extent of the output tile handled by one program along M / N.
// - `$wm`, `$wn`: number of simdgroups the tile is divided among along M / N
//   (`sub_m = bm / wm`, `sub_n = bn / wn`; the spec launches `wm * wn * 32`
//   threads per group).
// - `$subop`: string literal used as the benchmark sub-operation label AND as
//   the tile suffix of the MLX kernel-name pattern, so it must be spelled the
//   way the reference kernel names spell it (e.g. "bm64_bn64_bk16_wm2_wn2").
//
// The numeric arguments are passed as suffixed literals (`64u32`), which is why
// they must not be fed to `stringify!` when building names: `stringify!(64u32)`
// yields "64u32", not "64".
macro_rules! steel_gemm_splitk_kernel {
    ($name:ident, $bm:literal, $bn:literal, $wm:literal, $wn:literal, $subop:literal) => {
        // Kernel contract, as far as the indexing below shows:
        // - `a` is read at `row * k + col` and `b` at `row * n + col`.
        // - `partials` is written at `split * m * n + row * n + col` in f32,
        //   i.e. one m x n plane per K split.
        // - program_id 0 / 1 / 2 select the column tile, row tile and K split.
        //
        // NOTE(review): no bounds checks are visible on M, N or the K tail; the
        // loads appear to assume m, n are multiples of the tile and each split's
        // K range is a multiple of 16 — confirm the dispatcher guarantees this.
        #[kernel]
        pub fn $name<T>(
            a: Tensor<T>,
            b: Tensor<T>,
            mut partials: Tensor<f32>,
            #[constexpr] m: u32,
            #[constexpr] n: u32,
            #[constexpr] k: u32,
            #[constexpr] k_per_split: u32,
        ) {
            let bm = $bm;
            let bn = $bn;
            let wm = $wm;
            let wn = $wn;
            // Sub-tile owned by one simdgroup, and how many 8x8 fragments it spans.
            let sub_m = bm / wm;
            let sub_n = bn / wn;
            let n_fm = sub_m / 8u32;
            let n_fn = sub_n / 8u32;
            // Two 8-wide K fragments per 16-wide K step.
            let n_kf = 2u32;
            let tg_col = program_id::<0>();
            let tg_row = program_id::<1>();
            let split = program_id::<2>();
            let sg_id = simd_group_id();
            // Position of this simdgroup inside the simdgroup grid (row-major over `wn`).
            let sg_m = sg_id / wn;
            let sg_n = sg_id % wn;
            let lane = simd_lane_id();
            // Per-lane coordinates inside an 8x8 fragment: one row (`fm`) and two
            // adjacent columns (`fn0`, `fn1`).
            // NOTE(review): presumably mirrors the element layout expected by the
            // `simdgroup_elem_*` intrinsics — confirm against the DSL.
            let qid = lane / 4u32;
            let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
            let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
            let fn1 = fn0 + 1u32;
            // Origin of the simdgroup sub-tile within the tile, and of the tile
            // within the output.
            let sub_m0 = sg_m * sub_m;
            let sub_n0 = sg_n * sub_n;
            let block_m0 = tg_row * bm;
            let block_n0 = tg_col * bn;
            // K range of this split; the end is limited to `k`, assuming
            // `select(cond, x, y)` yields `x` when `cond` holds — TODO confirm.
            let k_start = split * k_per_split;
            let k_end_raw = k_start + k_per_split;
            let k_end = select(k_end_raw < k, k_end_raw, k);
            // Offset of this split's m x n plane inside `partials`.
            let part_base = split * m * n;
            for _fm_i in range(0, n_fm, 1) {
                for _fn_i in range(0, n_fn, 1) {
                    // f32 accumulator fragment; both per-lane elements start at zero.
                    let acc = simdgroup_alloc::<f32, 8, 8>();
                    simdgroup_elem_store(acc, 0, 0.0f32);
                    simdgroup_elem_store(acc, 1, 0.0f32);
                    // Top-left output coordinate of this 8x8 fragment.
                    let m_row = block_m0 + sub_m0 + _fm_i * 8u32;
                    let n_col = block_n0 + sub_n0 + _fn_i * 8u32;
                    for kb in range(k_start, k_end, 16) {
                        for _kf in range(0, n_kf, 1) {
                            let kf = kb + _kf * 8u32;
                            let sub_a = simdgroup_alloc::<T, 8, 8>();
                            let sub_b = simdgroup_alloc::<T, 8, 8>();
                            // 8x8 fragment of `a`: rows from M, columns from K.
                            simdgroup_elem_store(
                                sub_a,
                                0,
                                load(a[(m_row + fm) * k + kf + fn0]).cast::<T>(),
                            );
                            simdgroup_elem_store(
                                sub_a,
                                1,
                                load(a[(m_row + fm) * k + kf + fn1]).cast::<T>(),
                            );
                            // 8x8 fragment of `b`: rows from K, columns from N.
                            simdgroup_elem_store(
                                sub_b,
                                0,
                                load(b[(kf + fm) * n + n_col + fn0]).cast::<T>(),
                            );
                            simdgroup_elem_store(
                                sub_b,
                                1,
                                load(b[(kf + fm) * n + n_col + fn1]).cast::<T>(),
                            );
                            // Presumably acc += sub_a * sub_b — confirm intrinsic semantics.
                            simdgroup_matmul(sub_a, sub_b, acc);
                        }
                    }
                    // Each lane writes back the two accumulator elements it owns.
                    let r0 = simdgroup_elem_load(acc, 0);
                    let r1 = simdgroup_elem_load(acc, 1);
                    store(partials[part_base + (m_row + fm) * n + n_col + fn0], r0);
                    store(partials[part_base + (m_row + fm) * n + n_col + fn1], r1);
                }
            }
        }
        inventory::submit! { crate::spec::BenchSpec {
            op: "steel_gemm_splitk", subop: $subop,
            kernel_name: stringify!($name),
            kernel_ir: $name::kernel_ir_for,
            dtypes: crate::bench_types::FLOAT_DTYPES, tol: 1e-2f32,
            mlx_src: Some(include_str!(concat!(env!("OUT_DIR"), "/metal/steel/gemm/steel_gemm_splitk.metal"))),
            // The tile suffix comes from `$subop` rather than from
            // `stringify!($bm)` etc.: the numeric arguments are suffixed literals,
            // so stringifying them would produce "bm64u32_bn64u32_..._wm2u32_wn2u32".
            mlx_pattern: Some(concat!("steel_gemm_splitk_nn_{tn}_{tn}_", $subop)),
            shapes: &[],
            dispatch: crate::spec::BenchDispatch::SteelGemm {
                m: 4096, n: 4096, k: 4096,
                check_m: $bm as usize, check_n: $bn as usize, check_k: 16,
                bm: $bm as usize, bn: $bn as usize,
                tpg: ($wm * $wn * 32u32) as usize,
            },
            kernel_mode: Some(metaltile_core::ir::KernelMode::SimdGroup2D),
        }}
    };
}
// 64x64 output tile split over a 2x2 simdgroup grid: each simdgroup covers a
// 32x32 sub-tile (4x4 fragments of 8x8), and the spec launches 2 * 2 * 32 = 128
// threads per group.
steel_gemm_splitk_kernel!(
mt_steel_gemm_splitk_64x64x16_2x2,
64u32,
64u32,
2u32,
2u32,
"bm64_bn64_bk16_wm2_wn2"
);
// 32x32 output tile split over a 2x2 simdgroup grid: each simdgroup covers a
// 16x16 sub-tile (2x2 fragments of 8x8), again with 128 threads per group.
steel_gemm_splitk_kernel!(
mt_steel_gemm_splitk_32x32x16_2x2,
32u32,
32u32,
2u32,
2u32,
"bm32_bn32_bk16_wm2_wn2"
);
// Split-K reduction: for the element selected by program_id 0, adds up the
// `n_splits` f32 partial values stored `m * n` elements apart in `partials`
// and writes the sum, cast to `T`, to the same position in `out`.
//
// NOTE(review): no guard against `elem >= m * n` is visible here; presumably
// the dispatch launches exactly one program per output element — confirm.
#[kernel]
pub fn mt_steel_gemm_splitk_accum<T>(
    partials: Tensor<f32>,
    mut out: Tensor<T>,
    #[constexpr] m: u32,
    #[constexpr] n: u32,
    #[constexpr] n_splits: u32,
) {
    // Size of one partial plane; consecutive splits are this far apart.
    let plane = m * n;
    let elem = program_id::<0>();
    let mut sum = 0.0f32;
    for split in range(0u32, n_splits, 1u32) {
        sum = sum + load(partials[split * plane + elem]);
    }
    store(out[elem], sum.cast::<T>());
}
// Registers `mt_steel_gemm_splitk_accum` with the benchmark inventory under
// op "steel_gemm_splitk" / subop "accum". The reference source is the
// steel_gemm_splitk.metal file from OUT_DIR, and the reference kernel is looked
// up through `mlx_pattern` (`{tn}` is presumably replaced by the dtype name —
// confirm against the harness). No explicit shapes are listed.
inventory::submit! { crate::spec::BenchSpec {
op: "steel_gemm_splitk", subop: "accum",
kernel_name: "mt_steel_gemm_splitk_accum",
kernel_ir: mt_steel_gemm_splitk_accum::kernel_ir_for,
dtypes: crate::bench_types::FLOAT_DTYPES, tol: 1e-3f32,
mlx_src: Some(include_str!(concat!(env!("OUT_DIR"), "/metal/steel/gemm/steel_gemm_splitk.metal"))),
mlx_pattern: Some("steel_gemm_splitk_accum_{tn}_float32"),
shapes: &[],
dispatch: crate::spec::BenchDispatch::Generic,
kernel_mode: Some(metaltile_core::ir::KernelMode::Elementwise),
}}
// Split-K reduction with scaling: for the element selected by program_id 0,
// computes `alpha * sum + beta * c_in[elem]` in f32, where `sum` is the total
// of the `n_splits` partial values stored `m * n` elements apart in
// `partials`, and writes the result, cast to `T`, to `out[elem]`.
//
// NOTE(review): no guard against `elem >= m * n` is visible here; presumably
// the dispatch launches exactly one program per output element — confirm.
#[kernel]
pub fn mt_steel_gemm_splitk_accum_axpby<T>(
    partials: Tensor<f32>,
    c_in: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] m: u32,
    #[constexpr] n: u32,
    #[constexpr] n_splits: u32,
    #[constexpr] alpha: f32,
    #[constexpr] beta: f32,
) {
    // Size of one partial plane; consecutive splits are this far apart.
    let plane = m * n;
    let elem = program_id::<0>();
    let mut sum = 0.0f32;
    for split in range(0u32, n_splits, 1u32) {
        sum = sum + load(partials[split * plane + elem]);
    }
    // Existing value is widened to f32 so the blend happens at full precision.
    let c_old = load(c_in[elem]).cast::<f32>();
    let blended = alpha * sum + beta * c_old;
    store(out[elem], blended.cast::<T>());
}
// Registers `mt_steel_gemm_splitk_accum_axpby` with the benchmark inventory
// under op "steel_gemm_splitk" / subop "accum_axpby", using the same reference
// source file and Generic / Elementwise dispatch as the plain accum kernel.
//
// NOTE(review): the pattern ends in `axbpy` (not `axpby`); this presumably
// mirrors the spelling of the reference kernel name in the MLX source — verify
// there before "correcting" it.
inventory::submit! { crate::spec::BenchSpec {
op: "steel_gemm_splitk", subop: "accum_axpby",
kernel_name: "mt_steel_gemm_splitk_accum_axpby",
kernel_ir: mt_steel_gemm_splitk_accum_axpby::kernel_ir_for,
dtypes: crate::bench_types::FLOAT_DTYPES, tol: 1e-3f32,
mlx_src: Some(include_str!(concat!(env!("OUT_DIR"), "/metal/steel/gemm/steel_gemm_splitk.metal"))),
mlx_pattern: Some("steel_gemm_splitk_accum_{tn}_float32_axbpy"),
shapes: &[],
dispatch: crate::spec::BenchDispatch::Generic,
kernel_mode: Some(metaltile_core::ir::KernelMode::Elementwise),
}}