use half::f16;
/// Dequantizes a run of GGML `Q4_K` super-blocks into `f32` values.
///
/// Each 144-byte super-block holds:
/// - a little-endian `f16` scale `d` (bytes `0..2`);
/// - a little-endian `f16` min-scale `dmin` (bytes `2..4`);
/// - 12 bytes of packed 6-bit sub-block scales/mins (bytes `4..16`, decoded by
///   [`unpack_q4k_q5k_scales`]);
/// - 128 bytes of 4-bit quants (bytes `16..144`).
///
/// A super-block expands to 256 outputs in eight sub-blocks of 32.
/// Sub-blocks `2g` and `2g + 1` share the quant bytes `qs[32 * g..32 * g + 32]`:
/// the low nibbles feed sub-block `2g` and the high nibbles feed `2g + 1`.
/// Every output is `d * scale[j] * q - dmin * min[j]`.
///
/// Only whole blocks are decoded. Trailing bytes of `blocks` that do not form
/// a complete 144-byte block are ignored.
///
/// # Panics
///
/// Panics if `output` is shorter than `256 * (blocks.len() / 144)`. In debug
/// builds it also panics unless `output.len()` equals that value exactly.
pub fn dequant_q4k(blocks: &[u8], output: &mut [f32]) {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 144;
    const GROUP_QS: usize = 32;

    let block_count = blocks.len() / BLOCK_BYTES;
    debug_assert_eq!(output.len(), block_count * BLOCK_SIZE);

    for (idx, raw) in blocks.chunks_exact(BLOCK_BYTES).enumerate() {
        let super_scale = f16::from_le_bytes([raw[0], raw[1]]).to_f32();
        let super_min = f16::from_le_bytes([raw[2], raw[3]]).to_f32();
        let quants = &raw[16..BLOCK_BYTES];
        // Explicit slicing (rather than zipping with `output`) keeps the panic
        // when `output` is too short for the number of input blocks.
        let dst = &mut output[idx * BLOCK_SIZE..][..BLOCK_SIZE];
        let (scales, mins) = unpack_q4k_q5k_scales(&raw[4..16]);

        // Each 32-byte quant group produces 64 outputs: 32 from the low
        // nibbles (even sub-block) followed by 32 from the high nibbles (odd).
        for (group, (packed, pair)) in quants
            .chunks_exact(GROUP_QS)
            .zip(dst.chunks_exact_mut(2 * GROUP_QS))
            .enumerate()
        {
            let (lo_out, hi_out) = pair.split_at_mut(GROUP_QS);
            let lo_scale = super_scale * f32::from(scales[2 * group]);
            let lo_min = super_min * f32::from(mins[2 * group]);
            let hi_scale = super_scale * f32::from(scales[2 * group + 1]);
            let hi_min = super_min * f32::from(mins[2 * group + 1]);
            for ((&byte, lo), hi) in packed.iter().zip(lo_out).zip(hi_out) {
                *lo = lo_scale * f32::from(byte & 0x0F) - lo_min;
                *hi = hi_scale * f32::from(byte >> 4) - hi_min;
            }
        }
    }
}
/// Dequantizes a run of GGML `Q5_K` super-blocks into `f32` values.
///
/// Each 176-byte super-block is laid out as:
///
/// | bytes     | field  | meaning                                              |
/// |-----------|--------|------------------------------------------------------|
/// | `0..2`    | `d`    | little-endian `f16` super-block scale                |
/// | `2..4`    | `dmin` | little-endian `f16` super-block min scale            |
/// | `4..16`   | scales | packed 6-bit scales/mins ([`unpack_q4k_q5k_scales`]) |
/// | `16..48`  | `qh`   | 5th (high) bit of every quant                        |
/// | `48..176` | `qs`   | low 4 bits of every quant                            |
///
/// A super-block expands to 256 outputs in eight sub-blocks of 32.
/// - The low nibbles use the same layout as `Q4_K`. Sub-blocks `2g` and
///   `2g + 1` share `qs[32 * g..32 * g + 32]`, taking the low and high nibble
///   respectively.
/// - The high bit of element `l` in sub-block `j` is bit `j` of `qh[l]`.
///
/// Each output is `d * scale[j] * q - dmin * min[j]`, with `q` in `0..32`.
///
/// Only whole blocks are decoded. Trailing bytes of `blocks` that do not form
/// a complete 176-byte block are ignored.
///
/// # Panics
///
/// Panics if `output` is shorter than `256 * (blocks.len() / 176)`. In debug
/// builds it also panics unless `output.len()` equals that value exactly.
pub fn dequant_q5k(blocks: &[u8], output: &mut [f32]) {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 176;
    const SUB_BLOCK: usize = 32;
    let num_blocks = blocks.len() / BLOCK_BYTES;
    debug_assert_eq!(output.len(), num_blocks * BLOCK_SIZE);
    for (b, block) in blocks.chunks_exact(BLOCK_BYTES).enumerate() {
        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let sc = &block[4..16];
        let qh = &block[16..48];
        let qs = &block[48..176];
        let out = &mut output[b * BLOCK_SIZE..][..BLOCK_SIZE];
        let (scales, mins) = unpack_q4k_q5k_scales(sc);
        for (j, sub) in out.chunks_exact_mut(SUB_BLOCK).enumerate() {
            let dl = d * f32::from(scales[j]);
            let ml = dmin * f32::from(mins[j]);
            // Sub-blocks 2g and 2g+1 share one 32-byte run of `qs`: even
            // sub-blocks take the low nibble, odd ones the high nibble.
            let ql = &qs[(j / 2) * SUB_BLOCK..][..SUB_BLOCK];
            let nibble_shift = 4 * (j % 2);
            for ((dst, &lo), &hi) in sub.iter_mut().zip(ql).zip(qh) {
                let low4 = (lo >> nibble_shift) & 0x0F;
                // Bit `j` of qh[l] is the 5th bit of element `l` of sub-block `j`.
                let high1 = (hi >> j) & 0x01;
                let q = f32::from(low4 | (high1 << 4));
                *dst = dl * q - ml;
            }
        }
    }
}
/// Unpacks the 12-byte scale/min table shared by `Q4_K` and `Q5_K` blocks.
///
/// The table encodes eight 6-bit scales and eight 6-bit mins.
/// - Entries `0..4` sit in the low six bits of `sc[0..4]` (scales) and
///   `sc[4..8]` (mins).
/// - Entries `4..8` are split. Their low four bits come from the low nibble
///   (scales) or high nibble (mins) of `sc[8..12]`. Their top two bits come
///   from the top two bits of `sc[0..4]` (scales) or `sc[4..8]` (mins).
///
/// Returns `(scales, mins)`.
///
/// # Panics
///
/// Panics if `sc` has fewer than 12 bytes.
pub fn unpack_q4k_q5k_scales(sc: &[u8]) -> ([u8; 8], [u8; 8]) {
    let mut scales = [0u8; 8];
    let mut mins = [0u8; 8];
    for k in 0..4 {
        let (scale_byte, min_byte, split_byte) = (sc[k], sc[k + 4], sc[k + 8]);
        // Directly stored 6-bit entries.
        scales[k] = scale_byte & 0x3F;
        mins[k] = min_byte & 0x3F;
        // Split entries: nibble from `split_byte`, top two bits borrowed from
        // the otherwise-unused high bits of the directly stored bytes.
        scales[k + 4] = (split_byte & 0x0F) | ((scale_byte >> 6) << 4);
        mins[k + 4] = (split_byte >> 4) | ((min_byte >> 6) << 4);
    }
    (scales, mins)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores `value` as a little-endian `f16` in the first two bytes of `dst`.
    fn put_f16(dst: &mut [u8], value: f32) {
        dst[..2].copy_from_slice(&f16::from_f32(value).to_le_bytes());
    }

    #[test]
    fn test_dequant_q4k_zero_scales() {
        let block = [0u8; 144];
        let mut output = [0.0f32; 256];
        dequant_q4k(&block, &mut output);
        for &v in &output {
            assert!(v.abs() < 1e-5, "expected ~0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q4k_nibble_layout() {
        // d = 1, dmin = 0, all eight scales = 1, all mins = 0, so output == raw quant.
        let mut block = [0u8; 144];
        put_f16(&mut block[0..], 1.0);
        put_f16(&mut block[2..], 0.0);
        block[4..8].fill(0x01); // scales[0..4]
        block[12..16].fill(0x01); // low nibbles -> scales[4..8]
        // qs starts at byte 16. Sub-blocks 2g and 2g+1 read the low and high
        // nibbles of qs[32g..32g + 32].
        block[16] = 0x21;
        block[16 + 32] = 0x43;
        block[16 + 127] = 0x87;
        let mut output = [0.0f32; 256];
        dequant_q4k(&block, &mut output);
        assert_eq!(output[0], 1.0);
        assert_eq!(output[32], 2.0);
        assert_eq!(output[64], 3.0);
        assert_eq!(output[96], 4.0);
        assert_eq!(output[223], 7.0);
        assert_eq!(output[255], 8.0);
        // Neighbouring elements read untouched (zero) quant bytes.
        assert_eq!(output[1], 0.0);
        assert_eq!(output[33], 0.0);
    }

    #[test]
    fn test_dequant_q4k_subtracts_min() {
        // d = 1, dmin = 0.5; sub-block 0 has scale 1 and min 4, so q = 0 -> -2.
        let mut block = [0u8; 144];
        put_f16(&mut block[0..], 1.0);
        put_f16(&mut block[2..], 0.5);
        block[4] = 1; // scales[0]
        block[8] = 4; // mins[0]
        let mut output = [0.0f32; 256];
        dequant_q4k(&block, &mut output);
        for (i, &v) in output[..32].iter().enumerate() {
            assert_eq!(v, -2.0, "elem {i}");
        }
        for (i, &v) in output[32..].iter().enumerate() {
            assert_eq!(v, 0.0, "elem {}", i + 32);
        }
    }

    #[test]
    fn test_dequant_q5k_zero() {
        let block = [0u8; 176];
        let mut output = [0.0f32; 256];
        dequant_q5k(&block, &mut output);
        for &v in &output {
            assert!(v.abs() < 1e-5, "expected ~0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q5k_known_values() {
        let mut block = [0u8; 176];
        put_f16(&mut block[0..], 1.0);
        put_f16(&mut block[2..], 0.0);
        block[4..8].fill(0x01);
        block[8..12].fill(0x00);
        block[12..16].fill(0x00);
        block[48..176].fill(0x55);
        let mut output = [0.0f32; 256];
        dequant_q5k(&block, &mut output);
        for (i, &v) in output[..128].iter().enumerate() {
            assert!((v - 5.0).abs() < 0.1, "elem {i}: expected 5.0, got {v}");
        }
        for (i, &v) in output[128..].iter().enumerate() {
            assert!(v.abs() < 1e-5, "elem {}: expected ~0, got {}", i + 128, v);
        }
    }

    #[test]
    fn test_unpack_q4k_q5k_scales_basic() {
        let mut sc = [0u8; 12];
        sc[0] = 10;
        sc[1] = 20;
        let (scales, _mins) = unpack_q4k_q5k_scales(&sc);
        assert_eq!(scales[0], 10);
        assert_eq!(scales[1], 20);
    }

    #[test]
    fn test_unpack_q4k_q5k_scales_high_entries() {
        // Entries 4..8 combine a nibble of sc[8..12] with the top two bits of
        // sc[0..4] (scales) or sc[4..8] (mins).
        let sc: [u8; 12] = [
            0xC1, 0x42, 0x83, 0x04, 0x45, 0xC6, 0x07, 0x88, 0x9A, 0xBC, 0xDE, 0xF0,
        ];
        let (scales, mins) = unpack_q4k_q5k_scales(&sc);
        assert_eq!(scales, [1, 2, 3, 4, 58, 28, 46, 0]);
        assert_eq!(mins, [5, 6, 7, 8, 25, 59, 13, 47]);
    }
}