use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
/// FP4 quantized matmul kernel built on 8x8 simdgroup matrix fragments.
///
/// Each threadgroup produces one 32x32 tile of `out` (tile column `tgid_x`, tile row
/// `tgid_y`). Each simdgroup owns one 16x16 quadrant of that tile, held in four 8x8
/// `f32` accumulators. K is walked in steps of 32: a 32x32 slab of `x` and a 32x32
/// slab of dequantized weights are staged in the `xs` / `ws` threadgroup buffers and
/// then consumed as four 8-wide multiply-accumulate steps.
///
/// Indexing used by the body:
/// - `w`: `u32` words holding eight 4-bit codes each (lowest nibble first), `k / 8`
///   words per row, one row per output column.
/// - `scales`: `gs_per_row` scales per weight row; one scale covers
///   `k / gs_per_row` consecutive K elements.
/// - `x`: row stride `k`.
/// - `out`: row stride `n`.
///
/// Each 4-bit code is 1 sign bit, 2 exponent bits and 1 mantissa bit; before scaling
/// it decodes to a magnitude in {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
///
/// NOTE(review): the thread indexing (`sg / 2`, `sg & 1`, `lane_in_tg / 4`) covers the
/// 32x32 tile exactly when a threadgroup has 4 simdgroups of 32 lanes — presumably the
/// dispatch guarantees 128 threads per threadgroup; confirm.
/// NOTE(review): there are no bounds checks; presumably the M and N extents and `k`
/// are all multiples of 32 — confirm with the caller.
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn mt_fp4_qmm_mma<T>(
w: Tensor<u32>,
scales: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
// Tile coordinates of this threadgroup within `out`.
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// Simdgroups are arranged 2 wide: `sm` picks the 16-row half, `sn` the 16-column half.
let sm = sg / 2u32;
let sn = sg & 1u32;
let lane_in_tg = sg * 32u32 + lane;
// Per-lane coordinates inside an 8x8 fragment: this lane holds row `fm`,
// columns `fn0` and `fn1 = fn0 + 1` (fragment elements 0 and 1).
// NOTE(review): presumably this matches the hardware simdgroup-matrix element
// layout — confirm against the Metal documentation.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Staging buffers: 32 rows at a row stride of 36 (see `xs_ld` / `ws_ld`) = 1152 elements.
// NOTE(review): the 4 extra elements per row look like padding against bank
// conflicts — confirm.
threadgroup_alloc("xs", 1152, T);
threadgroup_alloc("ws", 1152, T);
// Four f32 accumulators (2x2 grid of 8x8 fragments); both per-lane elements start at zero.
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments, refilled for every 8-wide K sub-step.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
// Staging assignment: 4 threads per tile row. For weights each thread decodes one
// packed word (8 codes); for activations each thread copies 8 consecutive elements.
let w_row = lane_in_tg / 4u32;
let word_in_row = lane_in_tg & 3u32;
let x_m_row = lane_in_tg / 4u32;
let x_k_quad = lane_in_tg & 3u32;
let x_k_base = x_k_quad * 8u32;
let xs_ld = 36u32;
let ws_ld = 36u32;
let x_m_base = m_tile * 32u32;
let w_n_base = n_tile * 32u32;
// Eight 4-bit codes per u32 word.
let packs_per_row = k / 8u32;
// Start of this thread's weight row in `scales` and in `w`.
let sb_base = (w_n_base + w_row) * gs_per_row;
let w_pack_row_base = (w_n_base + w_row) * packs_per_row;
let group_size = k / gs_per_row;
for kb in range(0u32, k, 32u32) {
// --- Stage activations: copy 8 consecutive K elements of one row of `x` into `xs`.
let x_row_dev_base = (x_m_base + x_m_row) * k + kb + x_k_base;
let x_ws_base = x_m_row * xs_ld + x_k_base;
threadgroup_store("xs", x_ws_base, load(x[x_row_dev_base]).cast::<T>());
threadgroup_store("xs", x_ws_base + 1u32, load(x[x_row_dev_base + 1u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 2u32, load(x[x_row_dev_base + 2u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 3u32, load(x[x_row_dev_base + 3u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 4u32, load(x[x_row_dev_base + 4u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 5u32, load(x[x_row_dev_base + 5u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 6u32, load(x[x_row_dev_base + 6u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 7u32, load(x[x_row_dev_base + 7u32]).cast::<T>());
// --- Stage weights: load one packed word and the scale of the group it falls in.
let pack_k_off = kb / 8u32 + word_in_row;
let pack = load(w[w_pack_row_base + pack_k_off]);
let k_off = kb + word_in_row * 8u32;
// The group index is taken from the first element of the word, so all 8 codes share
// one scale. NOTE(review): presumably `group_size` is a multiple of 8 — confirm.
let g = k_off / group_size;
let s = load(scales[sb_base + g]).cast::<f32>();
let ws_base = w_row * ws_ld + word_in_row * 8u32;
for _ci in range(0u32, 8u32, 1u32) {
// Extract code `_ci` (lowest nibble first).
let nibble = (pack >> (_ci * 4u32)) & 15u32;
// Bit 3 is the sign: 0 -> +1.0, 1 -> -1.0.
let sign = 1.0f32 - 2.0f32 * ((nibble >> 3u32) & 1u32).cast::<f32>();
let code3 = nibble & 7u32;
let exp = code3 >> 1u32;
let mantissa = code3 & 1u32;
let is_normal = select(exp > 0u32, 1u32, 0u32);
// Magnitudes are built as integers equal to twice the value, halved below:
// subnormal (exp == 0): 2 * value = mantissa
// normal: 2 * value = (2 + mantissa) << (exp - 1)
let two_m_int_sub = mantissa;
// NOTE(review): when exp == 0, `exp - 1u32` wraps around; that result is discarded by
// the select below (assuming `select(cond, a, b)` yields `a` when `cond` holds) —
// confirm the oversized shift amount is harmless in the generated code.
let two_m_int_norm = (mantissa + 2u32) << (exp - 1u32);
let two_m_int = select(is_normal == 1u32, two_m_int_norm, two_m_int_sub);
let val = s * sign * two_m_int.cast::<f32>() * 0.5f32;
threadgroup_store("ws", ws_base + _ci, val.cast::<T>());
}
// The whole slab must be staged before any simdgroup reads fragments from it.
threadgroup_barrier();
// Rows of `xs` (A operand) and rows of `ws` (output columns, B operand) for this simdgroup.
let row_a0 = sm * 16u32 + fm;
let row_a1 = sm * 16u32 + 8u32 + fm;
let col_b0 = sn * 16u32;
let col_b1 = sn * 16u32 + 8u32;
// --- K sub-step 0: staged columns 0..7.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
simdgroup_barrier_mem_none();
// `ws` is stored one row per output column, so indexing it as [fn][fm] reads the
// weight slab transposed: fragment element (row = k, col = n).
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
simdgroup_barrier_mem_none();
// Presumably each call accumulates a * b into its third argument (the accumulators
// are never reset inside the loop) — confirm against the DSL definition.
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// --- K sub-step 1: staged columns 8..15.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// --- K sub-step 2: staged columns 16..23.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// --- K sub-step 3: staged columns 24..31.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// All reads of `xs` / `ws` must finish before the next iteration overwrites them.
threadgroup_barrier();
}
// --- Write back: each lane stores its two elements of each of the four accumulators.
// c_fRC covers rows `out_m_base + 8 * R ..` and columns `out_n_base + 8 * C ..`.
let out_m_base = m_tile * 32u32 + sm * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(out[(out_m_base + fm) * n + out_n_base + fn0], simdgroup_elem_load(c_f00, 0).cast::<T>());
store(out[(out_m_base + fm) * n + out_n_base + fn1], simdgroup_elem_load(c_f00, 1).cast::<T>());
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
// Submits a `BenchSpec` describing `mt_fp4_qmm_mma` via `inventory::submit!`,
// listed under op "fp_quantized" / subop "fp4_qmm_mma".
inventory::submit! {
BenchSpec {
op: "fp_quantized",
subop: "fp4_qmm_mma",
kernel_name: "mt_fp4_qmm_mma",
// IR generator emitted by `#[kernel]` for the function above.
kernel_ir: mt_fp4_qmm_mma::kernel_ir_for,
dtypes: &[DType::F32, DType::F16, DType::BF16],
// NOTE(review): presumably the comparison tolerance against a reference result;
// whether it is absolute or relative is not visible here — confirm.
tol: 5e-2,
// No MLX source or pattern is attached to this spec.
mlx_src: None,
mlx_pattern: None,
// NOTE(review): the shape list is empty; presumably shapes are supplied by the
// `Generic` dispatch path — confirm.
shapes: &[],
dispatch: BenchDispatch::Generic,
kernel_mode: Some(KernelMode::Reduction),
}
}
/// FP8 (E4M3) quantized matmul kernel built on 8x8 simdgroup matrix fragments.
///
/// Each threadgroup produces one 32x32 tile of `out` (tile column `tgid_x`, tile row
/// `tgid_y`). Each simdgroup owns one 16x16 quadrant of that tile, held in four 8x8
/// `f32` accumulators. K is walked in steps of 32: a 32x32 slab of `x` and a 32x32
/// slab of dequantized weights are staged in the `xs` / `ws` threadgroup buffers and
/// then consumed as four 8-wide multiply-accumulate steps.
///
/// Indexing used by the body:
/// - `w`: `u32` words holding four 8-bit codes each (lowest byte first), `k / 4`
///   words per row, one row per output column.
/// - `scales`: `gs_per_row` scales per weight row; one scale covers
///   `k / gs_per_row` consecutive K elements.
/// - `x`: row stride `k`.
/// - `out`: row stride `n`.
///
/// Each 8-bit code is 1 sign bit, 4 exponent bits (bias 7) and 3 mantissa bits:
/// normal values decode to `2^(e - 7) * (1 + m / 8)`, and codes with `e == 0` decode
/// to `2^-6 * m / 8`.
///
/// NOTE(review): no bit pattern is special-cased, so a code with all exponent and
/// mantissa bits set (0x7F / 0xFF) decodes to +/-480; in the OCP E4M3 format that
/// pattern is NaN — confirm this is intended or that such codes never occur.
/// NOTE(review): the thread indexing (`sg / 2`, `sg & 1`, `lane_in_tg / 4`) covers the
/// 32x32 tile exactly when a threadgroup has 4 simdgroups of 32 lanes — presumably the
/// dispatch guarantees 128 threads per threadgroup; confirm.
/// NOTE(review): there are no bounds checks; presumably the M and N extents and `k`
/// are all multiples of 32 — confirm with the caller.
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn mt_fp8_e4m3_qmm_mma<T>(
w: Tensor<u32>,
scales: Tensor<T>,
x: Tensor<T>,
out: Tensor<T>,
#[constexpr] k: u32,
#[constexpr] n: u32,
#[constexpr] gs_per_row: u32,
) {
// Tile coordinates of this threadgroup within `out`.
let n_tile = tgid_x;
let m_tile = tgid_y;
let lane = simd_lane;
let sg = simd_group_id();
// Simdgroups are arranged 2 wide: `sm` picks the 16-row half, `sn` the 16-column half.
let sm = sg / 2u32;
let sn = sg & 1u32;
let lane_in_tg = sg * 32u32 + lane;
// Per-lane coordinates inside an 8x8 fragment: this lane holds row `fm`,
// columns `fn0` and `fn1 = fn0 + 1` (fragment elements 0 and 1).
// NOTE(review): presumably this matches the hardware simdgroup-matrix element
// layout — confirm against the Metal documentation.
let qid = lane / 4u32;
let fm = (qid & 4u32) + ((lane / 2u32) % 4u32);
let fn0 = (qid & 2u32) * 2u32 + (lane % 2u32) * 2u32;
let fn1 = fn0 + 1u32;
// Staging buffers: 32 rows at a row stride of 36 (see `xs_ld` / `ws_ld`) = 1152 elements.
// NOTE(review): the 4 extra elements per row look like padding against bank
// conflicts — confirm.
threadgroup_alloc("xs", 1152, T);
threadgroup_alloc("ws", 1152, T);
// Four f32 accumulators (2x2 grid of 8x8 fragments); both per-lane elements start at zero.
let c_f00 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f00, 0, 0.0f32);
simdgroup_elem_store(c_f00, 1, 0.0f32);
let c_f01 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f01, 0, 0.0f32);
simdgroup_elem_store(c_f01, 1, 0.0f32);
let c_f10 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f10, 0, 0.0f32);
simdgroup_elem_store(c_f10, 1, 0.0f32);
let c_f11 = simdgroup_alloc::<f32, 8, 8>();
simdgroup_elem_store(c_f11, 0, 0.0f32);
simdgroup_elem_store(c_f11, 1, 0.0f32);
// Operand fragments, refilled for every 8-wide K sub-step.
let a_f0 = simdgroup_alloc::<T, 8, 8>();
let a_f1 = simdgroup_alloc::<T, 8, 8>();
let b_f0 = simdgroup_alloc::<T, 8, 8>();
let b_f1 = simdgroup_alloc::<T, 8, 8>();
// Staging assignment: 4 threads per tile row. For activations each thread copies 8
// consecutive elements; for weights each thread decodes two packed words (4 codes
// each): word `word_in_row` and word `word_in_row + 4` of the 8 words in the slab row.
let w_row = lane_in_tg / 4u32;
let word_in_row = lane_in_tg & 3u32;
let x_m_row = lane_in_tg / 4u32;
let x_k_quad = lane_in_tg & 3u32;
let x_k_base = x_k_quad * 8u32;
let xs_ld = 36u32;
let ws_ld = 36u32;
let x_m_base = m_tile * 32u32;
let w_n_base = n_tile * 32u32;
// Four 8-bit codes per u32 word.
let packs_per_row = k / 4u32;
// Start of this thread's weight row in `scales` and in `w`.
let sb_base = (w_n_base + w_row) * gs_per_row;
let w_pack_row_base = (w_n_base + w_row) * packs_per_row;
let group_size = k / gs_per_row;
for kb in range(0u32, k, 32u32) {
// --- Stage activations: copy 8 consecutive K elements of one row of `x` into `xs`.
let x_row_dev_base = (x_m_base + x_m_row) * k + kb + x_k_base;
let x_ws_base = x_m_row * xs_ld + x_k_base;
threadgroup_store("xs", x_ws_base, load(x[x_row_dev_base]).cast::<T>());
threadgroup_store("xs", x_ws_base + 1u32, load(x[x_row_dev_base + 1u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 2u32, load(x[x_row_dev_base + 2u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 3u32, load(x[x_row_dev_base + 3u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 4u32, load(x[x_row_dev_base + 4u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 5u32, load(x[x_row_dev_base + 5u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 6u32, load(x[x_row_dev_base + 6u32]).cast::<T>());
threadgroup_store("xs", x_ws_base + 7u32, load(x[x_row_dev_base + 7u32]).cast::<T>());
// --- Stage weights, first word: slab columns `word_in_row * 4 .. + 3`.
let k_off = kb + word_in_row * 4u32;
// The group index is taken from the first element of the word, so all 4 codes share
// one scale. NOTE(review): presumably `group_size` is a multiple of 4 — confirm.
let g = k_off / group_size;
let s = load(scales[sb_base + g]).cast::<f32>();
let pack_a = load(w[w_pack_row_base + kb / 4u32 + word_in_row]);
let ws_base_a = w_row * ws_ld + word_in_row * 4u32;
for _ci in range(0u32, 4u32, 1u32) {
// Extract code `_ci` (lowest byte first).
let code = (pack_a >> (_ci * 8u32)) & 255u32;
// Bit 7 is the sign: 0 -> +1.0, 1 -> -1.0.
let sign = 1.0f32 - 2.0f32 * (code >> 7u32).cast::<f32>();
let code7 = code & 127u32;
let e_raw = code7 >> 3u32;
let m = code7 & 7u32;
let is_normal = select(e_raw > 0u32, 1u32, 0u32);
// Normal: 2^(e_raw - 7) * (1 + m / 8). Subnormal (e_raw == 0): 2^-6 * m / 8.
let exp_f = e_raw.cast::<f32>() - 7.0f32;
let norm_mag = exp2(exp_f) * (1.0f32 + m.cast::<f32>() * 0.125f32);
let sub_mag = exp2(-6.0f32) * m.cast::<f32>() * 0.125f32;
let mag = select(is_normal == 1u32, norm_mag, sub_mag);
let val = s * sign * mag;
threadgroup_store("ws", ws_base_a + _ci, val.cast::<T>());
}
// --- Stage weights, second word: slab columns `(word_in_row + 4) * 4 .. + 3`,
// with its own scale lookup since it may fall in a different group.
let k_off_b = kb + (word_in_row + 4u32) * 4u32;
let g_b = k_off_b / group_size;
let s_b = load(scales[sb_base + g_b]).cast::<f32>();
let pack_b = load(w[w_pack_row_base + kb / 4u32 + word_in_row + 4u32]);
let ws_base_b = w_row * ws_ld + (word_in_row + 4u32) * 4u32;
for _ci in range(0u32, 4u32, 1u32) {
// Same decode as the first word, applied to `pack_b` with scale `s_b`.
let code = (pack_b >> (_ci * 8u32)) & 255u32;
let sign = 1.0f32 - 2.0f32 * (code >> 7u32).cast::<f32>();
let code7 = code & 127u32;
let e_raw = code7 >> 3u32;
let m = code7 & 7u32;
let is_normal = select(e_raw > 0u32, 1u32, 0u32);
let exp_f = e_raw.cast::<f32>() - 7.0f32;
let norm_mag = exp2(exp_f) * (1.0f32 + m.cast::<f32>() * 0.125f32);
let sub_mag = exp2(-6.0f32) * m.cast::<f32>() * 0.125f32;
let mag = select(is_normal == 1u32, norm_mag, sub_mag);
let val = s_b * sign * mag;
threadgroup_store("ws", ws_base_b + _ci, val.cast::<T>());
}
// The whole slab must be staged before any simdgroup reads fragments from it.
threadgroup_barrier();
// Rows of `xs` (A operand) and rows of `ws` (output columns, B operand) for this simdgroup.
let row_a0 = sm * 16u32 + fm;
let row_a1 = sm * 16u32 + 8u32 + fm;
let col_b0 = sn * 16u32;
let col_b1 = sn * 16u32 + 8u32;
// --- K sub-step 0: staged columns 0..7.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + fn1));
simdgroup_barrier_mem_none();
// `ws` is stored one row per output column, so indexing it as [fn][fm] reads the
// weight slab transposed: fragment element (row = k, col = n).
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + fm));
simdgroup_barrier_mem_none();
// Presumably each call accumulates a * b into its third argument (the accumulators
// are never reset inside the loop) — confirm against the DSL definition.
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// --- K sub-step 1: staged columns 8..15.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 8u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 8u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 8u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 8u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// --- K sub-step 2: staged columns 16..23.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 16u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 16u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 16u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 16u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// --- K sub-step 3: staged columns 24..31.
simdgroup_elem_store(a_f0, 0, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f0, 1, threadgroup_load("xs", row_a0 * xs_ld + 24u32 + fn1));
simdgroup_elem_store(a_f1, 0, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn0));
simdgroup_elem_store(a_f1, 1, threadgroup_load("xs", row_a1 * xs_ld + 24u32 + fn1));
simdgroup_barrier_mem_none();
simdgroup_elem_store(b_f0, 0, threadgroup_load("ws", (col_b0 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f0, 1, threadgroup_load("ws", (col_b0 + fn1) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 0, threadgroup_load("ws", (col_b1 + fn0) * ws_ld + 24u32 + fm));
simdgroup_elem_store(b_f1, 1, threadgroup_load("ws", (col_b1 + fn1) * ws_ld + 24u32 + fm));
simdgroup_barrier_mem_none();
simdgroup_matmul(a_f0, b_f0, c_f00);
simdgroup_matmul(a_f0, b_f1, c_f01);
simdgroup_matmul(a_f1, b_f1, c_f11);
simdgroup_matmul(a_f1, b_f0, c_f10);
simdgroup_barrier_mem_none();
// All reads of `xs` / `ws` must finish before the next iteration overwrites them.
threadgroup_barrier();
}
// --- Write back: each lane stores its two elements of each of the four accumulators.
// c_fRC covers rows `out_m_base + 8 * R ..` and columns `out_n_base + 8 * C ..`.
let out_m_base = m_tile * 32u32 + sm * 16u32;
let out_n_base = n_tile * 32u32 + sn * 16u32;
store(out[(out_m_base + fm) * n + out_n_base + fn0], simdgroup_elem_load(c_f00, 0).cast::<T>());
store(out[(out_m_base + fm) * n + out_n_base + fn1], simdgroup_elem_load(c_f00, 1).cast::<T>());
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f01, 0).cast::<T>(),
);
store(
out[(out_m_base + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f01, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn0],
simdgroup_elem_load(c_f10, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + fn1],
simdgroup_elem_load(c_f10, 1).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn0],
simdgroup_elem_load(c_f11, 0).cast::<T>(),
);
store(
out[(out_m_base + 8u32 + fm) * n + out_n_base + 8u32 + fn1],
simdgroup_elem_load(c_f11, 1).cast::<T>(),
);
}
// Submits a `BenchSpec` describing `mt_fp8_e4m3_qmm_mma` via `inventory::submit!`,
// listed under op "fp_quantized" / subop "fp8_e4m3_qmm_mma".
inventory::submit! {
BenchSpec {
op: "fp_quantized",
subop: "fp8_e4m3_qmm_mma",
kernel_name: "mt_fp8_e4m3_qmm_mma",
// IR generator emitted by `#[kernel]` for the function above.
kernel_ir: mt_fp8_e4m3_qmm_mma::kernel_ir_for,
dtypes: &[DType::F32, DType::F16, DType::BF16],
// NOTE(review): presumably the comparison tolerance against a reference result;
// whether it is absolute or relative is not visible here — confirm.
tol: 5e-2,
// No MLX source or pattern is attached to this spec.
mlx_src: None,
mlx_pattern: None,
// NOTE(review): the shape list is empty; presumably shapes are supplied by the
// `Generic` dispatch path — confirm.
shapes: &[],
dispatch: BenchDispatch::Generic,
kernel_mode: Some(KernelMode::Reduction),
}
}