// crates/axonml-quant/src/types.rs
1//! Quantization Types
2//!
3//! # File
4//! `crates/axonml-quant/src/types.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use half::f16;
18use std::fmt;
19
20// =============================================================================
21// Quantization Type Enum
22// =============================================================================
23
/// Quantization type enumeration.
///
/// Block-quantized variants (`Q8_0`..`Q5_1`) group 32 values per block with
/// shared per-block parameters; `F16`/`F32` store each value individually
/// (block size 1). Byte layouts below match `bytes_per_block`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit quantization with per-block scale.
    /// Block layout: scale (f16) + 32 x int8 = 34 bytes.
    Q8_0,

    /// 4-bit quantization with per-block scale.
    /// Block layout: scale (f16) + 16 x uint8 (two 4-bit values each) = 18 bytes.
    Q4_0,

    /// 4-bit quantization with per-block scale and min.
    /// Block layout: scale (f16) + min (f16) + 16 x uint8 = 20 bytes.
    Q4_1,

    /// 5-bit quantization with per-block scale.
    /// Block layout: scale (f16) + 20 packed bytes = 22 bytes.
    Q5_0,

    /// 5-bit quantization with per-block scale and min.
    /// Block layout: scale (f16) + min (f16) + 20 packed bytes = 24 bytes.
    Q5_1,

    /// Half-precision (16-bit float); no block structure.
    F16,

    /// Full precision (32-bit float); stored as-is.
    F32,
}
51
52impl QuantType {
53    /// Returns the block size for this quantization type.
54    pub fn block_size(&self) -> usize {
55        match self {
56            QuantType::Q8_0
57            | QuantType::Q4_0
58            | QuantType::Q4_1
59            | QuantType::Q5_0
60            | QuantType::Q5_1 => 32,
61            QuantType::F16 | QuantType::F32 => 1,
62        }
63    }
64
65    /// Returns the number of bytes per block.
66    pub fn bytes_per_block(&self) -> usize {
67        match self {
68            QuantType::Q8_0 => 2 + 32, // f16 scale + 32 int8
69            QuantType::Q4_0 => 2 + 16, // f16 scale + 16 bytes (32 x 4-bit)
70            QuantType::Q4_1 => 4 + 16, // f16 scale + f16 min + 16 bytes
71            QuantType::Q5_0 => 2 + 20, // f16 scale + 20 bytes (32 x 5-bit)
72            QuantType::Q5_1 => 4 + 20, // f16 scale + f16 min + 20 bytes
73            QuantType::F16 => 2,
74            QuantType::F32 => 4,
75        }
76    }
77
78    /// Returns the bits per value.
79    pub fn bits_per_value(&self) -> usize {
80        match self {
81            QuantType::Q8_0 => 8,
82            QuantType::Q4_0 | QuantType::Q4_1 => 4,
83            QuantType::Q5_0 | QuantType::Q5_1 => 5,
84            QuantType::F16 => 16,
85            QuantType::F32 => 32,
86        }
87    }
88
89    /// Returns the compression ratio compared to F32.
90    pub fn compression_ratio(&self) -> f32 {
91        32.0 / self.bits_per_value() as f32
92    }
93
94    /// Returns true if this type uses block quantization.
95    pub fn is_block_quantized(&self) -> bool {
96        matches!(
97            self,
98            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
99        )
100    }
101
102    /// Parses a quantization type from a string.
103    #[allow(clippy::should_implement_trait)]
104    pub fn from_str(s: &str) -> Option<Self> {
105        match s.to_uppercase().as_str() {
106            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
107            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
108            "Q4_1" => Some(QuantType::Q4_1),
109            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
110            "Q5_1" => Some(QuantType::Q5_1),
111            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
112            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
113            _ => None,
114        }
115    }
116}
117
118impl fmt::Display for QuantType {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        match self {
121            QuantType::Q8_0 => write!(f, "Q8_0"),
122            QuantType::Q4_0 => write!(f, "Q4_0"),
123            QuantType::Q4_1 => write!(f, "Q4_1"),
124            QuantType::Q5_0 => write!(f, "Q5_0"),
125            QuantType::Q5_1 => write!(f, "Q5_1"),
126            QuantType::F16 => write!(f, "F16"),
127            QuantType::F32 => write!(f, "F32"),
128        }
129    }
130}
131
132// =============================================================================
133// Quantized Block Structures
134// =============================================================================
135
/// A block of Q8_0 quantized data.
///
/// Serialized form is 34 bytes: little-endian f16 scale followed by the 32
/// int8 values (see `to_bytes` / `from_bytes`).
#[derive(Debug, Clone)]
pub struct Q8Block {
    /// Per-block scale factor (stored as f16).
    pub scale: f16,
    /// Quantized values (32 x int8).
    pub data: [i8; 32],
}
144
145impl Q8Block {
146    /// Creates a new Q8 block.
147    pub fn new(scale: f16, data: [i8; 32]) -> Self {
148        Self { scale, data }
149    }
150
151    /// Returns the byte representation of this block.
152    pub fn to_bytes(&self) -> Vec<u8> {
153        let mut bytes = Vec::with_capacity(34);
154        bytes.extend_from_slice(&self.scale.to_le_bytes());
155        bytes.extend(self.data.iter().map(|&x| x as u8));
156        bytes
157    }
158
159    /// Creates a block from bytes.
160    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
161        if bytes.len() < 34 {
162            return None;
163        }
164        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
165        let mut data = [0i8; 32];
166        for (i, &b) in bytes[2..34].iter().enumerate() {
167            data[i] = b as i8;
168        }
169        Some(Self { scale, data })
170    }
171}
172
/// A block of Q4_0 quantized data.
///
/// Serialized form is 18 bytes: little-endian f16 scale followed by 16 bytes
/// holding 32 packed 4-bit values (see `to_bytes` / `from_bytes`).
#[derive(Debug, Clone)]
pub struct Q4Block {
    /// Per-block scale factor (stored as f16).
    pub scale: f16,
    /// Packed quantized values (16 bytes = 32 x 4-bit; low nibble holds the
    /// even-indexed value, high nibble the odd-indexed one — see `unpack`).
    pub data: [u8; 16],
}
181
182impl Q4Block {
183    /// Creates a new Q4 block.
184    pub fn new(scale: f16, data: [u8; 16]) -> Self {
185        Self { scale, data }
186    }
187
188    /// Extracts the 4-bit values as i8 (range -8 to 7).
189    pub fn unpack(&self) -> [i8; 32] {
190        let mut result = [0i8; 32];
191        for i in 0..16 {
192            let byte = self.data[i];
193            result[i * 2] = ((byte & 0x0F) as i8) - 8;
194            result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
195        }
196        result
197    }
198
199    /// Packs 32 i8 values (-8 to 7 range) into 16 bytes.
200    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
201        let mut data = [0u8; 16];
202        for i in 0..16 {
203            let low = ((values[i * 2] + 8) as u8) & 0x0F;
204            let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
205            data[i] = low | (high << 4);
206        }
207        data
208    }
209
210    /// Returns the byte representation of this block.
211    pub fn to_bytes(&self) -> Vec<u8> {
212        let mut bytes = Vec::with_capacity(18);
213        bytes.extend_from_slice(&self.scale.to_le_bytes());
214        bytes.extend_from_slice(&self.data);
215        bytes
216    }
217
218    /// Creates a block from bytes.
219    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
220        if bytes.len() < 18 {
221            return None;
222        }
223        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
224        let mut data = [0u8; 16];
225        data.copy_from_slice(&bytes[2..18]);
226        Some(Self { scale, data })
227    }
228}
229
/// A block of Q4_1 quantized data (with min value).
///
/// Serialized form is 20 bytes: little-endian f16 scale, f16 min, then 16
/// bytes holding 32 packed 4-bit values (see `to_bytes`).
#[derive(Debug, Clone)]
pub struct Q4_1Block {
    /// Per-block scale factor (stored as f16).
    pub scale: f16,
    /// Per-block minimum value (stored as f16).
    pub min: f16,
    /// Packed quantized values (16 bytes = 32 x 4-bit; low nibble holds the
    /// even-indexed value, high nibble the odd-indexed one — see `unpack`).
    pub data: [u8; 16],
}
240
241impl Q4_1Block {
242    /// Creates a new Q4_1 block.
243    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
244        Self { scale, min, data }
245    }
246
247    /// Extracts the 4-bit values as u8 (range 0 to 15).
248    pub fn unpack(&self) -> [u8; 32] {
249        let mut result = [0u8; 32];
250        for i in 0..16 {
251            let byte = self.data[i];
252            result[i * 2] = byte & 0x0F;
253            result[i * 2 + 1] = byte >> 4;
254        }
255        result
256    }
257
258    /// Returns the byte representation of this block.
259    pub fn to_bytes(&self) -> Vec<u8> {
260        let mut bytes = Vec::with_capacity(20);
261        bytes.extend_from_slice(&self.scale.to_le_bytes());
262        bytes.extend_from_slice(&self.min.to_le_bytes());
263        bytes.extend_from_slice(&self.data);
264        bytes
265    }
266}
267
268// =============================================================================
269// Generic Quantized Block
270// =============================================================================
271
/// Generic quantized block enum, tagging a block payload with its format.
///
/// NOTE(review): `QuantType` declares `Q5_0`/`Q5_1` but there are no matching
/// variants here — confirm whether Q5 storage is handled elsewhere or is
/// still unimplemented.
#[derive(Debug, Clone)]
pub enum QuantizedBlock {
    /// Q8_0 block (fixed 32 values).
    Q8(Q8Block),
    /// Q4_0 block (fixed 32 values).
    Q4(Q4Block),
    /// Q4_1 block (fixed 32 values).
    Q4_1(Q4_1Block),
    /// F16 values (block size 1, so the Vec length is arbitrary).
    F16(Vec<f16>),
    /// F32 values (original, unquantized; arbitrary length).
    F32(Vec<f32>),
}
286
287impl QuantizedBlock {
288    /// Returns the quantization type of this block.
289    pub fn quant_type(&self) -> QuantType {
290        match self {
291            QuantizedBlock::Q8(_) => QuantType::Q8_0,
292            QuantizedBlock::Q4(_) => QuantType::Q4_0,
293            QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
294            QuantizedBlock::F16(_) => QuantType::F16,
295            QuantizedBlock::F32(_) => QuantType::F32,
296        }
297    }
298}
299
300// =============================================================================
301// Quantized Tensor
302// =============================================================================
303
/// A quantized tensor containing compressed weight data.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Original (unquantized) tensor shape.
    pub shape: Vec<usize>,
    /// Quantization type applied to this tensor.
    pub quant_type: QuantType,
    /// Quantized data blocks.
    pub blocks: Vec<QuantizedBlock>,
    /// Number of elements (product of `shape`; cached at construction).
    pub numel: usize,
}
316
317impl QuantizedTensor {
318    /// Creates a new quantized tensor.
319    pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
320        let numel = shape.iter().product();
321        Self {
322            shape,
323            quant_type,
324            blocks,
325            numel,
326        }
327    }
328
329    /// Returns the memory size in bytes.
330    pub fn size_bytes(&self) -> usize {
331        self.blocks.len() * self.quant_type.bytes_per_block()
332    }
333
334    /// Returns the compression ratio compared to F32.
335    pub fn compression_ratio(&self) -> f32 {
336        let original_bytes = self.numel * 4;
337        original_bytes as f32 / self.size_bytes() as f32
338    }
339
340    /// Returns the number of blocks.
341    pub fn num_blocks(&self) -> usize {
342        self.blocks.len()
343    }
344}
345
346// =============================================================================
347// Tests
348// =============================================================================
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quant_type_properties() {
        // Block sizes: 32 for block-quantized formats, 1 for plain floats.
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        // Payload width per value.
        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    #[test]
    fn test_quant_type_from_str() {
        // Canonical names and aliases resolve; unknown names yield None.
        let cases = [
            ("Q8_0", Some(QuantType::Q8_0)),
            ("INT8", Some(QuantType::Q8_0)),
            ("Q4", Some(QuantType::Q4_0)),
            ("F16", Some(QuantType::F16)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(QuantType::from_str(input), expected);
        }
    }

    #[test]
    fn test_q4_pack_unpack() {
        // Two full sweeps of the representable range -8..=7.
        let mut values = [0i8; 32];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i % 16) as i8 - 8;
        }

        let block = Q4Block::new(f16::from_f32(1.0), Q4Block::pack(&values));
        assert_eq!(block.unpack(), values);
    }

    #[test]
    fn test_q8_block() {
        // Serialization round-trip preserves scale and data.
        let block = Q8Block::new(f16::from_f32(0.5), [0i8; 32]);
        let restored = Q8Block::from_bytes(&block.to_bytes()).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }
}