use bytemuck::{Pod, Zeroable};
use half::f16;
use crate::error::{Result, RullamaError};
use super::dtype::GgmlDtype;
/// Number of `f32` values produced from one quantized block by the K-quant
/// decoders in this file (`dequant_q4_k`, `dequant_q6_k`).
pub const QK_K: usize = 256;
/// Size in bytes of one Q4_K block; `dequant_q4_k` rejects sources whose length
/// is not a multiple of this value.
pub const Q4_K_BLOCK_BYTES: usize = 144;
/// One Q4_K block as laid out in the source byte stream.
///
/// `#[repr(C)]` pins the field order; the fields add up to
/// 2 + 2 + 12 + 128 = 144 bytes, which equals `Q4_K_BLOCK_BYTES`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct BlockQ4K {
/// Raw f16 bit pattern of the block-wide multiplier applied to every sub-block scale.
d_bits: u16,
/// Raw f16 bit pattern of the block-wide multiplier applied to every sub-block minimum.
dmin_bits: u16,
/// Eight 6-bit scales and eight 6-bit minimums packed into 12 bytes
/// (unpacked by `get_scale_min_k4`).
scales: [u8; 12],
/// Packed 4-bit quants, two per byte: the low nibble and the high nibble of a
/// byte belong to two different 32-value sub-blocks.
qs: [u8; 128],
}
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the 12-byte
/// packed scale table of a Q4_K block.
///
/// * For `j < 4` the scale is the low 6 bits of `q[j]` and the minimum is the
///   low 6 bits of `q[j + 4]`.
/// * For `j >= 4` the low 4 bits of each value come from the two nibbles of
///   `q[j + 4]`, and the upper 2 bits come from the top two bits of `q[j - 4]`
///   (scale) and `q[j]` (minimum).
///
/// # Panics
/// Panics on an out-of-bounds index when `j >= 8`.
#[inline]
fn get_scale_min_k4(j: usize, q: &[u8; 12]) -> (u8, u8) {
    if j >= 4 {
        // Nibble-packed low bits shared by scale (low nibble) and min (high nibble).
        let packed = q[j + 4];
        let scale_top = q[j - 4] >> 6;
        let min_top = q[j] >> 6;
        let scale = (packed & 0x0F) | (scale_top << 4);
        let min = (packed >> 4) | (min_top << 4);
        return (scale, min);
    }
    (q[j] & 0x3F, q[j + 4] & 0x3F)
}
/// Dequantizes a byte stream of Q4_K blocks into `f32` values.
///
/// `src` is interpreted as consecutive `Q4_K_BLOCK_BYTES`-byte blocks; each
/// block expands to `QK_K` values written to the matching window of `out`.
/// Within a block, every 32 quant bytes produce 64 outputs: the 32 low nibbles
/// first (scaled by sub-block `2k`), then the 32 high nibbles (sub-block
/// `2k + 1`). Each output is `d * scale * quant - dmin * min`.
///
/// `src` may start at any address: blocks are copied out with an unaligned
/// read, so no alignment requirement is placed on the caller. The two `u16`
/// header fields are decoded as little-endian, matching `f16_to_f32`.
///
/// # Errors
/// Returns `RullamaError::Gguf` when `src.len()` is not a multiple of
/// `Q4_K_BLOCK_BYTES`, or when `out.len()` differs from `blocks * QK_K`.
pub fn dequant_q4_k(src: &[u8], out: &mut [f32]) -> Result<()> {
    if src.len() % Q4_K_BLOCK_BYTES != 0 {
        return Err(RullamaError::Gguf(format!(
            "Q4_K source not multiple of {Q4_K_BLOCK_BYTES} bytes (got {})", src.len()
        )));
    }
    let nb = src.len() / Q4_K_BLOCK_BYTES;
    if out.len() != nb * QK_K {
        return Err(RullamaError::Gguf(format!(
            "Q4_K dest expected {} elements, got {}", nb * QK_K, out.len()
        )));
    }

    for (raw, dst) in src
        .chunks_exact(Q4_K_BLOCK_BYTES)
        .zip(out.chunks_exact_mut(QK_K))
    {
        // `cast_slice` would panic if `src` were not 2-byte aligned; copying the
        // block out with an unaligned read works for any byte slice.
        let blk: BlockQ4K = bytemuck::pod_read_unaligned(raw);
        // The header fields were loaded in native byte order; `from_le` turns
        // them into the little-endian interpretation on every host.
        let d = f16::from_bits(u16::from_le(blk.d_bits)).to_f32();
        let dmin = f16::from_bits(u16::from_le(blk.dmin_bits)).to_f32();

        // 32 quant bytes -> 64 outputs (low nibbles, then high nibbles).
        for (pair, (quants, group)) in blk
            .qs
            .chunks_exact(32)
            .zip(dst.chunks_exact_mut(64))
            .enumerate()
        {
            let (sc_lo, mn_lo) = get_scale_min_k4(2 * pair, &blk.scales);
            let (sc_hi, mn_hi) = get_scale_min_k4(2 * pair + 1, &blk.scales);
            let (d_lo, off_lo) = (d * f32::from(sc_lo), dmin * f32::from(mn_lo));
            let (d_hi, off_hi) = (d * f32::from(sc_hi), dmin * f32::from(mn_hi));

            let (lo_half, hi_half) = group.split_at_mut(32);
            for ((&byte, lo), hi) in quants.iter().zip(lo_half.iter_mut()).zip(hi_half.iter_mut()) {
                *lo = d_lo * f32::from(byte & 0x0F) - off_lo;
                *hi = d_hi * f32::from(byte >> 4) - off_hi;
            }
        }
    }
    Ok(())
}
/// Size in bytes of one Q6_K block; `dequant_q6_k` rejects sources whose length
/// is not a multiple of this value.
pub const Q6_K_BLOCK_BYTES: usize = 210;
/// One Q6_K block as laid out in the source byte stream.
///
/// `#[repr(C)]` pins the field order; the fields add up to
/// 128 + 64 + 16 + 2 = 210 bytes, which equals `Q6_K_BLOCK_BYTES`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct BlockQ6K {
/// Low 4 bits of each 6-bit quant, two quants per byte.
ql: [u8; 128],
/// Upper 2 bits of each 6-bit quant, four quants per byte.
qh: [u8; 64],
/// Signed 8-bit scales, one per run of 16 output values.
scales: [i8; 16],
/// Raw f16 bit pattern of the block-wide multiplier.
d_bits: u16,
}
/// Dequantizes a byte stream of Q6_K blocks into `f32` values.
///
/// `src` is interpreted as consecutive `Q6_K_BLOCK_BYTES`-byte blocks; each
/// block expands to `QK_K` values written to the matching window of `out`.
/// A block is processed in two passes of 128 outputs. In each pass, position
/// `l` (0..32) yields four 6-bit quants assembled from 4 low bits in `ql` and
/// 2 high bits in `qh`; each quant is re-centred by subtracting 32 and
/// multiplied by `d * scale`.
///
/// `src` may start at any address: blocks are copied out with an unaligned
/// read, so no alignment requirement is placed on the caller. The `u16`
/// header field is decoded as little-endian, matching `f16_to_f32`.
///
/// # Errors
/// Returns `RullamaError::Gguf` when `src.len()` is not a multiple of
/// `Q6_K_BLOCK_BYTES`, or when `out.len()` differs from `blocks * QK_K`.
pub fn dequant_q6_k(src: &[u8], out: &mut [f32]) -> Result<()> {
    if src.len() % Q6_K_BLOCK_BYTES != 0 {
        return Err(RullamaError::Gguf(format!(
            "Q6_K source not multiple of {Q6_K_BLOCK_BYTES} bytes (got {})", src.len()
        )));
    }
    let nb = src.len() / Q6_K_BLOCK_BYTES;
    if out.len() != nb * QK_K {
        return Err(RullamaError::Gguf(format!(
            "Q6_K dest expected {} elements, got {}", nb * QK_K, out.len()
        )));
    }

    for (raw, dst) in src
        .chunks_exact(Q6_K_BLOCK_BYTES)
        .zip(out.chunks_exact_mut(QK_K))
    {
        // `cast_slice` would panic if `src` were not 2-byte aligned; copying the
        // block out with an unaligned read works for any byte slice.
        let blk: BlockQ6K = bytemuck::pod_read_unaligned(raw);
        // Loaded in native byte order; `from_le` yields the little-endian
        // interpretation on every host.
        let d = f16::from_bits(u16::from_le(blk.d_bits)).to_f32();

        for (pass, dst_half) in dst.chunks_exact_mut(128).enumerate() {
            let ql = &blk.ql[pass * 64..(pass + 1) * 64];
            let qh = &blk.qh[pass * 32..(pass + 1) * 32];
            let sc = &blk.scales[pass * 8..(pass + 1) * 8];
            for l in 0..32 {
                // Two scale entries per 32-wide run: one for each 16 values.
                let is = l / 16;
                let hi = qh[l];
                let q1 = i32::from((ql[l] & 0xF) | ((hi & 3) << 4)) - 32;
                let q2 = i32::from((ql[l + 32] & 0xF) | (((hi >> 2) & 3) << 4)) - 32;
                let q3 = i32::from((ql[l] >> 4) | (((hi >> 4) & 3) << 4)) - 32;
                let q4 = i32::from((ql[l + 32] >> 4) | (((hi >> 6) & 3) << 4)) - 32;
                dst_half[l] = d * f32::from(sc[is]) * q1 as f32;
                dst_half[l + 32] = d * f32::from(sc[is + 2]) * q2 as f32;
                dst_half[l + 64] = d * f32::from(sc[is + 4]) * q3 as f32;
                dst_half[l + 96] = d * f32::from(sc[is + 6]) * q4 as f32;
            }
        }
    }
    Ok(())
}
/// Widens little-endian 16-bit values to `f32` by placing each one in the
/// upper half of a 32-bit float bit pattern (the BF16 → F32 conversion).
///
/// # Errors
/// Returns `RullamaError::Gguf` when `src.len()` is odd, or when `out` does
/// not hold exactly `src.len() / 2` elements.
pub fn bf16_to_f32(src: &[u8], out: &mut [f32]) -> Result<()> {
    let n_bytes = src.len();
    if n_bytes % 2 != 0 {
        return Err(RullamaError::Gguf(format!("BF16 source byte length {} is odd", n_bytes)));
    }
    if out.len() * 2 != n_bytes {
        return Err(RullamaError::Gguf(format!(
            "BF16 dest expected {} elements, got {}", n_bytes / 2, out.len()
        )));
    }
    for (slot, pair) in out.iter_mut().zip(src.chunks_exact(2)) {
        let upper = u16::from_le_bytes([pair[0], pair[1]]);
        // The 16 stored bits are the top 16 bits of the f32; the rest are zero.
        *slot = f32::from_bits(u32::from(upper) << 16);
    }
    Ok(())
}
/// Converts little-endian IEEE half-precision values to `f32`.
///
/// Each pair of bytes is read as a little-endian `u16` bit pattern and widened
/// through `f16::from_bits(..).to_f32()`.
///
/// # Errors
/// Returns `RullamaError::Gguf` when `src.len()` is odd, or when `out` does
/// not hold exactly `src.len() / 2` elements.
pub fn f16_to_f32(src: &[u8], out: &mut [f32]) -> Result<()> {
    let n_bytes = src.len();
    if n_bytes % 2 != 0 {
        return Err(RullamaError::Gguf(format!("F16 source byte length {} is odd", n_bytes)));
    }
    if out.len() * 2 != n_bytes {
        return Err(RullamaError::Gguf(format!(
            "F16 dest expected {} elements, got {}", n_bytes / 2, out.len()
        )));
    }
    for (slot, pair) in out.iter_mut().zip(src.chunks_exact(2)) {
        let raw = u16::from_le_bytes([pair[0], pair[1]]);
        *slot = f16::from_bits(raw).to_f32();
    }
    Ok(())
}
/// Decodes little-endian `f32` values from a byte slice into `out`.
///
/// # Errors
/// Returns `RullamaError::Gguf` when `src.len()` is not a multiple of 4, or
/// when `out` does not hold exactly `src.len() / 4` elements.
pub fn f32_to_f32(src: &[u8], out: &mut [f32]) -> Result<()> {
    let n_bytes = src.len();
    if n_bytes % 4 != 0 {
        return Err(RullamaError::Gguf(format!("F32 source byte length {} not /4", n_bytes)));
    }
    if out.len() * 4 != n_bytes {
        return Err(RullamaError::Gguf(format!(
            "F32 dest expected {} elements, got {}", n_bytes / 4, out.len()
        )));
    }
    for (slot, quad) in out.iter_mut().zip(src.chunks_exact(4)) {
        let le = [quad[0], quad[1], quad[2], quad[3]];
        *slot = f32::from_le_bytes(le);
    }
    Ok(())
}
pub fn dequant_into_f32(dtype: GgmlDtype, src: &[u8], out: &mut [f32]) -> Result<()> {
match dtype {
GgmlDtype::F32 => f32_to_f32(src, out),
GgmlDtype::F16 => f16_to_f32(src, out),
GgmlDtype::BF16 => bf16_to_f32(src, out),
GgmlDtype::Q4_K => dequant_q4_k(src, out),
GgmlDtype::Q6_K => dequant_q6_k(src, out),
other => Err(RullamaError::Gguf(format!(
"dtype {other:?} is not in v1 dequant scope (only F32, F16, BF16, Q4_K, Q6_K)"
))),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one Q4_K block with `d = 1.0` (f16 `0x3C00`), `dmin = 0.0`,
    /// every sub-block scale equal to 1, every sub-block minimum equal to 0,
    /// and all quants zero.
    fn synth_q4_k_zero() -> Vec<u8> {
        let mut buf = vec![0u8; Q4_K_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&0x3C00u16.to_le_bytes());
        buf[2..4].copy_from_slice(&0x0000u16.to_le_bytes());
        // Packed scale bytes 0..4: scales of sub-blocks 0..4.
        for j in 0..4 {
            buf[4 + j] = 1;
        }
        // Packed scale bytes 8..12: low nibble = scale, high nibble = min of
        // sub-blocks 4..8.
        for j in 4..8 {
            buf[4 + j + 4] = 0x01;
        }
        buf
    }

    #[test]
    fn q4_k_zero_block_dequants_to_zero() {
        let src = synth_q4_k_zero();
        let mut out = vec![999f32; QK_K];
        dequant_q4_k(&src, &mut out).unwrap();
        for &v in &out {
            assert_eq!(v, 0.0, "dequant of all-zero quants must be zero");
        }
    }

    #[test]
    fn q4_k_alternating_nibbles() {
        let mut buf = synth_q4_k_zero();
        // Every quant byte: low nibble 5, high nibble 10.
        for b in &mut buf[16..16 + 128] {
            *b = 0xA5;
        }
        let mut out = vec![0f32; QK_K];
        dequant_q4_k(&buf, &mut out).unwrap();
        for chunk in 0..(QK_K / 64) {
            for l in 0..32 {
                assert_eq!(out[chunk * 64 + l], 5.0, "low nibble dequant");
                assert_eq!(out[chunk * 64 + l + 32], 10.0, "high nibble dequant");
            }
        }
    }

    #[test]
    fn q6_k_zero_block_dequants_to_zero() {
        // All scales are zero, so every output is zero regardless of quants.
        let mut buf = vec![0u8; Q6_K_BLOCK_BYTES];
        buf[208..210].copy_from_slice(&0x3C00u16.to_le_bytes());
        let mut out = vec![999f32; QK_K];
        dequant_q6_k(&buf, &mut out).unwrap();
        assert!(out.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn q6_k_unit_scale_constant_quants() {
        let mut buf = vec![0u8; Q6_K_BLOCK_BYTES];
        // Scales live at bytes 192..208; set all sixteen to 1.
        for i in 0..16 {
            buf[192 + i] = 1;
        }
        buf[208..210].copy_from_slice(&0x3C00u16.to_le_bytes());
        let mut out = vec![0f32; QK_K];
        dequant_q6_k(&buf, &mut out).unwrap();
        // Stored quant 0 is re-centred to -32; with d = 1 and scale = 1 that is the output.
        for &v in &out {
            assert_eq!(v, -32.0);
        }
    }

    #[test]
    fn f16_round_trip() {
        // 3.3 is not exactly representable in f16, so the tolerance check is
        // exercised; a literal close to pi would trip clippy::approx_constant.
        let values: [f32; 4] = [0.0, 1.0, -2.5, 3.3];
        let mut bytes = Vec::with_capacity(values.len() * 2);
        for v in values {
            bytes.extend_from_slice(&f16::from_f32(v).to_bits().to_le_bytes());
        }
        let mut out = vec![0f32; values.len()];
        f16_to_f32(&bytes, &mut out).unwrap();
        for (got, want) in out.iter().zip(values.iter()) {
            assert!((got - want).abs() < 0.01, "got {} want {}", got, want);
        }
    }
}