1use half::f16;
18use std::fmt;
19
/// Storage formats supported for tensor data.
///
/// The `Q*` variants are block formats: 32 values share one small per-block
/// header. `F16` and `F32` store every value on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8 bits per value; a block is a 2-byte header plus 32 data bytes.
    Q8_0,

    /// 4 bits per value; a block is a 2-byte header plus 16 data bytes.
    Q4_0,

    /// 4 bits per value; a block is a 4-byte header plus 16 data bytes.
    Q4_1,

    /// 5 bits per value; a block is a 2-byte header plus 20 data bytes.
    Q5_0,

    /// 5 bits per value; a block is a 4-byte header plus 20 data bytes.
    Q5_1,

    /// 16-bit floating point, 2 bytes per value.
    F16,

    /// 32-bit floating point, 4 bytes per value.
    F32,
}

impl QuantType {
    /// Number of values grouped into one block: 32 for the `Q*` formats,
    /// 1 for `F16` and `F32`.
    pub fn block_size(&self) -> usize {
        match self {
            QuantType::Q8_0
            | QuantType::Q4_0
            | QuantType::Q4_1
            | QuantType::Q5_0
            | QuantType::Q5_1 => 32,
            QuantType::F16 | QuantType::F32 => 1,
        }
    }

    /// Size in bytes of one block, written as `header + payload` for the
    /// block formats.
    pub fn bytes_per_block(&self) -> usize {
        match self {
            QuantType::Q8_0 => 2 + 32,
            QuantType::Q4_0 => 2 + 16,
            QuantType::Q4_1 => 4 + 16,
            QuantType::Q5_0 => 2 + 20,
            QuantType::Q5_1 => 4 + 20,
            QuantType::F16 => 2,
            QuantType::F32 => 4,
        }
    }

    /// Nominal number of bits per stored value, not counting block headers.
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::Q8_0 => 8,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::F16 => 16,
            QuantType::F32 => 32,
        }
    }

    /// Nominal compression ratio relative to 32-bit values, computed as
    /// `32 / bits_per_value()`. Block header bytes are not accounted for.
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bits_per_value() as f32
    }

    /// Returns `true` for the block formats (`Q8_0`, `Q4_0`, `Q4_1`, `Q5_0`,
    /// `Q5_1`) and `false` for `F16` / `F32`.
    pub fn is_block_quantized(&self) -> bool {
        matches!(
            self,
            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
        )
    }

    /// Parses a format name or alias, ignoring ASCII case.
    ///
    /// Accepted spellings:
    /// `Q8_0`/`Q8`/`INT8`, `Q4_0`/`Q4`/`INT4`, `Q4_1`, `Q5_0`/`Q5`, `Q5_1`,
    /// `F16`/`FLOAT16`/`HALF`, `F32`/`FLOAT32`/`FLOAT`.
    ///
    /// Returns `None` for anything else. The input is not trimmed, and only
    /// ASCII letters are case-folded, so non-ASCII look-alike characters never
    /// match.
    pub fn parse_type(s: &str) -> Option<Self> {
        // Comparing against a static table avoids allocating an upper-cased
        // copy of the input on every call.
        const ALIASES: &[(&str, QuantType)] = &[
            ("Q8_0", QuantType::Q8_0),
            ("Q8", QuantType::Q8_0),
            ("INT8", QuantType::Q8_0),
            ("Q4_0", QuantType::Q4_0),
            ("Q4", QuantType::Q4_0),
            ("INT4", QuantType::Q4_0),
            ("Q4_1", QuantType::Q4_1),
            ("Q5_0", QuantType::Q5_0),
            ("Q5", QuantType::Q5_0),
            ("Q5_1", QuantType::Q5_1),
            ("F16", QuantType::F16),
            ("FLOAT16", QuantType::F16),
            ("HALF", QuantType::F16),
            ("F32", QuantType::F32),
            ("FLOAT32", QuantType::F32),
            ("FLOAT", QuantType::F32),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| alias.eq_ignore_ascii_case(s))
            .map(|&(_, quant_type)| quant_type)
    }
}
116
117impl std::str::FromStr for QuantType {
118 type Err = String;
119
120 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
121 Self::parse_type(s).ok_or_else(|| format!("Unknown quant type: '{s}'"))
122 }
123}
124
125impl fmt::Display for QuantType {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 match self {
128 QuantType::Q8_0 => write!(f, "Q8_0"),
129 QuantType::Q4_0 => write!(f, "Q4_0"),
130 QuantType::Q4_1 => write!(f, "Q4_1"),
131 QuantType::Q5_0 => write!(f, "Q5_0"),
132 QuantType::Q5_1 => write!(f, "Q5_1"),
133 QuantType::F16 => write!(f, "F16"),
134 QuantType::F32 => write!(f, "F32"),
135 }
136 }
137}
138
139#[derive(Debug, Clone)]
145pub struct Q8Block {
146 pub scale: f16,
148 pub data: [i8; 32],
150}
151
152impl Q8Block {
153 pub fn new(scale: f16, data: [i8; 32]) -> Self {
155 Self { scale, data }
156 }
157
158 pub fn to_bytes(&self) -> Vec<u8> {
160 let mut bytes = Vec::with_capacity(34);
161 bytes.extend_from_slice(&self.scale.to_le_bytes());
162 bytes.extend(self.data.iter().map(|&x| x as u8));
163 bytes
164 }
165
166 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
168 if bytes.len() < 34 {
169 return None;
170 }
171 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
172 let mut data = [0i8; 32];
173 for (i, &b) in bytes[2..34].iter().enumerate() {
174 data[i] = b as i8;
175 }
176 Some(Self { scale, data })
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct Q4Block {
183 pub scale: f16,
185 pub data: [u8; 16],
187}
188
189impl Q4Block {
190 pub fn new(scale: f16, data: [u8; 16]) -> Self {
192 Self { scale, data }
193 }
194
195 pub fn unpack(&self) -> [i8; 32] {
197 let mut result = [0i8; 32];
198 for i in 0..16 {
199 let byte = self.data[i];
200 result[i * 2] = ((byte & 0x0F) as i8) - 8;
201 result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
202 }
203 result
204 }
205
206 pub fn pack(values: &[i8; 32]) -> [u8; 16] {
208 let mut data = [0u8; 16];
209 for i in 0..16 {
210 let low = ((values[i * 2] + 8) as u8) & 0x0F;
211 let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
212 data[i] = low | (high << 4);
213 }
214 data
215 }
216
217 pub fn to_bytes(&self) -> Vec<u8> {
219 let mut bytes = Vec::with_capacity(18);
220 bytes.extend_from_slice(&self.scale.to_le_bytes());
221 bytes.extend_from_slice(&self.data);
222 bytes
223 }
224
225 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
227 if bytes.len() < 18 {
228 return None;
229 }
230 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
231 let mut data = [0u8; 16];
232 data.copy_from_slice(&bytes[2..18]);
233 Some(Self { scale, data })
234 }
235}
236
237#[derive(Debug, Clone)]
239pub struct Q4_1Block {
240 pub scale: f16,
242 pub min: f16,
244 pub data: [u8; 16],
246}
247
248impl Q4_1Block {
249 pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
251 Self { scale, min, data }
252 }
253
254 pub fn unpack(&self) -> [u8; 32] {
256 let mut result = [0u8; 32];
257 for i in 0..16 {
258 let byte = self.data[i];
259 result[i * 2] = byte & 0x0F;
260 result[i * 2 + 1] = byte >> 4;
261 }
262 result
263 }
264
265 pub fn to_bytes(&self) -> Vec<u8> {
267 let mut bytes = Vec::with_capacity(20);
268 bytes.extend_from_slice(&self.scale.to_le_bytes());
269 bytes.extend_from_slice(&self.min.to_le_bytes());
270 bytes.extend_from_slice(&self.data);
271 bytes
272 }
273}
274
275#[derive(Debug, Clone)]
283pub struct Q5Block {
284 pub scale: f16,
286 pub data: [u8; 20],
288}
289
290impl Q5Block {
291 pub fn new(scale: f16, data: [u8; 20]) -> Self {
293 Self { scale, data }
294 }
295
296 pub fn pack(values: &[i8; 32]) -> [u8; 20] {
298 let mut packed = [0u8; 20];
299 for i in 0..32 {
302 let v = (values[i] as u8) & 0x1F; let bit_offset = i * 5;
304 let byte_offset = bit_offset / 8;
305 let bit_shift = bit_offset % 8;
306 packed[byte_offset] |= v << bit_shift;
307 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
308 packed[byte_offset + 1] |= v >> (8 - bit_shift);
309 }
310 }
311 packed
312 }
313
314 pub fn unpack(&self) -> [i8; 32] {
316 let mut result = [0i8; 32];
317 for i in 0..32 {
318 let bit_offset = i * 5;
319 let byte_offset = bit_offset / 8;
320 let bit_shift = bit_offset % 8;
321 let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
322 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
323 v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
324 }
325 if v & 0x10 != 0 {
327 result[i] = (v | 0xE0) as i8; } else {
329 result[i] = v as i8;
330 }
331 }
332 result
333 }
334
335 pub fn to_bytes(&self) -> Vec<u8> {
337 let mut bytes = Vec::with_capacity(22);
338 bytes.extend_from_slice(&self.scale.to_le_bytes());
339 bytes.extend_from_slice(&self.data);
340 bytes
341 }
342}
343
344#[derive(Debug, Clone)]
350pub struct Q5_1Block {
351 pub scale: f16,
353 pub min: f16,
355 pub data: [u8; 20],
357}
358
359impl Q5_1Block {
360 pub fn new(scale: f16, min: f16, data: [u8; 20]) -> Self {
362 Self { scale, min, data }
363 }
364
365 pub fn pack(values: &[u8; 32]) -> [u8; 20] {
367 let mut packed = [0u8; 20];
368 for i in 0..32 {
369 let v = values[i] & 0x1F;
370 let bit_offset = i * 5;
371 let byte_offset = bit_offset / 8;
372 let bit_shift = bit_offset % 8;
373 packed[byte_offset] |= v << bit_shift;
374 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
375 packed[byte_offset + 1] |= v >> (8 - bit_shift);
376 }
377 }
378 packed
379 }
380
381 pub fn unpack(&self) -> [u8; 32] {
383 let mut result = [0u8; 32];
384 for i in 0..32 {
385 let bit_offset = i * 5;
386 let byte_offset = bit_offset / 8;
387 let bit_shift = bit_offset % 8;
388 let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
389 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
390 v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
391 }
392 result[i] = v;
393 }
394 result
395 }
396
397 pub fn to_bytes(&self) -> Vec<u8> {
399 let mut bytes = Vec::with_capacity(24);
400 bytes.extend_from_slice(&self.scale.to_le_bytes());
401 bytes.extend_from_slice(&self.min.to_le_bytes());
402 bytes.extend_from_slice(&self.data);
403 bytes
404 }
405}
406
407#[derive(Debug, Clone)]
413pub enum QuantizedBlock {
414 Q8(Q8Block),
416 Q4(Q4Block),
418 Q4_1(Q4_1Block),
420 Q5(Q5Block),
422 Q5_1(Q5_1Block),
424 F16(Vec<f16>),
426 F32(Vec<f32>),
428}
429
430impl QuantizedBlock {
431 pub fn quant_type(&self) -> QuantType {
433 match self {
434 QuantizedBlock::Q8(_) => QuantType::Q8_0,
435 QuantizedBlock::Q4(_) => QuantType::Q4_0,
436 QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
437 QuantizedBlock::Q5(_) => QuantType::Q5_0,
438 QuantizedBlock::Q5_1(_) => QuantType::Q5_1,
439 QuantizedBlock::F16(_) => QuantType::F16,
440 QuantizedBlock::F32(_) => QuantType::F32,
441 }
442 }
443}
444
445#[derive(Debug, Clone)]
451pub struct QuantizedTensor {
452 pub shape: Vec<usize>,
454 pub quant_type: QuantType,
456 pub blocks: Vec<QuantizedBlock>,
458 pub numel: usize,
460}
461
462impl QuantizedTensor {
463 pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
465 let numel = shape.iter().product();
466 Self {
467 shape,
468 quant_type,
469 blocks,
470 numel,
471 }
472 }
473
474 pub fn size_bytes(&self) -> usize {
476 self.blocks.len() * self.quant_type.bytes_per_block()
477 }
478
479 pub fn compression_ratio(&self) -> f32 {
481 let original_bytes = self.numel * 4;
482 original_bytes as f32 / self.size_bytes() as f32
483 }
484
485 pub fn num_blocks(&self) -> usize {
487 self.blocks.len()
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks block size, bit width and the block-quantized flag.
    #[test]
    fn test_quant_type_properties() {
        let expected_block_sizes = [
            (QuantType::Q8_0, 32),
            (QuantType::Q4_0, 32),
            (QuantType::F16, 1),
        ];
        for (quant_type, block_size) in expected_block_sizes {
            assert_eq!(quant_type.block_size(), block_size);
        }

        let expected_bits = [(QuantType::Q8_0, 8), (QuantType::Q4_0, 4)];
        for (quant_type, bits) in expected_bits {
            assert_eq!(quant_type.bits_per_value(), bits);
        }

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    /// Canonical names and aliases parse; unknown names give `None`.
    #[test]
    fn test_quant_type_from_str() {
        let cases = [
            ("Q8_0", Some(QuantType::Q8_0)),
            ("INT8", Some(QuantType::Q8_0)),
            ("Q4", Some(QuantType::Q4_0)),
            ("F16", Some(QuantType::F16)),
            ("invalid", None),
        ];
        for (text, expected) in cases {
            assert_eq!(QuantType::parse_type(text), expected);
        }
    }

    /// Packing then unpacking the full `-8..=7` range (twice) is lossless.
    #[test]
    fn test_q4_pack_unpack() {
        let mut values = [0i8; 32];
        for (index, value) in values.iter_mut().enumerate() {
            *value = (index % 16) as i8 - 8;
        }

        let block = Q4Block::new(f16::from_f32(1.0), Q4Block::pack(&values));

        assert_eq!(values, block.unpack());
    }

    /// A Q8 block survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_q8_block() {
        let original = Q8Block::new(f16::from_f32(0.5), [0i8; 32]);
        let restored = Q8Block::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original.scale, restored.scale);
        assert_eq!(original.data, restored.data);
    }
}