use bytemuck::{Pod, Zeroable};
use half::f16;
use super::dtype::GgmlDtype;
use crate::error::{Result, RullamaError};
/// Number of values encoded by one K-quant super-block (shared by Q4_K and Q6_K).
pub const QK_K: usize = 256;
/// Size in bytes of one on-disk Q4_K block; must equal `size_of::<BlockQ4K>()`.
pub const Q4_K_BLOCK_BYTES: usize = 144;

/// Raw field layout of one Q4_K super-block as stored in the source bytes.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct BlockQ4K {
    /// Bits of the `f16` super-block scale that multiplies every sub-block scale.
    d_bits: u16,
    /// Bits of the `f16` super-block scale that multiplies every sub-block min.
    dmin_bits: u16,
    /// Eight packed 6-bit (scale, min) pairs, unpacked by `get_scale_min_k4`.
    scales: [u8; 12],
    /// `QK_K` 4-bit quants, two per byte.
    qs: [u8; 128],
}

// The decoder reinterprets `Q4_K_BLOCK_BYTES`-sized byte chunks as `BlockQ4K`;
// fail the build if the struct layout ever drifts from the on-disk block size.
const _: () = assert!(std::mem::size_of::<BlockQ4K>() == Q4_K_BLOCK_BYTES);
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` (expected `0..8`)
/// from the 12-byte packed `scales` field of a Q4_K block.
///
/// Sub-blocks 0..4 keep scale and min in the low 6 bits of bytes `j` and
/// `j + 4`. Sub-blocks 4..8 build each value from one nibble of byte `j + 4`
/// plus the two top bits of byte `j - 4` (scale) or byte `j` (min).
#[inline]
fn get_scale_min_k4(j: usize, q: &[u8; 12]) -> (u8, u8) {
    const LOW6: u8 = 0x3F;
    match j {
        0..=3 => (q[j] & LOW6, q[j + 4] & LOW6),
        _ => {
            let packed = q[j + 4];
            let scale = (packed & 0x0F) | ((q[j - 4] >> 6) << 4);
            let min = (packed >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
/// Dequantizes a buffer of GGML Q4_K super-blocks into `f32` values.
///
/// Each [`Q4_K_BLOCK_BYTES`]-byte block expands to [`QK_K`] outputs, handled as
/// four 64-value chunks. Chunk `c` reads 32 bytes of `qs`. The low nibbles of
/// those bytes use sub-block `2c`'s `(scale, min)`, and the high nibbles use
/// sub-block `2c + 1`'s. Each nibble `q` decodes to `d * scale * q - dmin * min`.
///
/// Blocks are copied out with `bytemuck::pod_read_unaligned`, so `src` may
/// start at any byte offset. The `f16` header fields are decoded as
/// little-endian, like the other decoders in this module, regardless of host
/// byte order.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] if `src.len()` is not a multiple of
/// [`Q4_K_BLOCK_BYTES`] or `out.len()` is not exactly `QK_K` per block.
pub fn dequant_q4_k(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(Q4_K_BLOCK_BYTES) {
        return Err(RullamaError::Gguf(format!(
            "Q4_K source not multiple of {Q4_K_BLOCK_BYTES} bytes (got {})",
            src.len()
        )));
    }
    let nb = src.len() / Q4_K_BLOCK_BYTES;
    if out.len() != nb * QK_K {
        return Err(RullamaError::Gguf(format!(
            "Q4_K dest expected {} elements, got {}",
            nb * QK_K,
            out.len()
        )));
    }
    for (raw, block_out) in src
        .chunks_exact(Q4_K_BLOCK_BYTES)
        .zip(out.chunks_exact_mut(QK_K))
    {
        // Copy instead of `cast_slice`: `BlockQ4K` is 2-byte aligned, and an
        // in-place cast panics when `src` starts at an odd address.
        let blk: BlockQ4K = bytemuck::pod_read_unaligned(raw);
        // Header fields are stored little-endian; `from_le` is a no-op on LE hosts.
        let d = f16::from_bits(u16::from_le(blk.d_bits)).to_f32();
        let dmin = f16::from_bits(u16::from_le(blk.dmin_bits)).to_f32();
        for (chunk, (q, chunk_out)) in blk
            .qs
            .chunks_exact(32)
            .zip(block_out.chunks_exact_mut(64))
            .enumerate()
        {
            let (sc_lo, m_lo) = get_scale_min_k4(2 * chunk, &blk.scales);
            let (sc_hi, m_hi) = get_scale_min_k4(2 * chunk + 1, &blk.scales);
            // `d * scale` is formed first, the same association as `d * scale * q`,
            // so rounding matches the left-to-right evaluation exactly.
            let (d_lo, min_lo) = (d * sc_lo as f32, dmin * m_lo as f32);
            let (d_hi, min_hi) = (d * sc_hi as f32, dmin * m_hi as f32);
            let (lo_half, hi_half) = chunk_out.split_at_mut(32);
            for ((&byte, lo), hi) in q.iter().zip(lo_half).zip(hi_half) {
                *lo = d_lo * (byte & 0xF) as f32 - min_lo;
                *hi = d_hi * (byte >> 4) as f32 - min_hi;
            }
        }
    }
    Ok(())
}
/// Size in bytes of one on-disk Q6_K block; must equal `size_of::<BlockQ6K>()`.
pub const Q6_K_BLOCK_BYTES: usize = 210;

/// Raw field layout of one Q6_K super-block as stored in the source bytes.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct BlockQ6K {
    /// Low 4 bits of each of the `QK_K` quants, two per byte.
    ql: [u8; 128],
    /// High 2 bits of each quant, four per byte.
    qh: [u8; 64],
    /// Signed 8-bit scales, one per run of 16 consecutive outputs.
    scales: [i8; 16],
    /// Bits of the `f16` super-block scale.
    d_bits: u16,
}

// The decoder reinterprets `Q6_K_BLOCK_BYTES`-sized byte chunks as `BlockQ6K`;
// fail the build if the struct layout ever drifts from the on-disk block size.
const _: () = assert!(std::mem::size_of::<BlockQ6K>() == Q6_K_BLOCK_BYTES);
/// Dequantizes a buffer of GGML Q6_K super-blocks into `f32` values.
///
/// Each [`Q6_K_BLOCK_BYTES`]-byte block expands to [`QK_K`] outputs, handled as
/// two 128-value passes. Each 6-bit quant is rebuilt from a nibble of `ql` and
/// two bits of `qh`, re-centred by subtracting 32, and multiplied by
/// `d * scales[k]`. Here `k` selects the run of 16 outputs the value falls in.
///
/// Blocks are copied out with `bytemuck::pod_read_unaligned`, so `src` may
/// start at any byte offset. The `f16` scale is decoded as little-endian, like
/// the other decoders in this module, regardless of host byte order.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] if `src.len()` is not a multiple of
/// [`Q6_K_BLOCK_BYTES`] or `out.len()` is not exactly `QK_K` per block.
pub fn dequant_q6_k(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(Q6_K_BLOCK_BYTES) {
        return Err(RullamaError::Gguf(format!(
            "Q6_K source not multiple of {Q6_K_BLOCK_BYTES} bytes (got {})",
            src.len()
        )));
    }
    let nb = src.len() / Q6_K_BLOCK_BYTES;
    if out.len() != nb * QK_K {
        return Err(RullamaError::Gguf(format!(
            "Q6_K dest expected {} elements, got {}",
            nb * QK_K,
            out.len()
        )));
    }
    for (raw, block_out) in src
        .chunks_exact(Q6_K_BLOCK_BYTES)
        .zip(out.chunks_exact_mut(QK_K))
    {
        // Copy instead of `cast_slice`: `BlockQ6K` is 2-byte aligned, and an
        // in-place cast panics when `src` starts at an odd address.
        let blk: BlockQ6K = bytemuck::pod_read_unaligned(raw);
        // `d_bits` is stored little-endian; `from_le` is a no-op on LE hosts.
        let d = f16::from_bits(u16::from_le(blk.d_bits)).to_f32();
        // Each 128-output pass consumes 64 `ql` bytes, 32 `qh` bytes and 8 scales.
        for (n_pass, pass_out) in block_out.chunks_exact_mut(128).enumerate() {
            let ql = &blk.ql[n_pass * 64..(n_pass + 1) * 64];
            let qh = &blk.qh[n_pass * 32..(n_pass + 1) * 32];
            let sc = &blk.scales[n_pass * 8..(n_pass + 1) * 8];
            for l in 0..32 {
                let is = l / 16;
                // `qh[l]` holds the top two bits for four outputs, at l, l+32, l+64, l+96.
                let q1 = ((ql[l] & 0xF) as i32 | ((qh[l] & 3) as i32) << 4) - 32;
                let q2 = ((ql[l + 32] & 0xF) as i32 | (((qh[l] >> 2) & 3) as i32) << 4) - 32;
                let q3 = ((ql[l] >> 4) as i32 | (((qh[l] >> 4) & 3) as i32) << 4) - 32;
                let q4 = ((ql[l + 32] >> 4) as i32 | (((qh[l] >> 6) & 3) as i32) << 4) - 32;
                pass_out[l] = d * sc[is] as f32 * q1 as f32;
                pass_out[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
                pass_out[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
                pass_out[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
            }
        }
    }
    Ok(())
}
/// Size in bytes of one Q4_0 block: an `f16` scale followed by 16 packed bytes.
pub const Q4_0_BLOCK_BYTES: usize = 18;
/// Number of values encoded by one Q4_0 block.
pub const QK4_0: usize = 32;

/// Dequantizes GGML Q4_0 blocks into `f32`.
///
/// Each block starts with a little-endian `f16` scale `d`, followed by 16
/// bytes of 4-bit quants. Byte `i` carries element `i` in its low nibble and
/// element `i + 16` in its high nibble. Every nibble `q` decodes to `(q - 8) * d`.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] when `src` is not a whole number of blocks
/// or `out` does not hold exactly [`QK4_0`] values per block.
pub fn dequant_q4_0(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(Q4_0_BLOCK_BYTES) {
        return Err(RullamaError::Gguf(format!(
            "Q4_0 source not multiple of {Q4_0_BLOCK_BYTES} bytes (got {})",
            src.len()
        )));
    }
    let expected = (src.len() / Q4_0_BLOCK_BYTES) * QK4_0;
    if out.len() != expected {
        return Err(RullamaError::Gguf(format!(
            "Q4_0 dest expected {} elements, got {}",
            expected,
            out.len()
        )));
    }
    let blocks = src.chunks_exact(Q4_0_BLOCK_BYTES);
    for (block, values) in blocks.zip(out.chunks_exact_mut(QK4_0)) {
        let (head, packed) = block.split_at(2);
        let scale = f16::from_bits(u16::from_le_bytes([head[0], head[1]])).to_f32();
        let (first, second) = values.split_at_mut(QK4_0 / 2);
        for ((&byte, lo), hi) in packed.iter().zip(first.iter_mut()).zip(second.iter_mut()) {
            *lo = ((byte & 0x0F) as f32 - 8.0) * scale;
            *hi = ((byte >> 4) as f32 - 8.0) * scale;
        }
    }
    Ok(())
}
/// Size in bytes of one Q5_0 block: `f16` scale, 4 bytes of high bits, 16 packed bytes.
pub const Q5_0_BLOCK_BYTES: usize = 22;
/// Number of values encoded by one Q5_0 block.
pub const QK5_0: usize = 32;

/// Dequantizes GGML Q5_0 blocks into `f32`.
///
/// Block layout:
/// - a little-endian `f16` scale `d`;
/// - a little-endian `u32` holding the fifth bit of every element;
/// - 16 bytes of 4-bit low parts, with element `j` in the low nibble of byte
///   `j` and element `j + 16` in the high nibble.
///
/// Bit `j` of the `u32` completes element `j`, and bit `j + 16` completes
/// element `j + 16`. Each 5-bit value `x` decodes to `(x - 16) * d`.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] when `src` is not a whole number of blocks
/// or `out` does not hold exactly [`QK5_0`] values per block.
pub fn dequant_q5_0(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(Q5_0_BLOCK_BYTES) {
        return Err(RullamaError::Gguf(format!(
            "Q5_0 source not multiple of {Q5_0_BLOCK_BYTES} bytes (got {})",
            src.len()
        )));
    }
    let expected = (src.len() / Q5_0_BLOCK_BYTES) * QK5_0;
    if out.len() != expected {
        return Err(RullamaError::Gguf(format!(
            "Q5_0 dest expected {} elements, got {}",
            expected,
            out.len()
        )));
    }
    let blocks = src.chunks_exact(Q5_0_BLOCK_BYTES);
    for (block, values) in blocks.zip(out.chunks_exact_mut(QK5_0)) {
        let scale = f16::from_bits(u16::from_le_bytes([block[0], block[1]])).to_f32();
        let high_bits = u32::from_le_bytes([block[2], block[3], block[4], block[5]]);
        let (first, second) = values.split_at_mut(QK5_0 / 2);
        let lanes = first.iter_mut().zip(second.iter_mut());
        for (j, (&packed, (lo_out, hi_out))) in block[6..].iter().zip(lanes).enumerate() {
            // Move each element's fifth bit into bit position 4 (value 0x10).
            let lo_bit = ((high_bits >> j) & 1) << 4;
            let hi_bit = ((high_bits >> (j + 16)) & 1) << 4;
            let lo = (u32::from(packed & 0x0F) | lo_bit) as i32 - 16;
            let hi = (u32::from(packed >> 4) | hi_bit) as i32 - 16;
            *lo_out = lo as f32 * scale;
            *hi_out = hi as f32 * scale;
        }
    }
    Ok(())
}
/// Size in bytes of one Q8_0 block: an `f16` scale followed by 32 signed bytes.
pub const Q8_0_BLOCK_BYTES: usize = 34;
/// Number of values encoded by one Q8_0 block.
pub const QK8_0: usize = 32;

/// Dequantizes GGML Q8_0 blocks into `f32`.
///
/// Each block is a little-endian `f16` scale `d` followed by 32 bytes. Each
/// byte is reinterpreted as `i8` and decoded as `q * d`.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] when `src` is not a whole number of blocks
/// or `out` does not hold exactly [`QK8_0`] values per block.
pub fn dequant_q8_0(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(Q8_0_BLOCK_BYTES) {
        return Err(RullamaError::Gguf(format!(
            "Q8_0 source not multiple of {Q8_0_BLOCK_BYTES} bytes (got {})",
            src.len()
        )));
    }
    let expected = (src.len() / Q8_0_BLOCK_BYTES) * QK8_0;
    if out.len() != expected {
        return Err(RullamaError::Gguf(format!(
            "Q8_0 dest expected {} elements, got {}",
            expected,
            out.len()
        )));
    }
    let pairs = src
        .chunks_exact(Q8_0_BLOCK_BYTES)
        .zip(out.chunks_exact_mut(QK8_0));
    for (block, values) in pairs {
        let (head, quants) = block.split_at(2);
        let scale = f16::from_bits(u16::from_le_bytes([head[0], head[1]])).to_f32();
        values
            .iter_mut()
            .zip(quants)
            .for_each(|(slot, &q)| *slot = f32::from(q as i8) * scale);
    }
    Ok(())
}
/// Widens little-endian `bf16` values in `src` into `f32` values in `out`.
///
/// A `bf16` is the upper 16 bits of an IEEE-754 `f32`. Widening therefore
/// shifts the bits left by 16 and reinterprets them, with no rounding.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] if `src` has an odd byte length or `out`
/// does not hold exactly `src.len() / 2` values.
pub fn bf16_to_f32(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(2) {
        return Err(RullamaError::Gguf(format!(
            "BF16 source byte length {} is odd",
            src.len()
        )));
    }
    let want = src.len() / 2;
    if out.len() != want {
        return Err(RullamaError::Gguf(format!(
            "BF16 dest expected {} elements, got {}",
            want,
            out.len()
        )));
    }
    for (slot, pair) in out.iter_mut().zip(src.chunks_exact(2)) {
        let upper = u32::from(u16::from_le_bytes([pair[0], pair[1]]));
        *slot = f32::from_bits(upper << 16);
    }
    Ok(())
}
/// Converts little-endian IEEE `f16` values in `src` into `f32` values in `out`.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] if `src` has an odd byte length or `out`
/// does not hold exactly `src.len() / 2` values.
pub fn f16_to_f32(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(2) {
        return Err(RullamaError::Gguf(format!(
            "F16 source byte length {} is odd",
            src.len()
        )));
    }
    let want = src.len() / 2;
    if out.len() != want {
        return Err(RullamaError::Gguf(format!(
            "F16 dest expected {} elements, got {}",
            want,
            out.len()
        )));
    }
    let decoded = src
        .chunks_exact(2)
        .map(|pair| f16::from_bits(u16::from_le_bytes([pair[0], pair[1]])).to_f32());
    for (slot, value) in out.iter_mut().zip(decoded) {
        *slot = value;
    }
    Ok(())
}
/// Decodes little-endian `f32` values from `src` into `out`.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] if `src.len()` is not a multiple of 4 or
/// `out` does not hold exactly `src.len() / 4` values.
pub fn f32_to_f32(src: &[u8], out: &mut [f32]) -> Result<()> {
    if !src.len().is_multiple_of(4) {
        return Err(RullamaError::Gguf(format!(
            "F32 source byte length {} not /4",
            src.len()
        )));
    }
    let want = src.len() / 4;
    if out.len() != want {
        return Err(RullamaError::Gguf(format!(
            "F32 dest expected {} elements, got {}",
            want,
            out.len()
        )));
    }
    for (slot, word) in out.iter_mut().zip(src.chunks_exact(4)) {
        *slot = f32::from_le_bytes([word[0], word[1], word[2], word[3]]);
    }
    Ok(())
}
pub fn dequant_into_f32(dtype: GgmlDtype, src: &[u8], out: &mut [f32]) -> Result<()> {
match dtype {
GgmlDtype::F32 => f32_to_f32(src, out),
GgmlDtype::F16 => f16_to_f32(src, out),
GgmlDtype::BF16 => bf16_to_f32(src, out),
GgmlDtype::Q4_K => dequant_q4_k(src, out),
GgmlDtype::Q6_K => dequant_q6_k(src, out),
GgmlDtype::Q4_0 => dequant_q4_0(src, out),
GgmlDtype::Q5_0 => dequant_q5_0(src, out),
GgmlDtype::Q8_0 => dequant_q8_0(src, out),
other => Err(RullamaError::Gguf(format!(
"dtype {other:?} is not in dequant scope (only F32, F16, BF16, Q4_0, Q5_0, Q8_0, Q4_K, Q6_K)"
))),
}
}
/// Copies little-endian `f16` values from `src` into `out` as raw bit patterns.
///
/// # Errors
/// Returns [`RullamaError::Gguf`] if `src` has an odd byte length or `out`
/// does not hold exactly `src.len() / 2` values.
pub fn f16_to_f16_bits(src: &[u8], out: &mut [u16]) -> Result<()> {
    if !src.len().is_multiple_of(2) {
        return Err(RullamaError::Gguf(format!(
            "F16 source byte length {} is odd",
            src.len()
        )));
    }
    let want = src.len() / 2;
    if out.len() != want {
        return Err(RullamaError::Gguf(format!(
            "F16 dest expected {} elements, got {}",
            want,
            out.len()
        )));
    }
    out.iter_mut()
        .zip(src.chunks_exact(2))
        .for_each(|(slot, pair)| *slot = u16::from_le_bytes([pair[0], pair[1]]));
    Ok(())
}
/// Decodes `src`, stored as `dtype`, into raw `f16` bit patterns in `out`.
///
/// `F16` input is copied through bit-for-bit. Any other dtype is first
/// decoded with [`dequant_into_f32`] into a temporary buffer sized to `out`.
/// Each value is then converted with `f16::from_f32`.
///
/// # Errors
/// Propagates the length and unsupported-dtype errors of the underlying decoder.
pub fn dequant_into_f16(dtype: GgmlDtype, src: &[u8], out: &mut [u16]) -> Result<()> {
    if matches!(dtype, GgmlDtype::F16) {
        return f16_to_f16_bits(src, out);
    }
    let mut widened = vec![0f32; out.len()];
    dequant_into_f32(dtype, src, &mut widened)?;
    for (slot, value) in out.iter_mut().zip(widened) {
        *slot = f16::from_f32(value).to_bits();
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `f16` bit patterns used as block scales below.
    const F16_ONE: u16 = 0x3C00; // 1.0
    const F16_TWO: u16 = 0x4000; // 2.0
    const F16_HALF: u16 = 0x3800; // 0.5

    /// Builds one Q4_K block with `d = 1.0`, `dmin = 0.0`, every sub-block scale
    /// set to 1, every min set to 0, and all quants zero.
    fn synth_q4_k_zero() -> Vec<u8> {
        let mut buf = vec![0u8; Q4_K_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_ONE.to_le_bytes()); // d
        buf[2..4].copy_from_slice(&0x0000u16.to_le_bytes()); // dmin
        // Sub-blocks 0..4 take their scale from the low 6 bits of scales[j].
        for j in 0..4 {
            buf[4 + j] = 1;
        }
        // Sub-blocks 4..8 take their scale from the low nibble of scales[j + 4].
        for j in 4..8 {
            buf[4 + j + 4] = 0x01;
        }
        buf
    }

    #[test]
    fn q4_k_zero_block_dequants_to_zero() {
        let src = synth_q4_k_zero();
        let mut out = vec![999f32; QK_K];
        dequant_q4_k(&src, &mut out).unwrap();
        for &v in &out {
            assert_eq!(v, 0.0, "dequant of all-zero quants must be zero");
        }
    }

    #[test]
    fn q4_0_dequant_matches_ggml_oracle() {
        let mut buf = vec![0u8; Q4_0_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_TWO.to_le_bytes()); // d = 2.0
        // elem 0 = (0xA - 8) * 2 = 4; elem 16 = (0x3 - 8) * 2 = -10
        buf[2] = 0x3A;
        // elem 15 = (0x0 - 8) * 2 = -16; elem 31 = (0x8 - 8) * 2 = 0
        buf[2 + 15] = 0x80;
        let mut out = vec![999f32; QK4_0];
        dequant_q4_0(&buf, &mut out).unwrap();
        assert_eq!(out[0], 4.0);
        assert_eq!(out[16], -10.0);
        assert_eq!(out[15], -16.0);
        assert_eq!(out[31], 0.0);
        for l in 1..15 {
            assert_eq!(out[l], -16.0, "low nibble {l}");
            assert_eq!(out[l + 16], -16.0, "high nibble {l}");
        }
    }

    #[test]
    fn q4_0_into_f32_dispatch_routes() {
        let mut buf = vec![0u8; Q4_0_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_ONE.to_le_bytes());
        // Every nibble is 8, which the -8 offset maps to zero.
        for b in &mut buf[2..] {
            *b = 0x88;
        }
        let mut out = vec![999f32; QK4_0];
        dequant_into_f32(GgmlDtype::Q4_0, &buf, &mut out).unwrap();
        assert!(out.iter().all(|&v| v == 0.0), "nibble 8 with offset -8 → 0");
    }

    #[test]
    fn q5_0_dequant_matches_ggml_oracle() {
        let mut buf = vec![0u8; Q5_0_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_TWO.to_le_bytes()); // d = 2.0
        // qh bit 0 supplies the 5th bit of elem 0; bit 16 that of elem 16.
        let qh: u32 = 1 | (1 << 16);
        buf[2..6].copy_from_slice(&qh.to_le_bytes());
        // elem 0 = (0xF | 0x10) - 16 = 15 → 30; elem 16 = (0x3 | 0x10) - 16 = 3 → 6
        buf[6] = 0x3F;
        let mut out = vec![999f32; QK5_0];
        dequant_q5_0(&buf, &mut out).unwrap();
        assert_eq!(out[0], 30.0);
        assert_eq!(out[16], 6.0);
        for j in 1..16 {
            assert_eq!(out[j], -32.0, "elem {j}");
            assert_eq!(out[j + 16], -32.0, "elem {}", j + 16);
        }
    }

    #[test]
    fn q5_0_into_f32_dispatch_routes() {
        let mut buf = vec![0u8; Q5_0_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_ONE.to_le_bytes());
        let mut out = vec![999f32; QK5_0];
        dequant_into_f32(GgmlDtype::Q5_0, &buf, &mut out).unwrap();
        assert!(out.iter().all(|&v| v == -16.0));
    }

    #[test]
    fn q8_0_dequant_matches_ggml_oracle() {
        let mut buf = vec![0u8; Q8_0_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_TWO.to_le_bytes()); // d = 2.0
        buf[2] = 10; // 10 * 2 = 20
        buf[2 + 1] = 0x80; // -128 * 2 = -256
        buf[2 + 15] = 0xFF; // -1 * 2 = -2
        buf[2 + 31] = 127; // 127 * 2 = 254
        let mut out = vec![999f32; QK8_0];
        dequant_q8_0(&buf, &mut out).unwrap();
        assert_eq!(out[0], 20.0);
        assert_eq!(out[1], -256.0);
        assert_eq!(out[15], -2.0);
        assert_eq!(out[31], 254.0);
        for l in 2..15 {
            assert_eq!(out[l], 0.0, "untouched qs[{l}] must dequant to 0");
        }
        for l in 16..31 {
            assert_eq!(out[l], 0.0, "untouched qs[{l}] must dequant to 0");
        }
    }

    #[test]
    fn q8_0_into_f32_dispatch_routes() {
        let mut buf = vec![0u8; 2 * Q8_0_BLOCK_BYTES];
        buf[0..2].copy_from_slice(&F16_ONE.to_le_bytes());
        buf[2] = 0x05; // 5 * 1.0 = 5
        let b1 = Q8_0_BLOCK_BYTES;
        buf[b1..b1 + 2].copy_from_slice(&F16_HALF.to_le_bytes());
        buf[b1 + 2] = 0xFC; // -4 * 0.5 = -2
        let mut out = vec![999f32; 2 * QK8_0];
        dequant_into_f32(GgmlDtype::Q8_0, &buf, &mut out).unwrap();
        assert_eq!(out[0], 5.0);
        assert_eq!(out[32], -2.0);
        assert!(out[1..32].iter().all(|&v| v == 0.0));
        assert!(out[33..].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn dequant_into_f16_f16_passthrough_is_lossless() {
        // 65504.0 is the largest finite f16.
        let vals = [1.0f32, 2.0, -0.5, 0.0, 65504.0];
        let mut src = Vec::new();
        for &v in &vals {
            src.extend_from_slice(&f16::from_f32(v).to_bits().to_le_bytes());
        }
        let mut bits = vec![0u16; vals.len()];
        dequant_into_f16(GgmlDtype::F16, &src, &mut bits).unwrap();
        for (i, &v) in vals.iter().enumerate() {
            assert_eq!(
                bits[i],
                f16::from_f32(v).to_bits(),
                "f16 passthrough bit-exact at {i}"
            );
            assert_eq!(
                f16::from_bits(bits[i]).to_f32(),
                v,
                "decoded value matches at {i}"
            );
        }
    }

    #[test]
    fn dequant_into_f16_f32_downcast() {
        let vals = [1.0f32, 2.0, -0.5, 0.1];
        let mut src = Vec::new();
        for &v in &vals {
            src.extend_from_slice(&v.to_le_bytes());
        }
        let mut bits = vec![0u16; vals.len()];
        dequant_into_f16(GgmlDtype::F32, &src, &mut bits).unwrap();
        for (i, &v) in vals.iter().enumerate() {
            assert_eq!(
                bits[i],
                f16::from_f32(v).to_bits(),
                "f32→f16 downcast matches at {i}"
            );
        }
    }

    #[test]
    fn q4_k_alternating_nibbles() {
        let mut buf = synth_q4_k_zero();
        // Low nibble 5, high nibble 10 in every quant byte.
        for b in &mut buf[16..16 + 128] {
            *b = 0xA5;
        }
        let mut out = vec![0f32; QK_K];
        dequant_q4_k(&buf, &mut out).unwrap();
        for chunk in 0..(QK_K / 64) {
            for l in 0..32 {
                assert_eq!(out[chunk * 64 + l], 5.0, "low nibble dequant");
                assert_eq!(out[chunk * 64 + l + 32], 10.0, "high nibble dequant");
            }
        }
    }

    #[test]
    fn q4_k_min_offsets_are_subtracted() {
        let mut buf = synth_q4_k_zero();
        buf[2..4].copy_from_slice(&F16_ONE.to_le_bytes()); // dmin = 1.0
        // Sub-blocks 0..4 take their min from the low 6 bits of scales[j + 4].
        for j in 0..4 {
            buf[4 + j + 4] = 2;
        }
        // Sub-blocks 4..8: keep scale nibble 1 and put min 2 in the high nibble.
        for j in 4..8 {
            buf[4 + j + 4] = 0x21;
        }
        let mut out = vec![999f32; QK_K];
        dequant_q4_k(&buf, &mut out).unwrap();
        assert!(out.iter().all(|&v| v == -2.0), "0 * scale - dmin * 2 = -2");
    }

    #[test]
    fn q6_k_zero_block_dequants_to_zero() {
        let mut buf = vec![0u8; Q6_K_BLOCK_BYTES];
        buf[208..210].copy_from_slice(&F16_ONE.to_le_bytes());
        let mut out = vec![999f32; QK_K];
        dequant_q6_k(&buf, &mut out).unwrap();
        assert!(out.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn q6_k_unit_scale_constant_quants() {
        let mut buf = vec![0u8; Q6_K_BLOCK_BYTES];
        for i in 0..16 {
            buf[192 + i] = 1;
        }
        buf[208..210].copy_from_slice(&F16_ONE.to_le_bytes());
        let mut out = vec![0f32; QK_K];
        dequant_q6_k(&buf, &mut out).unwrap();
        for &v in &out {
            assert_eq!(v, -32.0);
        }
    }

    #[test]
    fn q6_k_all_bits_set_decodes_to_31() {
        // ql and qh all ones, so every 6-bit quant is 63, minus the 32 offset gives 31.
        let mut buf = vec![0xFFu8; Q6_K_BLOCK_BYTES];
        for b in &mut buf[192..208] {
            *b = 1;
        }
        buf[208..210].copy_from_slice(&F16_ONE.to_le_bytes());
        let mut out = vec![0f32; QK_K];
        dequant_q6_k(&buf, &mut out).unwrap();
        assert!(out.iter().all(|&v| v == 31.0));
    }

    #[test]
    fn f16_round_trip() {
        let values: [f32; 4] = [0.0, 1.0, -2.5, 3.5];
        let mut bytes = Vec::with_capacity(values.len() * 2);
        for v in values {
            bytes.extend_from_slice(&f16::from_f32(v).to_bits().to_le_bytes());
        }
        let mut out = vec![0f32; values.len()];
        f16_to_f32(&bytes, &mut out).unwrap();
        for i in 0..values.len() {
            assert!(
                (out[i] - values[i]).abs() < 0.01,
                "got {} want {}",
                out[i],
                values[i]
            );
        }
    }

    #[test]
    fn bf16_widens_to_upper_half_of_f32() {
        // 1.0 = 0x3F80_0000, -2.5 = 0xC020_0000, 0.0 = 0x0000_0000
        let bits: [u16; 3] = [0x3F80, 0xC020, 0x0000];
        let mut src = Vec::new();
        for b in bits {
            src.extend_from_slice(&b.to_le_bytes());
        }
        let mut out = vec![999f32; bits.len()];
        dequant_into_f32(GgmlDtype::BF16, &src, &mut out).unwrap();
        assert_eq!(out, [1.0, -2.5, 0.0]);
    }

    #[test]
    fn f32_passthrough_reads_little_endian() {
        let vals = [1.5f32, -3.25, 0.0];
        let mut src = Vec::new();
        for v in vals {
            src.extend_from_slice(&v.to_le_bytes());
        }
        let mut out = vec![999f32; vals.len()];
        dequant_into_f32(GgmlDtype::F32, &src, &mut out).unwrap();
        assert_eq!(out, vals);
    }

    #[test]
    fn mismatched_lengths_are_rejected() {
        // Source one byte short of a whole Q4_0 block.
        let mut out = vec![0f32; QK4_0];
        assert!(dequant_q4_0(&[0u8; Q4_0_BLOCK_BYTES - 1], &mut out).is_err());
        // Whole Q8_0 block, destination one element short.
        let mut short = vec![0f32; QK8_0 - 1];
        assert!(dequant_q8_0(&[0u8; Q8_0_BLOCK_BYTES], &mut short).is_err());
        // Q4_K source that is not a whole number of blocks.
        let mut q4k_out = vec![0f32; QK_K];
        assert!(dequant_q4_k(&[0u8; Q4_K_BLOCK_BYTES + 1], &mut q4k_out).is_err());
        // Odd BF16 byte count.
        assert!(bf16_to_f32(&[0u8; 3], &mut [0f32; 1]).is_err());
    }
}