Skip to main content

axonml_quant/
types.rs

1//! Quantization Types
2//!
3//! # File
4//! `crates/axonml-quant/src/types.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use half::f16;
19use std::fmt;
20
21// =============================================================================
22// Quantization Type Enum
23// =============================================================================
24
/// Storage formats understood by the quantizer.
///
/// The `Q*` variants are block formats: a header (an f16 scale, plus an f16
/// minimum for the `_1` variants) followed by a packed payload. `F16` and
/// `F32` store each value on its own.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QuantType {
    /// 8-bit symmetric quantization.
    /// Layout: f16 scale, then 32 x i8.
    Q8_0,

    /// 4-bit symmetric quantization.
    /// Layout: f16 scale, then 16 bytes holding two 4-bit values each.
    Q4_0,

    /// 4-bit asymmetric quantization (scale and minimum).
    /// Layout: f16 scale, f16 min, then 16 bytes of packed nibbles.
    Q4_1,

    /// 5-bit symmetric quantization with a per-block scale.
    Q5_0,

    /// 5-bit asymmetric quantization with a per-block scale and minimum.
    Q5_1,

    /// Half-precision (16-bit) floating point.
    F16,

    /// Full-precision (32-bit) floating point.
    F32,
}
52
53impl QuantType {
54    /// Returns the block size for this quantization type.
55    pub fn block_size(&self) -> usize {
56        match self {
57            QuantType::Q8_0
58            | QuantType::Q4_0
59            | QuantType::Q4_1
60            | QuantType::Q5_0
61            | QuantType::Q5_1 => 32,
62            QuantType::F16 | QuantType::F32 => 1,
63        }
64    }
65
66    /// Returns the number of bytes per block.
67    pub fn bytes_per_block(&self) -> usize {
68        match self {
69            QuantType::Q8_0 => 2 + 32, // f16 scale + 32 int8
70            QuantType::Q4_0 => 2 + 16, // f16 scale + 16 bytes (32 x 4-bit)
71            QuantType::Q4_1 => 4 + 16, // f16 scale + f16 min + 16 bytes
72            QuantType::Q5_0 => 2 + 20, // f16 scale + 20 bytes (32 x 5-bit)
73            QuantType::Q5_1 => 4 + 20, // f16 scale + f16 min + 20 bytes
74            QuantType::F16 => 2,
75            QuantType::F32 => 4,
76        }
77    }
78
79    /// Returns the bits per value.
80    pub fn bits_per_value(&self) -> usize {
81        match self {
82            QuantType::Q8_0 => 8,
83            QuantType::Q4_0 | QuantType::Q4_1 => 4,
84            QuantType::Q5_0 | QuantType::Q5_1 => 5,
85            QuantType::F16 => 16,
86            QuantType::F32 => 32,
87        }
88    }
89
90    /// Returns the compression ratio compared to F32.
91    pub fn compression_ratio(&self) -> f32 {
92        32.0 / self.bits_per_value() as f32
93    }
94
95    /// Returns true if this type uses block quantization.
96    pub fn is_block_quantized(&self) -> bool {
97        matches!(
98            self,
99            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
100        )
101    }
102
103    /// Parses a quantization type from a string.
104    pub fn parse_type(s: &str) -> Option<Self> {
105        match s.to_uppercase().as_str() {
106            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
107            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
108            "Q4_1" => Some(QuantType::Q4_1),
109            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
110            "Q5_1" => Some(QuantType::Q5_1),
111            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
112            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
113            _ => None,
114        }
115    }
116}
117
118impl std::str::FromStr for QuantType {
119    type Err = String;
120
121    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
122        Self::parse_type(s).ok_or_else(|| format!("Unknown quant type: '{s}'"))
123    }
124}
125
126impl fmt::Display for QuantType {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        match self {
129            QuantType::Q8_0 => write!(f, "Q8_0"),
130            QuantType::Q4_0 => write!(f, "Q4_0"),
131            QuantType::Q4_1 => write!(f, "Q4_1"),
132            QuantType::Q5_0 => write!(f, "Q5_0"),
133            QuantType::Q5_1 => write!(f, "Q5_1"),
134            QuantType::F16 => write!(f, "F16"),
135            QuantType::F32 => write!(f, "F32"),
136        }
137    }
138}
139
140// =============================================================================
141// Quantized Block Structures
142// =============================================================================
143
144/// A block of Q8_0 quantized data.
145#[derive(Debug, Clone)]
146pub struct Q8Block {
147    /// Scale factor (stored as f16).
148    pub scale: f16,
149    /// Quantized values (32 x int8).
150    pub data: [i8; 32],
151}
152
153impl Q8Block {
154    /// Creates a new Q8 block.
155    pub fn new(scale: f16, data: [i8; 32]) -> Self {
156        Self { scale, data }
157    }
158
159    /// Returns the byte representation of this block.
160    pub fn to_bytes(&self) -> Vec<u8> {
161        let mut bytes = Vec::with_capacity(34);
162        bytes.extend_from_slice(&self.scale.to_le_bytes());
163        bytes.extend(self.data.iter().map(|&x| x as u8));
164        bytes
165    }
166
167    /// Creates a block from bytes.
168    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
169        if bytes.len() < 34 {
170            return None;
171        }
172        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
173        let mut data = [0i8; 32];
174        for (i, &b) in bytes[2..34].iter().enumerate() {
175            data[i] = b as i8;
176        }
177        Some(Self { scale, data })
178    }
179}
180
181/// A block of Q4_0 quantized data.
182#[derive(Debug, Clone)]
183pub struct Q4Block {
184    /// Scale factor (stored as f16).
185    pub scale: f16,
186    /// Packed quantized values (16 bytes = 32 x 4-bit).
187    pub data: [u8; 16],
188}
189
190impl Q4Block {
191    /// Creates a new Q4 block.
192    pub fn new(scale: f16, data: [u8; 16]) -> Self {
193        Self { scale, data }
194    }
195
196    /// Extracts the 4-bit values as i8 (range -8 to 7).
197    pub fn unpack(&self) -> [i8; 32] {
198        let mut result = [0i8; 32];
199        for i in 0..16 {
200            let byte = self.data[i];
201            result[i * 2] = ((byte & 0x0F) as i8) - 8;
202            result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
203        }
204        result
205    }
206
207    /// Packs 32 i8 values (-8 to 7 range) into 16 bytes.
208    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
209        let mut data = [0u8; 16];
210        for i in 0..16 {
211            let low = ((values[i * 2] + 8) as u8) & 0x0F;
212            let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
213            data[i] = low | (high << 4);
214        }
215        data
216    }
217
218    /// Returns the byte representation of this block.
219    pub fn to_bytes(&self) -> Vec<u8> {
220        let mut bytes = Vec::with_capacity(18);
221        bytes.extend_from_slice(&self.scale.to_le_bytes());
222        bytes.extend_from_slice(&self.data);
223        bytes
224    }
225
226    /// Creates a block from bytes.
227    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
228        if bytes.len() < 18 {
229            return None;
230        }
231        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
232        let mut data = [0u8; 16];
233        data.copy_from_slice(&bytes[2..18]);
234        Some(Self { scale, data })
235    }
236}
237
238/// A block of Q4_1 quantized data (with min value).
239#[derive(Debug, Clone)]
240pub struct Q4_1Block {
241    /// Scale factor (stored as f16).
242    pub scale: f16,
243    /// Minimum value (stored as f16).
244    pub min: f16,
245    /// Packed quantized values (16 bytes = 32 x 4-bit).
246    pub data: [u8; 16],
247}
248
249impl Q4_1Block {
250    /// Creates a new Q4_1 block.
251    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
252        Self { scale, min, data }
253    }
254
255    /// Extracts the 4-bit values as u8 (range 0 to 15).
256    pub fn unpack(&self) -> [u8; 32] {
257        let mut result = [0u8; 32];
258        for i in 0..16 {
259            let byte = self.data[i];
260            result[i * 2] = byte & 0x0F;
261            result[i * 2 + 1] = byte >> 4;
262        }
263        result
264    }
265
266    /// Returns the byte representation of this block.
267    pub fn to_bytes(&self) -> Vec<u8> {
268        let mut bytes = Vec::with_capacity(20);
269        bytes.extend_from_slice(&self.scale.to_le_bytes());
270        bytes.extend_from_slice(&self.min.to_le_bytes());
271        bytes.extend_from_slice(&self.data);
272        bytes
273    }
274}
275
276// =============================================================================
277// Q5_0 Block (5-bit symmetric)
278// =============================================================================
279
280/// A block of Q5_0 quantized data (5-bit with per-block scale).
281///
282/// 32 values × 5 bits = 160 bits = 20 bytes of packed data + f16 scale.
283#[derive(Debug, Clone)]
284pub struct Q5Block {
285    /// Scale factor (stored as f16).
286    pub scale: f16,
287    /// Packed quantized values (20 bytes = 32 × 5-bit).
288    pub data: [u8; 20],
289}
290
291impl Q5Block {
292    /// Creates a new Q5_0 block.
293    pub fn new(scale: f16, data: [u8; 20]) -> Self {
294        Self { scale, data }
295    }
296
297    /// Packs 32 signed 5-bit values (range -16 to 15) into 20 bytes.
298    pub fn pack(values: &[i8; 32]) -> [u8; 20] {
299        let mut packed = [0u8; 20];
300        // Pack 32 × 5-bit values: 8 groups of 4 values → 20 bits each → 2.5 bytes
301        // Simpler: treat as 160-bit bitstream
302        #[allow(clippy::needless_range_loop)]
303        for i in 0..32 {
304            let v = (values[i] as u8) & 0x1F; // 5-bit unsigned representation
305            let bit_offset = i * 5;
306            let byte_offset = bit_offset / 8;
307            let bit_shift = bit_offset % 8;
308            packed[byte_offset] |= v << bit_shift;
309            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
310                packed[byte_offset + 1] |= v >> (8 - bit_shift);
311            }
312        }
313        packed
314    }
315
316    /// Unpacks 20 bytes into 32 signed 5-bit values.
317    pub fn unpack(&self) -> [i8; 32] {
318        let mut result = [0i8; 32];
319        #[allow(clippy::needless_range_loop)]
320        for i in 0..32 {
321            let bit_offset = i * 5;
322            let byte_offset = bit_offset / 8;
323            let bit_shift = bit_offset % 8;
324            let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
325            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
326                v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
327            }
328            // Sign extend: if bit 4 is set, value is negative
329            if v & 0x10 != 0 {
330                result[i] = (v | 0xE0) as i8; // sign extend to i8
331            } else {
332                result[i] = v as i8;
333            }
334        }
335        result
336    }
337
338    /// Returns byte representation.
339    pub fn to_bytes(&self) -> Vec<u8> {
340        let mut bytes = Vec::with_capacity(22);
341        bytes.extend_from_slice(&self.scale.to_le_bytes());
342        bytes.extend_from_slice(&self.data);
343        bytes
344    }
345}
346
347// =============================================================================
348// Q5_1 Block (5-bit asymmetric)
349// =============================================================================
350
351/// A block of Q5_1 quantized data (5-bit with per-block scale and min).
352#[derive(Debug, Clone)]
353pub struct Q5_1Block {
354    /// Scale factor (stored as f16).
355    pub scale: f16,
356    /// Minimum value (stored as f16).
357    pub min: f16,
358    /// Packed quantized values (20 bytes = 32 × 5-bit unsigned).
359    pub data: [u8; 20],
360}
361
362impl Q5_1Block {
363    /// Creates a new Q5_1 block.
364    pub fn new(scale: f16, min: f16, data: [u8; 20]) -> Self {
365        Self { scale, min, data }
366    }
367
368    /// Packs 32 unsigned 5-bit values (range 0 to 31) into 20 bytes.
369    pub fn pack(values: &[u8; 32]) -> [u8; 20] {
370        let mut packed = [0u8; 20];
371        #[allow(clippy::needless_range_loop)]
372        for i in 0..32 {
373            let v = values[i] & 0x1F;
374            let bit_offset = i * 5;
375            let byte_offset = bit_offset / 8;
376            let bit_shift = bit_offset % 8;
377            packed[byte_offset] |= v << bit_shift;
378            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
379                packed[byte_offset + 1] |= v >> (8 - bit_shift);
380            }
381        }
382        packed
383    }
384
385    /// Unpacks 20 bytes into 32 unsigned 5-bit values.
386    pub fn unpack(&self) -> [u8; 32] {
387        let mut result = [0u8; 32];
388        #[allow(clippy::needless_range_loop)]
389        for i in 0..32 {
390            let bit_offset = i * 5;
391            let byte_offset = bit_offset / 8;
392            let bit_shift = bit_offset % 8;
393            let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
394            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
395                v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
396            }
397            result[i] = v;
398        }
399        result
400    }
401
402    /// Returns byte representation.
403    pub fn to_bytes(&self) -> Vec<u8> {
404        let mut bytes = Vec::with_capacity(24);
405        bytes.extend_from_slice(&self.scale.to_le_bytes());
406        bytes.extend_from_slice(&self.min.to_le_bytes());
407        bytes.extend_from_slice(&self.data);
408        bytes
409    }
410}
411
412// =============================================================================
413// Generic Quantized Block
414// =============================================================================
415
416/// Generic quantized block enum.
417#[derive(Debug, Clone)]
418pub enum QuantizedBlock {
419    /// Q8_0 block.
420    Q8(Q8Block),
421    /// Q4_0 block.
422    Q4(Q4Block),
423    /// Q4_1 block.
424    Q4_1(Q4_1Block),
425    /// Q5_0 block (5-bit symmetric).
426    Q5(Q5Block),
427    /// Q5_1 block (5-bit asymmetric).
428    Q5_1(Q5_1Block),
429    /// F16 values (block size 1).
430    F16(Vec<f16>),
431    /// F32 values (original).
432    F32(Vec<f32>),
433}
434
435impl QuantizedBlock {
436    /// Returns the quantization type of this block.
437    pub fn quant_type(&self) -> QuantType {
438        match self {
439            QuantizedBlock::Q8(_) => QuantType::Q8_0,
440            QuantizedBlock::Q4(_) => QuantType::Q4_0,
441            QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
442            QuantizedBlock::Q5(_) => QuantType::Q5_0,
443            QuantizedBlock::Q5_1(_) => QuantType::Q5_1,
444            QuantizedBlock::F16(_) => QuantType::F16,
445            QuantizedBlock::F32(_) => QuantType::F32,
446        }
447    }
448}
449
450// =============================================================================
451// Quantized Tensor
452// =============================================================================
453
454/// A quantized tensor containing compressed weight data.
455#[derive(Debug, Clone)]
456pub struct QuantizedTensor {
457    /// Original tensor shape.
458    pub shape: Vec<usize>,
459    /// Quantization type.
460    pub quant_type: QuantType,
461    /// Quantized data blocks.
462    pub blocks: Vec<QuantizedBlock>,
463    /// Number of elements.
464    pub numel: usize,
465}
466
467impl QuantizedTensor {
468    /// Creates a new quantized tensor.
469    pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
470        let numel = shape.iter().product();
471        Self {
472            shape,
473            quant_type,
474            blocks,
475            numel,
476        }
477    }
478
479    /// Returns the memory size in bytes.
480    pub fn size_bytes(&self) -> usize {
481        self.blocks.len() * self.quant_type.bytes_per_block()
482    }
483
484    /// Returns the compression ratio compared to F32.
485    pub fn compression_ratio(&self) -> f32 {
486        let original_bytes = self.numel * 4;
487        original_bytes as f32 / self.size_bytes() as f32
488    }
489
490    /// Returns the number of blocks.
491    pub fn num_blocks(&self) -> usize {
492        self.blocks.len()
493    }
494}
495
496// =============================================================================
497// Tests
498// =============================================================================
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for exhaustive round-trip checks.
    const ALL_TYPES: [QuantType; 7] = [
        QuantType::Q8_0,
        QuantType::Q4_0,
        QuantType::Q4_1,
        QuantType::Q5_0,
        QuantType::Q5_1,
        QuantType::F16,
        QuantType::F32,
    ];

    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::parse_type("Q8_0"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("INT8"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("Q4"), Some(QuantType::Q4_0));
        assert_eq!(QuantType::parse_type("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::parse_type("invalid"), None);
    }

    /// `Display` output must always parse back to the same variant.
    #[test]
    fn test_quant_type_display_round_trips_through_from_str() {
        for &qt in ALL_TYPES.iter() {
            let name = qt.to_string();
            assert_eq!(name.parse::<QuantType>(), Ok(qt));
        }
        assert!("bogus".parse::<QuantType>().is_err());
    }

    #[test]
    fn test_q4_pack_unpack() {
        let values: [i8; 32] = [
            -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
            0, 1, 2, 3, 4, 5, 6, 7,
        ];

        let packed = Q4Block::pack(&values);
        let block = Q4Block::new(f16::from_f32(1.0), packed);
        let unpacked = block.unpack();

        assert_eq!(values, unpacked);
    }

    /// Covers the full signed 5-bit range, including values whose bits
    /// straddle a byte boundary.
    #[test]
    fn test_q5_pack_unpack() {
        let mut values = [0i8; 32];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as i8 - 16;
        }

        let block = Q5Block::new(f16::from_f32(1.0), Q5Block::pack(&values));

        assert_eq!(block.unpack(), values);
    }

    /// Covers the full unsigned 5-bit range.
    #[test]
    fn test_q5_1_pack_unpack() {
        let mut values = [0u8; 32];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (31 - i) as u8;
        }

        let block = Q5_1Block::new(
            f16::from_f32(1.0),
            f16::from_f32(0.0),
            Q5_1Block::pack(&values),
        );

        assert_eq!(block.unpack(), values);
    }

    #[test]
    fn test_q8_block() {
        let data = [0i8; 32];
        let block = Q8Block::new(f16::from_f32(0.5), data);
        let bytes = block.to_bytes();
        let restored = Q8Block::from_bytes(&bytes).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }

    /// Serialized block sizes must agree with `QuantType::bytes_per_block`.
    #[test]
    fn test_serialized_sizes_match_bytes_per_block() {
        let scale = f16::from_f32(1.0);
        let min = f16::from_f32(0.0);
        assert_eq!(
            Q8Block::new(scale, [0; 32]).to_bytes().len(),
            QuantType::Q8_0.bytes_per_block()
        );
        assert_eq!(
            Q4Block::new(scale, [0; 16]).to_bytes().len(),
            QuantType::Q4_0.bytes_per_block()
        );
        assert_eq!(
            Q4_1Block::new(scale, min, [0; 16]).to_bytes().len(),
            QuantType::Q4_1.bytes_per_block()
        );
        assert_eq!(
            Q5Block::new(scale, [0; 20]).to_bytes().len(),
            QuantType::Q5_0.bytes_per_block()
        );
        assert_eq!(
            Q5_1Block::new(scale, min, [0; 20]).to_bytes().len(),
            QuantType::Q5_1.bytes_per_block()
        );
    }
}