kizzasi_inference/
precision.rs

//! Mixed precision support for efficient inference
//!
//! This module provides support for FP16 (half precision) and BF16 (bfloat16)
//! inference to reduce memory usage and increase throughput on supported hardware.
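//!
//! # Example
//!
//! A minimal usage sketch; the import path `kizzasi_inference::precision` is an
//! assumption here and should be adjusted to the actual crate layout.
//!
//! ```ignore
//! use kizzasi_inference::precision::{PrecisionConfig, PrecisionConverter};
//! use scirs2_core::ndarray::Array1;
//!
//! // Round-trip the input through FP16 and apply an FP32 operation on the result.
//! let converter = PrecisionConverter::new(PrecisionConfig::new().fp16());
//! let data = Array1::from_vec(vec![1.0_f32, 2.5, -3.75]);
//! let doubled = converter
//!     .convert_and_compute_1d(&data, |x| x.mapv(|v| v * 2.0))
//!     .unwrap();
//! assert_eq!(doubled.len(), 3);
//! ```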

use crate::error::{InferenceError, InferenceResult};
use half::{bf16, f16};
use scirs2_core::ndarray::{Array1, Array2};

/// Precision mode for inference
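///
/// A small selection sketch (import path assumed as in the module-level example):
///
/// ```ignore
/// use kizzasi_inference::precision::{ComputePrecision, PrecisionMode};
///
/// // Plain half precision.
/// let fp16 = PrecisionMode::FP16;
/// assert_eq!(fp16.name(), "FP16");
/// assert_eq!(fp16.memory_reduction_factor(), 0.5);
///
/// // Reduced-precision compute with FP32 accumulation.
/// let mixed = PrecisionMode::Mixed {
///     compute: ComputePrecision::BF16,
///     accumulate_fp32: true,
/// };
/// assert_eq!(mixed.name(), "Mixed-BF16");
/// ```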
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
pub enum PrecisionMode {
    /// Full precision (FP32)
    #[default]
    FP32,
    /// Half precision (FP16) - good for NVIDIA GPUs
    FP16,
    /// Brain float 16 (BF16) - good for modern accelerators
    BF16,
    /// Mixed precision - compute in FP16/BF16 but accumulate in FP32
    Mixed {
        /// Compute precision
        compute: ComputePrecision,
        /// Whether to accumulate in FP32
        accumulate_fp32: bool,
    },
}

/// Compute precision for mixed mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComputePrecision {
    /// FP16 compute
    FP16,
    /// BF16 compute
    BF16,
}

impl PrecisionMode {
    /// Check if this mode uses reduced precision
    pub fn is_reduced_precision(&self) -> bool {
        !matches!(self, PrecisionMode::FP32)
    }

    /// Get the memory reduction factor compared to FP32
    pub fn memory_reduction_factor(&self) -> f32 {
        match self {
            PrecisionMode::FP32 => 1.0,
            PrecisionMode::FP16 | PrecisionMode::BF16 => 0.5,
            PrecisionMode::Mixed { .. } => 0.75, // Mixed uses some FP32 for accumulation
        }
    }

    /// Get human-readable name
    pub fn name(&self) -> &str {
        match self {
            PrecisionMode::FP32 => "FP32",
            PrecisionMode::FP16 => "FP16",
            PrecisionMode::BF16 => "BF16",
            PrecisionMode::Mixed { compute, .. } => match compute {
                ComputePrecision::FP16 => "Mixed-FP16",
                ComputePrecision::BF16 => "Mixed-BF16",
            },
        }
    }
}

/// Configuration for mixed precision inference
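///
/// # Example
///
/// A builder sketch (import path assumed as in the module-level example):
///
/// ```ignore
/// use kizzasi_inference::precision::PrecisionConfig;
///
/// // Mixed BF16 compute with FP32 accumulation, a fixed loss scale, and clipping.
/// let config = PrecisionConfig::new()
///     .mixed_bf16(true)
///     .loss_scale(1024.0)
///     .grad_clip_threshold(1.0);
///
/// assert!(config.mode.is_reduced_precision());
/// assert_eq!(config.loss_scale, 1024.0);
/// ```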
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PrecisionConfig {
    /// Precision mode to use
    pub mode: PrecisionMode,
    /// Loss scaling factor to prevent underflow in FP16
    pub loss_scale: f32,
    /// Whether to automatically adjust loss scale
    pub dynamic_loss_scale: bool,
    /// Threshold for gradient clipping in reduced precision
    pub grad_clip_threshold: Option<f32>,
}

impl Default for PrecisionConfig {
    fn default() -> Self {
        Self {
            mode: PrecisionMode::FP32,
            loss_scale: 1.0,
            dynamic_loss_scale: false,
            grad_clip_threshold: None,
        }
    }
}

impl PrecisionConfig {
    /// Create a new precision configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set precision mode
    pub fn mode(mut self, mode: PrecisionMode) -> Self {
        self.mode = mode;
        self
    }

    /// Enable FP16 precision
    pub fn fp16(mut self) -> Self {
        self.mode = PrecisionMode::FP16;
        self
    }

    /// Enable BF16 precision
    pub fn bf16(mut self) -> Self {
        self.mode = PrecisionMode::BF16;
        self
    }

    /// Enable mixed precision with FP16 compute
    pub fn mixed_fp16(mut self, accumulate_fp32: bool) -> Self {
        self.mode = PrecisionMode::Mixed {
            compute: ComputePrecision::FP16,
            accumulate_fp32,
        };
        self
    }

    /// Enable mixed precision with BF16 compute
    pub fn mixed_bf16(mut self, accumulate_fp32: bool) -> Self {
        self.mode = PrecisionMode::Mixed {
            compute: ComputePrecision::BF16,
            accumulate_fp32,
        };
        self
    }

    /// Set loss scaling factor
    pub fn loss_scale(mut self, scale: f32) -> Self {
        self.loss_scale = scale;
        self
    }

    /// Enable dynamic loss scaling
    pub fn dynamic_loss_scale(mut self, enabled: bool) -> Self {
        self.dynamic_loss_scale = enabled;
        self
    }

    /// Set gradient clipping threshold
    pub fn grad_clip_threshold(mut self, threshold: f32) -> Self {
        self.grad_clip_threshold = Some(threshold);
        self
    }
}

/// Precision converter for array operations
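///
/// # Example
///
/// A 2D round-trip sketch (import path assumed as in the module-level example):
///
/// ```ignore
/// use kizzasi_inference::precision::{PrecisionConfig, PrecisionConverter};
/// use scirs2_core::ndarray::Array2;
///
/// let converter = PrecisionConverter::new(PrecisionConfig::new().bf16());
/// let weights = Array2::from_shape_vec((2, 2), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
///
/// // Store in BF16 (half the bytes), restore to FP32 when needed.
/// let packed = converter.to_bf16_2d(&weights);
/// let restored = converter.from_bf16_2d(&packed, (2, 2)).unwrap();
/// assert_eq!(restored.shape(), &[2, 2]);
/// ```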
pub struct PrecisionConverter {
    config: PrecisionConfig,
}

impl PrecisionConverter {
    /// Create a new precision converter
    pub fn new(config: PrecisionConfig) -> Self {
        Self { config }
    }

    /// Apply `op` to `data` while emulating the configured precision.
    ///
    /// For reduced-precision modes the input is round-tripped through FP16 or
    /// BF16 (quantizing each value) and `op` then runs on the restored FP32
    /// data, which models the numerical effect of reduced-precision inference.
    pub fn convert_and_compute_1d(
        &self,
        data: &Array1<f32>,
        op: impl Fn(&Array1<f32>) -> Array1<f32>,
    ) -> InferenceResult<Array1<f32>> {
        match self.config.mode {
            PrecisionMode::FP32 => Ok(op(data)),
            PrecisionMode::FP16 => {
                let fp16_data = self.to_fp16_1d(data);
                let fp16_result = op(&self.from_fp16_1d(&fp16_data));
                Ok(fp16_result)
            }
            PrecisionMode::BF16 => {
                let bf16_data = self.to_bf16_1d(data);
                let bf16_result = op(&self.from_bf16_1d(&bf16_data));
                Ok(bf16_result)
            }
            PrecisionMode::Mixed { compute, .. } => {
                // On this CPU reference path the `accumulate_fp32` flag does not
                // change the computation: the input is quantized to the compute
                // precision and `op` always runs on FP32 data.
                let reduced = match compute {
                    ComputePrecision::FP16 => {
                        let fp16_data = self.to_fp16_1d(data);
                        self.from_fp16_1d(&fp16_data)
                    }
                    ComputePrecision::BF16 => {
                        let bf16_data = self.to_bf16_1d(data);
                        self.from_bf16_1d(&bf16_data)
                    }
                };
                Ok(op(&reduced))
            }
        }
    }

    /// Convert FP32 array to FP16
    pub fn to_fp16_1d(&self, data: &Array1<f32>) -> Vec<f16> {
        data.iter().map(|&x| f16::from_f32(x)).collect()
    }

    /// Convert FP16 array to FP32
    pub fn from_fp16_1d(&self, data: &[f16]) -> Array1<f32> {
        Array1::from_vec(data.iter().map(|&x| x.to_f32()).collect())
    }

    /// Convert FP32 array to BF16
    pub fn to_bf16_1d(&self, data: &Array1<f32>) -> Vec<bf16> {
        data.iter().map(|&x| bf16::from_f32(x)).collect()
    }

    /// Convert BF16 array to FP32
    pub fn from_bf16_1d(&self, data: &[bf16]) -> Array1<f32> {
        Array1::from_vec(data.iter().map(|&x| x.to_f32()).collect())
    }

    /// Convert 2D FP32 array to FP16
    pub fn to_fp16_2d(&self, data: &Array2<f32>) -> Vec<f16> {
        data.iter().map(|&x| f16::from_f32(x)).collect()
    }

    /// Convert FP16 to 2D FP32 array
    pub fn from_fp16_2d(
        &self,
        data: &[f16],
        shape: (usize, usize),
    ) -> InferenceResult<Array2<f32>> {
        let vec: Vec<f32> = data.iter().map(|&x| x.to_f32()).collect();
        Array2::from_shape_vec(shape, vec).map_err(|e| {
            InferenceError::ForwardError(format!("Shape error in FP16 conversion: {}", e))
        })
    }

    /// Convert 2D FP32 array to BF16
    pub fn to_bf16_2d(&self, data: &Array2<f32>) -> Vec<bf16> {
        data.iter().map(|&x| bf16::from_f32(x)).collect()
    }

    /// Convert BF16 to 2D FP32 array
    pub fn from_bf16_2d(
        &self,
        data: &[bf16],
        shape: (usize, usize),
    ) -> InferenceResult<Array2<f32>> {
        let vec: Vec<f32> = data.iter().map(|&x| x.to_f32()).collect();
        Array2::from_shape_vec(shape, vec).map_err(|e| {
            InferenceError::ForwardError(format!("Shape error in BF16 conversion: {}", e))
        })
    }

    /// Get the configuration
    pub fn config(&self) -> &PrecisionConfig {
        &self.config
    }
}

/// Statistics about precision conversion
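///
/// # Example
///
/// A bookkeeping sketch (import path assumed as in the module-level example):
///
/// ```ignore
/// use kizzasi_inference::precision::{PrecisionMode, PrecisionStats};
///
/// let mut stats = PrecisionStats::new();
/// // Converting a 4 MiB FP32 tensor to FP16 saves half the bytes.
/// stats.record_conversion(4 * 1024 * 1024, &PrecisionMode::FP16);
/// stats.record_error(1.2e-4);
///
/// assert_eq!(stats.memory_saved, 2 * 1024 * 1024);
/// assert_eq!(stats.memory_saved_mb(), 2.0);
/// ```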
#[derive(Debug, Clone, Default)]
pub struct PrecisionStats {
    /// Number of conversions performed
    pub num_conversions: usize,
    /// Total memory saved (bytes)
    pub memory_saved: usize,
    /// Average numerical error from conversion
    pub avg_error: f64,
    /// Maximum numerical error observed
    pub max_error: f64,
}

impl PrecisionStats {
    /// Create new statistics
    pub fn new() -> Self {
        Self::default()
    }

    /// Record a conversion
    pub fn record_conversion(&mut self, original_size: usize, precision_mode: &PrecisionMode) {
        self.num_conversions += 1;
        let saved =
            (original_size as f32 * (1.0 - precision_mode.memory_reduction_factor())) as usize;
        self.memory_saved += saved;
    }

    /// Record a numerical error observation
    ///
    /// The running average treats this error as the sample for the most recent
    /// conversion, so call `record_conversion` before recording its error.
    pub fn record_error(&mut self, error: f64) {
        // Guard against division by zero when no conversion has been recorded yet.
        let n = self.num_conversions.max(1) as f64;
        self.avg_error = (self.avg_error * (n - 1.0) + error) / n;
        self.max_error = self.max_error.max(error);
    }

    /// Get memory saved in MiB
    pub fn memory_saved_mb(&self) -> f64 {
        self.memory_saved as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_mode_creation() {
        let mode = PrecisionMode::FP32;
        assert_eq!(mode.name(), "FP32");
        assert!(!mode.is_reduced_precision());
    }

    #[test]
    fn test_precision_mode_fp16() {
        let mode = PrecisionMode::FP16;
        assert_eq!(mode.name(), "FP16");
        assert!(mode.is_reduced_precision());
        assert_eq!(mode.memory_reduction_factor(), 0.5);
    }

    #[test]
    fn test_precision_mode_bf16() {
        let mode = PrecisionMode::BF16;
        assert_eq!(mode.name(), "BF16");
        assert!(mode.is_reduced_precision());
        assert_eq!(mode.memory_reduction_factor(), 0.5);
    }

    #[test]
    fn test_precision_mode_mixed() {
        let mode = PrecisionMode::Mixed {
            compute: ComputePrecision::FP16,
            accumulate_fp32: true,
        };
        assert_eq!(mode.name(), "Mixed-FP16");
        assert!(mode.is_reduced_precision());
    }

    #[test]
    fn test_precision_config_builder() {
        let config = PrecisionConfig::new()
            .fp16()
            .loss_scale(128.0)
            .dynamic_loss_scale(true);

        assert_eq!(config.mode, PrecisionMode::FP16);
        assert_eq!(config.loss_scale, 128.0);
        assert!(config.dynamic_loss_scale);
    }

    #[test]
    fn test_fp16_conversion_1d() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.5, -3.75, 0.0]);
        let fp16_data = converter.to_fp16_1d(&data);
        let restored = converter.from_fp16_1d(&fp16_data);

        // Should be very close (within FP16 precision)
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.001);
        }
    }

    #[test]
    fn test_bf16_conversion_1d() {
        let config = PrecisionConfig::new().bf16();
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.5, -3.75, 0.0]);
        let bf16_data = converter.to_bf16_1d(&data);
        let restored = converter.from_bf16_1d(&bf16_data);

        // BF16 has less precision than FP16
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.01);
        }
    }

    #[test]
    fn test_convert_and_compute() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        // Simple doubling operation
        let result = converter
            .convert_and_compute_1d(&data, |x| x.mapv(|v| v * 2.0))
            .unwrap();

        // Check results are close to expected (within FP16 precision)
        for (i, &val) in result.iter().enumerate() {
            let expected = data[i] * 2.0;
            assert!((val - expected).abs() < 0.01);
        }
    }

    #[test]
    fn test_precision_stats() {
        let mut stats = PrecisionStats::new();

        assert_eq!(stats.num_conversions, 0);
        assert_eq!(stats.memory_saved, 0);

        let mode = PrecisionMode::FP16;
        stats.record_conversion(1000, &mode);

        assert_eq!(stats.num_conversions, 1);
        assert_eq!(stats.memory_saved, 500); // 50% reduction

        stats.record_error(0.001);
        assert!(stats.avg_error > 0.0);
        assert!(stats.max_error > 0.0);
    }

    #[test]
    fn test_mixed_precision_compute() {
        let config = PrecisionConfig::new().mixed_fp16(true);
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = converter
            .convert_and_compute_1d(&data, |x| x.mapv(|v| v * 2.0))
            .unwrap();

        // Mixed precision should have good accuracy
        for (i, &val) in result.iter().enumerate() {
            let expected = data[i] * 2.0;
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_fp16_2d_conversion() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let fp16_data = converter.to_fp16_2d(&data);
        let restored = converter.from_fp16_2d(&fp16_data, (2, 3)).unwrap();

        assert_eq!(restored.shape(), &[2, 3]);

        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.001);
        }
    }

    #[test]
    fn test_bf16_2d_conversion() {
        let config = PrecisionConfig::new().bf16();
        let converter = PrecisionConverter::new(config);

        let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let bf16_data = converter.to_bf16_2d(&data);
        let restored = converter.from_bf16_2d(&bf16_data, (2, 2)).unwrap();

        assert_eq!(restored.shape(), &[2, 2]);

        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.01);
        }
    }
}