1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
//! Model weight quantisation levels.
use std::fmt;
use serde::{Deserialize, Serialize};
/// Precision at which a model's weights are stored.
///
/// Every variant has a fixed per-parameter bit width, reported by
/// [`QuantizationLevel::bits_per_param`]. The `Display` impl renders the
/// conventional short label (`FP32`, `FP16`, `BF16`, `INT8`, `INT4`).
///
/// The enum is `#[non_exhaustive]`, so a `match` on it written in another
/// crate must include a wildcard arm.
///
/// # Examples
///
/// ```rust
/// use ai_hwaccel::QuantizationLevel;
///
/// let q = QuantizationLevel::Int8;
/// assert_eq!(q.bits_per_param(), 8);
/// assert!((q.memory_reduction_factor() - 4.0).abs() < f64::EPSILON);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum QuantizationLevel {
/// No quantisation: full-precision FP32, 32 bits per parameter.
None,
/// Half precision (FP16), 16 bits per parameter.
Float16,
/// Brain floating point (BF16), 16 bits per parameter.
BFloat16,
/// 8-bit integer quantisation, 8 bits per parameter.
Int8,
/// 4-bit integer quantisation (GPTQ / AWQ style), 4 bits per parameter.
Int4,
}
impl QuantizationLevel {
    /// Storage width, in bits, of a single model parameter at this level.
    ///
    /// Returns 32 for `None` (FP32), 16 for both `Float16` and `BFloat16`,
    /// 8 for `Int8` and 4 for `Int4`.
    #[must_use]
    #[inline]
    pub fn bits_per_param(&self) -> u32 {
        match *self {
            Self::Int4 => 4,
            Self::Int8 => 8,
            Self::BFloat16 | Self::Float16 => 16,
            Self::None => 32,
        }
    }

    /// How many times smaller the weights are than when stored as FP32.
    ///
    /// Computed as `32 / bits_per_param()`, so `None` yields `1.0` and
    /// `Int4` yields `8.0`.
    #[must_use]
    #[inline]
    pub fn memory_reduction_factor(&self) -> f64 {
        // `u32 -> f64` is lossless, so `f64::from` expresses the widening exactly.
        let bits = f64::from(self.bits_per_param());
        32.0 / bits
    }
}
impl TryFrom<u32> for QuantizationLevel {
    /// The rejected bit width, handed back unchanged.
    type Error = u32;

    /// Maps a per-parameter bit width onto a quantisation level.
    ///
    /// - `32` → `None` (FP32)
    /// - `16` → `Float16` (this conversion never produces `BFloat16`)
    /// - `8` → `Int8`
    /// - `4` → `Int4`
    ///
    /// # Errors
    ///
    /// Any other width is returned as `Err(bits)`.
    fn try_from(bits: u32) -> Result<Self, Self::Error> {
        let level = match bits {
            4 => Self::Int4,
            8 => Self::Int8,
            16 => Self::Float16,
            32 => Self::None,
            unsupported => return Err(unsupported),
        };
        Ok(level)
    }
}
impl fmt::Display for QuantizationLevel {
    /// Writes the conventional short label for the level: `FP32`, `FP16`,
    /// `BF16`, `INT8` or `INT4`.
    ///
    /// Width, fill and alignment flags are honoured, so `format!("{:<6}", q)`
    /// pads the label to six columns. A precision flag truncates the label,
    /// as it does for `str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::None => "FP32",
            Self::Float16 => "FP16",
            Self::BFloat16 => "BF16",
            Self::Int8 => "INT8",
            Self::Int4 => "INT4",
        };
        // `pad` applies the caller's formatting flags; `write!(f, "...")`
        // would silently ignore them and break column-aligned output.
        f.pad(label)
    }
}