#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use half::f16;
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from a packed scale table.
///
/// For `j < 4` the pair is the low six bits of `scales[j]` and `scales[j + 4]`.
/// For `j >= 4` the low four bits of each value come from the two nibbles of
/// `scales[j + 4]`, and the upper two bits are taken from the top two bits of
/// `scales[j - 4]` (scale) and `scales[j]` (min).
///
/// Panics if `scales` is too short for the indices above.
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j >= 4 {
        // Low nibble carries the scale bits, high nibble the min bits.
        let nibble_pair = scales[j + 4];
        let scale_top = (scales[j - 4] >> 6) << 4;
        let min_top = (scales[j] >> 6) << 4;
        return ((nibble_pair & 0x0F) | scale_top, (nibble_pair >> 4) | min_top);
    }
    (scales[j] & 0x3F, scales[j + 4] & 0x3F)
}
/// Assembles a 176-byte Q5_K block from its components.
///
/// Byte layout of the result: `d` as little-endian f16 (0..2), `dmin` as
/// little-endian f16 (2..4), `scales` (4..16), `qh` (16..48), `qs` (48..176).
fn build_q5_k_block(d: f32, dmin: f32, scales: &[u8; 12], qh: &[u8; 32], qs: &[u8; 128]) -> [u8; 176] {
    let d_bytes = f16::from_f32(d).to_le_bytes();
    let dmin_bytes = f16::from_f32(dmin).to_le_bytes();

    // Fields in on-disk order; their lengths sum to exactly 176 bytes.
    let fields: [&[u8]; 5] = [&d_bytes, &dmin_bytes, scales, qh, qs];

    let mut block = [0u8; 176];
    let mut cursor = 0usize;
    for field in fields {
        let end = cursor + field.len();
        block[cursor..end].copy_from_slice(field);
        cursor = end;
    }
    block
}
/// CPU reference dequantizer: expands one 176-byte Q5_K block into 256 `f32` values.
///
/// Layout read from `block`: bytes 0..2 hold `d` (f16, little-endian), 2..4 hold
/// `dmin` (f16, little-endian), 4..16 the packed sub-block scales/mins, 16..48 the
/// fifth-bit plane, and 48..176 the 4-bit quants.
///
/// Each 32-byte group of quants yields 64 outputs: first the 32 low nibbles, then
/// the 32 high nibbles. Every output is `d * scale * q - dmin * min`, where `q` is
/// the nibble plus 16 when the matching bit of the fifth-bit plane is set.
fn cpu_dequant_q5_k(block: &[u8; 176]) -> [f32; 256] {
    let super_scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
    let super_min = f16::from_le_bytes([block[2], block[3]]).to_f32();
    let packed_scales = &block[4..16];
    let high_bits = &block[16..48];
    let nibbles = &block[48..176];

    let mut out = [0.0f32; 256];
    let groups = nibbles.chunks_exact(32).zip(out.chunks_exact_mut(64));
    for (pair, (ql, dst)) in groups.enumerate() {
        // Group `pair` uses bit 2*pair of the fifth-bit plane for its low
        // nibbles and bit 2*pair + 1 for its high nibbles.
        let lo_mask = 1u8 << (2 * pair);
        let hi_mask = 2u8 << (2 * pair);

        let (sc_lo, min_lo) = get_scale_min_k4(2 * pair, packed_scales);
        let (sc_hi, min_hi) = get_scale_min_k4(2 * pair + 1, packed_scales);
        let d_lo = super_scale * f32::from(sc_lo);
        let m_lo = super_min * f32::from(min_lo);
        let d_hi = super_scale * f32::from(sc_hi);
        let m_hi = super_min * f32::from(min_hi);

        let (lo_half, hi_half) = dst.split_at_mut(32);
        for ((slot, &q), &h) in lo_half.iter_mut().zip(ql).zip(high_bits) {
            let fifth: u32 = if h & lo_mask != 0 { 16 } else { 0 };
            *slot = d_lo * (u32::from(q & 0x0F) + fifth) as f32 - m_lo;
        }
        for ((slot, &q), &h) in hi_half.iter_mut().zip(ql).zip(high_bits) {
            let fifth: u32 = if h & hi_mask != 0 { 16 } else { 0 };
            *slot = d_hi * (u32::from(q >> 4) + fifth) as f32 - m_hi;
        }
    }
    out
}
/// An all-zero block (zero super-scales, scales, and quants) must dequantize to zeros.
#[test]
fn q5_k_all_zeros() {
    let zero_block = build_q5_k_block(0.0, 0.0, &[0; 12], &[0; 32], &[0; 128]);
    let values = cpu_dequant_q5_k(&zero_block);
    values.iter().enumerate().for_each(|(i, v)| {
        assert!(v.abs() < 1e-6, "expected 0 at {}, got {}", i, v);
    });
}
/// With every quant at zero and `dmin == 0`, a non-zero scale must still yield zeros.
#[test]
fn q5_k_zero_quant_nonzero_scale() {
    // Give all eight sub-blocks a scale of 1 while keeping every min at 0:
    //  - bytes 0..4: low six bits are the scales of sub-blocks 0-3;
    //  - bytes 4..8 stay 0, so the mins of sub-blocks 0-3 and the upper
    //    bits of the mins of sub-blocks 4-7 are 0;
    //  - bytes 8..12: low nibble is the scale of sub-blocks 4-7, and the
    //    high nibble (their min) stays 0.
    let mut scales = [0u8; 12];
    for s in &mut scales[..4] {
        *s = 1;
    }
    for s in &mut scales[8..] {
        *s = 1;
    }

    let block = build_q5_k_block(1.0, 0.0, &scales, &[0; 32], &[0; 128]);
    let out = cpu_dequant_q5_k(&block);
    for (i, v) in out.iter().enumerate() {
        assert!(v.abs() < 1e-6, "expected 0 at {} (zero-quant), got {}", i, v);
    }
}
/// Element 0 combines its low nibble (15) with the fifth bit (16) under scale 1.
#[test]
fn q5_k_first_value_with_high_bit() {
    let mut scales = [0u8; 12];
    let mut qh = [0u8; 32];
    let mut qs = [0u8; 128];
    scales[0] = 1; // sub-block 0: scale 1
    qs[0] = 0x0F; // low nibble of element 0 = 15
    qh[0] = 0x01; // bit 0 set: adds 16 to element 0

    let out = cpu_dequant_q5_k(&build_q5_k_block(1.0, 0.0, &scales, &qh, &qs));
    let (first, second) = (out[0], out[1]);

    assert!((first - 31.0).abs() < 1e-6, "expected 31, got {}", first);
    assert!(second.abs() < 1e-6, "expected 0, got {}", second);
}
/// Output 32 (first element of sub-block 1) must come from the high nibble,
/// bit 1 of the fifth-bit plane, and the second scale.
#[test]
fn q5_k_second_subblock_uses_high_nibble_and_u2() {
    let mut scales = [0u8; 12];
    let mut qh = [0u8; 32];
    let mut qs = [0u8; 128];
    scales[0] = 1; // sub-block 0 scale
    scales[1] = 2; // sub-block 1 scale
    qs[0] = 0xF0; // high nibble 15, low nibble 0
    qh[0] = 0x02; // only bit 1 set

    let out = cpu_dequant_q5_k(&build_q5_k_block(1.0, 0.0, &scales, &qh, &qs));

    // Low nibble is 0 and bit 0 is clear, so sub-block 0 sees nothing.
    assert!(out[0].abs() < 1e-6);
    // (15 + 16) * 2 = 62.
    let second_subblock_first = out[32];
    assert!(
        (second_subblock_first - 62.0).abs() < 1e-6,
        "expected 62, got {}",
        second_subblock_first
    );
}
/// The third 64-value group must use sub-block 4's scale and bit 4 of the fifth-bit plane.
#[test]
fn q5_k_third_pair_uses_shifted_u1() {
    let mut scales = [0u8; 12];
    let mut qh = [0u8; 32];
    let mut qs = [0u8; 128];
    scales[8] = 3; // low nibble of byte 8: scale of sub-block 4
    qs[64] = 0x05; // first quant byte of the third group, low nibble 5
    qh[0] = 0x10; // bit 4: the low-nibble mask after two shifts by 2

    let block = build_q5_k_block(1.0, 0.0, &scales, &qh, &qs);
    let value = cpu_dequant_q5_k(&block)[128];

    // (5 + 16) * 3 = 63.
    assert!(
        (value - 63.0).abs() < 1e-6,
        "expected 63 at idx 128, got {}",
        value
    );
}
/// The min term (`dmin * min`) must be subtracted from the scaled quant.
#[test]
fn q5_k_with_nonzero_min() {
    let mut scales = [0u8; 12];
    let mut qh = [0u8; 32];
    let mut qs = [0u8; 128];
    scales[0] = 1; // sub-block 0 scale
    scales[4] = 1; // sub-block 0 min
    qs[0] = 0x0F;
    qh[0] = 0x01;

    let out = cpu_dequant_q5_k(&build_q5_k_block(2.0, 3.0, &scales, &qh, &qs));

    // 2 * 1 * (15 + 16) - 3 * 1 = 59.
    let first = out[0];
    assert!((first - 59.0).abs() < 1e-6, "expected 59, got {}", first);
}
use mlx_native::{gguf::GgufFile, MlxDevice};
use std::io::Write;
/// Writes a tiny GGUF-style file at `path` containing one tensor and no metadata entries.
///
/// Bytes emitted, all integers little-endian: the magic `GGUF`, `3u32`, `1u64`,
/// `0u64`, the name length as `u64`, the name bytes, `1u32`, `256u64`, `13u32`,
/// `0u64`, zero padding up to the next multiple of 32 bytes, then `tensor_data`.
/// The `256` and `13` fields are fixed regardless of `tensor_data`'s length.
///
/// Panics if the file cannot be created, written, or flushed.
fn write_minimal_gguf(path: &std::path::Path, tensor_name: &str, tensor_data: &[u8]) {
    const ALIGNMENT: usize = 32;

    let mut image: Vec<u8> = b"GGUF".to_vec();
    image.extend(3u32.to_le_bytes());
    image.extend(1u64.to_le_bytes());
    image.extend(0u64.to_le_bytes());

    // Single tensor-info record.
    image.extend((tensor_name.len() as u64).to_le_bytes());
    image.extend(tensor_name.bytes());
    image.extend(1u32.to_le_bytes());
    image.extend(256u64.to_le_bytes());
    image.extend(13u32.to_le_bytes());
    image.extend(0u64.to_le_bytes());

    // Zero-pad so the tensor payload starts on an aligned boundary.
    let overhang = image.len() % ALIGNMENT;
    if overhang != 0 {
        image.resize(image.len() + (ALIGNMENT - overhang), 0);
    }
    image.extend(tensor_data.iter().copied());

    let mut file = std::fs::File::create(path).expect("create tmp gguf");
    file.write_all(&image).expect("write");
    file.flush().expect("flush");
}
/// End-to-end check: a Q5_K block written to a minimal GGUF file and loaded through
/// `GgufFile::load_tensor_f32` must match the CPU reference dequantizer.
#[test]
fn q5_k_mlx_native_load_tensor_f32_matches_cpu_reference() {
    // Deletes the temp file when dropped, so it is also cleaned up when an
    // `expect` or assertion below panics.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            // Best-effort cleanup; a missing file is not an error here.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    // Bytes 0..4: scales of sub-blocks 0-3; bytes 4..8: their mins;
    // bytes 8..12: nibble-packed scale (low) / min (high) of sub-blocks 4-7.
    let scales: [u8; 12] = [5, 7, 11, 13, 2, 3, 5, 7, 0x21, 0x43, 0x65, 0x17];

    // Deterministic, non-trivial bit patterns for the fifth-bit plane and the quants.
    let mut qh = [0u8; 32];
    for (i, h) in qh.iter_mut().enumerate() {
        *h = (i as u8).wrapping_mul(37);
    }
    let mut qs = [0u8; 128];
    for (i, q) in qs.iter_mut().enumerate() {
        *q = (i as u8).wrapping_mul(97).wrapping_add(3);
    }

    let block = build_q5_k_block(0.125, 0.0625, &scales, &qh, &qs);
    let expected = cpu_dequant_q5_k(&block);

    // Declared before the loader objects so it is dropped (and the file
    // removed) only after they have been released.
    let tmp = TempFileGuard(std::env::temp_dir().join(format!(
        "mlx_q5k_test_{}.gguf",
        std::process::id()
    )));
    write_minimal_gguf(&tmp.0, "test_tensor", &block);

    let device = MlxDevice::new().expect("device");
    let gguf = GgufFile::open(&tmp.0).expect("open mini gguf");
    let buf = gguf
        .load_tensor_f32("test_tensor", &device)
        .expect("load_tensor_f32");
    let got: &[f32] = buf.as_slice().expect("as slice");

    assert_eq!(got.len(), 256);
    for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
        let d = (g - e).abs();
        assert!(
            d < 1e-5,
            "mismatch at {}: got {}, expected {}, diff {}",
            i, g, e, d
        );
    }
}
/// Five little-endian `i16` values stored in a minimal GGUF file must load through
/// `GgufFile::load_tensor_f32` as the same values cast to `f32`.
#[test]
fn i16_dequant_via_mlx_native_matches_simple_cast() {
    // Deletes the temp file when dropped, so it is also cleaned up when an
    // `expect` or assertion below panics.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            // Best-effort cleanup; a missing file is not an error here.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    // Covers zero, +/-1, and both extremes of the i16 range.
    let values: [i16; 5] = [0, 1, -1, 32767, -32768];
    let mut bytes = [0u8; 10];
    for (chunk, v) in bytes.chunks_exact_mut(2).zip(values.iter()) {
        chunk.copy_from_slice(&v.to_le_bytes());
    }

    let name = "i16_tensor";
    let mut buf = Vec::new();
    buf.extend_from_slice(b"GGUF");
    buf.extend_from_slice(&3u32.to_le_bytes());
    buf.extend_from_slice(&1u64.to_le_bytes());
    buf.extend_from_slice(&0u64.to_le_bytes());
    buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
    buf.extend_from_slice(name.as_bytes());
    buf.extend_from_slice(&1u32.to_le_bytes());
    buf.extend_from_slice(&5u64.to_le_bytes());
    // NOTE(review): type id 17 is presumably what the loader maps to I16 — confirm
    // against mlx_native's GGUF type table.
    buf.extend_from_slice(&17u32.to_le_bytes());
    buf.extend_from_slice(&0u64.to_le_bytes());
    // Zero-pad so the tensor payload starts on a 32-byte boundary.
    while buf.len() % 32 != 0 {
        buf.push(0);
    }
    buf.extend_from_slice(&bytes);

    // Declared before the loader objects so it is dropped (and the file
    // removed) only after they have been released.
    let tmp = TempFileGuard(std::env::temp_dir().join(format!(
        "mlx_i16_test_{}.gguf",
        std::process::id()
    )));
    std::fs::write(&tmp.0, &buf).expect("write");

    let device = MlxDevice::new().expect("device");
    let gguf = GgufFile::open(&tmp.0).expect("open mini gguf i16");
    let mbuf = gguf
        .load_tensor_f32("i16_tensor", &device)
        .expect("load_tensor_f32 i16");
    let got: &[f32] = mbuf.as_slice().expect("slice");

    assert_eq!(got.len(), 5);
    for (i, v) in values.iter().enumerate() {
        assert_eq!(got[i], *v as f32, "I16 cast at {}: got {}", i, got[i]);
    }
}