// axonml_quant/types.rs
1//! Quantization Types
2//!
3//! Defines quantization formats and data structures.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use half::f16;
9use std::fmt;
10
11// =============================================================================
12// Quantization Type Enum
13// =============================================================================
14
/// Quantization type enumeration.
///
/// Block-quantized variants (`Q*`) store values in fixed-size groups of 32
/// with per-block metadata (an f16 scale, plus an f16 min for the `_1`
/// variants); `F16`/`F32` store raw floating-point values one at a time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit quantization with per-block scale.
    /// Format: scale (f16) + 32 x int8
    Q8_0,

    /// 4-bit quantization with per-block scale.
    /// Format: scale (f16) + 16 x uint8 (two 4-bit values each)
    Q4_0,

    /// 4-bit quantization with per-block scale and min.
    /// Format: scale (f16) + min (f16) + 16 x uint8
    Q4_1,

    /// 5-bit quantization with per-block scale.
    Q5_0,

    /// 5-bit quantization with per-block scale and min.
    Q5_1,

    /// Half-precision (16-bit float).
    F16,

    /// Full precision (32-bit float).
    F32,
}
42
43impl QuantType {
44    /// Returns the block size for this quantization type.
45    pub fn block_size(&self) -> usize {
46        match self {
47            QuantType::Q8_0
48            | QuantType::Q4_0
49            | QuantType::Q4_1
50            | QuantType::Q5_0
51            | QuantType::Q5_1 => 32,
52            QuantType::F16 | QuantType::F32 => 1,
53        }
54    }
55
56    /// Returns the number of bytes per block.
57    pub fn bytes_per_block(&self) -> usize {
58        match self {
59            QuantType::Q8_0 => 2 + 32, // f16 scale + 32 int8
60            QuantType::Q4_0 => 2 + 16, // f16 scale + 16 bytes (32 x 4-bit)
61            QuantType::Q4_1 => 4 + 16, // f16 scale + f16 min + 16 bytes
62            QuantType::Q5_0 => 2 + 20, // f16 scale + 20 bytes (32 x 5-bit)
63            QuantType::Q5_1 => 4 + 20, // f16 scale + f16 min + 20 bytes
64            QuantType::F16 => 2,
65            QuantType::F32 => 4,
66        }
67    }
68
69    /// Returns the bits per value.
70    pub fn bits_per_value(&self) -> usize {
71        match self {
72            QuantType::Q8_0 => 8,
73            QuantType::Q4_0 | QuantType::Q4_1 => 4,
74            QuantType::Q5_0 | QuantType::Q5_1 => 5,
75            QuantType::F16 => 16,
76            QuantType::F32 => 32,
77        }
78    }
79
80    /// Returns the compression ratio compared to F32.
81    pub fn compression_ratio(&self) -> f32 {
82        32.0 / self.bits_per_value() as f32
83    }
84
85    /// Returns true if this type uses block quantization.
86    pub fn is_block_quantized(&self) -> bool {
87        matches!(
88            self,
89            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
90        )
91    }
92
93    /// Parses a quantization type from a string.
94    pub fn from_str(s: &str) -> Option<Self> {
95        match s.to_uppercase().as_str() {
96            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
97            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
98            "Q4_1" => Some(QuantType::Q4_1),
99            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
100            "Q5_1" => Some(QuantType::Q5_1),
101            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
102            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
103            _ => None,
104        }
105    }
106}
107
108impl fmt::Display for QuantType {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        match self {
111            QuantType::Q8_0 => write!(f, "Q8_0"),
112            QuantType::Q4_0 => write!(f, "Q4_0"),
113            QuantType::Q4_1 => write!(f, "Q4_1"),
114            QuantType::Q5_0 => write!(f, "Q5_0"),
115            QuantType::Q5_1 => write!(f, "Q5_1"),
116            QuantType::F16 => write!(f, "F16"),
117            QuantType::F32 => write!(f, "F32"),
118        }
119    }
120}
121
122// =============================================================================
123// Quantized Block Structures
124// =============================================================================
125
/// A block of Q8_0 quantized data: 32 values, each stored as a signed
/// byte, sharing a single f16 scale factor.
#[derive(Debug, Clone)]
pub struct Q8Block {
    /// Scale factor (stored as f16).
    pub scale: f16,
    /// Quantized values (32 x int8).
    pub data: [i8; 32],
}
134
135impl Q8Block {
136    /// Creates a new Q8 block.
137    pub fn new(scale: f16, data: [i8; 32]) -> Self {
138        Self { scale, data }
139    }
140
141    /// Returns the byte representation of this block.
142    pub fn to_bytes(&self) -> Vec<u8> {
143        let mut bytes = Vec::with_capacity(34);
144        bytes.extend_from_slice(&self.scale.to_le_bytes());
145        bytes.extend(self.data.iter().map(|&x| x as u8));
146        bytes
147    }
148
149    /// Creates a block from bytes.
150    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
151        if bytes.len() < 34 {
152            return None;
153        }
154        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
155        let mut data = [0i8; 32];
156        for (i, &b) in bytes[2..34].iter().enumerate() {
157            data[i] = b as i8;
158        }
159        Some(Self { scale, data })
160    }
161}
162
/// A block of Q4_0 quantized data: 32 4-bit values packed two per byte,
/// sharing a single f16 scale factor.
#[derive(Debug, Clone)]
pub struct Q4Block {
    /// Scale factor (stored as f16).
    pub scale: f16,
    /// Packed quantized values (16 bytes = 32 x 4-bit; even index in the
    /// low nibble, odd index in the high nibble — see `unpack`/`pack`).
    pub data: [u8; 16],
}
171
172impl Q4Block {
173    /// Creates a new Q4 block.
174    pub fn new(scale: f16, data: [u8; 16]) -> Self {
175        Self { scale, data }
176    }
177
178    /// Extracts the 4-bit values as i8 (range -8 to 7).
179    pub fn unpack(&self) -> [i8; 32] {
180        let mut result = [0i8; 32];
181        for i in 0..16 {
182            let byte = self.data[i];
183            result[i * 2] = ((byte & 0x0F) as i8) - 8;
184            result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
185        }
186        result
187    }
188
189    /// Packs 32 i8 values (-8 to 7 range) into 16 bytes.
190    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
191        let mut data = [0u8; 16];
192        for i in 0..16 {
193            let low = ((values[i * 2] + 8) as u8) & 0x0F;
194            let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
195            data[i] = low | (high << 4);
196        }
197        data
198    }
199
200    /// Returns the byte representation of this block.
201    pub fn to_bytes(&self) -> Vec<u8> {
202        let mut bytes = Vec::with_capacity(18);
203        bytes.extend_from_slice(&self.scale.to_le_bytes());
204        bytes.extend_from_slice(&self.data);
205        bytes
206    }
207
208    /// Creates a block from bytes.
209    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
210        if bytes.len() < 18 {
211            return None;
212        }
213        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
214        let mut data = [0u8; 16];
215        data.copy_from_slice(&bytes[2..18]);
216        Some(Self { scale, data })
217    }
218}
219
/// A block of Q4_1 quantized data: 32 unsigned 4-bit values packed two
/// per byte, with a per-block f16 scale and f16 minimum.
#[derive(Debug, Clone)]
pub struct Q4_1Block {
    /// Scale factor (stored as f16).
    pub scale: f16,
    /// Minimum value (stored as f16).
    pub min: f16,
    /// Packed quantized values (16 bytes = 32 x 4-bit; even index in the
    /// low nibble, odd index in the high nibble).
    pub data: [u8; 16],
}
230
231impl Q4_1Block {
232    /// Creates a new Q4_1 block.
233    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
234        Self { scale, min, data }
235    }
236
237    /// Extracts the 4-bit values as u8 (range 0 to 15).
238    pub fn unpack(&self) -> [u8; 32] {
239        let mut result = [0u8; 32];
240        for i in 0..16 {
241            let byte = self.data[i];
242            result[i * 2] = byte & 0x0F;
243            result[i * 2 + 1] = byte >> 4;
244        }
245        result
246    }
247
248    /// Returns the byte representation of this block.
249    pub fn to_bytes(&self) -> Vec<u8> {
250        let mut bytes = Vec::with_capacity(20);
251        bytes.extend_from_slice(&self.scale.to_le_bytes());
252        bytes.extend_from_slice(&self.min.to_le_bytes());
253        bytes.extend_from_slice(&self.data);
254        bytes
255    }
256}
257
258// =============================================================================
259// Generic Quantized Block
260// =============================================================================
261
/// Generic quantized block enum, wrapping one block (or value run) of
/// any supported format.
///
/// NOTE(review): there are no Q5_0/Q5_1 variants here even though
/// `QuantType` declares them — presumably Q5 block structs are defined
/// elsewhere or not yet implemented; confirm before relying on Q5.
#[derive(Debug, Clone)]
pub enum QuantizedBlock {
    /// Q8_0 block.
    Q8(Q8Block),
    /// Q4_0 block.
    Q4(Q4Block),
    /// Q4_1 block.
    Q4_1(Q4_1Block),
    /// F16 values (block size 1).
    F16(Vec<f16>),
    /// F32 values (original).
    F32(Vec<f32>),
}
276
277impl QuantizedBlock {
278    /// Returns the quantization type of this block.
279    pub fn quant_type(&self) -> QuantType {
280        match self {
281            QuantizedBlock::Q8(_) => QuantType::Q8_0,
282            QuantizedBlock::Q4(_) => QuantType::Q4_0,
283            QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
284            QuantizedBlock::F16(_) => QuantType::F16,
285            QuantizedBlock::F32(_) => QuantType::F32,
286        }
287    }
288}
289
290// =============================================================================
291// Quantized Tensor
292// =============================================================================
293
/// A quantized tensor containing compressed weight data as a list of
/// uniformly-typed blocks.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Original tensor shape.
    pub shape: Vec<usize>,
    /// Quantization type (applies to every block in `blocks`).
    pub quant_type: QuantType,
    /// Quantized data blocks.
    pub blocks: Vec<QuantizedBlock>,
    /// Number of elements (product of `shape`, cached at construction).
    pub numel: usize,
}
306
307impl QuantizedTensor {
308    /// Creates a new quantized tensor.
309    pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
310        let numel = shape.iter().product();
311        Self {
312            shape,
313            quant_type,
314            blocks,
315            numel,
316        }
317    }
318
319    /// Returns the memory size in bytes.
320    pub fn size_bytes(&self) -> usize {
321        self.blocks.len() * self.quant_type.bytes_per_block()
322    }
323
324    /// Returns the compression ratio compared to F32.
325    pub fn compression_ratio(&self) -> f32 {
326        let original_bytes = self.numel * 4;
327        original_bytes as f32 / self.size_bytes() as f32
328    }
329
330    /// Returns the number of blocks.
331    pub fn num_blocks(&self) -> usize {
332        self.blocks.len()
333    }
334}
335
336// =============================================================================
337// Tests
338// =============================================================================
339
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::from_str("Q8_0"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::from_str("INT8"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::from_str("Q4"), Some(QuantType::Q4_0));
        assert_eq!(QuantType::from_str("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::from_str("invalid"), None);
    }

    // New: every variant's Display output must parse back to itself,
    // keeping Display and from_str in sync as variants are added.
    #[test]
    fn test_display_from_str_roundtrip() {
        let all = [
            QuantType::Q8_0,
            QuantType::Q4_0,
            QuantType::Q4_1,
            QuantType::Q5_0,
            QuantType::Q5_1,
            QuantType::F16,
            QuantType::F32,
        ];
        for qt in all {
            assert_eq!(QuantType::from_str(&qt.to_string()), Some(qt));
        }
    }

    #[test]
    fn test_q4_pack_unpack() {
        let values: [i8; 32] = [
            -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
            0, 1, 2, 3, 4, 5, 6, 7,
        ];

        let packed = Q4Block::pack(&values);
        let block = Q4Block::new(f16::from_f32(1.0), packed);
        let unpacked = block.unpack();

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_q8_block() {
        let data = [0i8; 32];
        let block = Q8Block::new(f16::from_f32(0.5), data);
        let bytes = block.to_bytes();
        let restored = Q8Block::from_bytes(&bytes).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }
}