Skip to main content

entrenar/quant/gguf_quant/
q4_0.rs

//! Q4_0 quantization format: 4-bit block quantization with one scale per block.
2
3use super::GGUF_BLOCK_SIZE;
4use serde::{Deserialize, Serialize};
5
6/// Q4_0 quantized tensor (GGUF format)
7///
8/// 4-bit quantization with per-block f16 scale factors.
9/// Each block: 32 values → 18 bytes (2 scale + 16 data)
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct Q4_0 {
12    /// Per-block scale factors (stored as f32, converted to f16 on export)
13    pub scales: Vec<f32>,
14    /// Packed 4-bit data (2 values per byte, 16 bytes per block)
15    pub data: Vec<u8>,
16    /// Original number of elements
17    pub len: usize,
18}
19
20impl Q4_0 {
21    /// Quantize f32 values to Q4_0 format
22    pub fn quantize(values: &[f32]) -> Self {
23        let len = values.len();
24        let num_blocks = len.div_ceil(GGUF_BLOCK_SIZE);
25
26        let mut scales = Vec::with_capacity(num_blocks);
27        let mut data = Vec::with_capacity(num_blocks * 16); // 16 bytes per block
28
29        for block_idx in 0..num_blocks {
30            let start = block_idx * GGUF_BLOCK_SIZE;
31            let end = (start + GGUF_BLOCK_SIZE).min(len);
32            let block = &values[start..end];
33
34            // Compute scale: max absolute value / 7 (4-bit signed: -8 to 7)
35            let max_abs = block
36                .iter()
37                .map(|v| v.abs())
38                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
39                .unwrap_or(0.0);
40
41            let scale = if max_abs < 1e-10 { 1e-10 } else { max_abs / 7.0 };
42            scales.push(scale);
43
44            // Quantize block (pad with zeros if incomplete)
45            let mut block_data = [0u8; 16];
46            for i in 0..GGUF_BLOCK_SIZE {
47                let val = if start + i < end { block[i] } else { 0.0 };
48
49                // Quantize to [-8, 7] range
50                let q = ((val / scale).round().clamp(-8.0, 7.0) as i8) & 0x0F;
51
52                // Pack 2 values per byte
53                if i % 2 == 0 {
54                    block_data[i / 2] = (q as u8) & 0x0F;
55                } else {
56                    block_data[i / 2] |= ((q as u8) & 0x0F) << 4;
57                }
58            }
59            data.extend_from_slice(&block_data);
60        }
61
62        Self { scales, data, len }
63    }
64
65    /// Dequantize Q4_0 back to f32
66    pub fn dequantize(&self) -> Vec<f32> {
67        let mut result = Vec::with_capacity(self.len);
68        let num_blocks = self.scales.len();
69
70        for block_idx in 0..num_blocks {
71            let scale = self.scales[block_idx];
72            let start = block_idx * GGUF_BLOCK_SIZE;
73            let block_len = (self.len - start).min(GGUF_BLOCK_SIZE);
74
75            for i in 0..block_len {
76                let byte_idx = block_idx * 16 + i / 2;
77                let byte = self.data[byte_idx];
78
79                // Extract 4-bit value
80                let nibble = if i % 2 == 0 { byte & 0x0F } else { (byte >> 4) & 0x0F };
81
82                // Sign extend from 4-bit
83                let q = if nibble & 0x08 != 0 { (nibble | 0xF0) as i8 } else { nibble as i8 };
84
85                result.push(f32::from(q) * scale);
86            }
87        }
88
89        result
90    }
91
92    /// Get memory usage in bytes
93    pub fn memory_bytes(&self) -> usize {
94        self.scales.len() * 4 + self.data.len() // scales as f32 for now
95    }
96
97    /// Get GGUF-format memory (with f16 scales)
98    pub fn gguf_bytes(&self) -> usize {
99        self.scales.len() * 2 + self.data.len() // 2 bytes per f16 scale
100    }
101
102    /// Get compression ratio vs f32
103    pub fn compression_ratio(&self) -> f32 {
104        let original = self.len * 4;
105        original as f32 / self.gguf_bytes() as f32
106    }
107
108    /// Number of blocks
109    pub fn num_blocks(&self) -> usize {
110        self.scales.len()
111    }
112}