use std::time::Instant;
use mlx_native::ops::qmm_affine::{
dispatch_qmm_affine_t_f32, dispatch_qmm_affine_t_f32_simd,
dispatch_qmm_affine_t_f32_simd4, dispatch_qmm_affine_t_f32_tiled,
};
use mlx_native::{DType, KernelRegistry, MlxDevice};
/// Micro-benchmark comparing four `qmm_affine_t` (quantized matmul,
/// transposed weights, affine dequant) kernel variants on identical inputs:
/// per-element, tiled, simd-MMA, and simd4-MMA.
///
/// CLI: `bench [M] [N] [K] [n_iter]` — defaults 64 / 4096 / 4096 / 20.
/// Group size is fixed at 32 to match the kernels' quantization layout.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut args = std::env::args().skip(1);
    let m: u32 = args
        .next()
        .map(|s| s.parse().expect("M"))
        .unwrap_or(64);
    let n: u32 = args
        .next()
        .map(|s| s.parse().expect("N"))
        .unwrap_or(4096);
    let k: u32 = args
        .next()
        .map(|s| s.parse().expect("K"))
        .unwrap_or(4096);
    let group_size: u32 = 32;
    let n_iter: usize = args
        .next()
        .map(|s| s.parse().expect("n_iter"))
        .unwrap_or(20);

    // Validate user-supplied shape/iteration parameters up front:
    // - K must be a multiple of group_size, otherwise `groups` below
    //   truncates and the scale/bias buffers are silently sized wrong.
    // - n_iter == 0 would divide by zero when averaging timings.
    if k % group_size != 0 {
        return Err(format!("K ({k}) must be a multiple of group_size ({group_size})").into());
    }
    if n_iter == 0 {
        return Err("n_iter must be >= 1".into());
    }

    println!("[bench] shape M={m} N={n} K={k} group_size={group_size} iter={n_iter}");
    let device = MlxDevice::new()?;
    let mut registry = KernelRegistry::new();
    let m_us = m as usize;
    let n_us = n as usize;
    let k_us = k as usize;
    // Number of quantization groups per output row (exact: checked above).
    let groups = k_us / group_size as usize;

    // Deterministic synthetic inputs: smooth activations, cycling 4-bit
    // codes, and slowly varying scales/biases per group.
    let x: Vec<f32> = (0..(m_us * k_us))
        .map(|i| ((i as f32) * 0.013).sin() * 0.5)
        .collect();
    let q_int: Vec<u8> = (0..(n_us * k_us)).map(|i| ((i * 7) % 16) as u8).collect();
    let scales: Vec<f32> = (0..(n_us * groups))
        .map(|i| 0.05 + (i as f32) * 1e-5)
        .collect();
    let biases: Vec<f32> = (0..(n_us * groups))
        .map(|i| -0.1 + (i as f32) * 1e-5)
        .collect();

    // Upload inputs to device-visible buffers (f32 elements are 4 bytes;
    // quantized codes are one u8 each).
    let mut x_buf = device.alloc_buffer(m_us * k_us * 4, DType::F32, vec![m_us, k_us])?;
    x_buf.as_mut_slice::<f32>()?.copy_from_slice(&x);
    let mut q_buf = device.alloc_buffer(n_us * k_us, DType::U8, vec![n_us, k_us])?;
    q_buf.as_mut_slice::<u8>()?.copy_from_slice(&q_int);
    let mut s_buf = device.alloc_buffer(n_us * groups * 4, DType::F32, vec![n_us, groups])?;
    s_buf.as_mut_slice::<f32>()?.copy_from_slice(&scales);
    let mut b_buf = device.alloc_buffer(n_us * groups * 4, DType::F32, vec![n_us, groups])?;
    b_buf.as_mut_slice::<f32>()?.copy_from_slice(&biases);
    let y_buf = device.alloc_buffer(m_us * n_us * 4, DType::F32, vec![m_us, n_us])?;
    // Kernel metadata: [M, N, K, group_size] as u32 (4 * 4 = 16 bytes).
    let mut meta = device.alloc_buffer(16, DType::U32, vec![4])?;
    meta.as_mut_slice::<u32>()?
        .copy_from_slice(&[m, n, k, group_size]);

    // Run `$dispatch` `$iters` times, waiting for completion each time so
    // wall-clock time covers the full GPU round trip. Identifiers in the
    // macro body (device, registry, buffers, shape scalars) resolve at the
    // definition site, i.e. the locals above — macro_rules! hygiene applies
    // only to identifiers the macro itself introduces (`encoder`).
    macro_rules! run_n {
        ($dispatch:ident, $iters:expr) => {
            for _ in 0..$iters {
                let mut encoder = device.command_encoder()?;
                $dispatch(
                    &mut encoder, &mut registry, device.metal_device(),
                    &x_buf, &q_buf, &s_buf, &b_buf, &y_buf, &meta, m, n, k, group_size,
                )?;
                encoder.commit_and_wait()?;
            }
        };
    }
    // Warm up (pipeline compilation, cache priming) immediately before the
    // timed pass for each kernel, then return the average seconds per call.
    macro_rules! bench {
        ($dispatch:ident) => {{
            run_n!($dispatch, 3);
            let t0 = Instant::now();
            run_n!($dispatch, n_iter);
            t0.elapsed().as_secs_f64() / n_iter as f64
        }};
    }

    let avg_pe = bench!(dispatch_qmm_affine_t_f32);
    let avg_tl = bench!(dispatch_qmm_affine_t_f32_tiled);
    let avg_sm = bench!(dispatch_qmm_affine_t_f32_simd);
    let avg_s4 = bench!(dispatch_qmm_affine_t_f32_simd4);

    // Report throughput (2*M*N*K flops per matmul) and speedups relative to
    // the naive per-element kernel.
    let speedup_tl = avg_pe / avg_tl;
    let speedup_sm = avg_pe / avg_sm;
    let speedup_s4 = avg_pe / avg_s4;
    let flops_per_call = 2.0 * (m as f64) * (n as f64) * (k as f64);
    let pe_gflops = flops_per_call / avg_pe / 1e9;
    let tl_gflops = flops_per_call / avg_tl / 1e9;
    let sm_gflops = flops_per_call / avg_sm / 1e9;
    let s4_gflops = flops_per_call / avg_s4 / 1e9;
    println!("[bench] per-element: avg {:.3} ms = {:.1} GFLOPS", avg_pe * 1000.0, pe_gflops);
    println!("[bench] tiled: avg {:.3} ms = {:.1} GFLOPS", avg_tl * 1000.0, tl_gflops);
    println!("[bench] simd-MMA: avg {:.3} ms = {:.1} GFLOPS", avg_sm * 1000.0, sm_gflops);
    println!("[bench] simd4-MMA: avg {:.3} ms = {:.1} GFLOPS", avg_s4 * 1000.0, s4_gflops);
    println!("[bench] speedup tiled / per-element = {speedup_tl:.2}×");
    println!("[bench] speedup simd / per-element = {speedup_sm:.2}×");
    println!("[bench] speedup simd4 / per-element = {speedup_s4:.2}×");
    Ok(())
}