#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
#![cfg(target_vendor = "apple")]
use mlx_native::{
DType, GgmlQuantizedMatmulIdParams, GgmlQuantizedMatmulParams, GgmlType,
KernelRegistry, MlxDevice,
};
/// Produces `n` deterministic pseudo-random `f32` samples from `seed`.
///
/// A 64-bit linear congruential generator is stepped once per sample with
/// wrapping arithmetic. The top 31 bits of the state are divided by
/// `u32::MAX as f32` and shifted down by 0.5, so every sample lies in
/// `[-0.5, 0.0]`.
fn pseudo_random_f32(seed: u64, n: usize) -> Vec<f32> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let mut samples = Vec::with_capacity(n);
    let mut state = seed;
    for _ in 0..n {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        // Drop the low 33 bits; what is left is a 31-bit integer.
        let top_bits = state >> 33;
        samples.push((top_bits as f32) / (u32::MAX as f32) - 0.5);
    }
    samples
}
/// Packs eight 6-bit scales (`sc`) and eight 6-bit mins (`m`) into 12 bytes.
///
/// Byte layout produced, for `i` in `0..4`:
/// * byte `i`     — low 6 bits: `sc[i]`;     top 2 bits: bits 4..6 of `sc[i + 4]`
/// * byte `i + 4` — low 6 bits: `m[i]`;      top 2 bits: bits 4..6 of `m[i + 4]`
/// * byte `i + 8` — low nibble: `sc[i + 4]`; high nibble: low nibble of `m[i + 4]`
///
/// Any input bits above the sixth are discarded.
fn encode_q4k_scales(sc: &[u8; 8], m: &[u8; 8]) -> [u8; 12] {
    let mut packed = [0u8; 12];
    for i in 0..4 {
        let (sc_hi, m_hi) = (sc[i + 4], m[i + 4]);
        // `<< 6` on a u8 keeps only the two lowest bits of the shifted-in value.
        packed[i] = (sc[i] & 63) | ((sc_hi >> 4) << 6);
        packed[i + 4] = (m[i] & 63) | ((m_hi >> 4) << 6);
        packed[i + 8] = (sc_hi & 0x0F) | ((m_hi & 0x0F) << 4);
    }
    packed
}
/// Extracts the 6-bit `(scale, min)` pair for sub-block `j` from the 12-byte
/// packed scale table.
///
/// For `j < 4` both values are the low 6 bits of bytes `j` and `j + 4`.
/// For `j >= 4` the low nibbles come from byte `j + 4` and the upper two bits
/// from the top two bits of bytes `j - 4` (scale) and `j` (min).
///
/// # Panics
/// Panics if an index derived from `j` is out of bounds for `scales`.
fn decode_q4k_scale_min(j: usize, scales: &[u8]) -> (u8, u8) {
    if j >= 4 {
        let nibbles = scales[j + 4];
        let scale = (nibbles & 0xF) | ((scales[j - 4] >> 6) << 4);
        let min = (nibbles >> 4) | ((scales[j] >> 6) << 4);
        return (scale, min);
    }
    (scales[j] & 63, scales[j + 4] & 63)
}
/// Quantizes exactly 256 `f32` values into one 144-byte Q4_K block.
///
/// Bytes written:
/// * `0..2`    — super-block scale `d` as little-endian f16
/// * `2..4`    — super-block min scale `dmin` as little-endian f16
/// * `4..16`   — eight 6-bit scale/min pairs, packed by `encode_q4k_scales`
/// * `16..144` — 4-bit quants; each byte holds one value from an even
///   32-value sub-block (low nibble) and one from the following odd sub-block
///   (high nibble)
///
/// This is a simple test-fixture quantizer: one shared shift is derived from
/// the block minimum, and every sub-block uses the same min code.
///
/// # Panics
/// Panics if `values.len() != 256`.
fn pack_q4_k_block(values: &[f32]) -> [u8; 144] {
    assert_eq!(values.len(), 256, "Q4_K block requires 256 values");

    // Shift that lifts the most negative value to zero (none if all values are >= 0).
    let lowest: f32 = values.iter().copied().fold(f32::MAX, f32::min);
    let shift = if lowest < 0.0 { -lowest } else { 0.0 };
    let shifted_peak: f32 = values.iter().map(|&v| v + shift).fold(0.0f32, f32::max);

    let d_val = if shifted_peak > 0.0 {
        shifted_peak / (15.0 * 63.0)
    } else {
        1.0
    };
    let dmin_val = if shift > 0.0 { shift / 63.0 } else { 1.0 };
    let id_d = 1.0 / d_val;
    let id_dmin = 1.0 / dmin_val;

    // The min code does not depend on the sub-block, so compute it once.
    let min_code = (shift * id_dmin).round().clamp(0.0, 63.0) as u8;
    let m_arr = [min_code; 8];

    let mut sc_arr = [0u8; 8];
    for (code, sub) in sc_arr.iter_mut().zip(values.chunks(32)) {
        let sub_peak = sub.iter().map(|&v| v + shift).fold(0.0f32, f32::max);
        *code = ((sub_peak / 15.0) * id_d).round().clamp(0.0, 63.0) as u8;
    }

    // Quantize every value to 4 bits using its sub-block's effective scale and min.
    let mut q4 = [0u8; 256];
    for (s, (sub, codes)) in values.chunks(32).zip(q4.chunks_mut(32)).enumerate() {
        let sub_scale = d_val * sc_arr[s] as f32;
        let sub_min = dmin_val * m_arr[s] as f32;
        let inv_sub_scale = if sub_scale == 0.0 { 0.0 } else { 1.0 / sub_scale };
        for (q, &v) in codes.iter_mut().zip(sub) {
            *q = ((v + sub_min) * inv_sub_scale).round().clamp(0.0, 15.0) as u8;
        }
    }

    // Interleave sub-block pairs (0,1), (2,3), ... into shared bytes.
    let mut qs = [0u8; 128];
    for (pair, bytes) in qs.chunks_mut(32).enumerate() {
        let even = &q4[pair * 64..pair * 64 + 32];
        let odd = &q4[pair * 64 + 32..pair * 64 + 64];
        for ((dst, &lo), &hi) in bytes.iter_mut().zip(even).zip(odd) {
            *dst = (lo & 0x0F) | ((hi & 0x0F) << 4);
        }
    }

    let mut block = [0u8; 144];
    block[..2].copy_from_slice(&half::f16::from_f32(d_val).to_le_bytes());
    block[2..4].copy_from_slice(&half::f16::from_f32(dmin_val).to_le_bytes());
    block[4..16].copy_from_slice(&encode_q4k_scales(&sc_arr, &m_arr));
    block[16..].copy_from_slice(&qs);
    block
}
fn pack_q4_k(values: &[f32]) -> Vec<u8> {
assert!(
values.len() % 256 == 0,
"Q4_K requires multiple of 256 values, got {}",
values.len()
);
let mut out = Vec::with_capacity(values.len() / 256 * 144);
for chunk in values.chunks(256) {
out.extend_from_slice(&pack_q4_k_block(chunk));
}
out
}
/// CPU reference dequantizer for a single 144-byte Q4_K block.
///
/// Reads `d` and `dmin` (little-endian f16) from bytes `0..4`, the packed
/// scale table from bytes `4..16`, and the nibbles from bytes `16..144`.
/// Each output is `d * scale * quant - dmin * min` for its 32-value sub-block.
/// Low nibbles feed even sub-blocks and high nibbles the odd sub-block that
/// follows.
fn cpu_dequant_q4k_block(block: &[u8; 144]) -> [f32; 256] {
    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
    let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
    let (scales, qs) = (&block[4..16], &block[16..144]);

    let mut out = [0.0f32; 256];
    for (pair, bytes) in qs.chunks(32).enumerate() {
        let (even, odd) = (2 * pair, 2 * pair + 1);
        let (even_sc, even_m) = decode_q4k_scale_min(even, scales);
        let (odd_sc, odd_m) = decode_q4k_scale_min(odd, scales);
        let even_scale = d * even_sc as f32;
        let even_min = dmin * even_m as f32;
        let odd_scale = d * odd_sc as f32;
        let odd_min = dmin * odd_m as f32;
        for (l, &byte) in bytes.iter().enumerate() {
            out[even * 32 + l] = even_scale * (byte & 0x0F) as f32 - even_min;
            out[odd * 32 + l] = odd_scale * (byte >> 4) as f32 - odd_min;
        }
    }
    out
}
/// CPU reference for `out = W * input`, where `W` is an `n x k` matrix stored
/// row-major as Q4_K blocks (`k / 256` blocks of 144 bytes per row).
///
/// Each row is accumulated sequentially in a single `f32`, block by block.
///
/// # Panics
/// Panics if `k` is not a multiple of 256, if `weight_packed` does not hold
/// exactly `n * (k / 256)` blocks, or if `input` has fewer than `k` elements.
fn cpu_q4k_matvec(weight_packed: &[u8], input: &[f32], n: usize, k: usize) -> Vec<f32> {
    assert_eq!(k % 256, 0);
    let blocks_per_row = k / 256;
    assert_eq!(weight_packed.len(), n * blocks_per_row * 144);

    (0..n)
        .map(|row| {
            let mut dot = 0.0f32;
            for b in 0..blocks_per_row {
                let start = (row * blocks_per_row + b) * 144;
                let block: [u8; 144] = weight_packed[start..start + 144]
                    .try_into()
                    .expect("block slice is 144 bytes");
                let weights = cpu_dequant_q4k_block(&block);
                let activations = &input[b * 256..(b + 1) * 256];
                for (w, x) in weights.iter().zip(activations) {
                    dot += w * x;
                }
            }
            dot
        })
        .collect()
}
/// Every 6-bit (scale, min) combination, replicated across all eight slots,
/// must survive an encode -> decode round trip at every slot index.
#[test]
fn q4k_encode_decode_roundtrip_all_6bit_pairs() {
    for sc_pattern in 0u8..64 {
        let sc = [sc_pattern; 8];
        for m_pattern in 0u8..64 {
            let m = [m_pattern; 8];
            let encoded = encode_q4k_scales(&sc, &m);
            (0..8).for_each(|j| {
                let (got_sc, got_m) = decode_q4k_scale_min(j, &encoded);
                assert_eq!(
                    got_sc, sc_pattern,
                    "round-trip sc mismatch at j={}: encoded={:?}",
                    j, encoded
                );
                assert_eq!(
                    got_m, m_pattern,
                    "round-trip m mismatch at j={}: encoded={:?}",
                    j, encoded
                );
            });
        }
    }
}
/// An all-zero block (d = 0, dmin = 0, zero scales and quants) must
/// dequantize to 256 zeros.
#[test]
fn q4k_all_zeros_dequants_to_zero() {
    let out = cpu_dequant_q4k_block(&[0u8; 144]);
    for (i, v) in out.iter().copied().enumerate() {
        assert!(v.abs() < 1e-9, "expected 0 at {}, got {}", i, v);
    }
}
/// With d = 1, scale[0] = 1 and the first quant byte set to 0x0F, the low
/// nibble must land in output 0 (sub-block 0) while the zero high nibble
/// leaves output 32 (sub-block 1) at zero.
#[test]
fn q4k_first_value_low_nibble() {
    let mut block = [0u8; 144];
    block[..2].copy_from_slice(&half::f16::from_f32(1.0).to_le_bytes());
    // Scale table byte 0: 6-bit scale of sub-block 0.
    block[4] = 1;
    // First quant byte: low nibble = 15, high nibble = 0.
    block[16] = 0x0F;

    let out = cpu_dequant_q4k_block(&block);
    let (first, paired) = (out[0], out[32]);
    assert!(
        (first - 15.0).abs() < 1e-6,
        "expected 15.0 at idx 0, got {}",
        first
    );
    assert!(paired.abs() < 1e-6, "expected 0 at idx 32, got {}", paired);
}
/// Checks that the `dmin * min` term is subtracted: with d = 2, dmin = 3,
/// scale[0] = 1, min[0] = 1 and quant 15, output 0 must be 2*1*15 - 3*1 = 27.
#[test]
fn q4k_dmin_offset() {
    let mut block = [0u8; 144];
    block[..2].copy_from_slice(&half::f16::from_f32(2.0).to_le_bytes());
    block[2..4].copy_from_slice(&half::f16::from_f32(3.0).to_le_bytes());
    // Scale table byte 0: scale of sub-block 0; byte 4: min of sub-block 0.
    block[4] = 1;
    block[8] = 1;
    // First quant byte: low nibble = 15.
    block[16] = 0x0F;

    let first = cpu_dequant_q4k_block(&block)[0];
    assert!(
        (first - 27.0).abs() < 1e-6,
        "expected 27.0 (= 2*1*15 - 3*1) at idx 0, got {}",
        first
    );
}
/// Runs the dense Q4_K mat-vec (`m = 1`) on the GPU and compares every output
/// element against the CPU reference `cpu_q4k_matvec`.
///
/// * `n`, `k` — weight matrix shape (`n` output rows, `k` input columns)
/// * `seed_w`, `seed_in` — seeds for the pseudo-random weights and input vector
/// * `tolerance` — maximum allowed absolute difference per output element
/// * `label` — prefix used in diagnostic output
///
/// The CPU reference dequantizes the same packed bytes that are uploaded to
/// the GPU, so quantization error itself does not contribute to the
/// comparison.
///
/// # Panics
/// Panics if `k` is not a multiple of 256, if `n` is odd, if any device or
/// dispatch call fails, or if any element differs by more than `tolerance`.
fn run_q4k_mv_vs_cpu(n: usize, k: usize, seed_w: u64, seed_in: u64, tolerance: f32, label: &str) {
    assert_eq!(k % 256, 0, "Q4_K requires k divisible by 256");
    assert_eq!(n % 2, 0, "Q4_K mv kernel requires n even (2 rows per tg)");

    let device = MlxDevice::new().expect("Metal device");
    let mut registry = KernelRegistry::new();
    let f32_sz = std::mem::size_of::<f32>();

    let weights_f32 = pseudo_random_f32(seed_w, n * k);
    let input = pseudo_random_f32(seed_in, k);
    // Rows are whole multiples of 256 values, so packing the row-major matrix
    // in one go yields the same bytes as packing row by row.
    let weight_bytes = pack_q4_k(&weights_f32);

    let mut input_buf = device
        .alloc_buffer(k * f32_sz, DType::F32, vec![k])
        .unwrap();
    input_buf
        .as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&input);

    let mut weight_buf = device
        .alloc_buffer(weight_bytes.len(), DType::U8, vec![weight_bytes.len()])
        .unwrap();
    weight_buf
        .as_mut_slice::<u8>()
        .unwrap()
        .copy_from_slice(&weight_bytes);

    let mut output_buf = device
        .alloc_buffer(n * f32_sz, DType::F32, vec![n])
        .unwrap();
    // Start from a known state so stale buffer contents cannot mask a missing write.
    output_buf.as_mut_slice::<f32>().unwrap().fill(0.0);

    let params = GgmlQuantizedMatmulParams {
        m: 1,
        n: n as u32,
        k: k as u32,
        ggml_type: GgmlType::Q4_K,
    };
    let mut encoder = device.command_encoder().unwrap();
    mlx_native::quantized_matmul_ggml(
        &mut encoder,
        &mut registry,
        &device,
        &input_buf,
        &weight_buf,
        &mut output_buf,
        &params,
    )
    .expect("dense Q4_K dispatch");
    encoder.commit_and_wait().expect("GPU exec");

    let gpu_out = output_buf.as_slice::<f32>().unwrap().to_vec();
    let cpu_out = cpu_q4k_matvec(&weight_bytes, &input, n, k);
    assert_eq!(
        gpu_out.len(),
        cpu_out.len(),
        "{}: GPU output length differs from CPU reference",
        label
    );

    let mut max_err = 0.0f32;
    let mut max_err_idx = 0usize;
    let mut err_count = 0usize;
    for (i, (&gpu, &cpu)) in gpu_out.iter().zip(&cpu_out).enumerate() {
        let err = (gpu - cpu).abs();
        if err > max_err {
            max_err = err;
            max_err_idx = i;
        }
        if err > tolerance {
            // Only print the first few mismatches to keep failure output readable.
            if err_count < 5 {
                eprintln!(
                    "  {}: mismatch [{}]: gpu={:.6} cpu={:.6} err={:.6}",
                    label, i, gpu, cpu, err
                );
            }
            err_count += 1;
        }
    }
    assert_eq!(
        err_count, 0,
        "{}: {} mismatches > {:.6} (max_err={:.6} at idx {})",
        label, err_count, tolerance, max_err, max_err_idx
    );
    eprintln!(
        "  PASS {}: n={} k={} max_err={:.6} (tol={:.6})",
        label, n, k, max_err, tolerance
    );
}
/// Dense Q4_K mat-vec, smallest shape exercised: n = 2 rows, k = 256 columns,
/// weight seed 42, input seed 100, absolute tolerance 1e-3.
#[test]
fn q4k_mv_synthetic_2x256() {
run_q4k_mv_vs_cpu(2, 256, 42, 100, 1e-3, "Q4_K mv 2x256");
}
/// Dense Q4_K mat-vec with more output rows: n = 8, k = 256,
/// weight seed 7, input seed 11, absolute tolerance 1e-3.
#[test]
fn q4k_mv_synthetic_8x256() {
run_q4k_mv_vs_cpu(8, 256, 7, 11, 1e-3, "Q4_K mv 8x256");
}
/// Dense Q4_K mat-vec with two 256-value blocks per row: n = 4, k = 512,
/// weight seed 13, input seed 17, absolute tolerance 1e-3.
#[test]
fn q4k_mv_synthetic_4x512() {
run_q4k_mv_vs_cpu(4, 512, 13, 17, 1e-3, "Q4_K mv 4x512");
}
/// Dense Q4_K mat-vec at a larger shape: n = 64, k = 1024 (four blocks per
/// row), weight seed 99, input seed 199. Uses a looser absolute tolerance
/// (5e-2) than the small-shape tests.
#[test]
fn q4k_mv_production_shape_64x1024() {
run_q4k_mv_vs_cpu(64, 1024, 99, 199, 5e-2, "Q4_K mv 64x1024 prod-ish");
}
/// Runs the expert-routed ("id") Q4_K matmul on the GPU and compares every
/// output element against a per-(token, slot) CPU reference.
///
/// * `n_tokens` — number of input rows, each of length `k`
/// * `n_experts` — number of stacked expert weight matrices (each `n x k`)
/// * `top_k` — number of expert slots evaluated per token
/// * `n`, `k` — per-expert weight shape
/// * `tolerance` — maximum allowed absolute difference per output element
///
/// Expert `e` uses weights seeded with `100 + e`; the inputs use seed 42.
/// The expert chosen for token `t`, slot `s` is `(3t + 7s + 1) % n_experts`.
/// Output row `t * top_k + s` is compared against
/// `cpu_q4k_matvec(expert weights, token t input)`.
///
/// # Panics
/// Panics if `k` is not a multiple of 256, if `n` is odd, if `n_experts` is
/// zero, if any device or dispatch call fails, or if any element differs by
/// more than `tolerance`.
fn run_q4k_mvid_vs_cpu(
    n_tokens: usize,
    n_experts: usize,
    top_k: usize,
    n: usize,
    k: usize,
    tolerance: f32,
) {
    assert_eq!(k % 256, 0);
    assert_eq!(n % 2, 0, "Q4_K mv_id requires n even (2 rows per tg)");
    // Without this the routing modulo and the `expert_packed[0]` lookup below
    // would fail with an unhelpful panic.
    assert!(n_experts > 0, "Q4_K mv_id requires at least one expert");

    let device = MlxDevice::new().expect("Metal device");
    let mut registry = KernelRegistry::new();
    let f32_sz = std::mem::size_of::<f32>();
    let u32_sz = std::mem::size_of::<u32>();

    let input_data = pseudo_random_f32(42, n_tokens * k);

    // One packed weight matrix per expert, then laid out back to back.
    let expert_packed: Vec<Vec<u8>> = (0..n_experts)
        .map(|e| pack_q4_k(&pseudo_random_f32(100 + e as u64, n * k)))
        .collect();
    let per_expert_bytes = expert_packed[0].len();
    let stacked: Vec<u8> = expert_packed.concat();

    // Deterministic routing table, laid out as [token][slot].
    let ids: Vec<u32> = (0..n_tokens)
        .flat_map(|t| (0..top_k).map(move |s| ((t * 3 + s * 7 + 1) % n_experts) as u32))
        .collect();
    let total_rows = n_tokens * top_k;

    let mut input_buf = device
        .alloc_buffer(n_tokens * k * f32_sz, DType::F32, vec![n_tokens * k])
        .unwrap();
    input_buf
        .as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&input_data);

    let mut weight_buf = device
        .alloc_buffer(stacked.len(), DType::U8, vec![stacked.len()])
        .unwrap();
    weight_buf
        .as_mut_slice::<u8>()
        .unwrap()
        .copy_from_slice(&stacked);

    let mut ids_buf = device
        .alloc_buffer(total_rows * u32_sz, DType::U32, vec![total_rows])
        .unwrap();
    ids_buf
        .as_mut_slice::<u32>()
        .unwrap()
        .copy_from_slice(&ids);

    let mut id_output_buf = device
        .alloc_buffer(total_rows * n * f32_sz, DType::F32, vec![total_rows * n])
        .unwrap();
    // Start from a known state so stale buffer contents cannot mask a missing write.
    id_output_buf.as_mut_slice::<f32>().unwrap().fill(0.0);

    {
        let params = GgmlQuantizedMatmulIdParams {
            n_tokens: n_tokens as u32,
            top_k: top_k as u32,
            n: n as u32,
            k: k as u32,
            n_experts: n_experts as u32,
            expert_stride: per_expert_bytes as u64,
            ggml_type: GgmlType::Q4_K,
        };
        let mut encoder = device.command_encoder().unwrap();
        mlx_native::ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
            &mut encoder,
            &mut registry,
            &device,
            &input_buf,
            &weight_buf,
            &ids_buf,
            &mut id_output_buf,
            &params,
        )
        .expect("Q4_K mv_id dispatch");
        encoder.commit_and_wait().unwrap();
    }

    // CPU reference: one mat-vec per (token, slot) using the routed expert.
    let mut cpu_results = vec![0.0f32; total_rows * n];
    for t in 0..n_tokens {
        let input_slice = &input_data[t * k..(t + 1) * k];
        for s in 0..top_k {
            let row_idx = t * top_k + s;
            let expert_id = ids[row_idx] as usize;
            let expert_weights =
                &stacked[expert_id * per_expert_bytes..(expert_id + 1) * per_expert_bytes];
            let cpu_out = cpu_q4k_matvec(expert_weights, input_slice, n, k);
            cpu_results[row_idx * n..(row_idx + 1) * n].copy_from_slice(&cpu_out);
        }
    }

    let gpu_out = id_output_buf.as_slice::<f32>().unwrap();
    assert_eq!(
        gpu_out.len(),
        cpu_results.len(),
        "Q4_K mv_id: GPU output length differs from CPU reference"
    );

    let mut max_err = 0.0f32;
    let mut max_idx = 0usize;
    let mut err_count = 0usize;
    for (i, (&gpu, &cpu)) in gpu_out.iter().zip(&cpu_results).enumerate() {
        let err = (gpu - cpu).abs();
        if err > max_err {
            max_err = err;
            max_idx = i;
        }
        if err > tolerance {
            // Only print the first few mismatches to keep failure output readable.
            if err_count < 5 {
                eprintln!(
                    "  Q4_K mv_id mismatch [{}]: gpu={:.6} cpu={:.6} err={:.6}",
                    i, gpu, cpu, err
                );
            }
            err_count += 1;
        }
    }
    assert_eq!(
        err_count, 0,
        "Q4_K mv_id vs cpu: {} mismatches > {:.6} (max_err={:.6} at {})",
        err_count, tolerance, max_err, max_idx
    );
    eprintln!(
        "  PASS Q4_K mv_id: n={} k={} {} tokens top-{} max_err={:.6}",
        n, k, n_tokens, top_k, max_err
    );
}
/// Expert-routed Q4_K matmul, minimal case: 1 token, 4 experts, top-1,
/// n = 2, k = 256, absolute tolerance 1e-3.
#[test]
fn q4k_mvid_1tok_4experts_top1() {
run_q4k_mvid_vs_cpu(1, 4, 1, 2, 256, 1e-3);
}
/// Expert-routed Q4_K matmul where the single token has as many slots as
/// there are experts: 1 token, 8 experts, top-8, n = 4, k = 256,
/// absolute tolerance 1e-3.
#[test]
fn q4k_mvid_1tok_8experts_top8() {
run_q4k_mvid_vs_cpu(1, 8, 8, 4, 256, 1e-3);
}
/// Expert-routed Q4_K matmul with several tokens and two blocks per row:
/// 4 tokens, 8 experts, top-2, n = 4, k = 512, absolute tolerance 1e-3.
#[test]
fn q4k_mvid_4tok_8experts_top2() {
run_q4k_mvid_vs_cpu(4, 8, 2, 4, 512, 1e-3);
}
/// Expert-routed Q4_K matmul at a larger shape: 8 tokens, 16 experts, top-4,
/// n = 16, k = 1024. Uses the looser absolute tolerance 5e-2.
#[test]
fn q4k_mvid_8tok_16experts_top4() {
run_q4k_mvid_vs_cpu(8, 16, 4, 16, 1024, 5e-2);
}
/// Expert-routed Q4_K matmul: 16 tokens, 8 experts, top-8, n = 16, k = 512,
/// absolute tolerance 5e-2. Goes through the same harness and dispatch entry
/// point as the `mvid` tests.
// NOTE(review): the `mmid` name suggests this token count is meant to reach a
// matrix-matrix kernel path; that routing happens inside the dispatcher and is
// not visible from this file — confirm.
#[test]
fn q4k_mmid_16tok_8experts_top8_k512() {
run_q4k_mvid_vs_cpu(16, 8, 8, 16, 512, 5e-2);
}
/// Expert-routed Q4_K matmul: 64 tokens, 4 experts, top-1, n = 16, k = 1024,
/// absolute tolerance 5e-2. Goes through the same harness and dispatch entry
/// point as the `mvid` tests.
// NOTE(review): the `mmid` name suggests this token count is meant to reach a
// matrix-matrix kernel path; that routing happens inside the dispatcher and is
// not visible from this file — confirm.
#[test]
fn q4k_mmid_64tok_4experts_top1_k1024() {
run_q4k_mvid_vs_cpu(64, 4, 1, 16, 1024, 5e-2);
}
/// Expert-routed Q4_K matmul, largest case in this file: 32 tokens,
/// 16 experts, top-8, n = 16, k = 2048 (eight blocks per row),
/// absolute tolerance 5e-2. Goes through the same harness and dispatch entry
/// point as the `mvid` tests.
// NOTE(review): the `mmid` name suggests this token count is meant to reach a
// matrix-matrix kernel path; that routing happens inside the dispatcher and is
// not visible from this file — confirm.
#[test]
fn q4k_mmid_32tok_16experts_top8_k2048() {
run_q4k_mvid_vs_cpu(32, 16, 8, 16, 2048, 5e-2);
}