// ai_hwaccel/quantization.rs

//! Model weight quantisation levels and their per-parameter storage cost.
2
3use std::fmt;
4
5use serde::{Deserialize, Serialize};
6
7/// Model weight quantisation levels.
8///
9/// # Examples
10///
11/// ```rust
12/// use ai_hwaccel::QuantizationLevel;
13///
14/// let q = QuantizationLevel::Int8;
15/// assert_eq!(q.bits_per_param(), 8);
16/// assert!((q.memory_reduction_factor() - 4.0).abs() < f64::EPSILON);
17/// ```
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[non_exhaustive]
20pub enum QuantizationLevel {
21    /// Full precision — FP32, 32 bits per parameter.
22    None,
23    /// Half precision — FP16, 16 bits per parameter.
24    Float16,
25    /// Brain floating point — BF16, 16 bits per parameter.
26    BFloat16,
27    /// 8-bit integer quantisation.
28    Int8,
29    /// 4-bit integer quantisation (GPTQ / AWQ style).
30    Int4,
31}
32
33impl QuantizationLevel {
34    /// Number of bits used per model parameter.
35    #[must_use]
36    #[inline]
37    pub fn bits_per_param(&self) -> u32 {
38        match self {
39            Self::None => 32,
40            Self::Float16 | Self::BFloat16 => 16,
41            Self::Int8 => 8,
42            Self::Int4 => 4,
43        }
44    }
45
46    /// Memory reduction factor relative to FP32.
47    #[must_use]
48    #[inline]
49    pub fn memory_reduction_factor(&self) -> f64 {
50        32.0 / self.bits_per_param() as f64
51    }
52}
53
54impl TryFrom<u32> for QuantizationLevel {
55    type Error = u32;
56
57    /// Convert from bit width to quantisation level.
58    ///
59    /// - `32` → `None` (FP32)
60    /// - `16` → `Float16`
61    /// - `8`  → `Int8`
62    /// - `4`  → `Int4`
63    fn try_from(bits: u32) -> Result<Self, u32> {
64        match bits {
65            32 => Ok(Self::None),
66            16 => Ok(Self::Float16),
67            8 => Ok(Self::Int8),
68            4 => Ok(Self::Int4),
69            other => Err(other),
70        }
71    }
72}
73
74impl fmt::Display for QuantizationLevel {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        match self {
77            Self::None => write!(f, "FP32"),
78            Self::Float16 => write!(f, "FP16"),
79            Self::BFloat16 => write!(f, "BF16"),
80            Self::Int8 => write!(f, "INT8"),
81            Self::Int4 => write!(f, "INT4"),
82        }
83    }
84}