Skip to main content

trueno/brick/tracing/
quant_type.rs

1// ============================================================================
2// QuantType - Quantization type tracking
3// ============================================================================
4
/// Numeric storage format of a value, used when tracking quantization errors (MLT-04).
///
/// The default format is [`QuantType::F32`].
///
/// Note: variant names deliberately mirror GGML spellings (e.g., `Q4_K`) for
/// interoperability, which is why the camel-case lint is silenced on this type.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 32-bit floating point (full precision).
    #[default]
    F32,
    /// 16-bit floating point (half precision).
    F16,
    /// 16-bit brain floating point.
    Bf16,
    /// 8-bit integer quantization.
    Q8_0,
    /// 4-bit quantization (GGML).
    Q4_0,
    /// 4-bit k-quant.
    Q4_K,
    /// 5-bit k-quant.
    Q5_K,
    /// 6-bit k-quant.
    Q6_K,
    /// 2-bit quantization.
    Q2_K,
    /// 3-bit quantization.
    Q3_K,
}

impl QuantType {
    /// Returns the number of bits each element is counted as occupying in this format.
    ///
    /// The floating-point formats and `Q8_0` report whole bit widths; the remaining
    /// quantized formats report their nominal width plus half a bit (e.g. `4.5` for
    /// `Q4_K`).
    pub fn bits_per_element(self) -> f32 {
        // NOTE(review): the extra half bit presumably stands for per-block metadata
        // such as scales; `Q8_0` is listed without it while `Q4_0` carries it.
        // Confirm that difference is intentional.
        match self {
            QuantType::F32 => 32.0,
            QuantType::F16 => 16.0,
            QuantType::Bf16 => 16.0,
            QuantType::Q8_0 => 8.0,
            QuantType::Q4_0 => 4.5,
            QuantType::Q4_K => 4.5,
            QuantType::Q5_K => 5.5,
            QuantType::Q6_K => 6.5,
            QuantType::Q2_K => 2.5,
            QuantType::Q3_K => 3.5,
        }
    }

    /// Returns how many times smaller this format is than FP32.
    ///
    /// Computed as 32 bits divided by [`Self::bits_per_element`], so `F32` yields `1.0`
    /// and `F16` yields `2.0`.
    pub fn compression_ratio(self) -> f32 {
        let bits = self.bits_per_element();
        32.0 / bits
    }
}
54
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the per-element bit width of a few representative formats.
    #[test]
    fn test_quant_type_bits() {
        let expected: [(QuantType, f32); 4] = [
            (QuantType::F32, 32.0),
            (QuantType::F16, 16.0),
            (QuantType::Q8_0, 8.0),
            (QuantType::Q4_K, 4.5),
        ];

        for (quant, bits) in expected {
            assert_eq!(quant.bits_per_element(), bits);
        }
    }

    /// Checks compression ratios relative to FP32 for exact and approximate cases.
    #[test]
    fn test_quant_type_compression_ratio() {
        const TOLERANCE: f32 = 0.01;

        // FP32 against itself: no compression (1x).
        let full = QuantType::F32.compression_ratio();
        assert!((full - 1.0).abs() < TOLERANCE);

        // FP16 halves the storage (2x).
        let half = QuantType::F16.compression_ratio();
        assert!((half - 2.0).abs() < TOLERANCE);

        // Q4_K at 4.5 bits per element lands a little above 7x.
        let q4k = QuantType::Q4_K.compression_ratio();
        assert!(q4k > 7.0);
    }
}