1use std::fmt;
9use half::f16;
10
/// Storage formats a tensor can be held in.
///
/// The `Q*` variants group values into blocks of 32 (see [`QuantType::block_size`]);
/// `F16` and `F32` are plain floating-point formats with one value per block.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8 bits per value, 34 bytes per block.
    Q8_0,

    /// 4 bits per value, 18 bytes per block.
    Q4_0,

    /// 4 bits per value, 20 bytes per block.
    Q4_1,

    /// 5 bits per value, 22 bytes per block.
    Q5_0,

    /// 5 bits per value, 24 bytes per block.
    Q5_1,

    /// 16-bit float, 2 bytes per value.
    F16,

    /// 32-bit float, 4 bytes per value.
    F32,
}

impl QuantType {
    /// Returns how many values make up one block: 32 for every `Q*` format, 1 for `F16`/`F32`.
    pub fn block_size(&self) -> usize {
        match self {
            QuantType::Q8_0
            | QuantType::Q4_0
            | QuantType::Q4_1
            | QuantType::Q5_0
            | QuantType::Q5_1 => 32,
            QuantType::F16 | QuantType::F32 => 1,
        }
    }

    /// Returns the serialized size of one block in bytes.
    ///
    /// Each arm is written as `header + payload` bytes. For `Q8_0`, `Q4_0` and `Q4_1` the
    /// totals match what the `to_bytes` methods of the block structs in this file produce.
    pub fn bytes_per_block(&self) -> usize {
        match self {
            QuantType::Q8_0 => 2 + 32,
            QuantType::Q4_0 => 2 + 16,
            QuantType::Q4_1 => 4 + 16,
            QuantType::Q5_0 => 2 + 20,
            QuantType::Q5_1 => 4 + 20,
            QuantType::F16 => 2,
            QuantType::F32 => 4,
        }
    }

    /// Returns the nominal number of bits used per stored value (block headers not included).
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::Q8_0 => 8,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::F16 => 16,
            QuantType::F32 => 32,
        }
    }

    /// Returns the nominal compression ratio relative to 32-bit floats.
    ///
    /// This is `32 / bits_per_value()`, so per-block header bytes are ignored: `Q8_0`
    /// reports `4.0` even though a block occupies 34 bytes for 32 values.
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bits_per_value() as f32
    }

    /// Returns `true` for the `Q*` formats and `false` for `F16`/`F32`.
    pub fn is_block_quantized(&self) -> bool {
        matches!(
            self,
            QuantType::Q8_0
                | QuantType::Q4_0
                | QuantType::Q4_1
                | QuantType::Q5_0
                | QuantType::Q5_1
        )
    }

    /// Parses a format name or alias, ignoring ASCII case.
    ///
    /// Accepted spellings: `Q8_0`/`Q8`/`INT8`, `Q4_0`/`Q4`/`INT4`, `Q4_1`, `Q5_0`/`Q5`,
    /// `Q5_1`, `F16`/`FLOAT16`/`HALF`, `F32`/`FLOAT32`/`FLOAT`. Anything else yields `None`.
    ///
    /// Matching is ASCII-only and allocation-free: Unicode case mapping is deliberately not
    /// applied, so look-alikes such as `"ﬂoat"` (which upper-cases to `"FLOAT"`) are rejected.
    // The inherent `from_str` returning `Option` is this type's established API, so it is kept
    // instead of being replaced by `std::str::FromStr`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        const ALIASES: &[(&str, QuantType)] = &[
            ("Q8_0", QuantType::Q8_0),
            ("Q8", QuantType::Q8_0),
            ("INT8", QuantType::Q8_0),
            ("Q4_0", QuantType::Q4_0),
            ("Q4", QuantType::Q4_0),
            ("INT4", QuantType::Q4_0),
            ("Q4_1", QuantType::Q4_1),
            ("Q5_0", QuantType::Q5_0),
            ("Q5", QuantType::Q5_0),
            ("Q5_1", QuantType::Q5_1),
            ("F16", QuantType::F16),
            ("FLOAT16", QuantType::F16),
            ("HALF", QuantType::F16),
            ("F32", QuantType::F32),
            ("FLOAT32", QuantType::F32),
            ("FLOAT", QuantType::F32),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| alias.eq_ignore_ascii_case(s))
            .map(|&(_, quant)| quant)
    }
}
102
103impl fmt::Display for QuantType {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 match self {
106 QuantType::Q8_0 => write!(f, "Q8_0"),
107 QuantType::Q4_0 => write!(f, "Q4_0"),
108 QuantType::Q4_1 => write!(f, "Q4_1"),
109 QuantType::Q5_0 => write!(f, "Q5_0"),
110 QuantType::Q5_1 => write!(f, "Q5_1"),
111 QuantType::F16 => write!(f, "F16"),
112 QuantType::F32 => write!(f, "F32"),
113 }
114 }
115}
116
117#[derive(Debug, Clone)]
123pub struct Q8Block {
124 pub scale: f16,
126 pub data: [i8; 32],
128}
129
130impl Q8Block {
131 pub fn new(scale: f16, data: [i8; 32]) -> Self {
133 Self { scale, data }
134 }
135
136 pub fn to_bytes(&self) -> Vec<u8> {
138 let mut bytes = Vec::with_capacity(34);
139 bytes.extend_from_slice(&self.scale.to_le_bytes());
140 bytes.extend(self.data.iter().map(|&x| x as u8));
141 bytes
142 }
143
144 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
146 if bytes.len() < 34 {
147 return None;
148 }
149 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
150 let mut data = [0i8; 32];
151 for (i, &b) in bytes[2..34].iter().enumerate() {
152 data[i] = b as i8;
153 }
154 Some(Self { scale, data })
155 }
156}
157
158#[derive(Debug, Clone)]
160pub struct Q4Block {
161 pub scale: f16,
163 pub data: [u8; 16],
165}
166
167impl Q4Block {
168 pub fn new(scale: f16, data: [u8; 16]) -> Self {
170 Self { scale, data }
171 }
172
173 pub fn unpack(&self) -> [i8; 32] {
175 let mut result = [0i8; 32];
176 for i in 0..16 {
177 let byte = self.data[i];
178 result[i * 2] = ((byte & 0x0F) as i8) - 8;
179 result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
180 }
181 result
182 }
183
184 pub fn pack(values: &[i8; 32]) -> [u8; 16] {
186 let mut data = [0u8; 16];
187 for i in 0..16 {
188 let low = ((values[i * 2] + 8) as u8) & 0x0F;
189 let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
190 data[i] = low | (high << 4);
191 }
192 data
193 }
194
195 pub fn to_bytes(&self) -> Vec<u8> {
197 let mut bytes = Vec::with_capacity(18);
198 bytes.extend_from_slice(&self.scale.to_le_bytes());
199 bytes.extend_from_slice(&self.data);
200 bytes
201 }
202
203 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
205 if bytes.len() < 18 {
206 return None;
207 }
208 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
209 let mut data = [0u8; 16];
210 data.copy_from_slice(&bytes[2..18]);
211 Some(Self { scale, data })
212 }
213}
214
215#[derive(Debug, Clone)]
217pub struct Q4_1Block {
218 pub scale: f16,
220 pub min: f16,
222 pub data: [u8; 16],
224}
225
226impl Q4_1Block {
227 pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
229 Self { scale, min, data }
230 }
231
232 pub fn unpack(&self) -> [u8; 32] {
234 let mut result = [0u8; 32];
235 for i in 0..16 {
236 let byte = self.data[i];
237 result[i * 2] = byte & 0x0F;
238 result[i * 2 + 1] = byte >> 4;
239 }
240 result
241 }
242
243 pub fn to_bytes(&self) -> Vec<u8> {
245 let mut bytes = Vec::with_capacity(20);
246 bytes.extend_from_slice(&self.scale.to_le_bytes());
247 bytes.extend_from_slice(&self.min.to_le_bytes());
248 bytes.extend_from_slice(&self.data);
249 bytes
250 }
251}
252
253#[derive(Debug, Clone)]
259pub enum QuantizedBlock {
260 Q8(Q8Block),
262 Q4(Q4Block),
264 Q4_1(Q4_1Block),
266 F16(Vec<f16>),
268 F32(Vec<f32>),
270}
271
272impl QuantizedBlock {
273 pub fn quant_type(&self) -> QuantType {
275 match self {
276 QuantizedBlock::Q8(_) => QuantType::Q8_0,
277 QuantizedBlock::Q4(_) => QuantType::Q4_0,
278 QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
279 QuantizedBlock::F16(_) => QuantType::F16,
280 QuantizedBlock::F32(_) => QuantType::F32,
281 }
282 }
283}
284
285#[derive(Debug, Clone)]
291pub struct QuantizedTensor {
292 pub shape: Vec<usize>,
294 pub quant_type: QuantType,
296 pub blocks: Vec<QuantizedBlock>,
298 pub numel: usize,
300}
301
302impl QuantizedTensor {
303 pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
305 let numel = shape.iter().product();
306 Self {
307 shape,
308 quant_type,
309 blocks,
310 numel,
311 }
312 }
313
314 pub fn size_bytes(&self) -> usize {
316 self.blocks.len() * self.quant_type.bytes_per_block()
317 }
318
319 pub fn compression_ratio(&self) -> f32 {
321 let original_bytes = self.numel * 4;
322 original_bytes as f32 / self.size_bytes() as f32
323 }
324
325 pub fn num_blocks(&self) -> usize {
327 self.blocks.len()
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks block size, bit width and the block-quantized flag.
    #[test]
    fn test_quant_type_properties() {
        let q8 = QuantType::Q8_0;
        let q4 = QuantType::Q4_0;
        let half = QuantType::F16;

        assert_eq!(q8.block_size(), 32);
        assert_eq!(q4.block_size(), 32);
        assert_eq!(half.block_size(), 1);

        assert_eq!(q8.bits_per_value(), 8);
        assert_eq!(q4.bits_per_value(), 4);

        assert!(q8.is_block_quantized());
        assert!(!half.is_block_quantized());
    }

    /// Canonical names and aliases parse; unknown names do not.
    #[test]
    fn test_quant_type_from_str() {
        let cases = [
            ("Q8_0", Some(QuantType::Q8_0)),
            ("INT8", Some(QuantType::Q8_0)),
            ("Q4", Some(QuantType::Q4_0)),
            ("F16", Some(QuantType::F16)),
            ("invalid", None),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(QuantType::from_str(name), *expected);
        }
    }

    /// Packing every representable value and unpacking it again is lossless.
    #[test]
    fn test_q4_pack_unpack() {
        // Two runs of -8..=7 fill the 32 slots.
        let mut values = [0i8; 32];
        for (slot, value) in values.iter_mut().zip((-8i8..8).cycle()) {
            *slot = value;
        }

        let block = Q4Block::new(f16::from_f32(1.0), Q4Block::pack(&values));

        assert_eq!(values, block.unpack());
    }

    /// A Q8 block survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_q8_block() {
        let original = Q8Block::new(f16::from_f32(0.5), [0i8; 32]);
        let encoded = original.to_bytes();
        let decoded = Q8Block::from_bytes(&encoded).unwrap();

        assert_eq!(original.scale, decoded.scale);
        assert_eq!(original.data, decoded.data);
    }
}