1use half::f16;
9use std::fmt;
10
/// Storage encoding of a tensor.
///
/// The `Q*` variants are block-quantized: elements are grouped into blocks of
/// [`QuantType::block_size`] values that share per-block parameters and occupy
/// [`QuantType::bytes_per_block`] bytes. `F16` and `F32` store every element directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8-bit values; 34-byte blocks of 32 (2-byte scale + 32 data bytes).
    Q8_0,
    /// 4-bit values; 18-byte blocks of 32 (2-byte scale + 16 packed bytes).
    Q4_0,
    /// 4-bit values; 20-byte blocks of 32 (2-byte scale, 2-byte min, 16 packed bytes).
    Q4_1,
    /// 5-bit values; 22-byte blocks of 32 (2-byte header + 20 data bytes).
    Q5_0,
    /// 5-bit values; 24-byte blocks of 32 (4-byte header + 20 data bytes).
    Q5_1,
    /// Unquantized 16-bit floats, 2 bytes per element.
    F16,
    /// Unquantized 32-bit floats, 4 bytes per element.
    F32,
}

impl QuantType {
    /// Number of elements encoded by one block: 32 for the block-quantized formats,
    /// 1 for the raw float formats.
    pub fn block_size(&self) -> usize {
        match self {
            QuantType::Q8_0
            | QuantType::Q4_0
            | QuantType::Q4_1
            | QuantType::Q5_0
            | QuantType::Q5_1 => 32,
            QuantType::F16 | QuantType::F32 => 1,
        }
    }

    /// Encoded size in bytes of one block (for `F16`/`F32`, of one element).
    pub fn bytes_per_block(&self) -> usize {
        match self {
            // f16 scale + 32 one-byte values
            QuantType::Q8_0 => 2 + 32,
            // f16 scale + 32 nibbles
            QuantType::Q4_0 => 2 + 16,
            // f16 scale + f16 min + 32 nibbles
            QuantType::Q4_1 => 4 + 16,
            // header + 32 * 5 bits = 160 bits = 20 data bytes
            QuantType::Q5_0 => 2 + 20,
            QuantType::Q5_1 => 4 + 20,
            QuantType::F16 => 2,
            QuantType::F32 => 4,
        }
    }

    /// Nominal number of bits stored per value, excluding any per-block scale/min overhead.
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::Q8_0 => 8,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::F16 => 16,
            QuantType::F32 => 32,
        }
    }

    /// Nominal compression ratio relative to `f32`, i.e. `32 / bits_per_value()`.
    ///
    /// Per-block overhead is not included, so this overstates the real ratio for the
    /// block-quantized formats (e.g. `Q8_0` reports 4.0 although a 32-element block takes
    /// 34 bytes rather than 128). For the measured ratio of an actual tensor see
    /// `QuantizedTensor::compression_ratio`.
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bits_per_value() as f32
    }

    /// `true` for the formats that group values into blocks sharing per-block parameters.
    pub fn is_block_quantized(&self) -> bool {
        matches!(
            self,
            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
        )
    }

    /// Parses a format name, ignoring ASCII case.
    ///
    /// Accepts the canonical names written by `Display` plus the aliases `Q8`/`INT8`,
    /// `Q4`/`INT4`, `Q5`, `FLOAT16`/`HALF` and `FLOAT32`/`FLOAT`. Returns `None` for
    /// anything else.
    ///
    /// Only ASCII letters are case-folded. Full Unicode upper-casing would map non-ASCII
    /// look-alikes onto valid names (U+0131 dotless `ı` becomes `I`; the `ﬂ` ligature becomes
    /// `FL`), so inputs such as `"ınt8"` are rejected.
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_ascii_uppercase().as_str() {
            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
            "Q4_1" => Some(QuantType::Q4_1),
            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
            "Q5_1" => Some(QuantType::Q5_1),
            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
            _ => None,
        }
    }
}

impl fmt::Display for QuantType {
    /// Writes the canonical name (e.g. `Q4_1`), which [`QuantType::from_str`] parses back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            QuantType::Q8_0 => "Q8_0",
            QuantType::Q4_0 => "Q4_0",
            QuantType::Q4_1 => "Q4_1",
            QuantType::Q5_0 => "Q5_0",
            QuantType::Q5_1 => "Q5_1",
            QuantType::F16 => "F16",
            QuantType::F32 => "F32",
        };
        f.write_str(name)
    }
}
121
122#[derive(Debug, Clone)]
128pub struct Q8Block {
129 pub scale: f16,
131 pub data: [i8; 32],
133}
134
135impl Q8Block {
136 pub fn new(scale: f16, data: [i8; 32]) -> Self {
138 Self { scale, data }
139 }
140
141 pub fn to_bytes(&self) -> Vec<u8> {
143 let mut bytes = Vec::with_capacity(34);
144 bytes.extend_from_slice(&self.scale.to_le_bytes());
145 bytes.extend(self.data.iter().map(|&x| x as u8));
146 bytes
147 }
148
149 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
151 if bytes.len() < 34 {
152 return None;
153 }
154 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
155 let mut data = [0i8; 32];
156 for (i, &b) in bytes[2..34].iter().enumerate() {
157 data[i] = b as i8;
158 }
159 Some(Self { scale, data })
160 }
161}
162
163#[derive(Debug, Clone)]
165pub struct Q4Block {
166 pub scale: f16,
168 pub data: [u8; 16],
170}
171
172impl Q4Block {
173 pub fn new(scale: f16, data: [u8; 16]) -> Self {
175 Self { scale, data }
176 }
177
178 pub fn unpack(&self) -> [i8; 32] {
180 let mut result = [0i8; 32];
181 for i in 0..16 {
182 let byte = self.data[i];
183 result[i * 2] = ((byte & 0x0F) as i8) - 8;
184 result[i * 2 + 1] = ((byte >> 4) as i8) - 8;
185 }
186 result
187 }
188
189 pub fn pack(values: &[i8; 32]) -> [u8; 16] {
191 let mut data = [0u8; 16];
192 for i in 0..16 {
193 let low = ((values[i * 2] + 8) as u8) & 0x0F;
194 let high = ((values[i * 2 + 1] + 8) as u8) & 0x0F;
195 data[i] = low | (high << 4);
196 }
197 data
198 }
199
200 pub fn to_bytes(&self) -> Vec<u8> {
202 let mut bytes = Vec::with_capacity(18);
203 bytes.extend_from_slice(&self.scale.to_le_bytes());
204 bytes.extend_from_slice(&self.data);
205 bytes
206 }
207
208 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
210 if bytes.len() < 18 {
211 return None;
212 }
213 let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
214 let mut data = [0u8; 16];
215 data.copy_from_slice(&bytes[2..18]);
216 Some(Self { scale, data })
217 }
218}
219
220#[derive(Debug, Clone)]
222pub struct Q4_1Block {
223 pub scale: f16,
225 pub min: f16,
227 pub data: [u8; 16],
229}
230
231impl Q4_1Block {
232 pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
234 Self { scale, min, data }
235 }
236
237 pub fn unpack(&self) -> [u8; 32] {
239 let mut result = [0u8; 32];
240 for i in 0..16 {
241 let byte = self.data[i];
242 result[i * 2] = byte & 0x0F;
243 result[i * 2 + 1] = byte >> 4;
244 }
245 result
246 }
247
248 pub fn to_bytes(&self) -> Vec<u8> {
250 let mut bytes = Vec::with_capacity(20);
251 bytes.extend_from_slice(&self.scale.to_le_bytes());
252 bytes.extend_from_slice(&self.min.to_le_bytes());
253 bytes.extend_from_slice(&self.data);
254 bytes
255 }
256}
257
258#[derive(Debug, Clone)]
264pub enum QuantizedBlock {
265 Q8(Q8Block),
267 Q4(Q4Block),
269 Q4_1(Q4_1Block),
271 F16(Vec<f16>),
273 F32(Vec<f32>),
275}
276
277impl QuantizedBlock {
278 pub fn quant_type(&self) -> QuantType {
280 match self {
281 QuantizedBlock::Q8(_) => QuantType::Q8_0,
282 QuantizedBlock::Q4(_) => QuantType::Q4_0,
283 QuantizedBlock::Q4_1(_) => QuantType::Q4_1,
284 QuantizedBlock::F16(_) => QuantType::F16,
285 QuantizedBlock::F32(_) => QuantType::F32,
286 }
287 }
288}
289
290#[derive(Debug, Clone)]
296pub struct QuantizedTensor {
297 pub shape: Vec<usize>,
299 pub quant_type: QuantType,
301 pub blocks: Vec<QuantizedBlock>,
303 pub numel: usize,
305}
306
307impl QuantizedTensor {
308 pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
310 let numel = shape.iter().product();
311 Self {
312 shape,
313 quant_type,
314 blocks,
315 numel,
316 }
317 }
318
319 pub fn size_bytes(&self) -> usize {
321 self.blocks.len() * self.quant_type.bytes_per_block()
322 }
323
324 pub fn compression_ratio(&self) -> f32 {
326 let original_bytes = self.numel * 4;
327 original_bytes as f32 / self.size_bytes() as f32
328 }
329
330 pub fn num_blocks(&self) -> usize {
332 self.blocks.len()
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Block sizes, nominal bit widths and the block-quantized flag for a few formats.
    #[test]
    fn test_quant_type_properties() {
        for qt in [QuantType::Q8_0, QuantType::Q4_0].iter() {
            assert_eq!(qt.block_size(), 32);
        }
        assert_eq!(QuantType::F16.block_size(), 1);

        let expected_bits = [(QuantType::Q8_0, 8), (QuantType::Q4_0, 4)];
        for &(qt, bits) in expected_bits.iter() {
            assert_eq!(qt.bits_per_value(), bits);
        }

        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    /// Canonical names and aliases parse; unknown names do not.
    #[test]
    fn test_quant_type_from_str() {
        let cases = [
            ("Q8_0", Some(QuantType::Q8_0)),
            ("INT8", Some(QuantType::Q8_0)),
            ("Q4", Some(QuantType::Q4_0)),
            ("F16", Some(QuantType::F16)),
            ("invalid", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(QuantType::from_str(input), expected);
        }
    }

    /// Every value in [-8, 7] survives a pack/unpack round trip.
    #[test]
    fn test_q4_pack_unpack() {
        let mut values = [0i8; 32];
        for (i, slot) in values.iter_mut().enumerate() {
            // two sweeps of -8..=7
            *slot = (i % 16) as i8 - 8;
        }

        let block = Q4Block::new(f16::from_f32(1.0), Q4Block::pack(&values));
        assert_eq!(block.unpack(), values);
    }

    /// A Q8 block survives a to_bytes/from_bytes round trip.
    #[test]
    fn test_q8_block() {
        let original = Q8Block::new(f16::from_f32(0.5), [0i8; 32]);
        let restored = Q8Block::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(restored.scale, original.scale);
        assert_eq!(restored.data, original.data);
    }
}