1use serde::{Deserialize, Serialize};
4
5#[derive(
7 Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
8)]
9pub enum Precision {
10 INT4,
12 INT8,
14 BF16,
16 #[default]
18 FP16,
19 FP32,
21}
22
23impl Precision {
24 pub fn bits(&self) -> u32 {
26 match self {
27 Precision::INT4 => 4,
28 Precision::INT8 => 8,
29 Precision::BF16 => 16,
30 Precision::FP16 => 16,
31 Precision::FP32 => 32,
32 }
33 }
34
35 pub fn bytes(&self) -> f32 {
37 self.bits() as f32 / 8.0
38 }
39
40 pub fn vram_ratio(&self) -> f32 {
42 self.bits() as f32 / 32.0
43 }
44
45 pub fn speedup_factor(&self) -> f32 {
47 match self {
48 Precision::INT4 => 4.0,
49 Precision::INT8 => 2.5,
50 Precision::BF16 => 1.8,
51 Precision::FP16 => 2.0,
52 Precision::FP32 => 1.0,
53 }
54 }
55
56 pub fn quality_factor(&self) -> f32 {
59 match self {
60 Precision::INT4 => 0.92,
61 Precision::INT8 => 0.97,
62 Precision::BF16 => 0.995,
63 Precision::FP16 => 0.998,
64 Precision::FP32 => 1.0,
65 }
66 }
67
68 pub fn is_lossless(&self) -> bool {
70 matches!(self, Precision::FP32 | Precision::FP16 | Precision::BF16)
71 }
72
73 pub fn parse(s: &str) -> Option<Self> {
75 s.parse().ok()
76 }
77}
78
79impl std::str::FromStr for Precision {
80 type Err = ();
81
82 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
83 match s.to_uppercase().as_str() {
84 "INT4" | "I4" | "4BIT" => Ok(Precision::INT4),
85 "INT8" | "I8" | "8BIT" => Ok(Precision::INT8),
86 "BF16" | "BFLOAT16" => Ok(Precision::BF16),
87 "FP16" | "FLOAT16" | "F16" | "HALF" => Ok(Precision::FP16),
88 "FP32" | "FLOAT32" | "F32" | "FULL" => Ok(Precision::FP32),
89 _ => Err(()),
90 }
91 }
92}
93
94impl std::fmt::Display for Precision {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 Precision::INT4 => write!(f, "INT4"),
98 Precision::INT8 => write!(f, "INT8"),
99 Precision::BF16 => write!(f, "BF16"),
100 Precision::FP16 => write!(f, "FP16"),
101 Precision::FP32 => write!(f, "FP32"),
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct PrecisionCapabilities {
109 pub supported: Vec<Precision>,
111 pub native: Precision,
113 pub int4_tensor_cores: bool,
115 pub int8_tensor_cores: bool,
117 pub vram_mb: u64,
119 pub compute_capability: Option<(u32, u32)>,
121}
122
123impl PrecisionCapabilities {
124 pub fn modern_gpu(vram_mb: u64) -> Self {
126 Self {
127 supported: vec![
128 Precision::INT4,
129 Precision::INT8,
130 Precision::BF16,
131 Precision::FP16,
132 Precision::FP32,
133 ],
134 native: Precision::FP16,
135 int4_tensor_cores: true,
136 int8_tensor_cores: true,
137 vram_mb,
138 compute_capability: Some((8, 6)), }
140 }
141
142 pub fn legacy_gpu(vram_mb: u64) -> Self {
144 Self {
145 supported: vec![Precision::FP16, Precision::FP32],
146 native: Precision::FP32,
147 int4_tensor_cores: false,
148 int8_tensor_cores: false,
149 vram_mb,
150 compute_capability: Some((6, 1)), }
152 }
153
154 pub fn cpu(memory_mb: u64) -> Self {
156 Self {
157 supported: vec![Precision::INT8, Precision::FP32],
158 native: Precision::FP32,
159 int4_tensor_cores: false,
160 int8_tensor_cores: false,
161 vram_mb: memory_mb,
162 compute_capability: None,
163 }
164 }
165
166 pub fn supports(&self, precision: Precision) -> bool {
168 self.supported.contains(&precision)
169 }
170
171 pub fn best_supported(&self, max_precision: Precision) -> Precision {
173 self.supported
174 .iter()
175 .filter(|&&p| p <= max_precision)
176 .max()
177 .copied()
178 .unwrap_or(self.native)
179 }
180
181 pub fn estimate_vram(&self, model_params: u64, precision: Precision) -> u64 {
183 let base_bytes = model_params * 4; (base_bytes as f32 * precision.vram_ratio()) as u64
185 }
186
187 pub fn fits_model(&self, model_params: u64, precision: Precision) -> bool {
189 let required_mb = self.estimate_vram(model_params, precision) / (1024 * 1024);
190 required_mb <= self.vram_mb
191 }
192}
193
194impl Default for PrecisionCapabilities {
195 fn default() -> Self {
196 Self::modern_gpu(8192) }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration (and therefore `Ord`) order.
    const ALL: [Precision; 5] = [
        Precision::INT4,
        Precision::INT8,
        Precision::BF16,
        Precision::FP16,
        Precision::FP32,
    ];

    #[test]
    fn test_precision_ordering() {
        assert!(Precision::INT4 < Precision::INT8);
        assert!(Precision::INT8 < Precision::BF16);
        assert!(Precision::BF16 < Precision::FP16);
        assert!(Precision::FP16 < Precision::FP32);
    }

    #[test]
    fn test_default_precision() {
        assert_eq!(Precision::default(), Precision::FP16);
    }

    #[test]
    fn test_precision_bits() {
        assert_eq!(Precision::INT4.bits(), 4);
        assert_eq!(Precision::INT8.bits(), 8);
        assert_eq!(Precision::BF16.bits(), 16);
        assert_eq!(Precision::FP16.bits(), 16);
        assert_eq!(Precision::FP32.bits(), 32);
    }

    #[test]
    fn test_vram_ratio() {
        assert!((Precision::INT4.vram_ratio() - 0.125).abs() < 0.001);
        assert!((Precision::INT8.vram_ratio() - 0.25).abs() < 0.001);
        assert!((Precision::BF16.vram_ratio() - 0.5).abs() < 0.001);
        assert!((Precision::FP16.vram_ratio() - 0.5).abs() < 0.001);
        assert!((Precision::FP32.vram_ratio() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_display_parse_round_trip() {
        for p in ALL.iter().copied() {
            // Display output must be accepted back by FromStr, in any case.
            assert_eq!(p.to_string().parse::<Precision>(), Ok(p));
            assert_eq!(Precision::parse(&p.to_string().to_lowercase()), Some(p));
        }
        assert_eq!(Precision::parse("half"), Some(Precision::FP16));
        assert_eq!(Precision::parse("fp64"), None);
    }

    #[test]
    fn test_capabilities() {
        let caps = PrecisionCapabilities::modern_gpu(12288);
        assert!(caps.supports(Precision::INT4));
        assert!(caps.supports(Precision::FP16));
        assert!(caps.int4_tensor_cores);
    }

    #[test]
    fn test_best_supported() {
        let caps = PrecisionCapabilities::legacy_gpu(4096);
        // Nothing supported is <= INT4, so the FP32 native precision is returned.
        assert_eq!(caps.best_supported(Precision::INT4), Precision::FP32);
        assert_eq!(caps.best_supported(Precision::FP16), Precision::FP16);

        let cpu = PrecisionCapabilities::cpu(1024);
        assert_eq!(cpu.best_supported(Precision::FP16), Precision::INT8);
    }

    #[test]
    fn test_estimate_vram_and_fit() {
        let caps = PrecisionCapabilities::modern_gpu(1);
        assert_eq!(caps.estimate_vram(1000, Precision::FP16), 2000);
        assert_eq!(caps.estimate_vram(1000, Precision::INT4), 500);
        // 262_144 FP32 parameters occupy exactly 1 MiB.
        assert!(caps.fits_model(262_144, Precision::FP32));
    }
}