#![cfg(target_os = "macos")]
use std::collections::BTreeMap;
mod common;
use common::gpu_lock;
use metaltile_core::dtype::DType;
use metaltile_runtime::{Context, DispatchSpec, ResidentBuffer};
use metaltile_std::mlx::quantized::{mt_qmm, mt_qmm_bm2, mt_qmm_bm4, mt_qmm_mma, mt_qmm_mma_m16};
/// Dispatches the `mt_qmm` kernel for `dtype` and returns the raw bytes of its `out` buffer.
///
/// `w` holds packed 4-bit weights (eight nibbles per `u32`). `scales_bytes`, `biases_bytes`
/// and `x_bytes` must already be encoded in `dtype`'s little-endian layout.
/// `out_bytes_per_elem` is the byte width of one output element (callers pass 4 for f32 and
/// 2 for f16/bf16). The kernel is forced into `KernelMode::Reduction` and launched as
/// `[n / 8, m, 1]` threadgroups of 64 threads.
///
/// # Panics
/// Panics if `n` is not a multiple of 8, or if the dispatch fails. Without the `n` check, the
/// integer division in the grid size would silently drop the trailing output columns
/// (presumably each threadgroup produces 8 columns).
// Each argument maps to a distinct kernel input; bundling them would only obscure call sites.
#[allow(clippy::too_many_arguments)]
fn run_qmm(
    ctx: &Context,
    dtype: DType,
    w: &[u32],
    scales_bytes: &[u8],
    biases_bytes: &[u8],
    x_bytes: &[u8],
    m: usize,
    n: usize,
    k: usize,
    gs_per_row: usize,
    out_bytes_per_elem: usize,
) -> Vec<u8> {
    assert!(n.is_multiple_of(8), "mt_qmm requires n % 8 == 0 (grid covers n / 8 groups)");
    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("w".into(), w.iter().flat_map(|v| v.to_le_bytes()).collect());
    buffers.insert("scales".into(), scales_bytes.to_vec());
    buffers.insert("biases".into(), biases_bytes.to_vec());
    buffers.insert("x".into(), x_bytes.to_vec());
    buffers.insert("out".into(), vec![0u8; m * n * out_bytes_per_elem]);
    // Shape scalars are passed to the kernel as 4-byte little-endian u32 buffers.
    buffers.insert("k".into(), (k as u32).to_le_bytes().to_vec());
    buffers.insert("n".into(), (n as u32).to_le_bytes().to_vec());
    buffers.insert("gs_per_row".into(), (gs_per_row as u32).to_le_bytes().to_vec());
    let mut kernel = mt_qmm::kernel_ir_for(dtype);
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [n / 8, m, 1], [64, 1, 1])
        .expect("dispatch_with_grid should succeed");
    result.outputs.get("out").expect("`out` buffer in dispatch result").clone()
}
/// CPU reference for the 4-bit affine-quantized matmul that the `mt_qmm*` kernels implement.
///
/// Computes `out[r][c] = Σ_g (scale[c][g] · Σ q·x + bias[c][g] · Σ x)`, where the inner sums
/// run over the `group_size` elements of group `g`. Equivalently, each weight dequantizes to
/// `scale · q + bias`.
///
/// Layouts (row-major):
/// - `w`: `n` rows of `k / 8` `u32` words. Each word packs eight 4-bit values, lowest nibble
///   first, and each quantization group spans `group_size / 8` consecutive words.
/// - `scales`, `biases`: `n × gs_per_row`, one entry per (output column, group).
/// - `x`: `m × k` activations.
///
/// Returns the `m × n` row-major result in f32.
///
/// # Panics
/// Panics if `group_size` is not a multiple of 8, if `gs_per_row * group_size != k`, or if a
/// slice is shorter than the layout above requires.
// Mirrors the kernel helpers' argument list so call sites read the same.
#[allow(clippy::too_many_arguments)]
fn cpu_qmm_reference(
    w: &[u32],
    scales: &[f32],
    biases: &[f32],
    x: &[f32],
    m: usize,
    n: usize,
    k: usize,
    gs_per_row: usize,
    group_size: usize,
) -> Vec<f32> {
    assert!(
        group_size.is_multiple_of(8),
        "group_size must be a multiple of 8 (nibbles per u32 word)"
    );
    assert_eq!(gs_per_row * group_size, k, "gs_per_row * group_size must equal k");
    // One u32 word packs eight 4-bit weights.
    let words_per_row = k / 8;
    let words_per_group = group_size / 8;
    let mut out = vec![0.0f32; m * n];
    for m_row in 0..m {
        let x_row = &x[m_row * k..(m_row + 1) * k];
        for n_col in 0..n {
            let w_row = &w[n_col * words_per_row..(n_col + 1) * words_per_row];
            let mut acc = 0.0f32;
            for g in 0..gs_per_row {
                let s = scales[n_col * gs_per_row + g];
                let bias = biases[n_col * gs_per_row + g];
                let mut q_dot = 0.0f32;
                let mut x_sum = 0.0f32;
                for p in 0..words_per_group {
                    let packed = w_row[g * words_per_group + p];
                    for bit in 0..8u32 {
                        let q = ((packed >> (bit * 4)) & 0xF) as f32;
                        let xv = x_row[g * group_size + p * 8 + bit as usize];
                        q_dot += q * xv;
                        x_sum += xv;
                    }
                }
                // Affine dequant folded out of the inner loop: Σ (s·q + b)·x = s·Σqx + b·Σx.
                acc += s * q_dot + bias * x_sum;
            }
            out[m_row * n + n_col] = acc;
        }
    }
    out
}
/// End-to-end check of `mt_qmm` in f32 against `cpu_qmm_reference`.
///
/// Uses a small shape (`m = 8`, `n = 16`, `k = 512`, group size 64) and an absolute
/// tolerance of `1e-3`. It shares the reference helper with every other correctness test
/// instead of carrying its own copy of the dequant loop.
#[test]
fn mt_qmm_matches_cpu_reference_f32() {
    let m = 8usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_diff = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let diff = (e - a).abs();
        if diff > max_diff {
            max_diff = diff;
            max_at = i;
        }
    }
    assert!(
        max_diff < 1e-3,
        "max |diff| = {max_diff:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Checks the f16 instantiation of `mt_qmm` against `cpu_qmm_reference`.
///
/// Scales, biases and activations are rounded through f16 before the reference runs, so the
/// comparison measures only the kernel's accumulation and output rounding. The tolerance is
/// relative (`5e-3`), with the denominator clamped to at least 1.
#[test]
fn mt_qmm_matches_cpu_reference_f16() {
    let (m, n, k) = (8usize, 16usize, 512usize);
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    // Deterministic packed weights: nibble `bit` of word `i` holds `(i + bit) & 0xF`.
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            (0..8u32).fold(0u32, |word, bit| word | (((i as u32 + bit) & 0xF) << (bit * 4)))
        })
        .collect();
    let to_half = |v: f32| half::f16::from_f32(v);
    let round_f16 = |v: f32| -> f32 { to_half(v).to_f32() };
    let scales: Vec<f32> =
        (0..n * gs_per_row).map(|i| round_f16(0.1 + (i as f32) * 0.001)).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| round_f16((i as f32) * 0.0001)).collect();
    let x: Vec<f32> = (0..m * k).map(|i| round_f16(1.0 + (i as f32) * 0.001)).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    // Little-endian f16 byte stream, the layout the F16 kernel reads.
    let encode = |vals: &[f32]| -> Vec<u8> {
        vals.iter().flat_map(|&v| to_half(v).to_bits().to_le_bytes()).collect()
    };
    let (scales_bytes, biases_bytes, x_bytes) = (encode(&scales), encode(&biases), encode(&x));
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm(
        &ctx,
        DType::F16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|pair| half::f16::from_bits(u16::from_le_bytes([pair[0], pair[1]])).to_f32())
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    // Earliest index of the largest relative error (strict `>` keeps the first on ties).
    let (max_at, max_rel) = expected
        .iter()
        .zip(&actual)
        .map(|(e, a)| (e - a).abs() / e.abs().max(1.0))
        .enumerate()
        .fold((0usize, 0.0_f32), |best, (i, rel)| if rel > best.1 { (i, rel) } else { best });
    assert!(
        max_rel < 5e-3,
        "max relative diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Smoke test: `mt_qmm` (f32) on a 5120 × 5120 projection with `m = 4`.
///
/// It must return exactly `m * n` outputs, all finite. No numerical reference is computed at
/// this size.
#[test]
fn mt_qmm_runs_on_qwen3_attention_proj_shape() {
    let (m, n, k) = (4usize, 5120usize, 5120usize);
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    // Pseudo-random packed weights from a multiplicative hash of the word index.
    let w: Vec<u32> = (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
    let scales_bytes: Vec<u8> = (0..n * gs_per_row)
        .map(|i| 0.01 + (i % 13) as f32 * 0.001)
        .flat_map(f32::to_le_bytes)
        .collect();
    let biases_bytes: Vec<u8> =
        (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).flat_map(f32::to_le_bytes).collect();
    let x_bytes: Vec<u8> =
        (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).flat_map(f32::to_le_bytes).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect();
    assert_eq!(actual.len(), m * n);
    actual
        .iter()
        .enumerate()
        .for_each(|(i, v)| assert!(v.is_finite(), "non-finite output at index {i}: {v}"));
}
/// Large `(n, k, label)` matmul shapes for the f16 dispatch smoke test and the perf bench.
///
/// As destructured by `for &(n, k, label) in QWEN3_SHAPES`, `n` is the number of output
/// columns (the output is `m × n`) and `k` is the reduction dimension (`x` is `m × k`).
/// Labels are only printed/used in assertion messages.
/// NOTE(review): the model attributions in the labels are not verified here; confirm them
/// against the respective model configs.
const QWEN3_SHAPES: &[(usize, usize, &str)] = &[
    (4096, 4096, "baseline 4096²"),
    (5120, 5120, "Qwen3-8B/14B attn proj"),
    (14336, 5120, "Qwen3-8B/14B MLP up_proj"),
    (5120, 14336, "Qwen3-8B/14B MLP down_proj"),
    (27648, 5120, "Qwen3-coder-30B MoE expert up_proj"),
];
/// `m = 1` (single-row) check for `mt_qmm` in f32.
///
/// Despite its name, this test does not dispatch a separate qmv kernel or compare bytes. It
/// checks the `m = 1` launch of `mt_qmm` against `cpu_qmm_reference` within an absolute
/// tolerance of `1e-3`. The output length is asserted first: otherwise a truncated or empty
/// result would pass vacuously, because `zip` stops at the shorter side.
#[test]
fn mt_qmm_m1_byte_identical_to_qmv_dispatch_path() {
    let _g = gpu_lock();
    let m = 1usize;
    let n = 32usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let max_diff =
        expected.iter().zip(actual.iter()).map(|(e, a)| (e - a).abs()).fold(0.0_f32, f32::max);
    assert!(max_diff < 1e-3, "max |diff| = {max_diff:.2e}");
}
/// Checks the bf16 instantiation of `mt_qmm` against `cpu_qmm_reference`.
///
/// Inputs are rounded through bf16 before the reference runs. The tolerance is relative and
/// looser than f16 (`2e-2`), reflecting bf16's 8-bit mantissa. The output length is asserted
/// so a truncated result cannot slip through `zip`.
#[test]
fn mt_qmm_matches_cpu_reference_bf16_small_shape() {
    let _g = gpu_lock();
    let m = 4usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales_f32: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases_f32: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x_f32: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let round_bf16 = |v: f32| -> f32 { half::bf16::from_f32(v).to_f32() };
    let scales: Vec<f32> = scales_f32.iter().map(|&v| round_bf16(v)).collect();
    let biases: Vec<f32> = biases_f32.iter().map(|&v| round_bf16(v)).collect();
    let x: Vec<f32> = x_f32.iter().map(|&v| round_bf16(v)).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> =
        scales.iter().flat_map(|v| half::bf16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let biases_bytes: Vec<u8> =
        biases.iter().flat_map(|v| half::bf16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let x_bytes: Vec<u8> =
        x.iter().flat_map(|v| half::bf16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm(
        &ctx,
        DType::BF16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|c| half::bf16::from_bits(u16::from_le_bytes([c[0], c[1]])).to_f32())
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let rel = (e - a).abs() / e.abs().max(1.0);
        if rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < 2e-2,
        "max relative diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Smoke test: dispatches the f16 `mt_qmm` kernel at `m = 4` on every `QWEN3_SHAPES` entry.
///
/// Uses 4-bit weights and group size 64, and requires `m * n` finite outputs per shape. A
/// single GPU context is reused across all shapes.
#[test]
fn mt_qmm_dispatches_all_qwen3_shapes_at_b4_f16() {
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let (m, group_size) = (4usize, 64usize);
    // f32 -> little-endian f16 bytes, and the inverse for reading the output back.
    let f16_le = |v: f32| half::f16::from_f32(v).to_bits().to_le_bytes();
    let f16_decode =
        |pair: &[u8]| half::f16::from_bits(u16::from_le_bytes([pair[0], pair[1]])).to_f32();
    for &(n, k, label) in QWEN3_SHAPES {
        let gs_per_row = k / group_size;
        let w: Vec<u32> =
            (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
        let scales_bytes: Vec<u8> =
            (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).flat_map(f16_le).collect();
        let biases_bytes: Vec<u8> =
            (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).flat_map(f16_le).collect();
        let x_bytes: Vec<u8> =
            (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).flat_map(f16_le).collect();
        let out_bytes = run_qmm(
            &ctx,
            DType::F16,
            &w,
            &scales_bytes,
            &biases_bytes,
            &x_bytes,
            m,
            n,
            k,
            gs_per_row,
            2,
        );
        let actual: Vec<f32> = out_bytes.chunks_exact(2).map(f16_decode).collect();
        assert_eq!(actual.len(), m * n, "{label}: output length");
        actual
            .iter()
            .enumerate()
            .for_each(|(i, &v)| assert!(v.is_finite(), "{label}: non-finite output at {i}: {v}"));
    }
}
/// Throughput benchmark for the f16 `mt_qmm` kernel over every `QWEN3_SHAPES` entry and
/// M ∈ {1, 4, 8, 32}.
///
/// For each shape, the packed weights, scales and biases are uploaded once as resident
/// buffers and reused for every M. The per-M `buffers` map carries `x`, `out` and the shape
/// scalars.
///
/// Each M runs `WARMUP` untimed dispatches and then `ITERS` timed ones. The reported time is
/// the median `elapsed_us`; because `ITERS` is even, this is the upper median. GB/s is derived
/// from the minimum bytes touched: 4-bit weights, f16 scales + biases, f16 activations and f16
/// output.
#[test]
#[ignore = "perf bench, run via --ignored --nocapture"]
fn mt_qmm_perf_bench_qwen3_shapes_f16_m_sweep() {
    const WARMUP: usize = 20;
    const ITERS: usize = 50;
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let group_size = 64usize;
    println!();
    println!(
        "mt_qmm f16 — Apple M-series (median of {ITERS} iters)\n {:>30} {:>5} {:>10} {:>10}",
        "shape (n × k)", "M", "µs", "GB/s"
    );
    let mut kernel = mt_qmm::kernel_ir_for(DType::F16);
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    let empty_fn_consts: BTreeMap<String, u32> = BTreeMap::new();
    // f32 -> little-endian f16 bytes, the layout the F16 kernel reads.
    let f16_le = |v: f32| half::f16::from_f32(v).to_bits().to_le_bytes();
    for &(n, k, label) in QWEN3_SHAPES {
        let gs_per_row = k / group_size;
        let w_bytes: Vec<u8> = (0..n * k / 8)
            .flat_map(|i| (i as u32).wrapping_mul(2654435761u32).to_le_bytes())
            .collect();
        let scales_bytes: Vec<u8> =
            (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).flat_map(f16_le).collect();
        let biases_bytes: Vec<u8> =
            (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).flat_map(f16_le).collect();
        // Array elements evaluate left to right, so uploads happen in w, scales, biases order.
        let residents: BTreeMap<String, ResidentBuffer> = BTreeMap::from([
            ("w".to_string(), ctx.upload_resident(&w_bytes).expect("upload w")),
            ("scales".to_string(), ctx.upload_resident(&scales_bytes).expect("upload scales")),
            ("biases".to_string(), ctx.upload_resident(&biases_bytes).expect("upload biases")),
        ]);
        for &m in &[1usize, 4, 8, 32] {
            let x_bytes: Vec<u8> =
                (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).flat_map(f16_le).collect();
            let buffers: BTreeMap<String, Vec<u8>> = BTreeMap::from([
                ("x".to_string(), x_bytes),
                ("out".to_string(), vec![0u8; m * n * 2]),
                ("k".to_string(), (k as u32).to_le_bytes().to_vec()),
                ("n".to_string(), (n as u32).to_le_bytes().to_vec()),
                ("gs_per_row".to_string(), (gs_per_row as u32).to_le_bytes().to_vec()),
            ]);
            let run_once = || {
                ctx.dispatch_chain(&[DispatchSpec {
                    kernel: &kernel,
                    buffers: &buffers,
                    fn_consts: &empty_fn_consts,
                    grid_groups: [n / 8, m, 1],
                    threads_per_group: [64, 1, 1],
                    resident: &residents,
                }])
                .expect("dispatch")
            };
            for _ in 0..WARMUP {
                let _ = run_once();
            }
            let mut samples: Vec<_> = (0..ITERS).map(|_| run_once()[0].elapsed_us).collect();
            let mid = samples.len() / 2;
            samples.select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap());
            let median_us = samples[mid];
            // Minimum traffic: packed 4-bit weights + f16 scales & biases + f16 x + f16 out.
            let bytes = (n * k / 2 + 2 * n * gs_per_row * 2 + m * k * 2 + m * n * 2) as f64;
            let gbps = bytes / (median_us * 1e-6) / 1e9;
            println!(" {label:>30} {m:>5} {median_us:>10.2} {gbps:>10.1}");
        }
    }
}
/// Dispatches the BM=2 variant `mt_qmm_bm2` and returns the raw bytes of its `out` buffer.
///
/// Inputs and encodings are the same as for `run_qmm`. The grid is `[n / 8, m / 2, 1]`
/// threadgroups of 64 threads, so each threadgroup row covers two rows of `x`.
///
/// # Panics
/// Panics if `m` is not a multiple of 2, if `n` is not a multiple of 8 (the grid's integer
/// divisions would otherwise silently skip rows or columns), or if the dispatch fails.
// Each argument maps to a distinct kernel input; bundling them would only obscure call sites.
#[allow(clippy::too_many_arguments)]
fn run_qmm_bm2(
    ctx: &Context,
    dtype: DType,
    w: &[u32],
    scales_bytes: &[u8],
    biases_bytes: &[u8],
    x_bytes: &[u8],
    m: usize,
    n: usize,
    k: usize,
    gs_per_row: usize,
    out_bytes_per_elem: usize,
) -> Vec<u8> {
    // Rust format strings only escape `{`/`}`; `%` is written literally.
    assert!(m.is_multiple_of(2), "mt_qmm_bm2 requires m % 2 == 0 (BM=2 tile)");
    assert!(n.is_multiple_of(8), "mt_qmm_bm2 requires n % 8 == 0 (grid covers n / 8 groups)");
    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("w".into(), w.iter().flat_map(|v| v.to_le_bytes()).collect());
    buffers.insert("scales".into(), scales_bytes.to_vec());
    buffers.insert("biases".into(), biases_bytes.to_vec());
    buffers.insert("x".into(), x_bytes.to_vec());
    buffers.insert("out".into(), vec![0u8; m * n * out_bytes_per_elem]);
    // Shape scalars are passed to the kernel as 4-byte little-endian u32 buffers.
    buffers.insert("k".into(), (k as u32).to_le_bytes().to_vec());
    buffers.insert("n".into(), (n as u32).to_le_bytes().to_vec());
    buffers.insert("gs_per_row".into(), (gs_per_row as u32).to_le_bytes().to_vec());
    let mut kernel = mt_qmm_bm2::kernel_ir_for(dtype);
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [n / 8, m / 2, 1], [64, 1, 1])
        .expect("dispatch_with_grid should succeed");
    result.outputs.get("out").expect("`out` buffer in dispatch result").clone()
}
/// Checks the BM=2 variant `mt_qmm_bm2` (f32) against `cpu_qmm_reference`.
///
/// Uses an `8 × 16 × 512` problem with group size 64 and an absolute tolerance of `1e-3`.
#[test]
fn mt_qmm_bm2_matches_cpu_reference_f32() {
    let (m, n, k, group_size) = (8usize, 16usize, 512usize, 64usize);
    let gs_per_row = k / group_size;
    // Deterministic packed weights: nibble `bit` of word `i` holds `(i + bit) & 0xF`.
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            (0..8u32).fold(0u32, |word, bit| word | (((i as u32 + bit) & 0xF) << (bit * 4)))
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let le_bytes =
        |vals: &[f32]| -> Vec<u8> { vals.iter().flat_map(|v| v.to_le_bytes()).collect() };
    let (scales_bytes, biases_bytes, x_bytes) =
        (le_bytes(&scales), le_bytes(&biases), le_bytes(&x));
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm2(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    // Earliest index of the largest absolute error (strict `>` keeps the first on ties).
    let (max_at, max_diff) = expected
        .iter()
        .zip(&actual)
        .map(|(e, a)| (e - a).abs())
        .enumerate()
        .fold((0usize, 0.0_f32), |best, (i, d)| if d > best.1 { (i, d) } else { best });
    assert!(
        max_diff < 1e-3,
        "max |diff| = {max_diff:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Checks `mt_qmm_bm2` in f16 against `cpu_qmm_reference`.
///
/// Inputs are rounded through f16 before the reference runs, so the comparison isolates the
/// kernel's accumulation and output rounding. The tolerance is relative (`5e-3`), with the
/// denominator clamped to at least 1.
#[test]
fn mt_qmm_bm2_matches_cpu_reference_f16() {
    let (m, n, k, group_size) = (8usize, 16usize, 512usize, 64usize);
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            (0..8u32).fold(0u32, |word, bit| word | (((i as u32 + bit) & 0xF) << (bit * 4)))
        })
        .collect();
    let quantize = |v: f32| half::f16::from_f32(v);
    let scales: Vec<f32> =
        (0..n * gs_per_row).map(|i| quantize(0.1 + (i as f32) * 0.001).to_f32()).collect();
    let biases: Vec<f32> =
        (0..n * gs_per_row).map(|i| quantize((i as f32) * 0.0001).to_f32()).collect();
    let x: Vec<f32> = (0..m * k).map(|i| quantize(1.0 + (i as f32) * 0.001).to_f32()).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    // Little-endian f16 byte stream, the layout the F16 kernel reads.
    let encode = |vals: &[f32]| -> Vec<u8> {
        vals.iter().flat_map(|&v| quantize(v).to_bits().to_le_bytes()).collect()
    };
    let scales_bytes = encode(&scales);
    let biases_bytes = encode(&biases);
    let x_bytes = encode(&x);
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm2(
        &ctx,
        DType::F16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|pair| half::f16::from_bits(u16::from_le_bytes([pair[0], pair[1]])).to_f32())
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let (max_at, max_rel) = expected
        .iter()
        .zip(&actual)
        .map(|(e, a)| (e - a).abs() / e.abs().max(1.0))
        .enumerate()
        .fold((0usize, 0.0_f32), |best, (i, rel)| if rel > best.1 { (i, rel) } else { best });
    assert!(
        max_rel < 5e-3,
        "max relative diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Cross-checks the BM=2 variant against the baseline `mt_qmm` kernel on identical f32 inputs.
///
/// Both kernels must return `m * n` elements and agree within an absolute tolerance of
/// `1e-3`. The length checks matter: `zip` stops at the shorter output, so a truncated result
/// would otherwise be compared only on its prefix, or not at all.
#[test]
fn mt_qmm_bm2_matches_mt_qmm_at_same_shape_f32() {
    let m = 8usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_v2 = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let out_bm2 = run_qmm_bm2(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let a_v2: Vec<f32> =
        out_v2.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    let a_bm2: Vec<f32> =
        out_bm2.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(a_v2.len(), m * n, "mt_qmm output element count");
    assert_eq!(a_bm2.len(), m * n, "mt_qmm_bm2 output element count");
    let mut max_diff = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (v2, b)) in a_v2.iter().zip(a_bm2.iter()).enumerate() {
        let d = (v2 - b).abs();
        if d > max_diff {
            max_diff = d;
            max_at = i;
        }
    }
    assert!(
        max_diff < 1e-3,
        "v2 vs bm2 diverge: max |diff| = {max_diff:.2e} at index {max_at} (v2 {:.6}, bm2 {:.6})",
        a_v2[max_at],
        a_bm2[max_at],
    );
}
/// Smoke test: `mt_qmm_bm2` (f32) on a 5120 × 5120 projection with `m = 8`.
///
/// It must return exactly `m * n` finite outputs. The length check keeps the finiteness
/// loop from passing vacuously on an empty or truncated result, matching the `mt_qmm`
/// counterpart.
#[test]
fn mt_qmm_bm2_runs_on_qwen3_attention_proj_shape() {
    let m = 8usize;
    let n = 5120usize;
    let k = 5120usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm2(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), m * n, "output element count");
    for (i, v) in actual.iter().enumerate() {
        assert!(v.is_finite(), "non-finite output at index {i}: {v}");
    }
}
/// Dispatches the BM=4 variant `mt_qmm_bm4` and returns the raw bytes of its `out` buffer.
///
/// Inputs and encodings are the same as for `run_qmm`. The grid is `[n / 8, m / 4, 1]`
/// threadgroups of 64 threads, so each threadgroup row covers four rows of `x`.
///
/// # Panics
/// Panics if `m` is not a multiple of 4, if `n` is not a multiple of 8 (the grid's integer
/// divisions would otherwise silently skip rows or columns), or if the dispatch fails.
// Each argument maps to a distinct kernel input; bundling them would only obscure call sites.
#[allow(clippy::too_many_arguments)]
fn run_qmm_bm4(
    ctx: &Context,
    dtype: DType,
    w: &[u32],
    scales_bytes: &[u8],
    biases_bytes: &[u8],
    x_bytes: &[u8],
    m: usize,
    n: usize,
    k: usize,
    gs_per_row: usize,
    out_bytes_per_elem: usize,
) -> Vec<u8> {
    // Rust format strings only escape `{`/`}`; `%` is written literally.
    assert!(m.is_multiple_of(4), "mt_qmm_bm4 requires m % 4 == 0 (BM=4 tile)");
    assert!(n.is_multiple_of(8), "mt_qmm_bm4 requires n % 8 == 0 (grid covers n / 8 groups)");
    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("w".into(), w.iter().flat_map(|v| v.to_le_bytes()).collect());
    buffers.insert("scales".into(), scales_bytes.to_vec());
    buffers.insert("biases".into(), biases_bytes.to_vec());
    buffers.insert("x".into(), x_bytes.to_vec());
    buffers.insert("out".into(), vec![0u8; m * n * out_bytes_per_elem]);
    // Shape scalars are passed to the kernel as 4-byte little-endian u32 buffers.
    buffers.insert("k".into(), (k as u32).to_le_bytes().to_vec());
    buffers.insert("n".into(), (n as u32).to_le_bytes().to_vec());
    buffers.insert("gs_per_row".into(), (gs_per_row as u32).to_le_bytes().to_vec());
    let mut kernel = mt_qmm_bm4::kernel_ir_for(dtype);
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [n / 8, m / 4, 1], [64, 1, 1])
        .expect("dispatch_with_grid should succeed");
    result.outputs.get("out").expect("`out` buffer in dispatch result").clone()
}
/// Checks the BM=4 variant `mt_qmm_bm4` (f32) against `cpu_qmm_reference`.
///
/// Uses an `8 × 16 × 512` problem with group size 64 and an absolute tolerance of `1e-3`.
#[test]
fn mt_qmm_bm4_matches_cpu_reference_f32() {
    let (m, n, k, group_size) = (8usize, 16usize, 512usize, 64usize);
    let gs_per_row = k / group_size;
    let pack_word = |i: usize| -> u32 {
        (0..8u32).fold(0u32, |word, bit| word | (((i as u32 + bit) & 0xF) << (bit * 4)))
    };
    let w: Vec<u32> = (0..n * k / 8).map(pack_word).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let le_bytes =
        |vals: &[f32]| -> Vec<u8> { vals.iter().flat_map(|v| v.to_le_bytes()).collect() };
    let scales_bytes = le_bytes(&scales);
    let biases_bytes = le_bytes(&biases);
    let x_bytes = le_bytes(&x);
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm4(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let (max_at, max_diff) = expected
        .iter()
        .zip(&actual)
        .map(|(e, a)| (e - a).abs())
        .enumerate()
        .fold((0usize, 0.0_f32), |best, (i, d)| if d > best.1 { (i, d) } else { best });
    assert!(
        max_diff < 1e-3,
        "max |diff| = {max_diff:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Checks `mt_qmm_bm4` in f16 against `cpu_qmm_reference`.
///
/// Inputs are rounded through f16 before the reference runs. The tolerance is relative
/// (`5e-3`, denominator clamped to at least 1). Like the sibling f16 tests, this asserts the
/// output length so a truncated result cannot pass through `zip` unnoticed.
#[test]
fn mt_qmm_bm4_matches_cpu_reference_f16() {
    let m = 8usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales_f32: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases_f32: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x_f32: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let round_f16 = |v: f32| -> f32 { half::f16::from_f32(v).to_f32() };
    let scales: Vec<f32> = scales_f32.iter().map(|&v| round_f16(v)).collect();
    let biases: Vec<f32> = biases_f32.iter().map(|&v| round_f16(v)).collect();
    let x: Vec<f32> = x_f32.iter().map(|&v| round_f16(v)).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> =
        scales.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let biases_bytes: Vec<u8> =
        biases.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let x_bytes: Vec<u8> =
        x.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm4(
        &ctx,
        DType::F16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|c| half::f16::from_bits(u16::from_le_bytes([c[0], c[1]])).to_f32())
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let rel = (e - a).abs() / e.abs().max(1.0);
        if rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < 5e-3,
        "max relative diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Cross-checks the BM=4 variant against the baseline `mt_qmm` kernel on identical f32 inputs.
///
/// Both kernels must return `m * n` elements and agree within an absolute tolerance of
/// `1e-3`. The length checks stop a truncated output from passing through `zip` on a prefix
/// only.
#[test]
fn mt_qmm_bm4_matches_mt_qmm_at_same_shape_f32() {
    let m = 8usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_v2 = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let out_bm4 = run_qmm_bm4(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let a_v2: Vec<f32> =
        out_v2.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    let a_bm4: Vec<f32> =
        out_bm4.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(a_v2.len(), m * n, "mt_qmm output element count");
    assert_eq!(a_bm4.len(), m * n, "mt_qmm_bm4 output element count");
    let mut max_diff = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (v2, b)) in a_v2.iter().zip(a_bm4.iter()).enumerate() {
        let d = (v2 - b).abs();
        if d > max_diff {
            max_diff = d;
            max_at = i;
        }
    }
    assert!(
        max_diff < 1e-3,
        "v2 vs bm4 diverge: max |diff| = {max_diff:.2e} at index {max_at} (v2 {:.6}, bm4 {:.6})",
        a_v2[max_at],
        a_bm4[max_at],
    );
}
/// Smoke test: `mt_qmm_bm4` (f32) on a 5120 × 5120 projection with `m = 16`.
///
/// It must return exactly `m * n` finite outputs. The length check keeps the finiteness
/// loop from passing vacuously on an empty or truncated result.
#[test]
fn mt_qmm_bm4_runs_on_qwen3_attention_proj_shape() {
    let m = 16usize;
    let n = 5120usize;
    let k = 5120usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm4(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), m * n, "output element count");
    for (i, v) in actual.iter().enumerate() {
        assert!(v.is_finite(), "non-finite output at index {i}: {v}");
    }
}
/// Dispatches the MMA variant `mt_qmm_mma` and returns the raw bytes of its `out` buffer.
///
/// Inputs and encodings are the same as for `run_qmm`. The kernel uses a 32 × 32 output tile
/// per threadgroup (`[n / 32, m / 32, 1]` groups of 128 threads) and steps K by 32, so `m`,
/// `n` and `k` must each be multiples of 32.
///
/// Before dispatch, the kernel IR is passed through
/// `patch_qmm_mma_dtype_aware_skew(&mut kernel, dtype)`. NOTE(review): the exact effect of
/// that patch is defined in `metaltile_std` and is not visible here.
///
/// # Panics
/// Panics if any tile-size precondition fails or if the dispatch fails.
// Each argument maps to a distinct kernel input; bundling them would only obscure call sites.
#[allow(clippy::too_many_arguments)]
fn run_qmm_mma(
    ctx: &Context,
    dtype: DType,
    w: &[u32],
    scales_bytes: &[u8],
    biases_bytes: &[u8],
    x_bytes: &[u8],
    m: usize,
    n: usize,
    k: usize,
    gs_per_row: usize,
    out_bytes_per_elem: usize,
) -> Vec<u8> {
    // Rust format strings only escape `{`/`}`; `%` is written literally.
    assert!(m.is_multiple_of(32), "mt_qmm_mma requires m % 32 == 0 (BM=32 tile)");
    assert!(n.is_multiple_of(32), "mt_qmm_mma requires n % 32 == 0 (BN=32 tile)");
    assert!(k.is_multiple_of(32), "mt_qmm_mma requires k % 32 == 0 (BK=32 step)");
    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("w".into(), w.iter().flat_map(|v| v.to_le_bytes()).collect());
    buffers.insert("scales".into(), scales_bytes.to_vec());
    buffers.insert("biases".into(), biases_bytes.to_vec());
    buffers.insert("x".into(), x_bytes.to_vec());
    buffers.insert("out".into(), vec![0u8; m * n * out_bytes_per_elem]);
    // Shape scalars are passed to the kernel as 4-byte little-endian u32 buffers.
    buffers.insert("k".into(), (k as u32).to_le_bytes().to_vec());
    buffers.insert("n".into(), (n as u32).to_le_bytes().to_vec());
    buffers.insert("gs_per_row".into(), (gs_per_row as u32).to_le_bytes().to_vec());
    let mut kernel = mt_qmm_mma::kernel_ir_for(dtype);
    metaltile_std::mlx::quantized::patch_qmm_mma_dtype_aware_skew(&mut kernel, dtype);
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [n / 32, m / 32, 1], [128, 1, 1])
        .expect("dispatch_with_grid should succeed");
    result.outputs.get("out").expect("`out` buffer in dispatch result").clone()
}
/// Compares `mt_qmm_mma` against `cpu_qmm_reference` on a single 32×32 output tile in f32
/// (`k = 64`, one quantization group per row). The worst absolute difference must stay
/// below `1e-3`. A NaN anywhere in the output fails the test.
#[test]
fn mt_qmm_mma_matches_cpu_reference_f32() {
    let m = 32usize;
    let n = 32usize;
    let k = 64usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_diff = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let diff = (e - a).abs();
        // `diff > max_diff` is false for NaN, so a NaN output would never be recorded;
        // latching NaN into `max_diff` makes the final `< 1e-3` check fail.
        if diff.is_nan() || diff > max_diff {
            max_diff = diff;
            max_at = i;
        }
    }
    assert!(
        max_diff < 1e-3,
        "max |diff| = {max_diff:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Compares `mt_qmm_mma` in f16 against `cpu_qmm_reference` on one 32×32 tile (`k = 64`).
///
/// Inputs are pre-rounded through f16 so the CPU reference consumes the same values the
/// GPU receives. The error metric is `|e - a| / max(|e|, 1)`, and its maximum must stay
/// below `1.5e-2`. NaN outputs and short outputs fail the test.
#[test]
fn mt_qmm_mma_matches_cpu_reference_f16() {
    let m = 32usize;
    let n = 32usize;
    let k = 64usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales_f32: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases_f32: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x_f32: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let round_f16 = |v: f32| -> f32 { half::f16::from_f32(v).to_f32() };
    let scales: Vec<f32> = scales_f32.iter().map(|&v| round_f16(v)).collect();
    let biases: Vec<f32> = biases_f32.iter().map(|&v| round_f16(v)).collect();
    let x: Vec<f32> = x_f32.iter().map(|&v| round_f16(v)).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> =
        scales.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let biases_bytes: Vec<u8> =
        biases.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let x_bytes: Vec<u8> =
        x.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma(
        &ctx,
        DType::F16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|c| half::f16::from_bits(u16::from_le_bytes([c[0], c[1]])).to_f32())
        .collect();
    // `zip` below stops at the shorter side, so a truncated output would otherwise pass.
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let rel = (e - a).abs() / e.abs().max(1.0);
        // Latch NaN explicitly: `rel > max_rel` is false for NaN and would hide it.
        if rel.is_nan() || rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < 1.5e-2,
        "max relative diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Cross-checks `mt_qmm_mma` against `mt_qmm` (dispatched via `run_qmm`) on identical f32
/// inputs (`m = n = 32`, `k = 512`, eight groups per row).
///
/// The maximum of `|v2 - mma| / max(|v2|, 1)` must stay below `1e-4`. A NaN produced by
/// either kernel, or a length mismatch, fails the test.
#[test]
fn mt_qmm_mma_matches_mt_qmm_at_same_shape_f32() {
    let m = 32usize;
    let n = 32usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_v2 = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let out_mma = run_qmm_mma(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let a_v2: Vec<f32> =
        out_v2.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    let a_mma: Vec<f32> =
        out_mma.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    // `zip` truncates silently; insist both kernels produced the full output.
    assert_eq!(a_v2.len(), a_mma.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (v2, m_)) in a_v2.iter().zip(a_mma.iter()).enumerate() {
        let rel = (v2 - m_).abs() / v2.abs().max(1.0);
        // NaN compares false with `>`; latch it so the final assert fails on NaN output
        // (including the case where both kernels emit NaN at the same index).
        if rel.is_nan() || rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < 1e-4,
        "v2 vs mma diverge: max rel diff = {max_rel:.2e} at index {max_at} (v2 {:.6}, mma {:.6})",
        a_v2[max_at],
        a_mma[max_at],
    );
}
/// Smoke test for `mt_qmm_mma` at a 5120×5120 projection shape with `m = 32` in f32.
///
/// It only checks that the full `m * n` output came back and every element is finite.
/// It does not compare values.
#[test]
fn mt_qmm_mma_runs_on_qwen3_attention_proj_shape() {
    let m = 32usize;
    let n = 5120usize;
    let k = 5120usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    // Without this, an empty or truncated output would make the loop below vacuous.
    assert_eq!(actual.len(), m * n, "output element count");
    for (i, v) in actual.iter().enumerate() {
        assert!(v.is_finite(), "non-finite output at index {i}: {v}");
    }
}
/// Runs the `mt_qmm_mma_m16` kernel once on `ctx` and returns the raw bytes of its `out`
/// buffer.
///
/// The grid is `[n / 32, m / 16, 1]` threadgroups of 64 threads, i.e. one threadgroup per
/// 16×32 (BM×BN) output tile. `w` is the packed 4-bit weight matrix (eight nibbles per
/// `u32`, the layout `cpu_qmm_reference` decodes). `scales_bytes`, `biases_bytes` and
/// `x_bytes` must already be encoded in `dtype`'s element format. `out_bytes_per_elem`
/// sizes the zeroed `m * n` output buffer.
///
/// # Panics
///
/// Panics if `m` is not a multiple of 16, if `n` or `k` is not a multiple of 32, or if
/// the dispatch fails or returns no `out` buffer.
#[allow(clippy::too_many_arguments)]
fn run_qmm_mma_m16(
    ctx: &Context,
    dtype: DType,
    w: &[u32],
    scales_bytes: &[u8],
    biases_bytes: &[u8],
    x_bytes: &[u8],
    m: usize,
    n: usize,
    k: usize,
    gs_per_row: usize,
    out_bytes_per_elem: usize,
) -> Vec<u8> {
    // Rust format strings only escape `{{` / `}}`; a literal `%` is written as-is.
    assert!(m.is_multiple_of(16), "mt_qmm_mma_m16 requires m % 16 == 0 (BM=16 tile)");
    assert!(n.is_multiple_of(32), "mt_qmm_mma_m16 requires n % 32 == 0 (BN=32 tile)");
    assert!(k.is_multiple_of(32), "mt_qmm_mma_m16 requires k % 32 == 0 (BK=32 step)");
    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("w".into(), w.iter().flat_map(|v| v.to_le_bytes()).collect());
    buffers.insert("scales".into(), scales_bytes.to_vec());
    buffers.insert("biases".into(), biases_bytes.to_vec());
    buffers.insert("x".into(), x_bytes.to_vec());
    buffers.insert("out".into(), vec![0u8; m * n * out_bytes_per_elem]);
    buffers.insert("k".into(), (k as u32).to_le_bytes().to_vec());
    buffers.insert("n".into(), (n as u32).to_le_bytes().to_vec());
    buffers.insert("gs_per_row".into(), (gs_per_row as u32).to_le_bytes().to_vec());
    let mut kernel = mt_qmm_mma_m16::kernel_ir_for(dtype);
    kernel.mode = metaltile_core::ir::KernelMode::Reduction;
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [n / 32, m / 16, 1], [64, 1, 1])
        .expect("dispatch_with_grid should succeed");
    result.outputs.get("out").expect("`out` buffer in dispatch result").clone()
}
/// Compares `mt_qmm_mma_m16` against `cpu_qmm_reference` on a single 16×32 output tile in
/// f32 (`k = 64`, one quantization group per row). The worst absolute difference must stay
/// below `1e-3`. A NaN anywhere in the output fails the test.
#[test]
fn mt_qmm_mma_m16_matches_cpu_reference_f32() {
    let m = 16usize;
    let n = 32usize;
    let k = 64usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma_m16(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_diff = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let diff = (e - a).abs();
        // `diff > max_diff` is false for NaN, so a NaN output would never be recorded;
        // latching NaN into `max_diff` makes the final `< 1e-3` check fail.
        if diff.is_nan() || diff > max_diff {
            max_diff = diff;
            max_at = i;
        }
    }
    assert!(
        max_diff < 1e-3,
        "max |diff| = {max_diff:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Compares `mt_qmm_mma_m16` in f16 against `cpu_qmm_reference` on one 16×32 tile
/// (`k = 64`).
///
/// Inputs are pre-rounded through f16 so the CPU reference consumes the same values the
/// GPU receives. The error metric is `|e - a| / max(|e|, 1)`, and its maximum must stay
/// below `1.5e-2`. NaN outputs and short outputs fail the test.
#[test]
fn mt_qmm_mma_m16_matches_cpu_reference_f16() {
    let m = 16usize;
    let n = 32usize;
    let k = 64usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales_f32: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases_f32: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x_f32: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let round_f16 = |v: f32| -> f32 { half::f16::from_f32(v).to_f32() };
    let scales: Vec<f32> = scales_f32.iter().map(|&v| round_f16(v)).collect();
    let biases: Vec<f32> = biases_f32.iter().map(|&v| round_f16(v)).collect();
    let x: Vec<f32> = x_f32.iter().map(|&v| round_f16(v)).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> =
        scales.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let biases_bytes: Vec<u8> =
        biases.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let x_bytes: Vec<u8> =
        x.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma_m16(
        &ctx,
        DType::F16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|c| half::f16::from_bits(u16::from_le_bytes([c[0], c[1]])).to_f32())
        .collect();
    // `zip` below stops at the shorter side, so a truncated output would otherwise pass.
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let rel = (e - a).abs() / e.abs().max(1.0);
        // Latch NaN explicitly: `rel > max_rel` is false for NaN and would hide it.
        if rel.is_nan() || rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < 1.5e-2,
        "max relative diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Cross-checks `mt_qmm_mma_m16` against `mt_qmm` (dispatched via `run_qmm`) on identical
/// f32 inputs (`m = 16`, `n = 32`, `k = 512`, eight groups per row).
///
/// The maximum of `|v2 - mma| / max(|v2|, 1)` must stay below `1e-4`. A NaN produced by
/// either kernel, or a length mismatch, fails the test.
#[test]
fn mt_qmm_mma_m16_matches_mt_qmm_at_same_shape_f32() {
    let m = 16usize;
    let n = 32usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            let mut v = 0u32;
            for bit in 0..8u32 {
                v |= ((i as u32 + bit) & 0xF) << (bit * 4);
            }
            v
        })
        .collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_v2 = run_qmm(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let out_mma = run_qmm_mma_m16(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let a_v2: Vec<f32> =
        out_v2.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    let a_mma: Vec<f32> =
        out_mma.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    // `zip` truncates silently; insist both kernels produced the full output.
    assert_eq!(a_v2.len(), a_mma.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (v2, m_)) in a_v2.iter().zip(a_mma.iter()).enumerate() {
        let rel = (v2 - m_).abs() / v2.abs().max(1.0);
        // NaN compares false with `>`; latch it so the final assert fails on NaN output
        // (including the case where both kernels emit NaN at the same index).
        if rel.is_nan() || rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < 1e-4,
        "v2 vs mma_m16 diverge: max rel diff = {max_rel:.2e} at index {max_at} (v2 {:.6}, mma {:.6})",
        a_v2[max_at],
        a_mma[max_at],
    );
}
/// Smoke test for `mt_qmm_mma_m16` at a 5120×5120 projection shape with `m = 16` in f32.
///
/// It only checks that the full `m * n` output came back and every element is finite.
/// It does not compare values.
#[test]
fn mt_qmm_mma_m16_runs_on_qwen3_attention_proj_shape() {
    let m = 16usize;
    let n = 5120usize;
    let k = 5120usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> = (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).collect();
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma_m16(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    // Without this, an empty or truncated output would make the loop below vacuous.
    assert_eq!(actual.len(), m * n, "output element count");
    for (i, v) in actual.iter().enumerate() {
        assert!(v.is_finite(), "non-finite output at index {i}: {v}");
    }
}
/// Head-to-head latency benchmark of `mt_qmm` (v2) versus `mt_qmm_bm2` in f16.
///
/// The benchmark covers each `(n, k, label)` in `QWEN3_SHAPES` and each `m` in `M_SWEEP`:
/// - Both kernels are dispatched `WARMUP + ITERS` times, with `w`/`scales`/`biases`
///   uploaded once per shape through `upload_resident`.
/// - The median of the last `ITERS` `elapsed_us` samples is printed alongside an
///   effective bandwidth.
/// - `mt_qmm_bm2` is only timed for even `m`, since its grid is `[n / 8, m / 2, 1]`.
///
/// A closing table compares, per `m`, the kernel that won more shapes with the hard-coded
/// route "bm2 for `4..=12`, otherwise v2".
#[test]
#[ignore = "perf bench, run via --ignored --nocapture"]
fn mt_qmm_v2_vs_bm2_head_to_head_f16_m_sweep() {
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let group_size = 64usize;
    const WARMUP: usize = 20;
    const ITERS: usize = 50;
    const M_SWEEP: &[usize] = &[1, 2, 4, 6, 8, 12, 16, 32];

    // Upper median of a set of timings. `f64::total_cmp` is a total order, so a stray NaN
    // sample sorts to one end instead of panicking inside `partial_cmp(..).unwrap()`.
    let median_us = |mut samples: Vec<f64>| -> f64 {
        let mid = samples.len() / 2;
        samples.select_nth_unstable_by(mid, f64::total_cmp);
        samples[mid]
    };

    println!();
    println!(
        "mt_qmm v2 vs mt_qmm_bm2 head-to-head — f16, Apple M-series (median of {ITERS} iters)"
    );
    let mut kernel_v2 = mt_qmm::kernel_ir_for(DType::F16);
    kernel_v2.mode = metaltile_core::ir::KernelMode::Reduction;
    let mut kernel_bm2 = mt_qmm_bm2::kernel_ir_for(DType::F16);
    kernel_bm2.mode = metaltile_core::ir::KernelMode::Reduction;
    let empty_fn_consts: BTreeMap<String, u32> = BTreeMap::new();

    /// Per-`m` tally of bm2-vs-v2 outcomes across all benchmarked shapes.
    #[derive(Default, Clone, Copy)]
    struct CellAgg {
        /// Shapes where bm2 was strictly faster (`bm2/v2 < 1.0`).
        bm2_wins: u32,
        /// Shapes where v2 was at least as fast as bm2.
        v2_wins: u32,
        /// Running sum of `bm2_us / v2_us` over timed shapes.
        ratio_sum: f64,
        /// Number of ratios folded into `ratio_sum`.
        ratio_n: u32,
    }
    let mut per_m: BTreeMap<usize, CellAgg> =
        M_SWEEP.iter().map(|&m| (m, CellAgg::default())).collect();

    for &(n, k, label) in QWEN3_SHAPES {
        let gs_per_row = k / group_size;
        let w: Vec<u32> = (0..n * k / 8).map(|i| (i as u32).wrapping_mul(2654435761u32)).collect();
        let scales: Vec<f32> =
            (0..n * gs_per_row).map(|i| 0.01 + (i % 13) as f32 * 0.001).collect();
        let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 7) as f32 * 0.0001).collect();
        let w_bytes_vec: Vec<u8> = w.iter().flat_map(|v| v.to_le_bytes()).collect();
        let scales_bytes: Vec<u8> =
            scales.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
        let biases_bytes: Vec<u8> =
            biases.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
        // Uploaded once per shape and reused for every `m` below.
        let w_res = ctx.upload_resident(&w_bytes_vec).expect("upload w");
        let scales_res = ctx.upload_resident(&scales_bytes).expect("upload scales");
        let biases_res = ctx.upload_resident(&biases_bytes).expect("upload biases");
        let mut residents: BTreeMap<String, ResidentBuffer> = BTreeMap::new();
        residents.insert("w".into(), w_res);
        residents.insert("scales".into(), scales_res);
        residents.insert("biases".into(), biases_res);
        println!();
        println!(" shape = {label} (n={n}, k={k})");
        println!(
            " {:>5} {:>10} {:>10} {:>10} {:>10} {:>10}",
            "M", "v2 µs", "v2 GB/s", "bm2 µs", "bm2 GB/s", "bm2/v2"
        );
        for &m in M_SWEEP {
            let x: Vec<f32> = (0..m * k).map(|i| 0.1 + ((i % 31) as f32) * 0.01).collect();
            let x_bytes: Vec<u8> =
                x.iter().flat_map(|v| half::f16::from_f32(*v).to_bits().to_le_bytes()).collect();
            let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
            buffers.insert("x".into(), x_bytes);
            buffers.insert("out".into(), vec![0u8; m * n * 2]);
            buffers.insert("k".into(), (k as u32).to_le_bytes().to_vec());
            buffers.insert("n".into(), (n as u32).to_le_bytes().to_vec());
            buffers.insert("gs_per_row".into(), (gs_per_row as u32).to_le_bytes().to_vec());

            let mut samples_v2 = Vec::with_capacity(ITERS);
            for i in 0..(WARMUP + ITERS) {
                let r = ctx
                    .dispatch_chain(&[DispatchSpec {
                        kernel: &kernel_v2,
                        buffers: &buffers,
                        fn_consts: &empty_fn_consts,
                        grid_groups: [n / 8, m, 1],
                        threads_per_group: [64, 1, 1],
                        resident: &residents,
                    }])
                    .expect("dispatch v2");
                if i >= WARMUP {
                    samples_v2.push(r[0].elapsed_us);
                }
            }
            let v2_us = median_us(samples_v2);
            // Bytes per dispatch: 4-bit weights, f16 scales + biases, f16 x, f16 out.
            let bytes = (n * k / 2 + 2 * n * gs_per_row * 2 + m * k * 2 + m * n * 2) as f64;
            let v2_gbps = bytes / (v2_us * 1e-6) / 1e9;

            // bm2 packs two rows per threadgroup, so it only runs for even `m`.
            let bm2_cell: Option<(f64, f64, f64)> = if m.is_multiple_of(2) {
                let mut samples_bm2 = Vec::with_capacity(ITERS);
                for i in 0..(WARMUP + ITERS) {
                    let r = ctx
                        .dispatch_chain(&[DispatchSpec {
                            kernel: &kernel_bm2,
                            buffers: &buffers,
                            fn_consts: &empty_fn_consts,
                            grid_groups: [n / 8, m / 2, 1],
                            threads_per_group: [64, 1, 1],
                            resident: &residents,
                        }])
                        .expect("dispatch bm2");
                    if i >= WARMUP {
                        samples_bm2.push(r[0].elapsed_us);
                    }
                }
                let bm2_us = median_us(samples_bm2);
                let bm2_gbps = bytes / (bm2_us * 1e-6) / 1e9;
                let ratio = bm2_us / v2_us;
                Some((bm2_us, bm2_gbps, ratio))
            } else {
                None
            };

            match bm2_cell {
                Some((bm2_us, bm2_gbps, ratio)) => {
                    println!(
                        " {m:>5} {v2_us:>10.2} {v2_gbps:>10.1} {bm2_us:>10.2} \
                         {bm2_gbps:>10.1} {ratio:>10.3}"
                    );
                    let agg = per_m
                        .get_mut(&m)
                        .expect("per_m is pre-seeded with every M_SWEEP entry");
                    if ratio < 1.0 {
                        agg.bm2_wins += 1;
                    } else {
                        agg.v2_wins += 1;
                    }
                    agg.ratio_sum += ratio;
                    agg.ratio_n += 1;
                },
                None => {
                    println!(
                        " {m:>5} {v2_us:>10.2} {v2_gbps:>10.1} {:>10} {:>10} {:>10}",
                        "n/a", "n/a", "skip"
                    );
                },
            }
        }
    }

    println!();
    println!("selector-route accuracy (across all {} shapes)", QWEN3_SHAPES.len());
    println!(
        " {:>5} {:>10} {:>10} {:>14} {:>14} {:>10}",
        "M", "bm2_wins", "v2_wins", "mean bm2/v2", "selector→", "matches?"
    );
    for &m in M_SWEEP {
        let agg = per_m[&m];
        let mean_ratio =
            if agg.ratio_n > 0 { agg.ratio_sum / f64::from(agg.ratio_n) } else { f64::NAN };
        // Hard-coded routing rule being evaluated: bm2 for 4 <= m <= 12, v2 otherwise.
        let selector_route = if (4..=12).contains(&m) { "bm2" } else { "v2" };
        // With no timed shapes both tallies are 0, so this also falls back to "v2".
        let data_winner = if agg.bm2_wins > agg.v2_wins { "bm2" } else { "v2" };
        let matches = if selector_route == data_winner { "YES" } else { "NO" };
        let mean_disp =
            if mean_ratio.is_nan() { "n/a".to_string() } else { format!("{mean_ratio:.3}") };
        println!(
            " {m:>5} {:>10} {:>10} {:>14} {:>14} {:>10}",
            agg.bm2_wins, agg.v2_wins, mean_disp, selector_route, matches
        );
    }
    println!();
    println!("legend: bm2/v2 < 1.0 ⇒ bm2 faster than v2 at that cell");
}
/// Narrows `v` to bf16 using the `half` crate and returns its raw 16-bit pattern.
fn f32_to_bf16_bits(v: f32) -> u16 {
    let narrowed = half::bf16::from_f32(v);
    narrowed.to_bits()
}
/// Interprets `b` as a bf16 bit pattern and widens it to `f32`. The widening is exact,
/// because every bf16 value is representable as an `f32`.
fn bf16_bits_to_f32(b: u16) -> f32 {
    let value = half::bf16::from_bits(b);
    value.to_f32()
}
/// Builds the deterministic fixtures used by the `*_matches_cpu_reference_bf16` tests.
///
/// Returns `(w, scales, biases, x, scales_bytes, biases_bytes, x_bytes)`:
/// - `w` has `n * k / 8` words. Nibble `b` (bits `4b..4b+4`) of word `i` is `(i + b) & 0xF`.
/// - `scales` and `biases` (`n * gs_per_row` each) and `x` (`m * k`) are f32 values already
///   round-tripped through bf16. They equal what the byte buffers decode to.
/// - Each `*_bytes` vector holds the same values as little-endian bf16 bit patterns.
#[allow(clippy::type_complexity)]
fn build_bf16_inputs(
    n: usize,
    k: usize,
    gs_per_row: usize,
    m: usize,
) -> (Vec<u32>, Vec<f32>, Vec<f32>, Vec<f32>, Vec<u8>, Vec<u8>, Vec<u8>) {
    // Eight 4-bit lanes per word; lane `bit` holds `(i + bit) & 0xF`.
    let w: Vec<u32> = (0..n * k / 8)
        .map(|i| {
            (0..8u32).fold(0u32, |word, bit| word | (((i as u32 + bit) & 0xF) << (bit * 4)))
        })
        .collect();

    let raw_scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.1 + (i as f32) * 0.001).collect();
    let raw_biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i as f32) * 0.0001).collect();
    let raw_x: Vec<f32> = (0..m * k).map(|i| 1.0 + (i as f32) * 0.001).collect();

    // f32 -> bf16 -> f32: the value each encoded bf16 element decodes back to.
    let quantize = |vals: &[f32]| -> Vec<f32> {
        vals.iter().map(|&v| bf16_bits_to_f32(f32_to_bf16_bits(v))).collect()
    };
    // f32 -> little-endian bf16 bytes.
    let encode = |vals: &[f32]| -> Vec<u8> {
        vals.iter().flat_map(|&v| f32_to_bf16_bits(v).to_le_bytes()).collect()
    };

    (
        w,
        quantize(&raw_scales),
        quantize(&raw_biases),
        quantize(&raw_x),
        encode(&raw_scales),
        encode(&raw_biases),
        encode(&raw_x),
    )
}
/// Decodes `out_bytes` as little-endian bf16 values and asserts they match `expected`.
///
/// The per-element error is `|e - a| / max(|e|, 1)`: absolute below magnitude 1 and
/// relative above it. The largest error must be strictly below `tol_rel`.
///
/// # Panics
///
/// Panics if the element counts differ, if any element's error is NaN (e.g. a NaN
/// output), or if the worst error is not below `tol_rel`.
fn check_bf16_outputs(out_bytes: &[u8], expected: &[f32], tol_rel: f32) {
    let actual: Vec<f32> = out_bytes
        .chunks_exact(2)
        .map(|c| bf16_bits_to_f32(u16::from_le_bytes([c[0], c[1]])))
        .collect();
    assert_eq!(actual.len(), expected.len(), "output element count");
    let mut max_rel = 0.0_f32;
    let mut max_at = 0usize;
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let rel = (e - a).abs() / e.abs().max(1.0);
        // `rel > max_rel` is false for NaN, which would let a NaN output pass; latching
        // NaN into `max_rel` makes the final `< tol_rel` check fail instead.
        if rel.is_nan() || rel > max_rel {
            max_rel = rel;
            max_at = i;
        }
    }
    assert!(
        max_rel < tol_rel,
        "max rel diff = {max_rel:.2e} at index {max_at} (expected {:.6}, got {:.6})",
        expected[max_at],
        actual[max_at],
    );
}
/// Checks `mt_qmm_bm2` in bf16 against `cpu_qmm_reference` (`m = 8`, `n = 16`, `k = 512`)
/// with a `5e-2` tolerance, as measured by `check_bf16_outputs`.
#[test]
fn mt_qmm_bm2_matches_cpu_reference_bf16() {
    let m = 8usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let (w, scales, biases, x, scales_bytes, biases_bytes, x_bytes) =
        build_bf16_inputs(n, k, gs_per_row, m);
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm2(
        &ctx,
        DType::BF16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    check_bf16_outputs(&out_bytes, &expected, 5e-2);
}
/// Checks `mt_qmm_bm4` in bf16 against `cpu_qmm_reference` (`m = 8`, `n = 16`, `k = 512`)
/// with a `5e-2` tolerance, as measured by `check_bf16_outputs`.
#[test]
fn mt_qmm_bm4_matches_cpu_reference_bf16() {
    let m = 8usize;
    let n = 16usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let (w, scales, biases, x, scales_bytes, biases_bytes, x_bytes) =
        build_bf16_inputs(n, k, gs_per_row, m);
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_bm4(
        &ctx,
        DType::BF16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    check_bf16_outputs(&out_bytes, &expected, 5e-2);
}
/// Checks `mt_qmm_mma` in bf16 against `cpu_qmm_reference` (`m = n = 32`, `k = 512`)
/// with a `5e-2` tolerance, as measured by `check_bf16_outputs`.
#[test]
fn mt_qmm_mma_matches_cpu_reference_bf16() {
    let m = 32usize;
    let n = 32usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let (w, scales, biases, x, scales_bytes, biases_bytes, x_bytes) =
        build_bf16_inputs(n, k, gs_per_row, m);
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma(
        &ctx,
        DType::BF16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    check_bf16_outputs(&out_bytes, &expected, 5e-2);
}
/// Checks `mt_qmm_mma_m16` in bf16 against `cpu_qmm_reference` (`m = 16`, `n = 32`,
/// `k = 512`) with a `5e-2` tolerance, as measured by `check_bf16_outputs`.
#[test]
fn mt_qmm_mma_m16_matches_cpu_reference_bf16() {
    let m = 16usize;
    let n = 32usize;
    let k = 512usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let (w, scales, biases, x, scales_bytes, biases_bytes, x_bytes) =
        build_bf16_inputs(n, k, gs_per_row, m);
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma_m16(
        &ctx,
        DType::BF16,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        2,
    );
    check_bf16_outputs(&out_bytes, &expected, 5e-2);
}
/// Full CPU-oracle comparison for `mt_qmm_mma` at a 5120×5120 shape with `m = 32` in f32.
/// Every output element must be within `5e-3` absolute of `cpu_qmm_reference`.
#[test]
fn mt_qmm_mma_matches_cpu_reference_f32_qwen3_shape_full_oracle() {
    let m = 32usize;
    let n = 5120usize;
    let k = 5120usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> =
        (0..n * k / 8).map(|i| ((i as u32) % 17).wrapping_mul(0xDEADBEEFu32)).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.005 + (i % 7) as f32 * 0.0007).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 5) as f32 * 0.00005).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 0.05 + ((i % 23) as f32) * 0.003).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    assert_eq!(actual.len(), expected.len());
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let d = (e - a).abs();
        // Check every element: `d < 5e-3` is false for NaN, so a NaN output fails here
        // instead of slipping past a `d > max_diff` gate. The first failing index reported
        // is the same one a running-maximum scan would hit.
        assert!(
            d < 5e-3,
            "production-shape mma divergence at [{i}]: {a:.6} vs {e:.6} (diff {d:.2e})"
        );
    }
}
/// Full CPU-oracle comparison for `mt_qmm_mma_m16` at a 5120×5120 shape with `m = 16` in
/// f32. Every output element must be within `5e-3` absolute of `cpu_qmm_reference`.
#[test]
fn mt_qmm_mma_m16_at_qwen3_shape_full_oracle_f32() {
    let m = 16usize;
    let n = 5120usize;
    let k = 5120usize;
    let group_size = 64usize;
    let gs_per_row = k / group_size;
    let w: Vec<u32> =
        (0..n * k / 8).map(|i| ((i as u32) % 13).wrapping_mul(0xC0FFEE00u32)).collect();
    let scales: Vec<f32> = (0..n * gs_per_row).map(|i| 0.003 + (i % 11) as f32 * 0.0005).collect();
    let biases: Vec<f32> = (0..n * gs_per_row).map(|i| (i % 9) as f32 * 0.00003).collect();
    let x: Vec<f32> = (0..m * k).map(|i| 0.07 + ((i % 19) as f32) * 0.002).collect();
    let expected = cpu_qmm_reference(&w, &scales, &biases, &x, m, n, k, gs_per_row, group_size);
    let scales_bytes: Vec<u8> = scales.iter().flat_map(|v| v.to_le_bytes()).collect();
    let biases_bytes: Vec<u8> = biases.iter().flat_map(|v| v.to_le_bytes()).collect();
    let x_bytes: Vec<u8> = x.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new should succeed on macOS");
    let out_bytes = run_qmm_mma_m16(
        &ctx,
        DType::F32,
        &w,
        &scales_bytes,
        &biases_bytes,
        &x_bytes,
        m,
        n,
        k,
        gs_per_row,
        4,
    );
    let actual: Vec<f32> =
        out_bytes.chunks_exact(4).map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])).collect();
    // `zip` truncates silently; a short output must not pass.
    assert_eq!(actual.len(), expected.len(), "output element count");
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        let d = (e - a).abs();
        // Check every element: `d < 5e-3` is false for NaN, so a NaN output fails here
        // instead of slipping past a `d > max_diff` gate.
        assert!(d < 5e-3, "production-shape mma_m16 divergence at [{i}]: {a:.6} vs {e:.6}");
    }
}