// axonml_quant/types.rs
//! Quantization Types
//!
//! Defines quantization formats and data structures.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::fmt;

use half::f16;

// =============================================================================
// Quantization Type Enum
// =============================================================================

/// Quantization type enumeration.
///
/// Block-quantized formats (`Q*`) compress groups of 32 weights sharing
/// per-block parameters; `F16`/`F32` are plain floating-point storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit quantization with per-block scale.
    /// Format: scale (f16) + 32 x int8
    Q8_0,

    /// 4-bit quantization with per-block scale.
    /// Format: scale (f16) + 16 x uint8 (two 4-bit values each)
    Q4_0,

    /// 4-bit quantization with per-block scale and min.
    /// Format: scale (f16) + min (f16) + 16 x uint8
    Q4_1,

    /// 5-bit quantization with per-block scale.
    Q5_0,

    /// 5-bit quantization with per-block scale and min.
    Q5_1,

    /// Half-precision (16-bit float).
    F16,

    /// Full precision (32-bit float).
    F32,
}

impl QuantType {
    /// Returns the number of values grouped into one block
    /// (1 for the non-block formats F16/F32).
    pub fn block_size(&self) -> usize {
        match self {
            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 |
            QuantType::Q5_0 | QuantType::Q5_1 => 32,
            QuantType::F16 | QuantType::F32 => 1,
        }
    }

    /// Returns the number of bytes per block (per single value for F16/F32).
    pub fn bytes_per_block(&self) -> usize {
        match self {
            QuantType::Q8_0 => 2 + 32,      // f16 scale + 32 int8
            QuantType::Q4_0 => 2 + 16,      // f16 scale + 16 bytes (32 x 4-bit)
            QuantType::Q4_1 => 4 + 16,      // f16 scale + f16 min + 16 bytes
            QuantType::Q5_0 => 2 + 20,      // f16 scale + 20 bytes (32 x 5-bit)
            QuantType::Q5_1 => 4 + 20,      // f16 scale + f16 min + 20 bytes
            QuantType::F16 => 2,
            QuantType::F32 => 4,
        }
    }

    /// Returns the nominal bits per quantized value, excluding per-block
    /// scale/min overhead.
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::Q8_0 => 8,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::F16 => 16,
            QuantType::F32 => 32,
        }
    }

    /// Returns the compression ratio compared to F32.
    ///
    /// Derived from `bits_per_value` only, so per-block scale/min overhead
    /// is not counted (e.g. Q8_0 reports 4.0 even though a 34-byte block
    /// for 32 values is slightly less than 4x).
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bits_per_value() as f32
    }

    /// Returns true if this type uses block quantization.
    pub fn is_block_quantized(&self) -> bool {
        matches!(self, QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 |
                       QuantType::Q5_0 | QuantType::Q5_1)
    }

    /// Parses a quantization type from a string (case-insensitive).
    ///
    /// Accepts common aliases such as "INT8" for Q8_0 and "HALF" for F16.
    /// Prefer `s.parse::<QuantType>()` in new code; this inherent method is
    /// kept for backward compatibility with existing callers.
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_uppercase().as_str() {
            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
            "Q4_1" => Some(QuantType::Q4_1),
            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
            "Q5_1" => Some(QuantType::Q5_1),
            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
            _ => None,
        }
    }
}

/// Standard-trait parsing so `QuantType` works with `str::parse` and with
/// generic code bounded on `FromStr` (addresses clippy's
/// `should_implement_trait` for the inherent `from_str`).
impl std::str::FromStr for QuantType {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Inherent methods take precedence over trait methods, so this
        // call resolves to the Option-returning parser above (no recursion).
        QuantType::from_str(s).ok_or_else(|| format!("unknown quantization type: {}", s))
    }
}

impl fmt::Display for QuantType {
    /// Writes the canonical uppercase name (round-trips through `from_str`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            QuantType::Q8_0 => write!(f, "Q8_0"),
            QuantType::Q4_0 => write!(f, "Q4_0"),
            QuantType::Q4_1 => write!(f, "Q4_1"),
            QuantType::Q5_0 => write!(f, "Q5_0"),
            QuantType::Q5_1 => write!(f, "Q5_1"),
            QuantType::F16 => write!(f, "F16"),
            QuantType::F32 => write!(f, "F32"),
        }
    }
}

// =============================================================================
// Quantized Block Structures
// =============================================================================

121/// A block of Q8_0 quantized data.
122#[derive(Debug, Clone)]
123pub struct Q8Block {
124    /// Scale factor (stored as f16).
125    pub scale: f16,
126    /// Quantized values (32 x int8).
127    pub data: [i8; 32],
128}
129
130impl Q8Block {
131    /// Creates a new Q8 block.
132    pub fn new(scale: f16, data: [i8; 32]) -> Self {
133        Self { scale, data }
134    }
135
136    /// Returns the byte representation of this block.
137    pub fn to_bytes(&self) -> Vec<u8> {
138        let mut bytes = Vec::with_capacity(34);
139        bytes.extend_from_slice(&self.scale.to_le_bytes());
140        bytes.extend(self.data.iter().map(|&x| x as u8));
141        bytes
142    }
143
144    /// Creates a block from bytes.
145    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
146        if bytes.len() < 34 {
147            return None;
148        }
149        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
150        let mut data = [0i8; 32];
151        for (i, &b) in bytes[2..34].iter().enumerate() {
152            data[i] = b as i8;
153        }
154        Some(Self { scale, data })
155    }
156}
158/// A block of Q4_0 quantized data.
159#[derive(Debug, Clone)]
160pub struct Q4Block {
161    /// Scale factor (stored as f16).
162    pub scale: f16,
163    /// Packed quantized values (16 bytes = 32 x 4-bit).
164    pub data: [u8; 16],
165}
166
167impl Q4Block {
168    /// Creates a new Q4 block.
169    pub fn new(scale: f16, data: [u8; 16]) -> Self {
170        Self { scale, data }
171    }
172
173    /// Extracts the 4-bit values as i8 (range -8 to 7).
174    pub fn unpack(&self) -> [i8; 32] {
175        let mut result = [0i8; 32];
176        for i in 0..16 {
177            let byte = self.data[i];
178            result[i * 2] = ((byte & 0x0F) as i8) - 8;
179            result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
180        }
181        result
182    }
183
184    /// Packs 32 i8 values (-8 to 7 range) into 16 bytes.
185    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
186        let mut data = [0u8; 16];
187        for i in 0..16 {
188            let low = ((values[i * 2] + 8) as u8) & 0x0F;
189            let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
190            data[i] = low | (high << 4);
191        }
192        data
193    }
194
195    /// Returns the byte representation of this block.
196    pub fn to_bytes(&self) -> Vec<u8> {
197        let mut bytes = Vec::with_capacity(18);
198        bytes.extend_from_slice(&self.scale.to_le_bytes());
199        bytes.extend_from_slice(&self.data);
200        bytes
201    }
202
203    /// Creates a block from bytes.
204    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
205        if bytes.len() < 18 {
206            return None;
207        }
208        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
209        let mut data = [0u8; 16];
210        data.copy_from_slice(&bytes[2..18]);
211        Some(Self { scale, data })
212    }
213}
215/// A block of Q4_1 quantized data (with min value).
216#[derive(Debug, Clone)]
217pub struct Q4_1Block {
218    /// Scale factor (stored as f16).
219    pub scale: f16,
220    /// Minimum value (stored as f16).
221    pub min: f16,
222    /// Packed quantized values (16 bytes = 32 x 4-bit).
223    pub data: [u8; 16],
224}
225
226impl Q4_1Block {
227    /// Creates a new Q4_1 block.
228    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
229        Self { scale, min, data }
230    }
231
232    /// Extracts the 4-bit values as u8 (range 0 to 15).
233    pub fn unpack(&self) -> [u8; 32] {
234        let mut result = [0u8; 32];
235        for i in 0..16 {
236            let byte = self.data[i];
237            result[i * 2] = byte & 0x0F;
238            result[i * 2 + 1] = byte >> 4;
239        }
240        result
241    }
242
243    /// Returns the byte representation of this block.
244    pub fn to_bytes(&self) -> Vec<u8> {
245        let mut bytes = Vec::with_capacity(20);
246        bytes.extend_from_slice(&self.scale.to_le_bytes());
247        bytes.extend_from_slice(&self.min.to_le_bytes());
248        bytes.extend_from_slice(&self.data);
249        bytes
250    }
251}

// =============================================================================
// Generic Quantized Block
// =============================================================================

/// Generic quantized block enum wrapping one concrete block format.
///
/// NOTE(review): there are no variants for Q5_0/Q5_1 even though
/// `QuantType` declares them, so `quant_type()` can never report those
/// formats — confirm whether Q5 block support is planned or intentionally
/// absent.
#[derive(Debug, Clone)]
pub enum QuantizedBlock {
    /// Q8_0 block.
    Q8(Q8Block),
    /// Q4_0 block.
    Q4(Q4Block),
    /// Q4_1 block.
    Q4_1(Q4_1Block),
    /// F16 values (block size 1).
    F16(Vec<f16>),
    /// F32 values (original).
    F32(Vec<f32>),
}

impl QuantizedBlock {
    /// Returns the quantization type of this block (one-to-one mapping
    /// from variant to `QuantType`).
    pub fn quant_type(&self) -> QuantType {
        match self {
            QuantizedBlock::Q8(_) => QuantType::Q8_0,
            QuantizedBlock::Q4(_) => QuantType::Q4_0,
            QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
            QuantizedBlock::F16(_) => QuantType::F16,
            QuantizedBlock::F32(_) => QuantType::F32,
        }
    }
}

// =============================================================================
// Quantized Tensor
// =============================================================================

289/// A quantized tensor containing compressed weight data.
290#[derive(Debug, Clone)]
291pub struct QuantizedTensor {
292    /// Original tensor shape.
293    pub shape: Vec<usize>,
294    /// Quantization type.
295    pub quant_type: QuantType,
296    /// Quantized data blocks.
297    pub blocks: Vec<QuantizedBlock>,
298    /// Number of elements.
299    pub numel: usize,
300}
301
302impl QuantizedTensor {
303    /// Creates a new quantized tensor.
304    pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
305        let numel = shape.iter().product();
306        Self {
307            shape,
308            quant_type,
309            blocks,
310            numel,
311        }
312    }
313
314    /// Returns the memory size in bytes.
315    pub fn size_bytes(&self) -> usize {
316        self.blocks.len() * self.quant_type.bytes_per_block()
317    }
318
319    /// Returns the compression ratio compared to F32.
320    pub fn compression_ratio(&self) -> f32 {
321        let original_bytes = self.numel * 4;
322        original_bytes as f32 / self.size_bytes() as f32
323    }
324
325    /// Returns the number of blocks.
326    pub fn num_blocks(&self) -> usize {
327        self.blocks.len()
328    }
329}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Block sizes, bit widths, and block-quantization flags per type.
    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    /// Canonical names and aliases parse; garbage does not.
    #[test]
    fn test_quant_type_from_str() {
        let cases = [
            ("Q8_0", Some(QuantType::Q8_0)),
            ("INT8", Some(QuantType::Q8_0)),
            ("Q4", Some(QuantType::Q4_0)),
            ("F16", Some(QuantType::F16)),
            ("invalid", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(QuantType::from_str(input), expected);
        }
    }

    /// pack followed by unpack reproduces the full -8..=7 value range.
    #[test]
    fn test_q4_pack_unpack() {
        let mut original = [0i8; 32];
        for (i, slot) in original.iter_mut().enumerate() {
            // Two ascending ramps covering -8..=7 twice.
            *slot = (i as i8 % 16) - 8;
        }

        let block = Q4Block::new(f16::from_f32(1.0), Q4Block::pack(&original));
        assert_eq!(block.unpack(), original);
    }

    /// to_bytes / from_bytes round-trips scale and data.
    #[test]
    fn test_q8_block() {
        let original = Q8Block::new(f16::from_f32(0.5), [0i8; 32]);
        let restored = Q8Block::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original.scale, restored.scale);
        assert_eq!(original.data, restored.data);
    }
}
385}