#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
#![cfg(target_vendor = "apple")]
use mlx_native::{DType, GgmlQuantizedMatmulParams, GgmlType, KernelRegistry, MlxDevice};
fn pseudo_random_f32(seed: u64, n: usize) -> Vec<f32> {
let mut state = seed;
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let frac = ((state >> 33) as f32) / (u32::MAX as f32) - 0.5;
frac
})
.collect()
}
fn pack_q4_0(values: &[f32]) -> Vec<u8> {
assert!(values.len() % 32 == 0);
let mut buf = Vec::new();
for block in values.chunks(32) {
let amax = block.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let d = amax / 7.0;
let id = if d != 0.0 { 1.0 / d } else { 0.0 };
let d_f16 = half::f16::from_f32(d);
buf.extend_from_slice(&d_f16.to_le_bytes());
let mut nibbles = [0u8; 16];
for i in 0..16 {
let v0 = (block[i] * id + 8.0).round().clamp(0.0, 15.0) as u8;
let v1 = (block[i + 16] * id + 8.0).round().clamp(0.0, 15.0) as u8;
nibbles[i] = v0 | (v1 << 4);
}
buf.extend_from_slice(&nibbles);
}
buf
}
fn pack_q8_0(values: &[f32]) -> Vec<u8> {
assert!(values.len() % 32 == 0);
let mut buf = Vec::new();
for block in values.chunks(32) {
let amax = block.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let d = amax / 127.0;
let id = if d != 0.0 { 1.0 / d } else { 0.0 };
let d_f16 = half::f16::from_f32(d);
buf.extend_from_slice(&d_f16.to_le_bytes());
for &v in block {
let q = (v * id).round().clamp(-128.0, 127.0) as i8;
buf.push(q as u8);
}
}
buf
}
fn pack_q6_k(values: &[f32]) -> Vec<u8> {
assert!(values.len() % 256 == 0);
let mut buf = Vec::new();
for block in values.chunks(256) {
let mut sub_scales = [0.0f32; 16];
let mut sub_scale_int = [0i8; 16];
let mut max_scale: f32 = 0.0;
for (s, sub) in block.chunks(16).enumerate() {
let amax = sub.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
sub_scales[s] = amax;
if amax > max_scale {
max_scale = amax;
}
}
let d = max_scale / (32.0 * 127.0);
let id = if d != 0.0 { 1.0 / d } else { 0.0 };
for s in 0..16 {
sub_scale_int[s] = if sub_scales[s] != 0.0 {
(sub_scales[s] * id / 32.0).round().clamp(-128.0, 127.0) as i8
} else {
0
};
}
let mut q6 = [0u8; 256];
for (s, sub) in block.chunks(16).enumerate() {
let sc = sub_scale_int[s] as f32;
let sub_d = d * sc;
let sub_id = if sub_d != 0.0 { 1.0 / sub_d } else { 0.0 };
for (i, &v) in sub.iter().enumerate() {
let q = (v * sub_id + 32.0).round().clamp(0.0, 63.0) as u8;
q6[s * 16 + i] = q;
}
}
let mut ql = [0u8; 128];
let mut qh = [0u8; 64];
for l0_base in (0..32usize).step_by(4) {
for l in 0..4usize {
let ql_idx = l0_base + l;
let v0 = q6[l0_base + l];
let v2 = q6[l0_base + l + 64];
ql[ql_idx] = (v0 & 0x0F) | ((v2 & 0x0F) << 4);
let v1 = q6[l0_base + l + 32];
let v3 = q6[l0_base + l + 96];
ql[ql_idx + 32] = (v1 & 0x0F) | ((v3 & 0x0F) << 4);
let h0 = (v0 >> 4) & 0x03;
let h1 = (v1 >> 4) & 0x03;
let h2 = (v2 >> 4) & 0x03;
let h3 = (v3 >> 4) & 0x03;
qh[ql_idx] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6);
}
}
for l0_base in (0..32usize).step_by(4) {
for l in 0..4usize {
let ql_idx = 64 + l0_base + l;
let qh_idx = 32 + l0_base + l;
let v0 = q6[128 + l0_base + l];
let v2 = q6[128 + l0_base + l + 64];
ql[ql_idx] = (v0 & 0x0F) | ((v2 & 0x0F) << 4);
let v1 = q6[128 + l0_base + l + 32];
let v3 = q6[128 + l0_base + l + 96];
ql[ql_idx + 32] = (v1 & 0x0F) | ((v3 & 0x0F) << 4);
let h0 = (v0 >> 4) & 0x03;
let h1 = (v1 >> 4) & 0x03;
let h2 = (v2 >> 4) & 0x03;
let h3 = (v3 >> 4) & 0x03;
qh[qh_idx] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6);
}
}
buf.extend_from_slice(&ql);
buf.extend_from_slice(&qh);
buf.extend_from_slice(
&sub_scale_int.iter().map(|&s| s as u8).collect::<Vec<_>>(),
);
let d_f16 = half::f16::from_f32(d);
buf.extend_from_slice(&d_f16.to_le_bytes());
}
buf
}
enum DispatchPath {
Mv,
Mm,
}
fn run_qmatmul_path(
path: DispatchPath,
m: usize,
n: usize,
k: usize,
ggml_type: GgmlType,
weight_bytes: &[u8],
input: &[f32],
) -> Vec<f32> {
let device = MlxDevice::new().expect("Metal device");
let mut registry = KernelRegistry::new();
let input_bytes = m * k * 4;
let mut input_buf = device
.alloc_buffer(input_bytes, DType::F32, vec![m, k])
.expect("alloc input");
input_buf
.as_mut_slice::<f32>()
.expect("input mut slice")
.copy_from_slice(input);
let mut weight_buf = device
.alloc_buffer(weight_bytes.len(), DType::U8, vec![weight_bytes.len()])
.expect("alloc weight");
weight_buf
.as_mut_slice::<u8>()
.expect("weight mut slice")
.copy_from_slice(weight_bytes);
let output_bytes = m * n * 4;
let mut output_buf = device
.alloc_buffer(output_bytes, DType::F32, vec![m, n])
.expect("alloc output");
for v in output_buf
.as_mut_slice::<f32>()
.expect("output mut slice")
.iter_mut()
{
*v = 0.0;
}
let params = GgmlQuantizedMatmulParams {
m: m as u32,
n: n as u32,
k: k as u32,
ggml_type,
};
let mut encoder = device.command_encoder().expect("encoder");
match path {
DispatchPath::Mv => {
mlx_native::quantized_matmul_ggml(
&mut encoder,
&mut registry,
&device,
&input_buf,
&weight_buf,
&mut output_buf,
¶ms,
)
.expect("mv dispatch");
}
DispatchPath::Mm => {
mlx_native::dispatch_mm_for_test(
&mut encoder,
&mut registry,
&device,
&input_buf,
&weight_buf,
&mut output_buf,
¶ms,
)
.expect("mm dispatch");
}
}
encoder.commit_and_wait().expect("GPU execution");
output_buf
.as_slice::<f32>()
.expect("read output")
.to_vec()
}
fn run_mv_gpu(
m: usize,
n: usize,
k: usize,
ggml_type: GgmlType,
weight_bytes: &[u8],
input: &[f32],
) -> Vec<f32> {
run_qmatmul_path(DispatchPath::Mv, m, n, k, ggml_type, weight_bytes, input)
}
fn run_mm_gpu(
m: usize,
n: usize,
k: usize,
ggml_type: GgmlType,
weight_bytes: &[u8],
input: &[f32],
) -> Vec<f32> {
run_qmatmul_path(DispatchPath::Mm, m, n, k, ggml_type, weight_bytes, input)
}
fn check_mm_matches_mv_by_row(
m: usize,
n: usize,
k: usize,
ggml_type: GgmlType,
weight_bytes: &[u8],
input: &[f32],
tolerance_abs: f32,
label: &str,
) {
let mut mv_output = vec![0.0f32; m * n];
for row in 0..m {
let row_input = &input[row * k..(row + 1) * k];
let row_output = run_mv_gpu(1, n, k, ggml_type, weight_bytes, row_input);
mv_output[row * n..(row + 1) * n].copy_from_slice(&row_output);
}
let mm_output = run_mm_gpu(m, n, k, ggml_type, weight_bytes, input);
let mut max_err: f32 = 0.0;
let mut max_err_idx = 0usize;
for i in 0..mm_output.len() {
let err = (mm_output[i] - mv_output[i]).abs();
if err > max_err {
max_err = err;
max_err_idx = i;
}
}
if max_err > tolerance_abs {
let row = max_err_idx / n;
let col = max_err_idx % n;
panic!(
"{}: mm diverges from mv at [{},{}]: mm={} mv={} err={} (tol={}, M={}, N={}, K={})",
label, row, col, mm_output[max_err_idx], mv_output[max_err_idx],
max_err, tolerance_abs, m, n, k,
);
}
eprintln!(
"{}: PASS (max_err={:.6}, tol={}, M={}, N={}, K={})",
label, max_err, tolerance_abs, m, n, k,
);
}
#[test]
fn test_q4_0_mm_matches_mv_small() {
let m = 16usize;
let n = 64usize;
let k = 128usize;
let weights_f32 = pseudo_random_f32(0xA1B2, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q4_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0x1234, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q4_0, &weight_bytes, &input, 5e-3,
"Q4_0 mm matches mv, M=16 N=64 K=128");
}
#[test]
fn test_q4_0_mm_matches_mv_prefill_shape() {
let m = 64usize;
let n = 256usize;
let k = 2048usize;
let weights_f32 = pseudo_random_f32(0xC0FFEE, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q4_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0xDEADBEEF, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q4_0, &weight_bytes, &input, 5e-2,
"Q4_0 mm matches mv, M=64 N=256 K=2048");
}
#[test]
fn test_q4_0_mm_matches_mv_irregular() {
let m = 40usize; let n = 96usize; let k = 256usize;
let weights_f32 = pseudo_random_f32(0x13, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q4_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0x42, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q4_0, &weight_bytes, &input, 1e-2,
"Q4_0 mm matches mv, M=40 N=96 K=256 (partial tiles)");
}
#[test]
fn test_q8_0_mm_matches_mv_small() {
let m = 16usize;
let n = 64usize;
let k = 128usize;
let weights_f32 = pseudo_random_f32(0xB1A2, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q8_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0x5678, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q8_0, &weight_bytes, &input, 1e-3,
"Q8_0 mm matches mv, M=16 N=64 K=128");
}
#[test]
fn test_q8_0_mm_matches_mv_prefill_shape() {
let m = 64usize;
let n = 128usize;
let k = 2112usize;
let weights_f32 = pseudo_random_f32(0xBEEF, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q8_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0xBABE, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q8_0, &weight_bytes, &input, 2e-2,
"Q8_0 mm matches mv, M=64 N=128 K=2112");
}
#[test]
fn test_q8_0_mm_matches_mv_irregular() {
let m = 17usize;
let n = 72usize;
let k = 256usize;
let weights_f32 = pseudo_random_f32(0x21, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q8_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0x22, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q8_0, &weight_bytes, &input, 5e-3,
"Q8_0 mm matches mv, M=17 N=72 K=256 (partial tiles)");
}
#[test]
fn test_q6_k_mm_matches_mv_small() {
let m = 16usize;
let n = 32usize;
let k = 256usize;
let weights_f32 = pseudo_random_f32(0xD1E2, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q6_k(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0x9ABC, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q6_K, &weight_bytes, &input, 5e-3,
"Q6_K mm matches mv, M=16 N=32 K=256");
}
#[test]
fn test_q6_k_mm_matches_mv_prefill_shape() {
let m = 64usize;
let n = 128usize;
let k = 2048usize;
let weights_f32 = pseudo_random_f32(0xFACE, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q6_k(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0xF00D, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q6_K, &weight_bytes, &input, 2e-2,
"Q6_K mm matches mv, M=64 N=128 K=2048");
}
#[test]
fn test_q6_k_mm_matches_mv_irregular() {
let m = 33usize; let n = 100usize;
let k = 512usize;
let weights_f32 = pseudo_random_f32(0x55, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q6_k(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0x66, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q6_K, &weight_bytes, &input, 5e-3,
"Q6_K mm matches mv, M=33 N=100 K=512 (partial tiles)");
}
#[test]
fn test_dispatcher_produces_finite_output_at_m_1() {
let m = 1usize;
let n = 16usize;
let k = 64usize;
let weights_f32 = pseudo_random_f32(0xAA, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q4_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0xBB, m * k);
let out = run_mv_gpu(m, n, k, GgmlType::Q4_0, &weight_bytes, &input);
assert_eq!(out.len(), m * n);
let abs_sum: f32 = out.iter().map(|v| v.abs()).sum();
assert!(abs_sum > 0.0);
for v in &out {
assert!(v.is_finite());
}
eprintln!("test_dispatcher_produces_finite_output_at_m_1: PASS (abs_sum={})", abs_sum);
}
#[test]
fn test_mm_for_test_helper_works_at_m_just_above_threshold() {
let m = (mlx_native::MM_ROUTING_THRESHOLD + 1) as usize;
let n = 32usize;
let k = 128usize;
let weights_f32 = pseudo_random_f32(0xEE, n * k);
let mut weight_bytes = Vec::new();
for row in 0..n {
weight_bytes.extend_from_slice(&pack_q4_0(&weights_f32[row * k..(row + 1) * k]));
}
let input = pseudo_random_f32(0xFF, m * k);
check_mm_matches_mv_by_row(m, n, k, GgmlType::Q4_0, &weight_bytes, &input, 5e-3,
"Q4_0 mm dispatch at m=threshold+1");
}