use super::GGUF_BLOCK_SIZE;
use serde::{Deserialize, Serialize};
/// A sequence of `f32` values quantized block-wise to signed 4-bit integers.
///
/// Values are grouped into blocks of `GGUF_BLOCK_SIZE` elements. Each block
/// stores one `f32` scale and its elements as two's-complement nibbles packed
/// two per byte; an element is reconstructed as `nibble * scale`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Q4_0 {
/// One scale factor per block, in block order.
pub scales: Vec<f32>,
/// Packed nibbles. `quantize` writes 16 bytes per block; within a block,
/// element `i` lives in byte `i / 2` (even `i` in the low nibble, odd `i`
/// in the high nibble).
pub data: Vec<u8>,
/// Number of original (unquantized) values; the final block may be only
/// partially filled.
pub len: usize,
}
impl Q4_0 {
    /// Quantizes `values` into blocks of `GGUF_BLOCK_SIZE` signed 4-bit integers.
    ///
    /// For every block the scale is `max_abs / 7.0`, where `max_abs` is the
    /// largest *finite* magnitude in the block; blocks whose `max_abs` is below
    /// `1e-10` use a scale of `1e-10` so the division below stays well defined.
    /// Each value is stored as `round(value / scale)` clamped to `[-8, 7]` and
    /// packed two per byte (even index in the low nibble, odd index in the high
    /// nibble). A partially filled final block is padded with zeros.
    ///
    /// Non-finite inputs do not influence the scale: `NaN` is stored as `0`
    /// and `±inf` saturates to the extreme codes `7` / `-8`.
    ///
    /// # Panics
    ///
    /// Panics if `GGUF_BLOCK_SIZE` is zero.
    pub fn quantize(values: &[f32]) -> Self {
        let len = values.len();
        // Two 4-bit codes share one byte.
        let bytes_per_block = GGUF_BLOCK_SIZE / 2;
        let num_blocks = len.div_ceil(GGUF_BLOCK_SIZE);
        let mut scales = Vec::with_capacity(num_blocks);
        let mut data = Vec::with_capacity(num_blocks * bytes_per_block);

        for block in values.chunks(GGUF_BLOCK_SIZE) {
            // `f32::max` ignores NaN, and the filter drops infinities, so one
            // bad element can no longer poison the scale of the whole block.
            let max_abs = block
                .iter()
                .map(|v| v.abs())
                .filter(|magnitude| magnitude.is_finite())
                .fold(0.0_f32, f32::max);
            let scale = if max_abs < 1e-10 { 1e-10 } else { max_abs / 7.0 };
            scales.push(scale);

            // Code for element `i` of the block; indices past the end of a
            // short final block are treated as 0.0 padding.
            let nibble = |i: usize| -> u8 {
                let val = block.get(i).copied().unwrap_or(0.0);
                // A NaN survives `clamp` and the float-to-int cast maps it to 0.
                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
                (q as u8) & 0x0F
            };
            data.extend((0..bytes_per_block).map(|b| nibble(2 * b) | (nibble(2 * b + 1) << 4)));
        }

        Self { scales, data, len }
    }

    /// Reconstructs the approximate `f32` values as `code * scale`.
    ///
    /// Produces at most `self.len` values. If `scales` describes more blocks
    /// than `len` requires (possible because the fields are public and
    /// deserializable), the surplus blocks contribute nothing instead of
    /// underflowing the remaining-length computation.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer bytes than the blocks being decoded need.
    pub fn dequantize(&self) -> Vec<f32> {
        let bytes_per_block = GGUF_BLOCK_SIZE / 2;
        // Never reserve more than the stored blocks could possibly yield.
        let capacity = self
            .len
            .min(self.scales.len().saturating_mul(GGUF_BLOCK_SIZE));
        let mut result = Vec::with_capacity(capacity);

        for (block_idx, &scale) in self.scales.iter().enumerate() {
            let start = block_idx * GGUF_BLOCK_SIZE;
            let block_len = self.len.saturating_sub(start).min(GGUF_BLOCK_SIZE);
            let base = block_idx * bytes_per_block;
            for i in 0..block_len {
                let byte = self.data[base + i / 2];
                let nibble = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
                // Sign-extend the 4-bit two's-complement code: move it to the
                // top of the byte, then arithmetic-shift it back down.
                let q = ((nibble << 4) as i8) >> 4;
                result.push(f32::from(q) * scale);
            }
        }
        result
    }

    /// Bytes occupied by the in-memory payload: 4 per `f32` scale plus the
    /// packed nibble bytes.
    pub fn memory_bytes(&self) -> usize {
        self.scales.len() * std::mem::size_of::<f32>() + self.data.len()
    }

    /// Size estimate that counts 2 bytes per scale plus the packed nibble
    /// bytes (presumably modelling a 16-bit scale in the GGUF file format —
    /// confirm against the writer).
    pub fn gguf_bytes(&self) -> usize {
        self.scales.len() * 2 + self.data.len()
    }

    /// Ratio of the original `f32` size (`len * 4` bytes) to [`Self::gguf_bytes`].
    ///
    /// An empty tensor (zero original bytes and zero stored bytes) reports
    /// `1.0` rather than the `NaN` that `0 / 0` would produce.
    pub fn compression_ratio(&self) -> f32 {
        let original = self.len * 4;
        let compressed = self.gguf_bytes();
        if original == 0 && compressed == 0 {
            return 1.0;
        }
        original as f32 / compressed as f32
    }

    /// Number of quantization blocks (one scale per block).
    pub fn num_blocks(&self) -> usize {
        self.scales.len()
    }
}