use half::f16;
/// Dequantizes a buffer of Q2_K blocks into `f32` values.
///
/// Each 84-byte block expands to 256 outputs and is laid out as:
/// - bytes `0..16`: one byte per 16-value group (low nibble = scale, high nibble = min),
/// - bytes `16..80`: 2-bit quants, four per byte,
/// - bytes `80..82`: little-endian `f16` super-scale `d`,
/// - bytes `82..84`: little-endian `f16` super-min `dmin`.
///
/// Every output is `d * scale * q - dmin * min`, with `q` in `0..=3`.
///
/// Trailing bytes of `blocks` that do not form a whole 84-byte block are ignored.
///
/// # Panics
///
/// Panics if `output.len()` is not exactly `256 * (blocks.len() / 84)`.
pub fn dequant_q2k(blocks: &[u8], output: &mut [f32]) {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 84;
    let num_blocks = blocks.len() / BLOCK_BYTES;
    // Enforced in release builds too: a mismatched buffer would otherwise either
    // panic with an opaque slice-index error or leave trailing outputs unwritten.
    assert_eq!(
        output.len(),
        num_blocks * BLOCK_SIZE,
        "dequant_q2k: output length must equal 256 * (number of whole 84-byte blocks)"
    );
    for (block, out) in blocks
        .chunks_exact(BLOCK_BYTES)
        .zip(output.chunks_exact_mut(BLOCK_SIZE))
    {
        let scales = &block[0..16];
        let qs = &block[16..80];
        let d = f16::from_le_bytes([block[80], block[81]]).to_f32();
        let dmin = f16::from_le_bytes([block[82], block[83]]).to_f32();

        // The 256 outputs form 16 groups of 16, one scale byte per group.
        // Group `g` reads 16-byte row `g % 2` of 32-byte quant half `g / 8`,
        // extracting the 2-bit field at shift `2 * ((g / 2) % 4)`.
        for (g, (group, &sc)) in out.chunks_exact_mut(16).zip(scales).enumerate() {
            let q = &qs[(g / 8) * 32 + (g % 2) * 16..][..16];
            let shift = 2 * ((g / 2) % 4);
            let dl = d * f32::from(sc & 0x0F);
            let ml = dmin * f32::from(sc >> 4);
            for (y, &qb) in group.iter_mut().zip(q) {
                *y = dl * f32::from((qb >> shift) & 3) - ml;
            }
        }
    }
}
/// Dequantizes a buffer of Q3_K blocks into `f32` values.
///
/// Each 110-byte block expands to 256 outputs and is laid out as:
/// - bytes `0..32`: high-bit mask; bit `k` of these 32 bytes covers outputs `32k..32k+32`,
/// - bytes `32..96`: low 2 bits of each quant, four per byte,
/// - bytes `96..108`: sixteen packed 6-bit scales (decoded by `unpack_q3k_scales`),
/// - bytes `108..110`: little-endian `f16` super-scale `d`.
///
/// Every output is `d * scale * q`, where `q = low2` when its mask bit is set
/// and `q = low2 - 4` when it is clear, so `q` lies in `-4..=3`.
///
/// Trailing bytes of `blocks` that do not form a whole 110-byte block are ignored.
///
/// # Panics
///
/// Panics if `output.len()` is not exactly `256 * (blocks.len() / 110)`.
pub fn dequant_q3k(blocks: &[u8], output: &mut [f32]) {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 110;
    let num_blocks = blocks.len() / BLOCK_BYTES;
    // Enforced in release builds too: a mismatched buffer would otherwise either
    // panic with an opaque slice-index error or leave trailing outputs unwritten.
    assert_eq!(
        output.len(),
        num_blocks * BLOCK_SIZE,
        "dequant_q3k: output length must equal 256 * (number of whole 110-byte blocks)"
    );
    for (block, out) in blocks
        .chunks_exact(BLOCK_BYTES)
        .zip(output.chunks_exact_mut(BLOCK_SIZE))
    {
        let hmask = &block[0..32];
        let qs = &block[32..96];
        let scales = unpack_q3k_scales(&block[96..108]);
        let d_all = f16::from_le_bytes([block[108], block[109]]).to_f32();

        // The 256 outputs form 16 groups of 16, one scale per group. Group `g`
        // reads row `g % 2` of 32-byte quant half `g / 8` at bit shift
        // `2 * ((g / 2) % 4)`, and its high bits come from mask bit `g / 2`
        // of the matching 16-byte mask row.
        for (g, (group, &scale)) in out.chunks_exact_mut(16).zip(scales.iter()).enumerate() {
            let row = (g % 2) * 16;
            let q = &qs[(g / 8) * 32 + row..][..16];
            let hm = &hmask[row..][..16];
            let shift = 2 * ((g / 2) % 4);
            let mask = 1u8 << (g / 2);
            let dl = d_all * f32::from(scale);
            for ((y, &qb), &hb) in group.iter_mut().zip(q).zip(hm) {
                let low2 = ((qb >> shift) & 3) as i8;
                let high_sub = if hb & mask != 0 { 0 } else { 4 };
                *y = dl * f32::from(low2 - high_sub);
            }
        }
    }
}
/// Unpacks the 12-byte packed Q3_K scale field into sixteen signed scales.
///
/// Each scale is a 6-bit value with 32 subtracted, so results lie in `-32..=31`.
/// The low 4 bits come from nibbles of bytes `0..8`, and the high 2 bits come
/// from bit pairs of bytes `8..12`. For lane `j` in `0..4`:
/// - scale `j`: low nibble of byte `j`, bits `0..2` of byte `8 + j`,
/// - scale `4 + j`: low nibble of byte `4 + j`, bits `2..4` of byte `8 + j`,
/// - scale `8 + j`: high nibble of byte `j`, bits `4..6` of byte `8 + j`,
/// - scale `12 + j`: high nibble of byte `4 + j`, bits `6..8` of byte `8 + j`.
///
/// Only the first 12 bytes of `sc` are read.
///
/// # Panics
///
/// Panics if `sc` is shorter than 12 bytes.
pub fn unpack_q3k_scales(sc: &[u8]) -> [i8; 16] {
    let packed = &sc[..12];
    let (nibbles, high_bits) = packed.split_at(8);
    let mut out = [0i8; 16];
    for (idx, slot) in out.iter_mut().enumerate() {
        let quarter = idx / 4;
        let lane = idx % 4;
        // Quarters 0 and 2 read bytes 0..4; quarters 1 and 3 read bytes 4..8.
        let nib_src = nibbles[(quarter % 2) * 4 + lane];
        let low4 = if quarter < 2 { nib_src & 0x0F } else { nib_src >> 4 };
        let high2 = (high_bits[lane] >> (2 * quarter)) & 0x03;
        // The combined value is at most 63, so the i8 cast is lossless.
        *slot = (low4 | (high2 << 4)) as i8 - 32;
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dequant_q2k_zero() {
        let block = [0u8; 84];
        let mut output = [0.0f32; 256];
        dequant_q2k(&block, &mut output);
        for &v in &output {
            assert!(v.abs() < 1e-5, "expected ~0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q2k_known_values() {
        let mut block = [0u8; 84];
        // Scale nibble 1, min nibble 0; every 2-bit quant is 0b01.
        block[0..16].fill(0x01);
        block[16..80].fill(0x55);
        block[80..82].copy_from_slice(&f16::from_f32(1.0).to_le_bytes());
        block[82..84].copy_from_slice(&f16::from_f32(0.0).to_le_bytes());
        let mut output = [0.0f32; 256];
        dequant_q2k(&block, &mut output);
        for (i, &v) in output.iter().enumerate() {
            assert!((v - 1.0).abs() < 0.01, "elem {i}: expected 1.0, got {v}");
        }
    }

    #[test]
    fn test_dequant_q2k_with_min() {
        let mut block = [0u8; 84];
        // Scale nibble 3, min nibble 2; every 2-bit quant is 0b10:
        // 2.0 * 3 * 2 - 0.5 * 2 = 11.
        block[0..16].fill(0x23);
        block[16..80].fill(0xAA);
        block[80..82].copy_from_slice(&f16::from_f32(2.0).to_le_bytes());
        block[82..84].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        let mut output = [0.0f32; 256];
        dequant_q2k(&block, &mut output);
        for (i, &v) in output.iter().enumerate() {
            assert!((v - 11.0).abs() < 0.1, "elem {i}: expected 11.0, got {v}");
        }
    }

    #[test]
    fn test_dequant_q3k_zero() {
        let block = [0u8; 110];
        let mut output = [0.0f32; 256];
        dequant_q3k(&block, &mut output);
        for &v in &output {
            assert!(v.abs() < 1e-5, "expected ~0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q3k_known_values() {
        let mut block = [0u8; 110];
        // All mask bits set and all low bits 0b11 -> q = 3; zero scale bytes
        // unpack to -32, so every output is 0.5 * -32 * 3 = -48.
        block[0..32].fill(0xFF);
        block[32..96].fill(0xFF);
        block[108..110].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        let mut output = [0.0f32; 256];
        dequant_q3k(&block, &mut output);
        for (i, &v) in output.iter().enumerate() {
            assert!(
                (v - (-48.0)).abs() < 0.1,
                "elem {i}: expected -48.0, got {v}"
            );
        }
    }

    #[test]
    fn test_unpack_q3k_scales_zeros() {
        let sc = [0u8; 12];
        let scales = unpack_q3k_scales(&sc);
        for &s in &scales {
            assert_eq!(s, -32);
        }
    }

    #[test]
    fn test_unpack_q3k_scales_all_ones() {
        let sc = [0xFF; 12];
        let scales = unpack_q3k_scales(&sc);
        for &s in &scales {
            assert_eq!(s, 31);
        }
    }

    #[test]
    fn test_unpack_q3k_scales_mixed() {
        // Distinct nibbles and high-bit pairs in every position pin the full
        // 16-scale unpacking order, not just the uniform cases above.
        let sc = [81u8, 211, 209, 159, 0, 207, 132, 227, 55, 195, 39, 36];
        let expected: [i8; 16] = [
            17, 19, 17, -17, -16, -17, -12, -13, 21, -19, 13, 9, -32, 28, -24, -18,
        ];
        assert_eq!(unpack_q3k_scales(&sc), expected);
    }

    #[test]
    fn test_dequant_q3k_real_block() {
        // d decodes from f16 bits 0x062C to 1580 / 2^24 (~9.4175e-5), and the
        // first sub-block scale unpacks to 17, so the first 16 outputs are
        // 17 * d * q for q in -4..=3.
        let block: [u8; 110] = [
            249, 99, 245, 234, 226, 236, 45, 116, 159, 178, 189, 173, 255, 243, 26, 125, 222, 253,
            238, 81, 247, 255, 191, 230, 74, 99, 179, 247, 70, 110, 203, 143, 238, 36, 144, 36,
            114, 66, 196, 206, 0, 61, 12, 12, 18, 224, 193, 9, 19, 0, 92, 104, 251, 17, 16, 138,
            255, 193, 98, 128, 3, 103, 20, 0, 0, 143, 16, 6, 3, 79, 4, 200, 50, 244, 48, 99, 20,
            24, 248, 204, 86, 1, 2, 12, 0, 0, 68, 74, 131, 147, 55, 33, 158, 192, 79, 63, 81, 211,
            209, 159, 0, 207, 132, 227, 55, 195, 39, 36, 44, 6,
        ];
        let expected_first8: [f32; 8] =
            [0.0032, 0.0, 0.0, -0.0064, -0.0032, -0.0032, 0.0, -0.0032];
        let mut output = [0.0f32; 256];
        dequant_q3k(&block, &mut output);
        for (i, &exp) in expected_first8.iter().enumerate() {
            assert!(
                (output[i] - exp).abs() < 0.001,
                "elem {i}: expected {exp}, got {}",
                output[i]
            );
        }
        // Elements in the second 128-value half use later mask bits and scales
        // (scale 21 with q = -1 at 129, scale -18 with q = 1 at 254).
        let d = 1580.0f32 / 16_777_216.0;
        assert!(
            (output[129] - (-21.0 * d)).abs() < 1e-6,
            "elem 129: got {}",
            output[129]
        );
        assert!(
            (output[254] - (-18.0 * d)).abs() < 1e-6,
            "elem 254: got {}",
            output[254]
        );
    }
}