Skip to main content

axonml_quant/
types.rs

//! Quantization Types
//!
//! # File
//! `crates/axonml-quant/src/types.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use half::f16;
18use std::fmt;
19
20// =============================================================================
21// Quantization Type Enum
22// =============================================================================
23
/// Storage formats supported for quantized tensor data.
///
/// The `Qx_y` variants are block formats in which 32 values share a small f16
/// header; `F16` and `F32` store every value on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit values sharing one scale per block.
    /// Layout: f16 scale, then 32 x i8.
    Q8_0,

    /// 4-bit values sharing one scale per block.
    /// Layout: f16 scale, then 16 bytes each holding two 4-bit values.
    Q4_0,

    /// 4-bit values with a per-block scale and minimum.
    /// Layout: f16 scale, f16 min, then 16 bytes of packed 4-bit values.
    Q4_1,

    /// 5-bit values sharing one scale per block.
    /// Layout: f16 scale, then 20 bytes of packed 5-bit values.
    Q5_0,

    /// 5-bit values with a per-block scale and minimum.
    /// Layout: f16 scale, f16 min, then 20 bytes of packed 5-bit values.
    Q5_1,

    /// Half precision (16-bit float), one value per "block".
    F16,

    /// Full precision (32-bit float), one value per "block".
    F32,
}

impl QuantType {
    /// Number of values covered by one block: 32 for block formats, 1 otherwise.
    pub fn block_size(&self) -> usize {
        if self.is_block_quantized() {
            32
        } else {
            1
        }
    }

    /// Size in bytes of one serialized block.
    ///
    /// Block formats store an f16 header (scale, plus min for the `_1`
    /// variants) followed by `block_size * bits_per_value / 8` bytes of packed
    /// values. For F16/F32 there is no header and a block is one value wide.
    pub fn bytes_per_block(&self) -> usize {
        let header_bytes = match self {
            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q5_0 => 2,
            QuantType::Q4_1 | QuantType::Q5_1 => 4,
            QuantType::F16 | QuantType::F32 => 0,
        };
        let payload_bytes = self.block_size() * self.bits_per_value() / 8;
        header_bytes + payload_bytes
    }

    /// Width in bits of a single stored value, excluding any block header.
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::F32 => 32,
            QuantType::F16 => 16,
            QuantType::Q8_0 => 8,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
        }
    }

    /// Nominal compression ratio relative to F32 (`32 / bits_per_value`).
    ///
    /// This ignores the per-block scale/min header, so it slightly overstates
    /// the real ratio of block formats.
    pub fn compression_ratio(&self) -> f32 {
        let bits = self.bits_per_value() as f32;
        32.0 / bits
    }

    /// `true` for the 32-value block formats, `false` for F16/F32.
    pub fn is_block_quantized(&self) -> bool {
        match self {
            QuantType::Q8_0
            | QuantType::Q4_0
            | QuantType::Q4_1
            | QuantType::Q5_0
            | QuantType::Q5_1 => true,
            QuantType::F16 | QuantType::F32 => false,
        }
    }

    /// Parses a case-insensitive name or alias (e.g. `"q8"`, `"int4"`, `"half"`).
    ///
    /// Returns `None` for unrecognized names.
    pub fn parse_type(s: &str) -> Option<Self> {
        let normalized = s.to_uppercase();
        let parsed = match normalized.as_str() {
            "Q8_0" | "Q8" | "INT8" => QuantType::Q8_0,
            "Q4_0" | "Q4" | "INT4" => QuantType::Q4_0,
            "Q4_1" => QuantType::Q4_1,
            "Q5_0" | "Q5" => QuantType::Q5_0,
            "Q5_1" => QuantType::Q5_1,
            "F16" | "FLOAT16" | "HALF" => QuantType::F16,
            "F32" | "FLOAT32" | "FLOAT" => QuantType::F32,
            _ => return None,
        };
        Some(parsed)
    }
}

impl std::str::FromStr for QuantType {
    type Err = String;

    /// Delegates to [`QuantType::parse_type`], reporting the rejected input on failure.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match Self::parse_type(s) {
            Some(parsed) => Ok(parsed),
            None => Err(format!("Unknown quant type: '{s}'")),
        }
    }
}

impl fmt::Display for QuantType {
    /// Writes the canonical name, which `parse_type` accepts back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            QuantType::Q8_0 => "Q8_0",
            QuantType::Q4_0 => "Q4_0",
            QuantType::Q4_1 => "Q4_1",
            QuantType::Q5_0 => "Q5_0",
            QuantType::Q5_1 => "Q5_1",
            QuantType::F16 => "F16",
            QuantType::F32 => "F32",
        };
        f.write_str(name)
    }
}
138
139// =============================================================================
140// Quantized Block Structures
141// =============================================================================
142
143/// A block of Q8_0 quantized data.
144#[derive(Debug, Clone)]
145pub struct Q8Block {
146    /// Scale factor (stored as f16).
147    pub scale: f16,
148    /// Quantized values (32 x int8).
149    pub data: [i8; 32],
150}
151
152impl Q8Block {
153    /// Creates a new Q8 block.
154    pub fn new(scale: f16, data: [i8; 32]) -> Self {
155        Self { scale, data }
156    }
157
158    /// Returns the byte representation of this block.
159    pub fn to_bytes(&self) -> Vec<u8> {
160        let mut bytes = Vec::with_capacity(34);
161        bytes.extend_from_slice(&self.scale.to_le_bytes());
162        bytes.extend(self.data.iter().map(|&x| x as u8));
163        bytes
164    }
165
166    /// Creates a block from bytes.
167    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
168        if bytes.len() < 34 {
169            return None;
170        }
171        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
172        let mut data = [0i8; 32];
173        for (i, &b) in bytes[2..34].iter().enumerate() {
174            data[i] = b as i8;
175        }
176        Some(Self { scale, data })
177    }
178}
179
180/// A block of Q4_0 quantized data.
181#[derive(Debug, Clone)]
182pub struct Q4Block {
183    /// Scale factor (stored as f16).
184    pub scale: f16,
185    /// Packed quantized values (16 bytes = 32 x 4-bit).
186    pub data: [u8; 16],
187}
188
189impl Q4Block {
190    /// Creates a new Q4 block.
191    pub fn new(scale: f16, data: [u8; 16]) -> Self {
192        Self { scale, data }
193    }
194
195    /// Extracts the 4-bit values as i8 (range -8 to 7).
196    pub fn unpack(&self) -> [i8; 32] {
197        let mut result = [0i8; 32];
198        for i in 0..16 {
199            let byte = self.data[i];
200            result[i * 2] = ((byte & 0x0F) as i8) - 8;
201            result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
202        }
203        result
204    }
205
206    /// Packs 32 i8 values (-8 to 7 range) into 16 bytes.
207    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
208        let mut data = [0u8; 16];
209        for i in 0..16 {
210            let low = ((values[i * 2] + 8) as u8) & 0x0F;
211            let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
212            data[i] = low | (high << 4);
213        }
214        data
215    }
216
217    /// Returns the byte representation of this block.
218    pub fn to_bytes(&self) -> Vec<u8> {
219        let mut bytes = Vec::with_capacity(18);
220        bytes.extend_from_slice(&self.scale.to_le_bytes());
221        bytes.extend_from_slice(&self.data);
222        bytes
223    }
224
225    /// Creates a block from bytes.
226    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
227        if bytes.len() < 18 {
228            return None;
229        }
230        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
231        let mut data = [0u8; 16];
232        data.copy_from_slice(&bytes[2..18]);
233        Some(Self { scale, data })
234    }
235}
236
237/// A block of Q4_1 quantized data (with min value).
238#[derive(Debug, Clone)]
239pub struct Q4_1Block {
240    /// Scale factor (stored as f16).
241    pub scale: f16,
242    /// Minimum value (stored as f16).
243    pub min: f16,
244    /// Packed quantized values (16 bytes = 32 x 4-bit).
245    pub data: [u8; 16],
246}
247
248impl Q4_1Block {
249    /// Creates a new Q4_1 block.
250    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
251        Self { scale, min, data }
252    }
253
254    /// Extracts the 4-bit values as u8 (range 0 to 15).
255    pub fn unpack(&self) -> [u8; 32] {
256        let mut result = [0u8; 32];
257        for i in 0..16 {
258            let byte = self.data[i];
259            result[i * 2] = byte & 0x0F;
260            result[i * 2 + 1] = byte >> 4;
261        }
262        result
263    }
264
265    /// Returns the byte representation of this block.
266    pub fn to_bytes(&self) -> Vec<u8> {
267        let mut bytes = Vec::with_capacity(20);
268        bytes.extend_from_slice(&self.scale.to_le_bytes());
269        bytes.extend_from_slice(&self.min.to_le_bytes());
270        bytes.extend_from_slice(&self.data);
271        bytes
272    }
273}
274
275// =============================================================================
276// Q5_0 Block (5-bit symmetric)
277// =============================================================================
278
279/// A block of Q5_0 quantized data (5-bit with per-block scale).
280///
281/// 32 values × 5 bits = 160 bits = 20 bytes of packed data + f16 scale.
282#[derive(Debug, Clone)]
283pub struct Q5Block {
284    /// Scale factor (stored as f16).
285    pub scale: f16,
286    /// Packed quantized values (20 bytes = 32 × 5-bit).
287    pub data: [u8; 20],
288}
289
290impl Q5Block {
291    /// Creates a new Q5_0 block.
292    pub fn new(scale: f16, data: [u8; 20]) -> Self {
293        Self { scale, data }
294    }
295
296    /// Packs 32 signed 5-bit values (range -16 to 15) into 20 bytes.
297    pub fn pack(values: &[i8; 32]) -> [u8; 20] {
298        let mut packed = [0u8; 20];
299        // Pack 32 × 5-bit values: 8 groups of 4 values → 20 bits each → 2.5 bytes
300        // Simpler: treat as 160-bit bitstream
301        for i in 0..32 {
302            let v = (values[i] as u8) & 0x1F; // 5-bit unsigned representation
303            let bit_offset = i * 5;
304            let byte_offset = bit_offset / 8;
305            let bit_shift = bit_offset % 8;
306            packed[byte_offset] |= v << bit_shift;
307            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
308                packed[byte_offset + 1] |= v >> (8 - bit_shift);
309            }
310        }
311        packed
312    }
313
314    /// Unpacks 20 bytes into 32 signed 5-bit values.
315    pub fn unpack(&self) -> [i8; 32] {
316        let mut result = [0i8; 32];
317        for i in 0..32 {
318            let bit_offset = i * 5;
319            let byte_offset = bit_offset / 8;
320            let bit_shift = bit_offset % 8;
321            let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
322            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
323                v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
324            }
325            // Sign extend: if bit 4 is set, value is negative
326            if v & 0x10 != 0 {
327                result[i] = (v | 0xE0) as i8; // sign extend to i8
328            } else {
329                result[i] = v as i8;
330            }
331        }
332        result
333    }
334
335    /// Returns byte representation.
336    pub fn to_bytes(&self) -> Vec<u8> {
337        let mut bytes = Vec::with_capacity(22);
338        bytes.extend_from_slice(&self.scale.to_le_bytes());
339        bytes.extend_from_slice(&self.data);
340        bytes
341    }
342}
343
344// =============================================================================
345// Q5_1 Block (5-bit asymmetric)
346// =============================================================================
347
348/// A block of Q5_1 quantized data (5-bit with per-block scale and min).
349#[derive(Debug, Clone)]
350pub struct Q5_1Block {
351    /// Scale factor (stored as f16).
352    pub scale: f16,
353    /// Minimum value (stored as f16).
354    pub min: f16,
355    /// Packed quantized values (20 bytes = 32 × 5-bit unsigned).
356    pub data: [u8; 20],
357}
358
359impl Q5_1Block {
360    /// Creates a new Q5_1 block.
361    pub fn new(scale: f16, min: f16, data: [u8; 20]) -> Self {
362        Self { scale, min, data }
363    }
364
365    /// Packs 32 unsigned 5-bit values (range 0 to 31) into 20 bytes.
366    pub fn pack(values: &[u8; 32]) -> [u8; 20] {
367        let mut packed = [0u8; 20];
368        for i in 0..32 {
369            let v = values[i] & 0x1F;
370            let bit_offset = i * 5;
371            let byte_offset = bit_offset / 8;
372            let bit_shift = bit_offset % 8;
373            packed[byte_offset] |= v << bit_shift;
374            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
375                packed[byte_offset + 1] |= v >> (8 - bit_shift);
376            }
377        }
378        packed
379    }
380
381    /// Unpacks 20 bytes into 32 unsigned 5-bit values.
382    pub fn unpack(&self) -> [u8; 32] {
383        let mut result = [0u8; 32];
384        for i in 0..32 {
385            let bit_offset = i * 5;
386            let byte_offset = bit_offset / 8;
387            let bit_shift = bit_offset % 8;
388            let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
389            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
390                v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
391            }
392            result[i] = v;
393        }
394        result
395    }
396
397    /// Returns byte representation.
398    pub fn to_bytes(&self) -> Vec<u8> {
399        let mut bytes = Vec::with_capacity(24);
400        bytes.extend_from_slice(&self.scale.to_le_bytes());
401        bytes.extend_from_slice(&self.min.to_le_bytes());
402        bytes.extend_from_slice(&self.data);
403        bytes
404    }
405}
406
407// =============================================================================
408// Generic Quantized Block
409// =============================================================================
410
411/// Generic quantized block enum.
412#[derive(Debug, Clone)]
413pub enum QuantizedBlock {
414    /// Q8_0 block.
415    Q8(Q8Block),
416    /// Q4_0 block.
417    Q4(Q4Block),
418    /// Q4_1 block.
419    Q4_1(Q4_1Block),
420    /// Q5_0 block (5-bit symmetric).
421    Q5(Q5Block),
422    /// Q5_1 block (5-bit asymmetric).
423    Q5_1(Q5_1Block),
424    /// F16 values (block size 1).
425    F16(Vec<f16>),
426    /// F32 values (original).
427    F32(Vec<f32>),
428}
429
430impl QuantizedBlock {
431    /// Returns the quantization type of this block.
432    pub fn quant_type(&self) -> QuantType {
433        match self {
434            QuantizedBlock::Q8(_) => QuantType::Q8_0,
435            QuantizedBlock::Q4(_) => QuantType::Q4_0,
436            QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
437            QuantizedBlock::Q5(_) => QuantType::Q5_0,
438            QuantizedBlock::Q5_1(_) => QuantType::Q5_1,
439            QuantizedBlock::F16(_) => QuantType::F16,
440            QuantizedBlock::F32(_) => QuantType::F32,
441        }
442    }
443}
444
445// =============================================================================
446// Quantized Tensor
447// =============================================================================
448
449/// A quantized tensor containing compressed weight data.
450#[derive(Debug, Clone)]
451pub struct QuantizedTensor {
452    /// Original tensor shape.
453    pub shape: Vec<usize>,
454    /// Quantization type.
455    pub quant_type: QuantType,
456    /// Quantized data blocks.
457    pub blocks: Vec<QuantizedBlock>,
458    /// Number of elements.
459    pub numel: usize,
460}
461
462impl QuantizedTensor {
463    /// Creates a new quantized tensor.
464    pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
465        let numel = shape.iter().product();
466        Self {
467            shape,
468            quant_type,
469            blocks,
470            numel,
471        }
472    }
473
474    /// Returns the memory size in bytes.
475    pub fn size_bytes(&self) -> usize {
476        self.blocks.len() * self.quant_type.bytes_per_block()
477    }
478
479    /// Returns the compression ratio compared to F32.
480    pub fn compression_ratio(&self) -> f32 {
481        let original_bytes = self.numel * 4;
482        original_bytes as f32 / self.size_bytes() as f32
483    }
484
485    /// Returns the number of blocks.
486    pub fn num_blocks(&self) -> usize {
487        self.blocks.len()
488    }
489}
490
491// =============================================================================
492// Tests
493// =============================================================================
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    #[test]
    fn test_bytes_per_block_matches_serialization() {
        assert_eq!(QuantType::Q8_0.bytes_per_block(), 34);
        assert_eq!(QuantType::Q4_0.bytes_per_block(), 18);
        assert_eq!(QuantType::Q4_1.bytes_per_block(), 20);
        assert_eq!(QuantType::Q5_0.bytes_per_block(), 22);
        assert_eq!(QuantType::Q5_1.bytes_per_block(), 24);
        assert_eq!(QuantType::F16.bytes_per_block(), 2);
        assert_eq!(QuantType::F32.bytes_per_block(), 4);
    }

    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::parse_type("Q8_0"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("INT8"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("Q4"), Some(QuantType::Q4_0));
        assert_eq!(QuantType::parse_type("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::parse_type("q5_1"), Some(QuantType::Q5_1));
        assert_eq!(QuantType::parse_type("invalid"), None);
    }

    #[test]
    fn test_quant_type_display_roundtrip() {
        for t in [
            QuantType::Q8_0,
            QuantType::Q4_0,
            QuantType::Q4_1,
            QuantType::Q5_0,
            QuantType::Q5_1,
            QuantType::F16,
            QuantType::F32,
        ] {
            assert_eq!(t.to_string().parse::<QuantType>(), Ok(t));
        }
        assert!("nope".parse::<QuantType>().is_err());
    }

    #[test]
    fn test_q4_pack_unpack() {
        let values: [i8; 32] = [
            -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
            0, 1, 2, 3, 4, 5, 6, 7,
        ];

        let packed = Q4Block::pack(&values);
        let block = Q4Block::new(f16::from_f32(1.0), packed);
        let unpacked = block.unpack();

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_q4_1_unpack_nibble_order() {
        let block = Q4_1Block::new(f16::from_f32(1.0), f16::from_f32(0.0), [0xA5; 16]);
        let unpacked = block.unpack();
        // Low nibble comes first, then the high nibble.
        assert_eq!(unpacked[0], 0x5);
        assert_eq!(unpacked[1], 0xA);
        assert_eq!(block.to_bytes().len(), QuantType::Q4_1.bytes_per_block());
    }

    #[test]
    fn test_q5_pack_unpack() {
        let mut values = [0i8; 32];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as i8 - 16; // covers the full -16..=15 range
        }
        let block = Q5Block::new(f16::from_f32(1.0), Q5Block::pack(&values));
        assert_eq!(block.unpack(), values);
        assert_eq!(block.to_bytes().len(), QuantType::Q5_0.bytes_per_block());
    }

    #[test]
    fn test_q5_1_pack_unpack() {
        let mut values = [0u8; 32];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (31 - i) as u8; // covers the full 0..=31 range
        }
        let block = Q5_1Block::new(f16::from_f32(1.0), f16::from_f32(0.0), Q5_1Block::pack(&values));
        assert_eq!(block.unpack(), values);
        assert_eq!(block.to_bytes().len(), QuantType::Q5_1.bytes_per_block());
    }

    #[test]
    fn test_q8_block() {
        let data = [0i8; 32];
        let block = Q8Block::new(f16::from_f32(0.5), data);
        let bytes = block.to_bytes();
        let restored = Q8Block::from_bytes(&bytes).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }

    #[test]
    fn test_quantized_tensor_sizes() {
        let block = QuantizedBlock::Q8(Q8Block::new(f16::from_f32(1.0), [0i8; 32]));
        assert_eq!(block.quant_type(), QuantType::Q8_0);
        let tensor = QuantizedTensor::new(vec![4, 8], QuantType::Q8_0, vec![block]);
        assert_eq!(tensor.numel, 32);
        assert_eq!(tensor.num_blocks(), 1);
        assert_eq!(tensor.size_bytes(), 34);
        assert!((tensor.compression_ratio() - 128.0 / 34.0).abs() < 1e-6);
    }
}