1use half::f16;
18use std::fmt;
19
/// Storage format of a tensor: one of several block-quantized integer formats,
/// or plain half/single-precision floats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit block quantization: a 2-byte `f16` scale plus 32 signed bytes
    /// (34 bytes per 32-value block).
    Q8_0,

    /// 4-bit block quantization: a 2-byte `f16` scale plus 32 nibbles packed
    /// two per byte (18 bytes per 32-value block).
    Q4_0,

    /// 4-bit block quantization with a 4-byte header (scale and min) plus 32
    /// packed nibbles (20 bytes per 32-value block).
    Q4_1,

    /// 5-bit block quantization; 22 bytes per 32-value block.
    Q5_0,

    /// 5-bit block quantization with a 4-byte header; 24 bytes per 32-value block.
    Q5_1,

    /// Unquantized half-precision floats (2 bytes per value).
    F16,

    /// Unquantized single-precision floats (4 bytes per value).
    F32,
}

impl QuantType {
    /// Number of values covered by one block: 32 for every block-quantized
    /// format, 1 for the float formats.
    pub fn block_size(&self) -> usize {
        match self {
            QuantType::Q8_0
            | QuantType::Q4_0
            | QuantType::Q4_1
            | QuantType::Q5_0
            | QuantType::Q5_1 => 32,
            QuantType::F16 | QuantType::F32 => 1,
        }
    }

    /// Encoded size in bytes of one block, including its per-block header.
    pub fn bytes_per_block(&self) -> usize {
        match self {
            // f16 scale + 32 signed bytes.
            QuantType::Q8_0 => 2 + 32,
            // f16 scale + 32 nibbles packed two per byte.
            QuantType::Q4_0 => 2 + 16,
            // f16 scale + f16 min + 32 packed nibbles.
            QuantType::Q4_1 => 4 + 16,
            QuantType::Q5_0 => 2 + 20,
            QuantType::Q5_1 => 4 + 20,
            QuantType::F16 => 2,
            QuantType::F32 => 4,
        }
    }

    /// Nominal number of bits stored per value, excluding per-block header
    /// bytes (scale / min).
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::Q8_0 => 8,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::F16 => 16,
            QuantType::F32 => 32,
        }
    }

    /// Nominal compression ratio relative to `f32`: `32 / bits_per_value()`.
    ///
    /// This ignores the per-block header. For example, `Q8_0` reports `4.0`
    /// even though a block occupies 34 bytes for 32 values (about 3.76x).
    /// `QuantizedTensor::compression_ratio` measures actual encoded sizes.
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bits_per_value() as f32
    }

    /// Returns `true` for the integer block formats, `false` for `F16` / `F32`.
    pub fn is_block_quantized(&self) -> bool {
        matches!(
            self,
            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
        )
    }

    /// Parses a format name case-insensitively, returning `None` if it is not
    /// recognized.
    ///
    /// Besides the canonical names produced by `Display`, these aliases are
    /// accepted:
    /// - `"Q8"` / `"INT8"`
    /// - `"Q4"` / `"INT4"`
    /// - `"Q5"`
    /// - `"FLOAT16"` / `"HALF"`
    /// - `"FLOAT32"` / `"FLOAT"`
    ///
    /// Only ASCII letters are case-folded. Full Unicode uppercasing maps the
    /// dotless `ı` to `I` and the `ﬂ` ligature to `FL`, which would silently
    /// accept non-ASCII spellings such as `"ınt8"`.
    // Kept as an inherent method returning `Option` rather than implementing
    // `std::str::FromStr` (which must return `Result`). This keeps existing
    // `QuantType::from_str(..) == Some(..)` call sites working.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_ascii_uppercase().as_str() {
            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
            "Q4_1" => Some(QuantType::Q4_1),
            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
            "Q5_1" => Some(QuantType::Q5_1),
            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
            _ => None,
        }
    }
}

/// Writes the canonical format name (e.g. `Q8_0`), which `QuantType::from_str`
/// parses back to the same variant.
impl fmt::Display for QuantType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            QuantType::Q8_0 => "Q8_0",
            QuantType::Q4_0 => "Q4_0",
            QuantType::Q4_1 => "Q4_1",
            QuantType::Q5_0 => "Q5_0",
            QuantType::Q5_1 => "Q5_1",
            QuantType::F16 => "F16",
            QuantType::F32 => "F32",
        };
        f.write_str(name)
    }
}
131
132#[derive(Debug, Clone)]
138pub struct Q8Block {
139 pub scale: f16,
141 pub data: [i8; 32],
143}
144
145impl Q8Block {
146 pub fn new(scale: f16, data: [i8; 32]) -> Self {
148 Self { scale, data }
149 }
150
151 pub fn to_bytes(&self) -> Vec<u8> {
153 let mut bytes = Vec::with_capacity(34);
154 bytes.extend_from_slice(&self.scale.to_le_bytes());
155 bytes.extend(self.data.iter().map(|&x| x as u8));
156 bytes
157 }
158
159 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
161 if bytes.len() < 34 {
162 return None;
163 }
164 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
165 let mut data = [0i8; 32];
166 for (i, &b) in bytes[2..34].iter().enumerate() {
167 data[i] = b as i8;
168 }
169 Some(Self { scale, data })
170 }
171}
172
173#[derive(Debug, Clone)]
175pub struct Q4Block {
176 pub scale: f16,
178 pub data: [u8; 16],
180}
181
182impl Q4Block {
183 pub fn new(scale: f16, data: [u8; 16]) -> Self {
185 Self { scale, data }
186 }
187
188 pub fn unpack(&self) -> [i8; 32] {
190 let mut result = [0i8; 32];
191 for i in 0..16 {
192 let byte = self.data[i];
193 result[i * 2] = ((byte & 0x0F) as i8) - 8;
194 result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
195 }
196 result
197 }
198
199 pub fn pack(values: &[i8; 32]) -> [u8; 16] {
201 let mut data = [0u8; 16];
202 for i in 0..16 {
203 let low = ((values[i * 2] + 8) as u8) & 0x0F;
204 let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
205 data[i] = low | (high << 4);
206 }
207 data
208 }
209
210 pub fn to_bytes(&self) -> Vec<u8> {
212 let mut bytes = Vec::with_capacity(18);
213 bytes.extend_from_slice(&self.scale.to_le_bytes());
214 bytes.extend_from_slice(&self.data);
215 bytes
216 }
217
218 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
220 if bytes.len() < 18 {
221 return None;
222 }
223 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
224 let mut data = [0u8; 16];
225 data.copy_from_slice(&bytes[2..18]);
226 Some(Self { scale, data })
227 }
228}
229
230#[derive(Debug, Clone)]
232pub struct Q4_1Block {
233 pub scale: f16,
235 pub min: f16,
237 pub data: [u8; 16],
239}
240
241impl Q4_1Block {
242 pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
244 Self { scale, min, data }
245 }
246
247 pub fn unpack(&self) -> [u8; 32] {
249 let mut result = [0u8; 32];
250 for i in 0..16 {
251 let byte = self.data[i];
252 result[i * 2] = byte & 0x0F;
253 result[i * 2 + 1] = byte >> 4;
254 }
255 result
256 }
257
258 pub fn to_bytes(&self) -> Vec<u8> {
260 let mut bytes = Vec::with_capacity(20);
261 bytes.extend_from_slice(&self.scale.to_le_bytes());
262 bytes.extend_from_slice(&self.min.to_le_bytes());
263 bytes.extend_from_slice(&self.data);
264 bytes
265 }
266}
267
268#[derive(Debug, Clone)]
274pub enum QuantizedBlock {
275 Q8(Q8Block),
277 Q4(Q4Block),
279 Q4_1(Q4_1Block),
281 F16(Vec<f16>),
283 F32(Vec<f32>),
285}
286
287impl QuantizedBlock {
288 pub fn quant_type(&self) -> QuantType {
290 match self {
291 QuantizedBlock::Q8(_) => QuantType::Q8_0,
292 QuantizedBlock::Q4(_) => QuantType::Q4_0,
293 QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
294 QuantizedBlock::F16(_) => QuantType::F16,
295 QuantizedBlock::F32(_) => QuantType::F32,
296 }
297 }
298}
299
300#[derive(Debug, Clone)]
306pub struct QuantizedTensor {
307 pub shape: Vec<usize>,
309 pub quant_type: QuantType,
311 pub blocks: Vec<QuantizedBlock>,
313 pub numel: usize,
315}
316
317impl QuantizedTensor {
318 pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
320 let numel = shape.iter().product();
321 Self {
322 shape,
323 quant_type,
324 blocks,
325 numel,
326 }
327 }
328
329 pub fn size_bytes(&self) -> usize {
331 self.blocks.len() * self.quant_type.bytes_per_block()
332 }
333
334 pub fn compression_ratio(&self) -> f32 {
336 let original_bytes = self.numel * 4;
337 original_bytes as f32 / self.size_bytes() as f32
338 }
339
340 pub fn num_blocks(&self) -> usize {
342 self.blocks.len()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);

        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    #[test]
    fn test_bytes_per_block() {
        assert_eq!(QuantType::Q8_0.bytes_per_block(), 34);
        assert_eq!(QuantType::Q4_0.bytes_per_block(), 18);
        assert_eq!(QuantType::Q4_1.bytes_per_block(), 20);
        assert_eq!(QuantType::Q5_0.bytes_per_block(), 22);
        assert_eq!(QuantType::Q5_1.bytes_per_block(), 24);
        assert_eq!(QuantType::F16.bytes_per_block(), 2);
        assert_eq!(QuantType::F32.bytes_per_block(), 4);
    }

    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::from_str("Q8_0"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::from_str("INT8"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::from_str("Q4"), Some(QuantType::Q4_0));
        assert_eq!(QuantType::from_str("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::from_str("invalid"), None);
    }

    #[test]
    fn test_quant_type_from_str_lowercase() {
        assert_eq!(QuantType::from_str("q4_1"), Some(QuantType::Q4_1));
        assert_eq!(QuantType::from_str("float32"), Some(QuantType::F32));
        assert_eq!(QuantType::from_str("half"), Some(QuantType::F16));
    }

    #[test]
    fn test_quant_type_display_round_trip() {
        let all = [
            QuantType::Q8_0,
            QuantType::Q4_0,
            QuantType::Q4_1,
            QuantType::Q5_0,
            QuantType::Q5_1,
            QuantType::F16,
            QuantType::F32,
        ];
        for &t in all.iter() {
            assert_eq!(QuantType::from_str(&t.to_string()), Some(t));
        }
    }

    #[test]
    fn test_q4_pack_unpack() {
        let values: [i8; 32] = [
            -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
            0, 1, 2, 3, 4, 5, 6, 7,
        ];

        let packed = Q4Block::pack(&values);
        let block = Q4Block::new(f16::from_f32(1.0), packed);
        let unpacked = block.unpack();

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_q4_block_bytes_round_trip() {
        let packed = Q4Block::pack(&[3i8; 32]);
        let block = Q4Block::new(f16::from_f32(2.0), packed);
        let bytes = block.to_bytes();
        assert_eq!(bytes.len(), 18);

        let restored = Q4Block::from_bytes(&bytes).unwrap();
        assert_eq!(restored.scale, block.scale);
        assert_eq!(restored.data, block.data);
        assert_eq!(restored.unpack(), [3i8; 32]);
    }

    #[test]
    fn test_q8_block() {
        let data = [0i8; 32];
        let block = Q8Block::new(f16::from_f32(0.5), data);
        let bytes = block.to_bytes();
        let restored = Q8Block::from_bytes(&bytes).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }

    #[test]
    fn test_short_input_rejected() {
        assert!(Q8Block::from_bytes(&[0u8; 33]).is_none());
        assert!(Q4Block::from_bytes(&[0u8; 17]).is_none());
    }

    #[test]
    fn test_q4_1_unpack() {
        let mut data = [0u8; 16];
        data[0] = 0xA5;
        let block = Q4_1Block::new(f16::from_f32(1.0), f16::from_f32(0.0), data);
        let values = block.unpack();

        // Low nibble first, then high nibble.
        assert_eq!(values[0], 0x5);
        assert_eq!(values[1], 0xA);
        assert!(values[2..].iter().all(|&v| v == 0));
        assert_eq!(block.to_bytes().len(), 20);
    }

    #[test]
    fn test_quantized_tensor_q8_size() {
        let block = QuantizedBlock::Q8(Q8Block::new(f16::from_f32(1.0), [0i8; 32]));
        assert_eq!(block.quant_type(), QuantType::Q8_0);

        let tensor =
            QuantizedTensor::new(vec![2, 32], QuantType::Q8_0, vec![block.clone(), block]);
        assert_eq!(tensor.numel, 64);
        assert_eq!(tensor.num_blocks(), 2);
        assert_eq!(tensor.size_bytes(), 68);
        assert!((tensor.compression_ratio() - 256.0 / 68.0).abs() < 1e-6);
    }
}