use half::f16;
/// Dequantizes a run of Q6_K blocks into `f32` values.
///
/// Every 210-byte block expands to 256 outputs and is read as follows:
///
/// - bytes `0..128`: the low 4 bits of each 6-bit quant, two quants per byte;
/// - bytes `128..192`: the high 2 bits of each quant, four quants per byte;
/// - bytes `192..208`: sixteen signed 8-bit sub-block scales;
/// - bytes `208..210`: a little-endian `f16` block scale `d`.
///
/// Each output is `d * scale * (q - 32)`, where `q` is the reassembled
/// unsigned 6-bit quant and `scale` is the scale of the 16-value group the
/// output belongs to.
///
/// Trailing bytes of `blocks` that do not make up a whole block are ignored.
///
/// # Panics
///
/// With debug assertions enabled, panics unless `output.len()` is exactly
/// 256 times the number of whole blocks in `blocks`. Without debug
/// assertions, panics if `output` is too short to hold every decoded block.
pub fn dequant_q6k(blocks: &[u8], output: &mut [f32]) {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 210;

    let num_blocks = blocks.len() / BLOCK_BYTES;
    debug_assert_eq!(output.len(), num_blocks * BLOCK_SIZE);

    // `chunks_exact` yields exactly `num_blocks` slices of `BLOCK_BYTES` bytes
    // and skips any incomplete tail.
    for (b, block) in blocks.chunks_exact(BLOCK_BYTES).enumerate() {
        let ql = &block[0..128];
        let qh = &block[128..192];
        let sc = &block[192..208];
        let d = f16::from_le_bytes([block[208], block[209]]).to_f32();
        let out = &mut output[b * BLOCK_SIZE..][..BLOCK_SIZE];

        // The scales are stored as raw bytes; `as i8` reinterprets the bit
        // pattern as a signed value, so no `unsafe` pointer cast is required.
        let scale = |idx: usize| f32::from(sc[idx] as i8);

        // A block is decoded as two halves of 128 outputs. Each half consumes
        // 64 low-nibble bytes, 32 high-bit bytes and 8 scales.
        for n in 0..2 {
            let y_base = n * 128;
            let ql_base = n * 64;
            let qh_base = n * 32;
            let sc_base = n * 8;

            for l in 0..32 {
                // Two scales cover each run of 32 outputs (16 values apiece).
                let is = l / 16;

                let lo_a = ql[ql_base + l];
                let lo_b = ql[ql_base + l + 32];
                let hi = qh[qh_base + l];

                // One `qh` byte carries the top two bits of four quants; the
                // nibbles of two `ql` bytes provide their low four bits.
                let q1 = ((lo_a & 0x0F) | ((hi & 0x03) << 4)) as i8 - 32;
                let q2 = ((lo_b & 0x0F) | (((hi >> 2) & 0x03) << 4)) as i8 - 32;
                let q3 = ((lo_a >> 4) | (((hi >> 4) & 0x03) << 4)) as i8 - 32;
                let q4 = ((lo_b >> 4) | (((hi >> 6) & 0x03) << 4)) as i8 - 32;

                out[y_base + l] = d * scale(sc_base + is) * f32::from(q1);
                out[y_base + l + 32] = d * scale(sc_base + is + 2) * f32::from(q2);
                out[y_base + l + 64] = d * scale(sc_base + is + 4) * f32::from(q3);
                out[y_base + l + 96] = d * scale(sc_base + is + 6) * f32::from(q4);
            }
        }
    }
}
/// Dequantizes a run of Q8_K blocks into `f32` values.
///
/// Every 292-byte block expands to 256 outputs: bytes `0..4` hold a
/// little-endian `f32` scale and bytes `4..260` hold 256 signed 8-bit quants.
/// Each output is `quant * scale`. Bytes `260..292` of a block are not read
/// by this function.
///
/// Trailing bytes of `blocks` that do not make up a whole block are ignored.
///
/// # Panics
///
/// With debug assertions enabled, panics unless `output.len()` is exactly
/// 256 times the number of whole blocks in `blocks`. Without debug
/// assertions, panics if `output` is too short to hold every decoded block.
pub fn dequant_q8k(blocks: &[u8], output: &mut [f32]) {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 292;
    const SCALE_BYTES: usize = 4;

    let num_blocks = blocks.len() / BLOCK_BYTES;
    debug_assert_eq!(output.len(), num_blocks * BLOCK_SIZE);

    for (index, raw) in blocks.chunks_exact(BLOCK_BYTES).enumerate() {
        // Split the block into its scale header and the bytes that follow.
        let (header, payload) = raw.split_at(SCALE_BYTES);
        let scale = f32::from_le_bytes([header[0], header[1], header[2], header[3]]);
        let quants = &payload[..BLOCK_SIZE];

        let start = index * BLOCK_SIZE;
        let dst = &mut output[start..start + BLOCK_SIZE];

        // `as i8` reinterprets each stored byte as a signed quant.
        quants
            .iter()
            .zip(dst.iter_mut())
            .for_each(|(&q, slot)| *slot = f32::from(q as i8) * scale);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An all-zero Q6_K block (zero block scale, zero sub-block scales) must
    /// decode to zeros in every output slot.
    #[test]
    fn test_dequant_q6k_zero_scales() {
        let block = [0u8; 210];
        // Start from NaN rather than 0.0 so the assertion below fails if any
        // slot is left unwritten; a zeroed buffer would pass even for a no-op.
        let mut output = [f32::NAN; 256];
        dequant_q6k(&block, &mut output);
        for &v in &output {
            assert!(v.abs() < 1e-5, "expected ~0, got {}", v);
        }
    }

    /// Checks bit unpacking, the `- 32` bias and signed scales on a
    /// hand-built Q6_K block.
    #[test]
    fn test_dequant_q6k_known_values() {
        let mut block = [0u8; 210];
        block[0] = 0x0F; // low nibble of the first quant -> 0b1111
        block[128] = 0x03; // high two bits of the first quant -> 0b11
        block[192..208].fill(1);
        block[192] = 2; // scale of the first 16 outputs
        block[207] = 0xFF; // -1 as i8: scale of the last 16 outputs
        block[208..210].copy_from_slice(&[0x00, 0x3C]); // f16 1.0, little-endian

        let mut output = [f32::NAN; 256];
        dequant_q6k(&block, &mut output);

        // First quant is 0b11_1111 = 63, so 1.0 * 2 * (63 - 32).
        assert_eq!(output[0], 62.0);
        // Every other quant is 0, i.e. -32 after the bias.
        assert_eq!(output[1], -64.0);
        assert_eq!(output[16], -32.0);
        assert_eq!(output[64], -32.0);
        assert_eq!(output[239], -32.0);
        // The negative scale flips the sign of the final group.
        assert_eq!(output[240], 32.0);
        assert_eq!(output[255], 32.0);
    }

    /// A uniform Q8_K block decodes to `quant * scale` everywhere.
    #[test]
    fn test_dequant_q8k_known_values() {
        let mut block = [0u8; 292];
        block[0..4].copy_from_slice(&0.5f32.to_le_bytes());
        block[4..260].fill(10);
        let mut output = [0.0f32; 256];
        dequant_q8k(&block, &mut output);
        for &v in &output {
            assert!((v - 5.0).abs() < 0.01, "expected 5.0, got {}", v);
        }
    }

    /// Q8_K quants are signed: bytes >= 0x80 must decode to negative values.
    #[test]
    fn test_dequant_q8k_negative_quants() {
        let mut block = [0u8; 292];
        block[0..4].copy_from_slice(&0.25f32.to_le_bytes());
        block[4] = 0xFF; // -1
        block[5] = 0x80; // -128
        block[6] = 0x7F; // 127
        let mut output = [f32::NAN; 256];
        dequant_q8k(&block, &mut output);
        assert_eq!(output[0], -0.25);
        assert_eq!(output[1], -32.0);
        assert_eq!(output[2], 31.75);
        assert_eq!(output[3], 0.0);
    }
}