use mlx_native::gguf::{
test_only_dequantize_iq4_nl, test_only_dequantize_q5_1, test_only_kvalues_iq4_nl,
};
const QK5_1: usize = 32;
const BLOCK_Q5_1_BYTES: usize = 24;
const QK4_NL: usize = 32;
const BLOCK_IQ4_NL_BYTES: usize = 18;
/// Reference Q5_1 quantizer used to build fixtures for the dequant tests.
///
/// `row` is processed in blocks of `QK5_1` values. For each block the encoder
/// emits `BLOCK_Q5_1_BYTES` bytes, in this order:
/// - `d`: scale `(max - min) / 31`, as f16 little-endian;
/// - `m`: the block minimum, as f16 little-endian;
/// - `qh`: `u32` little-endian holding bit 4 of every quant (bit `j` for
///   element `j`, bit `j + 16` for element `j + 16`);
/// - `qs`: 16 bytes, low nibble = element `j`, high nibble = element `j + 16`.
///
/// Each quant is `(x - min) / d` rounded to nearest (`+ 0.5`, then truncate)
/// and clamped to `0..=31`. A block whose values are all equal gets `d = 0`
/// and all-zero quants.
///
/// # Panics
/// Panics if `row.len()` is not a multiple of `QK5_1`.
fn ref_quantize_q5_1(row: &[f32]) -> Vec<u8> {
    assert!(row.len() % QK5_1 == 0);
    let half_block = QK5_1 / 2;
    let mut encoded = Vec::with_capacity((row.len() / QK5_1) * BLOCK_Q5_1_BYTES);

    for block in row.chunks_exact(QK5_1) {
        // Strict comparisons, so the first occurrence of the extreme wins.
        let (lo, hi) = block.iter().fold((f32::MAX, f32::MIN), |(lo, hi), &v| {
            (if v < lo { v } else { lo }, if v > hi { v } else { hi })
        });
        let scale = (hi - lo) / 31.0;
        let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
        let to_q5 = |v: f32| ((v - lo) * inv_scale + 0.5).clamp(0.0, 31.0) as u32;

        let mut packed = [0u8; QK5_1 / 2];
        let mut high_bits = 0u32;
        let (first, second) = block.split_at(half_block);
        for (j, (&a, &b)) in first.iter().zip(second).enumerate() {
            let (qa, qb) = (to_q5(a), to_q5(b));
            // Low 4 bits go into the nibble pair; bit 4 goes into `qh`.
            packed[j] = ((qa & 0x0F) | ((qb & 0x0F) << 4)) as u8;
            high_bits |= ((qa >> 4) & 1) << j;
            high_bits |= ((qb >> 4) & 1) << (j + half_block);
        }

        encoded.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
        encoded.extend_from_slice(&half::f16::from_f32(lo).to_bits().to_le_bytes());
        encoded.extend_from_slice(&high_bits.to_le_bytes());
        encoded.extend_from_slice(&packed);
    }
    encoded
}
/// Reference IQ4_NL quantizer used to build fixtures for the dequant tests.
///
/// `row` is processed in blocks of `QK4_NL` values. Each block uses a single
/// fixed scale `d = max|x| / 113` (113 is the largest codebook entry); there
/// is no scale search. Every `x / d` is replaced by the index of the nearest
/// codebook entry from `test_only_kvalues_iq4_nl()`.
///
/// Per block the output is `BLOCK_IQ4_NL_BYTES` bytes:
/// - `d` as f16 little-endian;
/// - 16 bytes of 4-bit indices: low nibble = element `j`, high nibble =
///   element `j + 16`.
///
/// An all-zero block gets `d = 0`, and each element maps to the entry nearest 0.
///
/// # Panics
/// Panics if `row.len()` is not a multiple of `QK4_NL`.
fn ref_quantize_iq4_nl(row: &[f32]) -> Vec<u8> {
    let kvalues = test_only_kvalues_iq4_nl();
    assert!(row.len() % QK4_NL == 0);
    let half_block = QK4_NL / 2;

    // Index of the codebook entry nearest to `target`. Only a strictly smaller
    // error replaces the current best, so ties keep the lower index.
    let nearest_idx = |target: f32| -> u8 {
        let (best, _) = kvalues.iter().enumerate().fold(
            (0usize, f32::MAX),
            |(best, best_err), (i, &kv)| {
                let err = (target - kv as f32).abs();
                if err < best_err {
                    (i, err)
                } else {
                    (best, best_err)
                }
            },
        );
        best as u8
    };

    let mut encoded = Vec::with_capacity((row.len() / QK4_NL) * BLOCK_IQ4_NL_BYTES);
    for block in row.chunks_exact(QK4_NL) {
        let amax = block.iter().fold(0.0_f32, |acc, v| acc.max(v.abs()));
        let scale = if amax == 0.0 { 0.0 } else { amax / 113.0 };
        let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };

        // Header first, then one packed byte per (j, j + 16) element pair.
        encoded.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
        let (first, second) = block.split_at(half_block);
        encoded.extend(first.iter().zip(second).map(|(&a, &b)| {
            let lo = nearest_idx(a * inv_scale);
            let hi = nearest_idx(b * inv_scale);
            (lo & 0x0F) | ((hi & 0x0F) << 4)
        }));
    }
    encoded
}
/// Returns `n` deterministic pseudo-random `f32` values in `[-1.0, 1.0]`.
///
/// Uses an xorshift64* generator seeded from `seed`. The same `seed` always
/// yields the same sequence, and a longer `n` only extends a shorter one.
/// The top 53 bits of each output word are mapped to `[0, 1]` and then to
/// `[-1, 1]`. The interval is closed because the `f32` cast rounds, which can
/// produce exactly `1.0`.
///
/// xorshift has a fixed point at state 0, so `seed == 0` would yield an
/// endless run of `-1.0`. That seed is remapped to a fixed non-zero state.
/// Non-zero seeds are unaffected.
fn deterministic_random_f32(seed: u64, n: usize) -> Vec<f32> {
    const MULT: u64 = 0x2545F4914F6CDD1D;
    // MULT is odd, so `seed * MULT` wraps to 0 only when `seed == 0`.
    let mut state = match seed.wrapping_mul(MULT) {
        0 => 0x9E37_79B9_7F4A_7C15,
        s => s,
    };
    let two_pow_53 = (1u64 << 53) as f32;
    (0..n)
        .map(|_| {
            state ^= state >> 12;
            state ^= state << 25;
            state ^= state >> 27;
            let bits = state.wrapping_mul(MULT);
            ((bits >> 11) as f32) / two_pow_53 * 2.0 - 1.0
        })
        .collect()
}
/// Round-trips pseudo-random data in `[-1, 1]` through the reference Q5_1
/// quantizer and the crate's dequantizer. Every element must come back within
/// 0.05: `d <= 2/31 ≈ 0.065`, so half a quant step is about 0.032, which
/// leaves headroom for f16 rounding of `d` and `m`.
#[test]
fn adr022_q5_1_dequant_parity_random_uniform() {
    const BLOCKS: usize = 10;
    let row = deterministic_random_f32(0xAD220051, BLOCKS * QK5_1);
    let bytes = ref_quantize_q5_1(&row);
    assert_eq!(bytes.len(), BLOCKS * BLOCK_Q5_1_BYTES);

    let mut decoded = vec![0.0_f32; BLOCKS * QK5_1];
    test_only_dequantize_q5_1(&bytes, &mut decoded).expect("q5_1 dequant");

    let max_delta = row
        .iter()
        .zip(&decoded)
        .map(|(orig, dec)| (orig - dec).abs())
        .fold(0.0_f32, |worst, delta| if delta > worst { delta } else { worst });
    assert!(
        max_delta <= 0.05_f32,
        "Q5_1 max abs delta {max_delta} exceeds 0.05 bound"
    );
}
/// An all-zero Q5_1 block (`d = 0`, `m = 0`) must decode back to exact zeros.
#[test]
fn adr022_q5_1_dequant_zero_input_yields_zero_output() {
    let zeros = [0.0_f32; QK5_1];
    let bytes = ref_quantize_q5_1(&zeros);
    let mut decoded = vec![0.0_f32; QK5_1];
    test_only_dequantize_q5_1(&bytes, &mut decoded).expect("q5_1 dequant");
    decoded.iter().for_each(|&v| assert_eq!(v, 0.0_f32));
}
/// Hand-built single Q5_1 block checking field order and nibble placement.
///
/// The block is `d = 1.0`, `m = -15.0`, `qh = 0`, `qs[1] = 0x53`. So:
/// - element 1 (low nibble 3) decodes to `3·d + m = -12`;
/// - element 17 (high nibble 5) decodes to `5·d + m = -10`;
/// - every other element (quant 0) decodes to `m = -15`.
///
/// The non-zero quants ensure that a dequantizer which ignores or misreads
/// `d`, or swaps the nibbles, fails this test.
#[test]
fn adr022_q5_1_dequant_single_block_byte_layout() {
    let mut bytes = Vec::with_capacity(BLOCK_Q5_1_BYTES);
    let d_bits = half::f16::from_f32(1.0).to_bits();
    let m_bits = half::f16::from_f32(-15.0).to_bits();
    bytes.extend_from_slice(&d_bits.to_le_bytes());
    bytes.extend_from_slice(&m_bits.to_le_bytes());
    bytes.extend_from_slice(&0u32.to_le_bytes());
    let mut qs = [0u8; 16];
    // Low nibble -> element 1, high nibble -> element 17.
    qs[1] = 0x53;
    bytes.extend_from_slice(&qs);
    assert_eq!(bytes.len(), BLOCK_Q5_1_BYTES);

    let mut decoded = vec![0.0_f32; QK5_1];
    test_only_dequantize_q5_1(&bytes, &mut decoded).expect("q5_1 dequant");
    for (i, &v) in decoded.iter().enumerate() {
        let expected = match i {
            1 => -12.0_f32,
            17 => -10.0_f32,
            _ => -15.0_f32,
        };
        assert_eq!(v, expected, "element {i}");
    }
}
/// Checks that the fifth quant bit is taken from `qh`. Bit `j` belongs to
/// element `j`, and bit `j + 16` belongs to element `j + 16`.
///
/// With `d = 0.5`, `m = 1.0`:
/// - element 0 has quant `0 | 16 = 16`;
/// - element 16 has quant `1 | 16 = 17`;
/// - every other element has quant 0 and must decode to exactly `m`.
#[test]
fn adr022_q5_1_dequant_qh_high_bit_round_trip() {
    let mut bytes = Vec::with_capacity(BLOCK_Q5_1_BYTES);
    let d = 0.5_f32;
    let m = 1.0_f32;
    bytes.extend_from_slice(&half::f16::from_f32(d).to_bits().to_le_bytes());
    bytes.extend_from_slice(&half::f16::from_f32(m).to_bits().to_le_bytes());
    // High bit set for element 0 (bit 0) and element 16 (bit 16).
    let qh: u32 = (1 << 0) | (1 << 16);
    bytes.extend_from_slice(&qh.to_le_bytes());
    let mut qs = [0u8; 16];
    // Low nibble 0 for element 0, high nibble 1 for element 16.
    qs[0] = 0x10;
    bytes.extend_from_slice(&qs);
    assert_eq!(bytes.len(), BLOCK_Q5_1_BYTES);

    let mut decoded = vec![0.0_f32; QK5_1];
    test_only_dequantize_q5_1(&bytes, &mut decoded).expect("q5_1 dequant");
    let expected_pos0 = 16.0_f32 * d + m;
    let expected_pos16 = 17.0_f32 * d + m;
    assert!((decoded[0] - expected_pos0).abs() < 1e-6);
    assert!((decoded[16] - expected_pos16).abs() < 1e-6);
    // The set qh bits must not leak into any other element.
    for (i, &v) in decoded.iter().enumerate() {
        if i != 0 && i != 16 {
            assert!((v - m).abs() < 1e-6, "element {i}: got {v}, expected {m}");
        }
    }
}
/// Round-trips pseudo-random data in `[-1, 1]` through the reference IQ4_NL
/// quantizer and the crate's dequantizer. The worst-case error is roughly
/// half of the widest codebook gap that the scaled values can fall into,
/// times `d`. The 0.15 bound leaves headroom for f16 rounding of `d`.
#[test]
fn adr022_iq4_nl_dequant_parity_random_uniform() {
    const BLOCKS: usize = 10;
    let row = deterministic_random_f32(0xAD22004F, BLOCKS * QK4_NL);
    let bytes = ref_quantize_iq4_nl(&row);
    assert_eq!(bytes.len(), BLOCKS * BLOCK_IQ4_NL_BYTES);

    let mut decoded = vec![0.0_f32; BLOCKS * QK4_NL];
    test_only_dequantize_iq4_nl(&bytes, &mut decoded).expect("iq4_nl dequant");

    let max_delta = row
        .iter()
        .zip(&decoded)
        .map(|(orig, dec)| (orig - dec).abs())
        .fold(0.0_f32, |worst, delta| if delta > worst { delta } else { worst });
    assert!(
        max_delta <= 0.15_f32,
        "IQ4_NL max abs delta {max_delta} exceeds 0.15 bound (codebook half-gap)"
    );
}
/// An all-zero IQ4_NL block gets `d = 0` and must decode back to exact zeros.
#[test]
fn adr022_iq4_nl_dequant_zero_input_yields_zero_output() {
    let zeros = [0.0_f32; QK4_NL];
    let bytes = ref_quantize_iq4_nl(&zeros);
    let mut decoded = vec![0.0_f32; QK4_NL];
    test_only_dequantize_iq4_nl(&bytes, &mut decoded).expect("iq4_nl dequant");
    decoded.iter().for_each(|&v| assert_eq!(v, 0.0_f32));
}
/// Pins the 16-entry IQ4_NL non-linear codebook to llama.cpp's
/// `kvalues_iq4nl` table, entry for entry.
#[test]
fn adr022_iq4_nl_codebook_constant_byte_equal_to_llama_cpp() {
    let expected = [
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
    ];
    assert_eq!(test_only_kvalues_iq4_nl(), expected);
}
/// Hand-built single IQ4_NL block checking header position and nibble order.
///
/// The block is `d = 2.0`, `qs[0] = 0x21`, and all other bytes are 0. So:
/// - element 0 (low nibble 1) decodes to `kvalues[1]·d = -104·2 = -208`;
/// - element 16 (high nibble 2) decodes to `kvalues[2]·d = -83·2 = -166`;
/// - every other element (index 0) decodes to `kvalues[0]·d = -127·2 = -254`.
#[test]
fn adr022_iq4_nl_dequant_single_block_byte_layout() {
    let mut bytes = Vec::with_capacity(BLOCK_IQ4_NL_BYTES);
    let d = 2.0_f32;
    bytes.extend_from_slice(&half::f16::from_f32(d).to_bits().to_le_bytes());
    let mut qs = [0u8; 16];
    // Low nibble -> element 0, high nibble -> element 16.
    qs[0] = 0x21;
    bytes.extend_from_slice(&qs);
    // Guard the fixture itself: one f16 scale plus 16 packed bytes.
    assert_eq!(bytes.len(), BLOCK_IQ4_NL_BYTES);

    let mut decoded = vec![0.0_f32; QK4_NL];
    test_only_dequantize_iq4_nl(&bytes, &mut decoded).expect("iq4_nl dequant");
    assert!((decoded[0] - (-208.0_f32)).abs() < 1e-3);
    assert!((decoded[16] - (-166.0_f32)).abs() < 1e-3);
    for &v in &decoded[1..16] {
        assert!((v - (-254.0_f32)).abs() < 1e-3);
    }
    for &v in &decoded[17..32] {
        assert!((v - (-254.0_f32)).abs() < 1e-3);
    }
}