Skip to main content

haagenti_adaptive/
precision.rs

//! Precision types and hardware capabilities
2
3use serde::{Deserialize, Serialize};
4
5/// Precision levels for inference
6#[derive(
7    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
8)]
9pub enum Precision {
10    /// 4-bit integer (most aggressive compression)
11    INT4,
12    /// 8-bit integer
13    INT8,
14    /// 16-bit floating point (brain float)
15    BF16,
16    /// 16-bit floating point
17    #[default]
18    FP16,
19    /// 32-bit floating point (full precision)
20    FP32,
21}
22
23impl Precision {
24    /// Bits per element
25    pub fn bits(&self) -> u32 {
26        match self {
27            Precision::INT4 => 4,
28            Precision::INT8 => 8,
29            Precision::BF16 => 16,
30            Precision::FP16 => 16,
31            Precision::FP32 => 32,
32        }
33    }
34
35    /// Bytes per element
36    pub fn bytes(&self) -> f32 {
37        self.bits() as f32 / 8.0
38    }
39
40    /// VRAM usage relative to FP32 (0.0 - 1.0)
41    pub fn vram_ratio(&self) -> f32 {
42        self.bits() as f32 / 32.0
43    }
44
45    /// Approximate speedup factor relative to FP32
46    pub fn speedup_factor(&self) -> f32 {
47        match self {
48            Precision::INT4 => 4.0,
49            Precision::INT8 => 2.5,
50            Precision::BF16 => 1.8,
51            Precision::FP16 => 2.0,
52            Precision::FP32 => 1.0,
53        }
54    }
55
56    /// Quality impact (lower is more lossy)
57    /// This is an approximation - actual impact depends on model and content
58    pub fn quality_factor(&self) -> f32 {
59        match self {
60            Precision::INT4 => 0.92,
61            Precision::INT8 => 0.97,
62            Precision::BF16 => 0.995,
63            Precision::FP16 => 0.998,
64            Precision::FP32 => 1.0,
65        }
66    }
67
68    /// Whether this precision is lossless (or nearly so)
69    pub fn is_lossless(&self) -> bool {
70        matches!(self, Precision::FP32 | Precision::FP16 | Precision::BF16)
71    }
72
73    /// Parse from string (returns None on failure)
74    pub fn parse(s: &str) -> Option<Self> {
75        s.parse().ok()
76    }
77}
78
79impl std::str::FromStr for Precision {
80    type Err = ();
81
82    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
83        match s.to_uppercase().as_str() {
84            "INT4" | "I4" | "4BIT" => Ok(Precision::INT4),
85            "INT8" | "I8" | "8BIT" => Ok(Precision::INT8),
86            "BF16" | "BFLOAT16" => Ok(Precision::BF16),
87            "FP16" | "FLOAT16" | "F16" | "HALF" => Ok(Precision::FP16),
88            "FP32" | "FLOAT32" | "F32" | "FULL" => Ok(Precision::FP32),
89            _ => Err(()),
90        }
91    }
92}
93
94impl std::fmt::Display for Precision {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        match self {
97            Precision::INT4 => write!(f, "INT4"),
98            Precision::INT8 => write!(f, "INT8"),
99            Precision::BF16 => write!(f, "BF16"),
100            Precision::FP16 => write!(f, "FP16"),
101            Precision::FP32 => write!(f, "FP32"),
102        }
103    }
104}
105
106/// Hardware precision capabilities
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct PrecisionCapabilities {
109    /// Supported precisions (ordered by preference)
110    pub supported: Vec<Precision>,
111    /// Native (fastest) precision
112    pub native: Precision,
113    /// Whether INT4 uses tensor cores
114    pub int4_tensor_cores: bool,
115    /// Whether INT8 uses tensor cores
116    pub int8_tensor_cores: bool,
117    /// Available VRAM in MB
118    pub vram_mb: u64,
119    /// Compute capability (for NVIDIA GPUs)
120    pub compute_capability: Option<(u32, u32)>,
121}
122
123impl PrecisionCapabilities {
124    /// Create capabilities for a modern GPU (RTX 30/40 series)
125    pub fn modern_gpu(vram_mb: u64) -> Self {
126        Self {
127            supported: vec![
128                Precision::INT4,
129                Precision::INT8,
130                Precision::BF16,
131                Precision::FP16,
132                Precision::FP32,
133            ],
134            native: Precision::FP16,
135            int4_tensor_cores: true,
136            int8_tensor_cores: true,
137            vram_mb,
138            compute_capability: Some((8, 6)), // Ampere
139        }
140    }
141
142    /// Create capabilities for an older GPU (GTX 10 series)
143    pub fn legacy_gpu(vram_mb: u64) -> Self {
144        Self {
145            supported: vec![Precision::FP16, Precision::FP32],
146            native: Precision::FP32,
147            int4_tensor_cores: false,
148            int8_tensor_cores: false,
149            vram_mb,
150            compute_capability: Some((6, 1)), // Pascal
151        }
152    }
153
154    /// Create capabilities for CPU
155    pub fn cpu(memory_mb: u64) -> Self {
156        Self {
157            supported: vec![Precision::INT8, Precision::FP32],
158            native: Precision::FP32,
159            int4_tensor_cores: false,
160            int8_tensor_cores: false,
161            vram_mb: memory_mb,
162            compute_capability: None,
163        }
164    }
165
166    /// Check if a precision is supported
167    pub fn supports(&self, precision: Precision) -> bool {
168        self.supported.contains(&precision)
169    }
170
171    /// Get the best supported precision at or below the given level
172    pub fn best_supported(&self, max_precision: Precision) -> Precision {
173        self.supported
174            .iter()
175            .filter(|&&p| p <= max_precision)
176            .max()
177            .copied()
178            .unwrap_or(self.native)
179    }
180
181    /// Estimate VRAM usage for a model at given precision
182    pub fn estimate_vram(&self, model_params: u64, precision: Precision) -> u64 {
183        let base_bytes = model_params * 4; // FP32 baseline
184        (base_bytes as f32 * precision.vram_ratio()) as u64
185    }
186
187    /// Check if a model fits at given precision
188    pub fn fits_model(&self, model_params: u64, precision: Precision) -> bool {
189        let required_mb = self.estimate_vram(model_params, precision) / (1024 * 1024);
190        required_mb <= self.vram_mb
191    }
192}
193
194impl Default for PrecisionCapabilities {
195    fn default() -> Self {
196        Self::modern_gpu(8192) // 8GB default
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived ordering follows declaration order, narrowest first.
    #[test]
    fn test_precision_ordering() {
        assert!(Precision::INT4 < Precision::INT8);
        assert!(Precision::INT8 < Precision::BF16);
        assert!(Precision::BF16 < Precision::FP16);
        assert!(Precision::INT8 < Precision::FP16);
        assert!(Precision::FP16 < Precision::FP32);
    }

    #[test]
    fn test_precision_bits() {
        assert_eq!(Precision::INT4.bits(), 4);
        assert_eq!(Precision::INT8.bits(), 8);
        assert_eq!(Precision::BF16.bits(), 16);
        assert_eq!(Precision::FP16.bits(), 16);
        assert_eq!(Precision::FP32.bits(), 32);
    }

    #[test]
    fn test_vram_ratio() {
        assert!((Precision::INT4.vram_ratio() - 0.125).abs() < 0.001);
        assert!((Precision::INT8.vram_ratio() - 0.25).abs() < 0.001);
        assert!((Precision::FP16.vram_ratio() - 0.5).abs() < 0.001);
        assert!((Precision::FP32.vram_ratio() - 1.0).abs() < 0.001);
    }

    /// Parsing is case-insensitive and `Display` emits the canonical name.
    #[test]
    fn test_parse_display_round_trip() {
        assert_eq!(Precision::parse("fp16"), Some(Precision::FP16));
        assert_eq!(Precision::parse("bfloat16"), Some(Precision::BF16));
        assert_eq!(Precision::parse("not-a-precision"), None);
        assert_eq!(Precision::INT8.to_string(), "INT8");
    }

    #[test]
    fn test_capabilities() {
        let caps = PrecisionCapabilities::modern_gpu(12288);
        assert!(caps.supports(Precision::INT4));
        assert!(caps.supports(Precision::FP16));
        assert!(caps.int4_tensor_cores);
    }

    #[test]
    fn test_best_supported() {
        let caps = PrecisionCapabilities::legacy_gpu(4096);
        // Legacy doesn't support INT4/INT8, so the native precision is the fallback
        assert_eq!(caps.best_supported(Precision::INT4), Precision::FP32);
        assert_eq!(caps.best_supported(Precision::FP16), Precision::FP16);
    }
}
242}