Skip to main content

axonml_quant/
types.rs

1//! Quantization Types
2//!
3//! # File
4//! `crates/axonml-quant/src/types.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use half::f16;
18use std::fmt;
19
20// =============================================================================
21// Quantization Type Enum
22// =============================================================================
23
/// Storage formats supported by the quantizer.
///
/// The `Q*` variants are block formats: values are grouped 32 at a time and
/// each group carries an `f16` scale (plus an `f16` minimum for the `_1`
/// variants). `F16` and `F32` are plain floats with a block size of one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit values with a per-block scale.
    /// Layout: scale (f16) + 32 x int8.
    Q8_0,

    /// 4-bit values with a per-block scale.
    /// Layout: scale (f16) + 16 x uint8, two 4-bit values per byte.
    Q4_0,

    /// 4-bit values with a per-block scale and minimum.
    /// Layout: scale (f16) + min (f16) + 16 x uint8.
    Q4_1,

    /// 5-bit values with a per-block scale.
    /// Layout: scale (f16) + 20 bytes holding 32 x 5-bit values.
    Q5_0,

    /// 5-bit values with a per-block scale and minimum.
    /// Layout: scale (f16) + min (f16) + 20 bytes holding 32 x 5-bit values.
    Q5_1,

    /// Half-precision (16-bit) float.
    F16,

    /// Full-precision (32-bit) float.
    F32,
}

impl QuantType {
    /// Number of values stored in one block: 32 for the `Q*` formats and 1
    /// for the float formats.
    pub fn block_size(&self) -> usize {
        if self.is_block_quantized() {
            32
        } else {
            1
        }
    }

    /// Size in bytes of one serialized block, header fields included.
    pub fn bytes_per_block(&self) -> usize {
        // Every header field (scale or min) is one f16, i.e. two bytes.
        const F16_BYTES: usize = 2;

        match self {
            QuantType::F32 => 4,
            QuantType::F16 => F16_BYTES,
            QuantType::Q8_0 => F16_BYTES + 32, // scale + 32 int8
            QuantType::Q4_0 => F16_BYTES + 16, // scale + 32 x 4-bit
            QuantType::Q4_1 => 2 * F16_BYTES + 16, // scale + min + 32 x 4-bit
            QuantType::Q5_0 => F16_BYTES + 20, // scale + 32 x 5-bit
            QuantType::Q5_1 => 2 * F16_BYTES + 20, // scale + min + 32 x 5-bit
        }
    }

    /// Width in bits of a single stored value, not counting block headers.
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::F32 => 32,
            QuantType::F16 => 16,
            QuantType::Q8_0 => 8,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
        }
    }

    /// Nominal compression ratio relative to F32.
    ///
    /// This is `32 / bits_per_value`; the per-block scale/min header is not
    /// included in the figure.
    pub fn compression_ratio(&self) -> f32 {
        let value_bits = self.bits_per_value() as f32;
        32.0 / value_bits
    }

    /// Returns `true` for the block formats (`Q*`) and `false` for the plain
    /// float formats.
    pub fn is_block_quantized(&self) -> bool {
        !matches!(self, QuantType::F16 | QuantType::F32)
    }

    /// Parses a quantization type from its name or one of its aliases.
    ///
    /// Matching is case-insensitive. Returns `None` when the name is not
    /// recognized; the input is not trimmed.
    pub fn parse_type(s: &str) -> Option<Self> {
        let normalized = s.to_uppercase();
        let parsed = match normalized.as_str() {
            "F32" | "FLOAT32" | "FLOAT" => QuantType::F32,
            "F16" | "FLOAT16" | "HALF" => QuantType::F16,
            "Q8_0" | "Q8" | "INT8" => QuantType::Q8_0,
            "Q5_0" | "Q5" => QuantType::Q5_0,
            "Q5_1" => QuantType::Q5_1,
            "Q4_0" | "Q4" | "INT4" => QuantType::Q4_0,
            "Q4_1" => QuantType::Q4_1,
            _ => return None,
        };
        Some(parsed)
    }
}
116
117impl std::str::FromStr for QuantType {
118    type Err = String;
119
120    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
121        Self::parse_type(s).ok_or_else(|| format!("Unknown quant type: '{s}'"))
122    }
123}
124
125impl fmt::Display for QuantType {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        match self {
128            QuantType::Q8_0 => write!(f, "Q8_0"),
129            QuantType::Q4_0 => write!(f, "Q4_0"),
130            QuantType::Q4_1 => write!(f, "Q4_1"),
131            QuantType::Q5_0 => write!(f, "Q5_0"),
132            QuantType::Q5_1 => write!(f, "Q5_1"),
133            QuantType::F16 => write!(f, "F16"),
134            QuantType::F32 => write!(f, "F32"),
135        }
136    }
137}
138
139// =============================================================================
140// Quantized Block Structures
141// =============================================================================
142
143/// A block of Q8_0 quantized data.
144#[derive(Debug, Clone)]
145pub struct Q8Block {
146    /// Scale factor (stored as f16).
147    pub scale: f16,
148    /// Quantized values (32 x int8).
149    pub data: [i8; 32],
150}
151
152impl Q8Block {
153    /// Creates a new Q8 block.
154    pub fn new(scale: f16, data: [i8; 32]) -> Self {
155        Self { scale, data }
156    }
157
158    /// Returns the byte representation of this block.
159    pub fn to_bytes(&self) -> Vec<u8> {
160        let mut bytes = Vec::with_capacity(34);
161        bytes.extend_from_slice(&self.scale.to_le_bytes());
162        bytes.extend(self.data.iter().map(|&x| x as u8));
163        bytes
164    }
165
166    /// Creates a block from bytes.
167    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
168        if bytes.len() < 34 {
169            return None;
170        }
171        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
172        let mut data = [0i8; 32];
173        for (i, &b) in bytes[2..34].iter().enumerate() {
174            data[i] = b as i8;
175        }
176        Some(Self { scale, data })
177    }
178}
179
180/// A block of Q4_0 quantized data.
181#[derive(Debug, Clone)]
182pub struct Q4Block {
183    /// Scale factor (stored as f16).
184    pub scale: f16,
185    /// Packed quantized values (16 bytes = 32 x 4-bit).
186    pub data: [u8; 16],
187}
188
189impl Q4Block {
190    /// Creates a new Q4 block.
191    pub fn new(scale: f16, data: [u8; 16]) -> Self {
192        Self { scale, data }
193    }
194
195    /// Extracts the 4-bit values as i8 (range -8 to 7).
196    pub fn unpack(&self) -> [i8; 32] {
197        let mut result = [0i8; 32];
198        for i in 0..16 {
199            let byte = self.data[i];
200            result[i * 2] = ((byte & 0x0F) as i8) - 8;
201            result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
202        }
203        result
204    }
205
206    /// Packs 32 i8 values (-8 to 7 range) into 16 bytes.
207    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
208        let mut data = [0u8; 16];
209        for i in 0..16 {
210            let low = ((values[i * 2] + 8) as u8) & 0x0F;
211            let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
212            data[i] = low | (high << 4);
213        }
214        data
215    }
216
217    /// Returns the byte representation of this block.
218    pub fn to_bytes(&self) -> Vec<u8> {
219        let mut bytes = Vec::with_capacity(18);
220        bytes.extend_from_slice(&self.scale.to_le_bytes());
221        bytes.extend_from_slice(&self.data);
222        bytes
223    }
224
225    /// Creates a block from bytes.
226    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
227        if bytes.len() < 18 {
228            return None;
229        }
230        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
231        let mut data = [0u8; 16];
232        data.copy_from_slice(&bytes[2..18]);
233        Some(Self { scale, data })
234    }
235}
236
237/// A block of Q4_1 quantized data (with min value).
238#[derive(Debug, Clone)]
239pub struct Q4_1Block {
240    /// Scale factor (stored as f16).
241    pub scale: f16,
242    /// Minimum value (stored as f16).
243    pub min: f16,
244    /// Packed quantized values (16 bytes = 32 x 4-bit).
245    pub data: [u8; 16],
246}
247
248impl Q4_1Block {
249    /// Creates a new Q4_1 block.
250    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
251        Self { scale, min, data }
252    }
253
254    /// Extracts the 4-bit values as u8 (range 0 to 15).
255    pub fn unpack(&self) -> [u8; 32] {
256        let mut result = [0u8; 32];
257        for i in 0..16 {
258            let byte = self.data[i];
259            result[i * 2] = byte & 0x0F;
260            result[i * 2 + 1] = byte >> 4;
261        }
262        result
263    }
264
265    /// Returns the byte representation of this block.
266    pub fn to_bytes(&self) -> Vec<u8> {
267        let mut bytes = Vec::with_capacity(20);
268        bytes.extend_from_slice(&self.scale.to_le_bytes());
269        bytes.extend_from_slice(&self.min.to_le_bytes());
270        bytes.extend_from_slice(&self.data);
271        bytes
272    }
273}
274
275// =============================================================================
276// Q5_0 Block (5-bit symmetric)
277// =============================================================================
278
279/// A block of Q5_0 quantized data (5-bit with per-block scale).
280///
281/// 32 values × 5 bits = 160 bits = 20 bytes of packed data + f16 scale.
282#[derive(Debug, Clone)]
283pub struct Q5Block {
284    /// Scale factor (stored as f16).
285    pub scale: f16,
286    /// Packed quantized values (20 bytes = 32 × 5-bit).
287    pub data: [u8; 20],
288}
289
290impl Q5Block {
291    /// Creates a new Q5_0 block.
292    pub fn new(scale: f16, data: [u8; 20]) -> Self {
293        Self { scale, data }
294    }
295
296    /// Packs 32 signed 5-bit values (range -16 to 15) into 20 bytes.
297    pub fn pack(values: &[i8; 32]) -> [u8; 20] {
298        let mut packed = [0u8; 20];
299        // Pack 32 × 5-bit values: 8 groups of 4 values → 20 bits each → 2.5 bytes
300        // Simpler: treat as 160-bit bitstream
301        #[allow(clippy::needless_range_loop)]
302        for i in 0..32 {
303            let v = (values[i] as u8) & 0x1F; // 5-bit unsigned representation
304            let bit_offset = i * 5;
305            let byte_offset = bit_offset / 8;
306            let bit_shift = bit_offset % 8;
307            packed[byte_offset] |= v << bit_shift;
308            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
309                packed[byte_offset + 1] |= v >> (8 - bit_shift);
310            }
311        }
312        packed
313    }
314
315    /// Unpacks 20 bytes into 32 signed 5-bit values.
316    pub fn unpack(&self) -> [i8; 32] {
317        let mut result = [0i8; 32];
318        #[allow(clippy::needless_range_loop)]
319        for i in 0..32 {
320            let bit_offset = i * 5;
321            let byte_offset = bit_offset / 8;
322            let bit_shift = bit_offset % 8;
323            let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
324            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
325                v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
326            }
327            // Sign extend: if bit 4 is set, value is negative
328            if v & 0x10 != 0 {
329                result[i] = (v | 0xE0) as i8; // sign extend to i8
330            } else {
331                result[i] = v as i8;
332            }
333        }
334        result
335    }
336
337    /// Returns byte representation.
338    pub fn to_bytes(&self) -> Vec<u8> {
339        let mut bytes = Vec::with_capacity(22);
340        bytes.extend_from_slice(&self.scale.to_le_bytes());
341        bytes.extend_from_slice(&self.data);
342        bytes
343    }
344}
345
346// =============================================================================
347// Q5_1 Block (5-bit asymmetric)
348// =============================================================================
349
350/// A block of Q5_1 quantized data (5-bit with per-block scale and min).
351#[derive(Debug, Clone)]
352pub struct Q5_1Block {
353    /// Scale factor (stored as f16).
354    pub scale: f16,
355    /// Minimum value (stored as f16).
356    pub min: f16,
357    /// Packed quantized values (20 bytes = 32 × 5-bit unsigned).
358    pub data: [u8; 20],
359}
360
361impl Q5_1Block {
362    /// Creates a new Q5_1 block.
363    pub fn new(scale: f16, min: f16, data: [u8; 20]) -> Self {
364        Self { scale, min, data }
365    }
366
367    /// Packs 32 unsigned 5-bit values (range 0 to 31) into 20 bytes.
368    pub fn pack(values: &[u8; 32]) -> [u8; 20] {
369        let mut packed = [0u8; 20];
370        #[allow(clippy::needless_range_loop)]
371        for i in 0..32 {
372            let v = values[i] & 0x1F;
373            let bit_offset = i * 5;
374            let byte_offset = bit_offset / 8;
375            let bit_shift = bit_offset % 8;
376            packed[byte_offset] |= v << bit_shift;
377            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
378                packed[byte_offset + 1] |= v >> (8 - bit_shift);
379            }
380        }
381        packed
382    }
383
384    /// Unpacks 20 bytes into 32 unsigned 5-bit values.
385    pub fn unpack(&self) -> [u8; 32] {
386        let mut result = [0u8; 32];
387        #[allow(clippy::needless_range_loop)]
388        for i in 0..32 {
389            let bit_offset = i * 5;
390            let byte_offset = bit_offset / 8;
391            let bit_shift = bit_offset % 8;
392            let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
393            if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
394                v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
395            }
396            result[i] = v;
397        }
398        result
399    }
400
401    /// Returns byte representation.
402    pub fn to_bytes(&self) -> Vec<u8> {
403        let mut bytes = Vec::with_capacity(24);
404        bytes.extend_from_slice(&self.scale.to_le_bytes());
405        bytes.extend_from_slice(&self.min.to_le_bytes());
406        bytes.extend_from_slice(&self.data);
407        bytes
408    }
409}
410
411// =============================================================================
412// Generic Quantized Block
413// =============================================================================
414
415/// Generic quantized block enum.
416#[derive(Debug, Clone)]
417pub enum QuantizedBlock {
418    /// Q8_0 block.
419    Q8(Q8Block),
420    /// Q4_0 block.
421    Q4(Q4Block),
422    /// Q4_1 block.
423    Q4_1(Q4_1Block),
424    /// Q5_0 block (5-bit symmetric).
425    Q5(Q5Block),
426    /// Q5_1 block (5-bit asymmetric).
427    Q5_1(Q5_1Block),
428    /// F16 values (block size 1).
429    F16(Vec<f16>),
430    /// F32 values (original).
431    F32(Vec<f32>),
432}
433
434impl QuantizedBlock {
435    /// Returns the quantization type of this block.
436    pub fn quant_type(&self) -> QuantType {
437        match self {
438            QuantizedBlock::Q8(_) => QuantType::Q8_0,
439            QuantizedBlock::Q4(_) => QuantType::Q4_0,
440            QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
441            QuantizedBlock::Q5(_) => QuantType::Q5_0,
442            QuantizedBlock::Q5_1(_) => QuantType::Q5_1,
443            QuantizedBlock::F16(_) => QuantType::F16,
444            QuantizedBlock::F32(_) => QuantType::F32,
445        }
446    }
447}
448
449// =============================================================================
450// Quantized Tensor
451// =============================================================================
452
453/// A quantized tensor containing compressed weight data.
454#[derive(Debug, Clone)]
455pub struct QuantizedTensor {
456    /// Original tensor shape.
457    pub shape: Vec<usize>,
458    /// Quantization type.
459    pub quant_type: QuantType,
460    /// Quantized data blocks.
461    pub blocks: Vec<QuantizedBlock>,
462    /// Number of elements.
463    pub numel: usize,
464}
465
466impl QuantizedTensor {
467    /// Creates a new quantized tensor.
468    pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
469        let numel = shape.iter().product();
470        Self {
471            shape,
472            quant_type,
473            blocks,
474            numel,
475        }
476    }
477
478    /// Returns the memory size in bytes.
479    pub fn size_bytes(&self) -> usize {
480        self.blocks.len() * self.quant_type.bytes_per_block()
481    }
482
483    /// Returns the compression ratio compared to F32.
484    pub fn compression_ratio(&self) -> f32 {
485        let original_bytes = self.numel * 4;
486        original_bytes as f32 / self.size_bytes() as f32
487    }
488
489    /// Returns the number of blocks.
490    pub fn num_blocks(&self) -> usize {
491        self.blocks.len()
492    }
493}
494
495// =============================================================================
496// Tests
497// =============================================================================
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Every format, for tests that must hold across the whole enum.
    const ALL_TYPES: [QuantType; 7] = [
        QuantType::Q8_0,
        QuantType::Q4_0,
        QuantType::Q4_1,
        QuantType::Q5_0,
        QuantType::Q5_1,
        QuantType::F16,
        QuantType::F32,
    ];

    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::parse_type("Q8_0"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("INT8"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("Q4"), Some(QuantType::Q4_0));
        assert_eq!(QuantType::parse_type("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::parse_type("invalid"), None);
    }

    /// `FromStr` is case-insensitive and reports the rejected input.
    #[test]
    fn test_quant_type_parse_trait() {
        assert_eq!("q4_1".parse::<QuantType>(), Ok(QuantType::Q4_1));
        assert_eq!(
            "bogus".parse::<QuantType>(),
            Err("Unknown quant type: 'bogus'".to_string())
        );
    }

    /// The `Display` name of every format parses back to the same format.
    #[test]
    fn test_quant_type_display_round_trip() {
        for quant_type in ALL_TYPES {
            assert_eq!(quant_type.to_string().parse::<QuantType>(), Ok(quant_type));
        }
    }

    #[test]
    fn test_q4_pack_unpack() {
        let values: [i8; 32] = [
            -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
            0, 1, 2, 3, 4, 5, 6, 7,
        ];

        let packed = Q4Block::pack(&values);
        let block = Q4Block::new(f16::from_f32(1.0), packed);
        let unpacked = block.unpack();

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_q4_block_bytes() {
        let block = Q4Block::new(f16::from_f32(0.25), [0xA5; 16]);
        let bytes = block.to_bytes();
        assert_eq!(bytes.len(), 18);

        let restored = Q4Block::from_bytes(&bytes).unwrap();
        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);

        assert!(Q4Block::from_bytes(&bytes[..17]).is_none());
    }

    #[test]
    fn test_q8_block() {
        let data = [0i8; 32];
        let block = Q8Block::new(f16::from_f32(0.5), data);
        let bytes = block.to_bytes();
        let restored = Q8Block::from_bytes(&bytes).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }

    /// Covers the full signed 5-bit range, including values that straddle a
    /// byte boundary in the packed stream.
    #[test]
    fn test_q5_pack_unpack() {
        let mut values = [0i8; 32];
        for (i, value) in values.iter_mut().enumerate() {
            *value = i as i8 - 16;
        }

        let block = Q5Block::new(f16::from_f32(1.0), Q5Block::pack(&values));
        assert_eq!(values, block.unpack());
    }

    /// Covers the full unsigned 5-bit range.
    #[test]
    fn test_q5_1_pack_unpack() {
        let mut values = [0u8; 32];
        for (i, value) in values.iter_mut().enumerate() {
            *value = i as u8;
        }

        let block = Q5_1Block::new(
            f16::from_f32(1.0),
            f16::from_f32(0.0),
            Q5_1Block::pack(&values),
        );
        assert_eq!(values, block.unpack());
    }

    /// Serialized block lengths must agree with `QuantType::bytes_per_block`.
    #[test]
    fn test_serialized_sizes_match_quant_type() {
        let scale = f16::from_f32(1.0);
        let min = f16::from_f32(0.0);

        assert_eq!(
            Q8Block::new(scale, [0; 32]).to_bytes().len(),
            QuantType::Q8_0.bytes_per_block()
        );
        assert_eq!(
            Q4Block::new(scale, [0; 16]).to_bytes().len(),
            QuantType::Q4_0.bytes_per_block()
        );
        assert_eq!(
            Q4_1Block::new(scale, min, [0; 16]).to_bytes().len(),
            QuantType::Q4_1.bytes_per_block()
        );
        assert_eq!(
            Q5Block::new(scale, [0; 20]).to_bytes().len(),
            QuantType::Q5_0.bytes_per_block()
        );
        assert_eq!(
            Q5_1Block::new(scale, min, [0; 20]).to_bytes().len(),
            QuantType::Q5_1.bytes_per_block()
        );
    }
}
549}