1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[repr(u8)]
6pub enum QuantScheme {
7 QInt8 = 0x01,
12
13 #[allow(dead_code)]
18 QInt4 = 0x02,
19
20 #[allow(dead_code)]
25 Binary = 0x03,
26
27 #[allow(dead_code)]
32 FP16Passthrough = 0x04,
33}
34
35impl QuantScheme {
36 pub fn bytes_per_value(self) -> usize {
38 match self {
39 QuantScheme::QInt8 => 1,
40 QuantScheme::QInt4 => 1, QuantScheme::Binary => 1, QuantScheme::FP16Passthrough => 2,
43 }
44 }
45
46 pub fn compression_ratio(self) -> f32 {
48 4.0 / self.bytes_per_value() as f32
49 }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the per-value byte width reported by every scheme.
    #[test]
    fn test_bytes_per_value() {
        let widths = [
            QuantScheme::QInt8,
            QuantScheme::QInt4,
            QuantScheme::Binary,
            QuantScheme::FP16Passthrough,
        ]
        .map(QuantScheme::bytes_per_value);

        // Same order as the scheme list above.
        assert_eq!(widths, [1_usize, 1, 1, 2]);
    }

    /// Pins the compression ratio for the one-byte and two-byte cases.
    #[test]
    fn test_compression_ratio() {
        let ratios = [QuantScheme::QInt8, QuantScheme::FP16Passthrough]
            .map(QuantScheme::compression_ratio);

        // Both quotients (4/1 and 4/2) are exactly representable in f32, so
        // exact equality is safe here.
        assert_eq!(ratios, [4.0_f32, 2.0]);
    }
}