use metaltile_core::{dtype::DType, ir::Kernel};
use crate::mlx::quantized::{mt_qmm_mma, patch_qmm_mma_dtype_aware_skew};
/// Row-tile extent: [`pad_t_to_bm`] rounds the row count `t` up to a multiple
/// of this value, and [`dispatch_grid`] divides the padded row count by it.
pub const BM_TILE: u32 = 32;
/// Column-tile extent: [`dispatch_grid`] requires `n` to be a multiple of this
/// value and divides `n` by it to obtain the first grid extent.
pub const BN_TILE: u32 = 32;
/// NOTE(review): presumably the tile extent along the reduction (K) dimension;
/// it is not referenced by the helpers visible here — confirm against the kernel.
pub const BK_TILE: u32 = 32;
/// NOTE(review): presumably the number of threads per threadgroup used when
/// dispatching the kernel; not referenced by the helpers visible here — confirm.
pub const TPG: u32 = 128;
/// Rounds `t` up to the nearest multiple of [`BM_TILE`].
///
/// Values that are already a multiple of the tile size (including `0`) are
/// returned unchanged.
pub const fn pad_t_to_bm(t: usize) -> usize {
    // Tile extent widened once so the arithmetic below stays in `usize`.
    const TILE: usize = BM_TILE as usize;
    let tile_count = t.div_ceil(TILE);
    tile_count * TILE
}
/// Copies a row-major byte buffer of `t` rows into a new buffer whose row
/// count is rounded up with [`pad_t_to_bm`], zero-filling the added rows.
///
/// Each row occupies `k * bytes_per_elem` bytes. The original bytes are kept
/// at the front of the returned vector; every trailing byte is `0`.
///
/// # Panics
///
/// Panics if `x_bytes.len()` is not exactly `t * k * bytes_per_elem`.
pub fn pad_x_rows_bytes(x_bytes: &[u8], t: usize, k: usize, bytes_per_elem: usize) -> Vec<u8> {
    let padded_rows = pad_t_to_bm(t);
    let stride = k * bytes_per_elem;
    assert_eq!(x_bytes.len(), t * stride, "x_bytes must be t * k * bytes_per_elem");

    // Start from an all-zero buffer of the padded size, then overwrite the
    // leading region with the caller's rows; the remainder stays zeroed.
    let mut padded = vec![0u8; padded_rows * stride];
    padded[..x_bytes.len()].copy_from_slice(x_bytes);
    padded
}
/// Builds the kernel IR used for the given `dtype`.
///
/// Starts from `mt_qmm_mma::kernel_ir_for(dtype)`, applies
/// `patch_qmm_mma_dtype_aware_skew` to it, and finally forces the kernel mode
/// to `KernelMode::Reduction` before returning it.
pub fn kernel_ir_for(dtype: DType) -> Kernel {
    let mut kernel = mt_qmm_mma::kernel_ir_for(dtype);
    patch_qmm_mma_dtype_aware_skew(&mut kernel, dtype);
    // The mode is overridden after patching so it is always `Reduction`,
    // regardless of what the base kernel or the patch set.
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    kernel
}
/// Computes the three grid extents for `t` rows and `n` columns.
///
/// The result is `[n / BN_TILE, pad_t_to_bm(t) / BM_TILE, 1]`: the first
/// extent counts column tiles, the second counts row tiles after `t` has been
/// rounded up to a whole number of row tiles.
///
/// # Panics
///
/// Panics if `n` is not a multiple of [`BN_TILE`].
pub fn dispatch_grid(t: usize, n: usize) -> [usize; 3] {
    let col_tile = BN_TILE as usize;
    let row_tile = BM_TILE as usize;
    assert!(n.is_multiple_of(col_tile), "n must be multiple of {} (BN tile)", BN_TILE);

    let row_groups = pad_t_to_bm(t) / row_tile;
    let col_groups = n / col_tile;
    [col_groups, row_groups, 1]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `pad_t_to_bm` leaves multiples of 32 alone and rounds everything else up.
    #[test]
    fn pad_t_to_bm_rounds_up_to_multiple_of_32() {
        let cases: [(usize, usize); 9] = [
            (0, 0),
            (1, 32),
            (31, 32),
            (32, 32),
            (33, 64),
            (37, 64),
            (64, 64),
            (4096, 4096),
            (4097, 4128),
        ];
        for (t, expected) in cases {
            assert_eq!(pad_t_to_bm(t), expected, "pad_t_to_bm({})", t);
        }
    }

    /// The second grid extent reflects the padded row count, not the raw `t`.
    #[test]
    fn dispatch_grid_pads_m_axis() {
        let cases: [((usize, usize), [usize; 3]); 3] = [
            ((1, 128), [4, 1, 1]),
            ((37, 128), [4, 2, 1]),
            ((4096, 2048), [64, 128, 1]),
        ];
        for ((t, n), expected) in cases {
            assert_eq!(dispatch_grid(t, n), expected, "dispatch_grid({}, {})", t, n);
        }
    }

    /// Two rows of four 2-byte elements are padded to 32 rows; the original
    /// bytes stay in front and everything after them is zero.
    #[test]
    fn pad_x_rows_zero_fills_trailing() {
        let x = [0x01u8; 16];
        let padded = pad_x_rows_bytes(&x, 2, 4, 2);
        assert_eq!(padded.len(), 32 * 4 * 2);

        let (payload, fill) = padded.split_at(x.len());
        assert!(payload.iter().all(|&b| b == 0x01));
        assert!(fill.iter().all(|&b| b == 0));
    }

    /// Every supported dtype yields the `mt_qmm_mma` kernel in reduction mode.
    #[test]
    fn kernel_ir_for_returns_mt_qmm_mma_per_dtype() {
        [DType::F32, DType::F16, DType::BF16].into_iter().for_each(|dt| {
            let kernel = kernel_ir_for(dt);
            assert_eq!(kernel.name, "mt_qmm_mma", "dynamic-M routes to mt_qmm_mma for dtype {:?}", dt);
            assert_eq!(kernel.mode, metaltile_core::ir::KernelMode::Reduction);
        });
    }
}