#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
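//! Integration tests for `quantized_matmul_id`: a quantized matrix multiply
//! that gathers per-token expert weights through a routing-id tensor.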
use mlx_native::ops::quantized_matmul_id::{quantized_matmul_id, QuantizedMatmulIdParams};
use mlx_native::{DType, KernelRegistry, MlxDevice};
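// Shared setup: a device handle plus an empty kernel registry for each test.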
fn setup() -> (MlxDevice, KernelRegistry) {
let device = MlxDevice::new().expect("MlxDevice::new");
let registry = KernelRegistry::new();
(device, registry)
}
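// Affine 4-bit quantization in the layout the kernel consumes: each row of
// length `k` is split into groups of `group_size` values that share a bf16
// scale and bias (w ~= q * scale + bias with q in 0..=15), and the quantized
// values are packed eight nibbles per little-endian u32.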
fn quantize_4bit(
data: &[f32],
n: usize,
k: usize,
group_size: usize,
) -> (Vec<u8>, Vec<u16>, Vec<u16>) {
let num_groups_per_row = (k + group_size - 1) / group_size;
let packs_per_row = (k + 7) / 8;
let mut weight_bytes = vec![0u8; n * packs_per_row * 4];
let mut scales = vec![0u16; n * num_groups_per_row];
let mut biases = vec![0u16; n * num_groups_per_row];
for row in 0..n {
for g in 0..num_groups_per_row {
let start = g * group_size;
let end = (start + group_size).min(k);
let mut min_val = f32::MAX;
let mut max_val = f32::MIN;
for i in start..end {
let v = data[row * k + i];
if v < min_val { min_val = v; }
if v > max_val { max_val = v; }
}
let range = max_val - min_val;
let scale = if range < 1e-10 { 1.0f32 } else { range / 15.0 };
let bias = min_val;
scales[row * num_groups_per_row + g] = f32_to_bf16(scale);
biases[row * num_groups_per_row + g] = f32_to_bf16(bias);
for i in start..end {
let v = data[row * k + i];
let qval = ((v - bias) / scale).round().clamp(0.0, 15.0) as u32;
let pack_idx = i / 8;
let in_pack = i % 8;
let byte_offset = row * packs_per_row * 4 + pack_idx * 4;
let existing = u32::from_le_bytes([
weight_bytes[byte_offset],
weight_bytes[byte_offset + 1],
weight_bytes[byte_offset + 2],
weight_bytes[byte_offset + 3],
]);
let updated = existing | (qval << (4 * in_pack as u32));
let bytes = updated.to_le_bytes();
weight_bytes[byte_offset] = bytes[0];
weight_bytes[byte_offset + 1] = bytes[1];
weight_bytes[byte_offset + 2] = bytes[2];
weight_bytes[byte_offset + 3] = bytes[3];
}
}
}
(weight_bytes, scales, biases)
}
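// CPU reference for a single token row: unpack each 4-bit weight, dequantize
// it with bf16-rounded intermediates to mimic the kernel's arithmetic, and
// accumulate the dot product in f64.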
fn dequant_matmul_ref(
input_row: &[f32],
weight_bytes: &[u8],
scales: &[u16],
biases: &[u16],
n: usize,
k: usize,
group_size: usize,
) -> Vec<f32> {
let num_groups_per_row = (k + group_size - 1) / group_size;
let packs_per_row = (k + 7) / 8;
let mut output = vec![0.0f32; n];
for col in 0..n {
let sb_base = col * num_groups_per_row;
let mut acc = 0.0f64;
for ki in 0..k {
let g = ki / group_size;
let scale_bf16 = scales[sb_base + g];
let bias_bf16 = biases[sb_base + g];
let pack_idx = ki / 8;
let in_pack = ki % 8;
let byte_offset = col * packs_per_row * 4 + pack_idx * 4;
let packed = u32::from_le_bytes([
weight_bytes[byte_offset],
weight_bytes[byte_offset + 1],
weight_bytes[byte_offset + 2],
weight_bytes[byte_offset + 3],
]);
let qval = (packed >> (4 * in_pack as u32)) & 0xF;
let w_bf16 = f32_to_bf16(qval as f32);
let w_dequant = bf16_to_f32(bf16_mul_add(w_bf16, scale_bf16, bias_bf16));
let x_bf16 = bf16_to_f32(f32_to_bf16(input_row[ki]));
acc += (w_dequant as f64) * (x_bf16 as f64);
}
output[col] = acc as f32;
}
output
}
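// f32 -> bf16 with round-to-nearest-even (NaN is not special-cased; test
// inputs are finite).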
fn f32_to_bf16(v: f32) -> u16 {
let bits = v.to_bits();
let round = ((bits >> 16) & 1) + 0x7FFF;
((bits.wrapping_add(round)) >> 16) as u16
}
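// bf16 -> f32 by placing the 16 bits in the high half of the f32 bit pattern.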
fn bf16_to_f32(v: u16) -> f32 {
f32::from_bits((v as u32) << 16)
}
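// a * b + c computed in f32 and rounded back to bf16; the reference uses this
// to emulate the kernel's bf16 dequantization step.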
fn bf16_mul_add(a: u16, b: u16, c: u16) -> u16 {
let fa = bf16_to_f32(a);
let fb = bf16_to_f32(b);
let fc = bf16_to_f32(c);
f32_to_bf16(fa * fb + fc)
}
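// Deterministic pseudo-random f32 values in roughly [-0.5, 0.5], produced by a
// SplitMix64-style mixer so the tests are reproducible without an RNG crate.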
fn make_f32_vec(len: usize, seed: u64) -> Vec<f32> {
let mut s = seed.wrapping_add(0x9E3779B97F4A7C15);
(0..len)
.map(|_| {
s = s.wrapping_add(0x9E3779B97F4A7C15);
let mut z = s;
z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
z ^= z >> 31;
((z as i64 as f64) * (1.0 / (i64::MAX as f64)) * 0.5) as f32
})
.collect()
}
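// End-to-end check: 4 tokens routed to their top-2 of 8 experts, Q4 weights
// with group size 32, compared against the bf16-emulating CPU reference.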
#[test]
fn test_quantized_matmul_id_4bit_4tokens_8experts_top2() {
let (device, mut registry) = setup();
let num_experts: usize = 8;
let n: usize = 32; let k: usize = 64; let group_size: usize = 32;
let n_tokens: usize = 4;
let n_expert_used: usize = 2;
let bits: u32 = 4;
let mut all_weight_bytes: Vec<u8> = Vec::new();
let mut all_scales: Vec<u16> = Vec::new();
let mut all_biases: Vec<u16> = Vec::new();
let mut expert_wb: Vec<Vec<u8>> = Vec::new();
let mut expert_sc: Vec<Vec<u16>> = Vec::new();
let mut expert_bi: Vec<Vec<u16>> = Vec::new();
for e in 0..num_experts {
let float_data = make_f32_vec(n * k, 0xCAFE + (e as u64) * 0x1234);
let (wb, sc, bi) = quantize_4bit(&float_data, n, k, group_size);
all_weight_bytes.extend_from_slice(&wb);
all_scales.extend_from_slice(&sc);
all_biases.extend_from_slice(&bi);
expert_wb.push(wb);
expert_sc.push(sc);
expert_bi.push(bi);
}
let input_data = make_f32_vec(n_tokens * k, 0xBEEF);
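// Routing ids, row-major [n_tokens, n_expert_used]: token t's experts are
// ids_data[t * n_expert_used .. (t + 1) * n_expert_used].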
let ids_data: Vec<u32> = vec![0, 3, 1, 5, 2, 7, 4, 6];
let mut ref_output = vec![0.0f32; n_tokens * n_expert_used * n];
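// Expected output: run the per-expert reference matmul for every
// (token, slot) pair into a [n_tokens, n_expert_used, n] layout.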
for t in 0..n_tokens {
for s in 0..n_expert_used {
let expert_id = ids_data[t * n_expert_used + s] as usize;
let input_row = &input_data[t * k..(t + 1) * k];
let expert_out = dequant_matmul_ref(
input_row,
&expert_wb[expert_id],
&expert_sc[expert_id],
&expert_bi[expert_id],
n,
k,
group_size,
);
let base = (t * n_expert_used + s) * n;
ref_output[base..base + n].copy_from_slice(&expert_out);
}
}
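// Upload the activations, packed weights, scales/biases, and routing ids.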
let mut input_buf = device
.alloc_buffer(input_data.len() * 4, DType::F32, vec![n_tokens, k])
.expect("input");
input_buf
.as_mut_slice::<f32>()
.expect("w")
.copy_from_slice(&input_data);
let mut weight_buf = device
.alloc_buffer(all_weight_bytes.len(), DType::U8, vec![all_weight_bytes.len()])
.expect("weight");
weight_buf
.as_mut_slice::<u8>()
.expect("w")
.copy_from_slice(&all_weight_bytes);
let mut scales_buf = device
.alloc_buffer(all_scales.len() * 2, DType::U16, vec![all_scales.len()])
.expect("scales");
scales_buf
.as_mut_slice::<u16>()
.expect("w")
.copy_from_slice(&all_scales);
let mut biases_buf = device
.alloc_buffer(all_biases.len() * 2, DType::U16, vec![all_biases.len()])
.expect("biases");
biases_buf
.as_mut_slice::<u16>()
.expect("w")
.copy_from_slice(&all_biases);
let mut ids_buf = device
.alloc_buffer(ids_data.len() * 4, DType::U32, vec![n_tokens, n_expert_used])
.expect("ids");
ids_buf
.as_mut_slice::<u32>()
.expect("w")
.copy_from_slice(&ids_data);
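// Encode and run the expert-routed quantized matmul.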
let mut encoder = device.command_encoder().expect("encoder");
let output_buf = quantized_matmul_id(
&mut encoder,
&mut registry,
&device,
&input_buf,
&weight_buf,
&scales_buf,
&biases_buf,
&ids_buf,
&QuantizedMatmulIdParams {
m: n_tokens as u32,
k: k as u32,
n: n as u32,
group_size: group_size as u32,
bits,
n_expert_used: n_expert_used as u32,
num_experts: num_experts as u32,
},
)
.expect("quantized_matmul_id");
encoder.commit_and_wait().expect("commit");
let output: &[f32] = output_buf.as_slice().expect("read");
assert_eq!(output.len(), ref_output.len());
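// Element-wise comparison against the reference; report the worst mismatch.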
let mut max_diff: f32 = 0.0;
let mut max_idx: usize = 0;
for i in 0..output.len() {
let diff = (output[i] - ref_output[i]).abs();
if diff > max_diff {
max_diff = diff;
max_idx = i;
}
}
eprintln!(
"quantized_matmul_id Q4: max|delta|={:.3e} at idx {} (gpu={}, ref={})",
max_diff, max_idx, output[max_idx], ref_output[max_idx]
);
assert!(
max_diff <= 1e-4,
"quantized_matmul_id Q4 exceeds tolerance 1e-4: max|delta|={}",
max_diff
);
}
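// num_experts == 0 is invalid and must be rejected before any dispatch is encoded.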
#[test]
fn test_quantized_matmul_id_zero_experts_error() {
let (device, mut registry) = setup();
let buf = device
.alloc_buffer(256, DType::F32, vec![64])
.expect("buf");
let mut encoder = device.command_encoder().expect("enc");
let result = quantized_matmul_id(
&mut encoder,
&mut registry,
&device,
&buf,
&buf,
&buf,
&buf,
&buf,
&QuantizedMatmulIdParams {
m: 4,
k: 64,
n: 32,
group_size: 32,
bits: 4,
n_expert_used: 2,
num_experts: 0,
},
);
assert!(result.is_err(), "zero num_experts should error");
}
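// An unsupported bit width (bits = 3) must be rejected.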
#[test]
fn test_quantized_matmul_id_unsupported_bits_error() {
let (device, mut registry) = setup();
let buf = device
.alloc_buffer(256, DType::F32, vec![64])
.expect("buf");
let mut encoder = device.command_encoder().expect("enc");
let result = quantized_matmul_id(
&mut encoder,
&mut registry,
&device,
&buf,
&buf,
&buf,
&buf,
&buf,
&QuantizedMatmulIdParams {
m: 1,
k: 64,
n: 32,
group_size: 32,
bits: 3,
n_expert_used: 1,
num_experts: 1,
},
);
assert!(result.is_err(), "bits=3 should error");
}