1use half::f16;
18use std::fmt;
19
/// Storage formats a tensor's values can be held in.
///
/// The five `Q*` variants are block formats that group 32 values per block;
/// `F16` and `F32` store every value on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8 bits per value; a 32-value block occupies 34 bytes.
    Q8_0,

    /// 4 bits per value; a 32-value block occupies 18 bytes.
    Q4_0,

    /// 4 bits per value; a 32-value block occupies 20 bytes.
    Q4_1,

    /// 5 bits per value; a 32-value block occupies 22 bytes.
    Q5_0,

    /// 5 bits per value; a 32-value block occupies 24 bytes.
    Q5_1,

    /// 16 bits (2 bytes) per value, no blocking.
    F16,

    /// 32 bits (4 bytes) per value, no blocking.
    F32,
}
51
52impl QuantType {
53 pub fn block_size(&self) -> usize {
55 match self {
56 QuantType::Q8_0
57 | QuantType::Q4_0
58 | QuantType::Q4_1
59 | QuantType::Q5_0
60 | QuantType::Q5_1 => 32,
61 QuantType::F16 | QuantType::F32 => 1,
62 }
63 }
64
65 pub fn bytes_per_block(&self) -> usize {
67 match self {
68 QuantType::Q8_0 => 2 + 32, QuantType::Q4_0 => 2 + 16, QuantType::Q4_1 => 4 + 16, QuantType::Q5_0 => 2 + 20, QuantType::Q5_1 => 4 + 20, QuantType::F16 => 2,
74 QuantType::F32 => 4,
75 }
76 }
77
78 pub fn bits_per_value(&self) -> usize {
80 match self {
81 QuantType::Q8_0 => 8,
82 QuantType::Q4_0 | QuantType::Q4_1 => 4,
83 QuantType::Q5_0 | QuantType::Q5_1 => 5,
84 QuantType::F16 => 16,
85 QuantType::F32 => 32,
86 }
87 }
88
89 pub fn compression_ratio(&self) -> f32 {
91 32.0 / self.bits_per_value() as f32
92 }
93
94 pub fn is_block_quantized(&self) -> bool {
96 matches!(
97 self,
98 QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
99 )
100 }
101
102 pub fn parse_type(s: &str) -> Option<Self> {
104 match s.to_uppercase().as_str() {
105 "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
106 "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
107 "Q4_1" => Some(QuantType::Q4_1),
108 "Q5_0" | "Q5" => Some(QuantType::Q5_0),
109 "Q5_1" => Some(QuantType::Q5_1),
110 "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
111 "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
112 _ => None,
113 }
114 }
115}
116
117impl std::str::FromStr for QuantType {
118 type Err = String;
119
120 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
121 Self::parse_type(s).ok_or_else(|| format!("Unknown quant type: '{s}'"))
122 }
123}
124
125impl fmt::Display for QuantType {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 match self {
128 QuantType::Q8_0 => write!(f, "Q8_0"),
129 QuantType::Q4_0 => write!(f, "Q4_0"),
130 QuantType::Q4_1 => write!(f, "Q4_1"),
131 QuantType::Q5_0 => write!(f, "Q5_0"),
132 QuantType::Q5_1 => write!(f, "Q5_1"),
133 QuantType::F16 => write!(f, "F16"),
134 QuantType::F32 => write!(f, "F32"),
135 }
136 }
137}
138
139#[derive(Debug, Clone)]
145pub struct Q8Block {
146 pub scale: f16,
148 pub data: [i8; 32],
150}
151
152impl Q8Block {
153 pub fn new(scale: f16, data: [i8; 32]) -> Self {
155 Self { scale, data }
156 }
157
158 pub fn to_bytes(&self) -> Vec<u8> {
160 let mut bytes = Vec::with_capacity(34);
161 bytes.extend_from_slice(&self.scale.to_le_bytes());
162 bytes.extend(self.data.iter().map(|&x| x as u8));
163 bytes
164 }
165
166 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
168 if bytes.len() < 34 {
169 return None;
170 }
171 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
172 let mut data = [0i8; 32];
173 for (i, &b) in bytes[2..34].iter().enumerate() {
174 data[i] = b as i8;
175 }
176 Some(Self { scale, data })
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct Q4Block {
183 pub scale: f16,
185 pub data: [u8; 16],
187}
188
189impl Q4Block {
190 pub fn new(scale: f16, data: [u8; 16]) -> Self {
192 Self { scale, data }
193 }
194
195 pub fn unpack(&self) -> [i8; 32] {
197 let mut result = [0i8; 32];
198 for i in 0..16 {
199 let byte = self.data[i];
200 result[i * 2] = ((byte & 0x0F) as i8) - 8;
201 result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
202 }
203 result
204 }
205
206 pub fn pack(values: &[i8; 32]) -> [u8; 16] {
208 let mut data = [0u8; 16];
209 for i in 0..16 {
210 let low = ((values[i * 2] + 8) as u8) & 0x0F;
211 let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
212 data[i] = low | (high << 4);
213 }
214 data
215 }
216
217 pub fn to_bytes(&self) -> Vec<u8> {
219 let mut bytes = Vec::with_capacity(18);
220 bytes.extend_from_slice(&self.scale.to_le_bytes());
221 bytes.extend_from_slice(&self.data);
222 bytes
223 }
224
225 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
227 if bytes.len() < 18 {
228 return None;
229 }
230 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
231 let mut data = [0u8; 16];
232 data.copy_from_slice(&bytes[2..18]);
233 Some(Self { scale, data })
234 }
235}
236
237#[derive(Debug, Clone)]
239pub struct Q4_1Block {
240 pub scale: f16,
242 pub min: f16,
244 pub data: [u8; 16],
246}
247
248impl Q4_1Block {
249 pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
251 Self { scale, min, data }
252 }
253
254 pub fn unpack(&self) -> [u8; 32] {
256 let mut result = [0u8; 32];
257 for i in 0..16 {
258 let byte = self.data[i];
259 result[i * 2] = byte & 0x0F;
260 result[i * 2 + 1] = byte >> 4;
261 }
262 result
263 }
264
265 pub fn to_bytes(&self) -> Vec<u8> {
267 let mut bytes = Vec::with_capacity(20);
268 bytes.extend_from_slice(&self.scale.to_le_bytes());
269 bytes.extend_from_slice(&self.min.to_le_bytes());
270 bytes.extend_from_slice(&self.data);
271 bytes
272 }
273}
274
275#[derive(Debug, Clone)]
283pub struct Q5Block {
284 pub scale: f16,
286 pub data: [u8; 20],
288}
289
290impl Q5Block {
291 pub fn new(scale: f16, data: [u8; 20]) -> Self {
293 Self { scale, data }
294 }
295
296 pub fn pack(values: &[i8; 32]) -> [u8; 20] {
298 let mut packed = [0u8; 20];
299 #[allow(clippy::needless_range_loop)]
302 for i in 0..32 {
303 let v = (values[i] as u8) & 0x1F; let bit_offset = i * 5;
305 let byte_offset = bit_offset / 8;
306 let bit_shift = bit_offset % 8;
307 packed[byte_offset] |= v << bit_shift;
308 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
309 packed[byte_offset + 1] |= v >> (8 - bit_shift);
310 }
311 }
312 packed
313 }
314
315 pub fn unpack(&self) -> [i8; 32] {
317 let mut result = [0i8; 32];
318 #[allow(clippy::needless_range_loop)]
319 for i in 0..32 {
320 let bit_offset = i * 5;
321 let byte_offset = bit_offset / 8;
322 let bit_shift = bit_offset % 8;
323 let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
324 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
325 v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
326 }
327 if v & 0x10 != 0 {
329 result[i] = (v | 0xE0) as i8; } else {
331 result[i] = v as i8;
332 }
333 }
334 result
335 }
336
337 pub fn to_bytes(&self) -> Vec<u8> {
339 let mut bytes = Vec::with_capacity(22);
340 bytes.extend_from_slice(&self.scale.to_le_bytes());
341 bytes.extend_from_slice(&self.data);
342 bytes
343 }
344}
345
346#[derive(Debug, Clone)]
352pub struct Q5_1Block {
353 pub scale: f16,
355 pub min: f16,
357 pub data: [u8; 20],
359}
360
361impl Q5_1Block {
362 pub fn new(scale: f16, min: f16, data: [u8; 20]) -> Self {
364 Self { scale, min, data }
365 }
366
367 pub fn pack(values: &[u8; 32]) -> [u8; 20] {
369 let mut packed = [0u8; 20];
370 #[allow(clippy::needless_range_loop)]
371 for i in 0..32 {
372 let v = values[i] & 0x1F;
373 let bit_offset = i * 5;
374 let byte_offset = bit_offset / 8;
375 let bit_shift = bit_offset % 8;
376 packed[byte_offset] |= v << bit_shift;
377 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
378 packed[byte_offset + 1] |= v >> (8 - bit_shift);
379 }
380 }
381 packed
382 }
383
384 pub fn unpack(&self) -> [u8; 32] {
386 let mut result = [0u8; 32];
387 #[allow(clippy::needless_range_loop)]
388 for i in 0..32 {
389 let bit_offset = i * 5;
390 let byte_offset = bit_offset / 8;
391 let bit_shift = bit_offset % 8;
392 let mut v = (self.data[byte_offset] >> bit_shift) & 0x1F;
393 if bit_shift + 5 > 8 && byte_offset + 1 < 20 {
394 v |= (self.data[byte_offset + 1] << (8 - bit_shift)) & 0x1F;
395 }
396 result[i] = v;
397 }
398 result
399 }
400
401 pub fn to_bytes(&self) -> Vec<u8> {
403 let mut bytes = Vec::with_capacity(24);
404 bytes.extend_from_slice(&self.scale.to_le_bytes());
405 bytes.extend_from_slice(&self.min.to_le_bytes());
406 bytes.extend_from_slice(&self.data);
407 bytes
408 }
409}
410
411#[derive(Debug, Clone)]
417pub enum QuantizedBlock {
418 Q8(Q8Block),
420 Q4(Q4Block),
422 Q4_1(Q4_1Block),
424 Q5(Q5Block),
426 Q5_1(Q5_1Block),
428 F16(Vec<f16>),
430 F32(Vec<f32>),
432}
433
434impl QuantizedBlock {
435 pub fn quant_type(&self) -> QuantType {
437 match self {
438 QuantizedBlock::Q8(_) => QuantType::Q8_0,
439 QuantizedBlock::Q4(_) => QuantType::Q4_0,
440 QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
441 QuantizedBlock::Q5(_) => QuantType::Q5_0,
442 QuantizedBlock::Q5_1(_) => QuantType::Q5_1,
443 QuantizedBlock::F16(_) => QuantType::F16,
444 QuantizedBlock::F32(_) => QuantType::F32,
445 }
446 }
447}
448
449#[derive(Debug, Clone)]
455pub struct QuantizedTensor {
456 pub shape: Vec<usize>,
458 pub quant_type: QuantType,
460 pub blocks: Vec<QuantizedBlock>,
462 pub numel: usize,
464}
465
466impl QuantizedTensor {
467 pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
469 let numel = shape.iter().product();
470 Self {
471 shape,
472 quant_type,
473 blocks,
474 numel,
475 }
476 }
477
478 pub fn size_bytes(&self) -> usize {
480 self.blocks.len() * self.quant_type.bytes_per_block()
481 }
482
483 pub fn compression_ratio(&self) -> f32 {
485 let original_bytes = self.numel * 4;
486 original_bytes as f32 / self.size_bytes() as f32
487 }
488
489 pub fn num_blocks(&self) -> usize {
491 self.blocks.len()
492 }
493}
494
/// Unit tests for the quantization types and block containers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Block size, bit width and the block-quantized flag for representative formats.
    #[test]
    fn test_quant_type_properties() {
        let block_sizes = [
            (QuantType::Q8_0, 32),
            (QuantType::Q4_0, 32),
            (QuantType::F16, 1),
        ];
        for (quant, expected) in block_sizes.iter() {
            assert_eq!(quant.block_size(), *expected);
        }

        let bit_widths = [(QuantType::Q8_0, 8), (QuantType::Q4_0, 4)];
        for (quant, expected) in bit_widths.iter() {
            assert_eq!(quant.bits_per_value(), *expected);
        }

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    /// Canonical names and aliases parse; unknown names are rejected.
    #[test]
    fn test_quant_type_from_str() {
        let cases = [
            ("Q8_0", Some(QuantType::Q8_0)),
            ("INT8", Some(QuantType::Q8_0)),
            ("Q4", Some(QuantType::Q4_0)),
            ("F16", Some(QuantType::F16)),
            ("invalid", None),
        ];
        for (text, expected) in cases.iter() {
            assert_eq!(QuantType::parse_type(text), *expected);
        }
    }

    /// Packing then unpacking every representable 4-bit value is lossless.
    #[test]
    fn test_q4_pack_unpack() {
        // -8..=7 repeated twice fills the 32 slots.
        let mut values = [0i8; 32];
        for (i, slot) in values.iter_mut().enumerate() {
            *slot = (i % 16) as i8 - 8;
        }

        let block = Q4Block::new(f16::from_f32(1.0), Q4Block::pack(&values));

        assert_eq!(values, block.unpack());
    }

    /// A Q8 block survives serialization to bytes and back.
    #[test]
    fn test_q8_block() {
        let block = Q8Block::new(f16::from_f32(0.5), [0i8; 32]);
        let restored = Q8Block::from_bytes(&block.to_bytes()).unwrap();

        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
    }
}