// ghostflow_nn/mixed_precision.rs

//! Mixed Precision Training
//!
//! Implements mixed-precision training for higher throughput and reduced memory use:
//! - FP16 (half precision)
//! - BF16 (Brain Float 16)
//! - FP8 (8-bit floating point)
//! - Gradient scaling and unscaling
//! - Automatic, dynamic loss scaling
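//!
//! # Example
//!
//! A minimal sketch of one training step with loss scaling (tensor values and
//! the optimizer closure are illustrative):
//!
//! ```ignore
//! use std::collections::HashMap;
//!
//! let mut scaler = GradScaler::new(MixedPrecisionConfig::fp16());
//!
//! let loss = Tensor::from_slice(&[0.25f32], &[1]).unwrap();
//! let scaled_loss = scaler.scale_loss(&loss); // backprop this instead of `loss`
//!
//! let mut grads: HashMap<String, Tensor> = HashMap::new();
//! // ... fill `grads` from the backward pass ...
//!
//! // Unscales in place, skips the update on inf/nan, and adjusts the scale.
//! let stepped = scaler.step(|| { /* apply optimizer update */ }, &mut grads);
//! ```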

use ghostflow_core::Tensor;
use std::collections::HashMap;

/// Precision mode for mixed precision training
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PrecisionMode {
    /// Full precision (FP32)
    FP32,
    /// Half precision (FP16)
    FP16,
    /// Brain Float 16
    BF16,
    /// 8-bit floating point
    FP8,
}

/// Mixed precision training configuration
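///
/// # Example
///
/// A sketch of picking a preset and overriding one field:
///
/// ```ignore
/// let config = MixedPrecisionConfig {
///     growth_interval: 1000,
///     ..MixedPrecisionConfig::fp16()
/// };
/// ```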
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Precision mode
    pub mode: PrecisionMode,
    /// Initial loss scale
    pub init_scale: f32,
    /// Growth factor for loss scale
    pub growth_factor: f32,
    /// Backoff factor for loss scale
    pub backoff_factor: f32,
    /// Growth interval (steps)
    pub growth_interval: usize,
    /// Enable dynamic loss scaling
    pub dynamic_loss_scale: bool,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        MixedPrecisionConfig {
            mode: PrecisionMode::FP16,
            init_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            dynamic_loss_scale: true,
        }
    }
}

impl MixedPrecisionConfig {
    /// FP16 configuration
    pub fn fp16() -> Self {
        Self {
            mode: PrecisionMode::FP16,
            ..Default::default()
        }
    }
    
    /// BF16 configuration
    pub fn bf16() -> Self {
        Self {
            mode: PrecisionMode::BF16,
            init_scale: 1.0, // BF16 has FP32's exponent range, so scaling is rarely needed
            dynamic_loss_scale: false,
            ..Default::default()
        }
    }
    
    /// FP8 configuration
    pub fn fp8() -> Self {
        Self {
            mode: PrecisionMode::FP8,
            init_scale: 1024.0,
            ..Default::default()
        }
    }
}

/// Gradient scaler for mixed precision training
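///
/// # Example
///
/// A sketch of the manual unscale path (gradient values are illustrative):
///
/// ```ignore
/// let scaler = GradScaler::new(MixedPrecisionConfig::default());
///
/// let mut grads = HashMap::new();
/// grads.insert("w".to_string(), Tensor::from_slice(&[32768.0f32, 65536.0], &[2]).unwrap());
///
/// // Unscales in place; returns false (and leaves `grads` untouched) on inf/nan.
/// if scaler.unscale_gradients(&mut grads) {
///     // gradients are now in real units: [0.5, 1.0]
/// }
/// assert_eq!(scaler.get_scale(), 65536.0);
/// ```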
pub struct GradScaler {
    config: MixedPrecisionConfig,
    scale: f32,
    growth_tracker: usize,
    found_inf_count: usize,
}

impl GradScaler {
    /// Create new gradient scaler
    pub fn new(config: MixedPrecisionConfig) -> Self {
        GradScaler {
            scale: config.init_scale,
            config,
            growth_tracker: 0,
            found_inf_count: 0,
        }
    }
    
    /// Scale loss for backward pass
    pub fn scale_loss(&self, loss: &Tensor) -> Tensor {
        if self.config.mode == PrecisionMode::FP32 {
            return loss.clone();
        }
        
        loss.mul_scalar(self.scale)
    }
    
    /// Unscale gradients in place; returns false if any gradient is non-finite
    pub fn unscale_gradients(&self, gradients: &mut HashMap<String, Tensor>) -> bool {
        if self.config.mode == PrecisionMode::FP32 {
            return true;
        }
        
        // First pass: check for inf/nan before touching anything, so the map
        // is never left partially unscaled.
        if gradients.values().any(|g| self.has_inf_or_nan(g)) {
            return false;
        }
        
        // Second pass: unscale every gradient.
        let inv_scale = 1.0 / self.scale;
        for grad in gradients.values_mut() {
            *grad = grad.mul_scalar(inv_scale);
        }
        
        true
    }
    
    /// Step optimizer with gradient scaling
    pub fn step<F>(&mut self, optimizer_step: F, gradients: &mut HashMap<String, Tensor>) -> bool
    where
        F: FnOnce(),
    {
        // Unscale gradients
        let success = self.unscale_gradients(gradients);
        
        if success {
            // Take optimizer step
            optimizer_step();
            
            // Update scale
            self.update_scale(false);
            true
        } else {
            // Skip step due to inf/nan
            self.update_scale(true);
            false
        }
    }
    
    /// Update loss scale
    fn update_scale(&mut self, found_inf: bool) {
        if !self.config.dynamic_loss_scale {
            return;
        }
        
        if found_inf {
            // Reduce scale
            self.scale *= self.config.backoff_factor;
            self.scale = self.scale.max(1.0);
            self.growth_tracker = 0;
            self.found_inf_count += 1;
        } else {
            // Increase scale after growth_interval successful steps
            self.growth_tracker += 1;
            if self.growth_tracker >= self.config.growth_interval {
                self.scale *= self.config.growth_factor;
                // Cap at 2^24 so growth stays bounded but the scale can still
                // grow past the default init_scale of 2^16
                self.scale = self.scale.min(16_777_216.0);
                self.growth_tracker = 0;
            }
        }
    }
    
    /// Check if tensor has inf or nan
    fn has_inf_or_nan(&self, tensor: &Tensor) -> bool {
        let data = tensor.data_f32();
        data.iter().any(|&x| !x.is_finite())
    }
    
    /// Get current scale
    pub fn get_scale(&self) -> f32 {
        self.scale
    }
    
    /// Get statistics: (current scale, growth tracker, inf/nan count)
    pub fn get_stats(&self) -> (f32, usize, usize) {
        (self.scale, self.growth_tracker, self.found_inf_count)
    }
}

/// Convert tensor to lower precision
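///
/// # Example
///
/// A sketch of a round trip through the simulated FP16 path (values chosen
/// so the cast is visibly lossy):
///
/// ```ignore
/// let t = Tensor::from_slice(&[0.1f32, 3.14159], &[2]).unwrap();
/// let half = to_half_precision(&t, PrecisionMode::FP16);
/// let back = to_full_precision(&half, PrecisionMode::FP16);
/// // `back` keeps the reduced-precision values; the dropped bits are gone.
/// ```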
pub fn to_half_precision(tensor: &Tensor, mode: PrecisionMode) -> Tensor {
    match mode {
        PrecisionMode::FP32 => tensor.clone(),
        PrecisionMode::FP16 => convert_to_fp16(tensor),
        PrecisionMode::BF16 => convert_to_bf16(tensor),
        PrecisionMode::FP8 => convert_to_fp8(tensor),
    }
}

/// Convert tensor from lower precision to FP32
pub fn to_full_precision(tensor: &Tensor, mode: PrecisionMode) -> Tensor {
    match mode {
        PrecisionMode::FP32 => tensor.clone(),
        PrecisionMode::FP16 => convert_from_fp16(tensor),
        PrecisionMode::BF16 => convert_from_bf16(tensor),
        PrecisionMode::FP8 => convert_from_fp8(tensor),
    }
}

/// Convert to FP16 (IEEE 754 half precision)
fn convert_to_fp16(tensor: &Tensor) -> Tensor {
    let data = tensor.data_f32();
    let dims = tensor.dims();
    
    // Simulate FP16 while keeping FP32 storage: clamp to the FP16 range
    // (±65504), then truncate the mantissa from FP32's 23 bits to FP16's 10.
    // This models the relative precision loss; the clamp stands in for
    // FP16's narrower 5-bit exponent range.
    let fp16_data: Vec<f32> = data.iter().map(|&x| {
        let clamped = x.clamp(-65504.0, 65504.0);
        let bits = clamped.to_bits();
        let truncated = bits & 0xFFFF_E000; // keep sign, exponent, top 10 mantissa bits
        f32::from_bits(truncated)
    }).collect();
    
    Tensor::from_slice(&fp16_data, dims).unwrap()
}

/// Convert from FP16 to FP32
fn convert_from_fp16(tensor: &Tensor) -> Tensor {
    // Already in FP32 representation, just return
    tensor.clone()
}

/// Convert to BF16 (Brain Float 16)
fn convert_to_bf16(tensor: &Tensor) -> Tensor {
    let data = tensor.data_f32();
    let dims = tensor.dims();
    
    // BF16: 8-bit exponent (same as FP32), 7-bit mantissa.
    // Better range than FP16, less precision.
    let bf16_data: Vec<f32> = data.iter().map(|&x| {
        // Truncate the mantissa to 7 bits
        let bits = x.to_bits();
        let truncated = bits & 0xFFFF_0000; // keep sign, exponent, top 7 mantissa bits
        f32::from_bits(truncated)
    }).collect();
    
    Tensor::from_slice(&bf16_data, dims).unwrap()
}

/// Convert from BF16 to FP32
fn convert_from_bf16(tensor: &Tensor) -> Tensor {
    tensor.clone()
}

/// Convert to FP8 (8-bit floating point, E4M3)
fn convert_to_fp8(tensor: &Tensor) -> Tensor {
    let data = tensor.data_f32();
    let dims = tensor.dims();
    
    // Simulate FP8 E4M3 (4-bit exponent, 3-bit mantissa) in FP32 storage:
    // clamp to E4M3's maximum magnitude (±448), then truncate the mantissa
    // to 3 bits. The FP32 exponent is kept, so this models the precision
    // loss while the clamp approximates the very limited range.
    let fp8_data: Vec<f32> = data.iter().map(|&x| {
        let clamped = x.clamp(-448.0, 448.0);
        let bits = clamped.to_bits();
        let truncated = bits & 0xFFF0_0000; // keep sign, exponent, top 3 mantissa bits
        f32::from_bits(truncated)
    }).collect();
    
    Tensor::from_slice(&fp8_data, dims).unwrap()
}

/// Convert from FP8 to FP32
fn convert_from_fp8(tensor: &Tensor) -> Tensor {
    tensor.clone()
}

/// Automatic Mixed Precision context manager
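///
/// # Example
///
/// A sketch of gating a cast on the context (values are illustrative):
///
/// ```ignore
/// let mut ctx = AutocastContext::new(PrecisionMode::BF16);
/// let x = Tensor::from_slice(&[0.1f32, 0.2], &[2]).unwrap();
///
/// let x_low = ctx.cast(&x);  // reduced precision while enabled
/// ctx.disable();
/// let x_full = ctx.cast(&x); // pass-through clone when disabled
/// ```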
pub struct AutocastContext {
    mode: PrecisionMode,
    enabled: bool,
}

impl AutocastContext {
    /// Create new autocast context
    pub fn new(mode: PrecisionMode) -> Self {
        AutocastContext {
            mode,
            enabled: true,
        }
    }
    
    /// Disable autocast
    pub fn disable(&mut self) {
        self.enabled = false;
    }
    
    /// Enable autocast
    pub fn enable(&mut self) {
        self.enabled = true;
    }
    
    /// Cast tensor if autocast is enabled
    pub fn cast(&self, tensor: &Tensor) -> Tensor {
        if self.enabled {
            to_half_precision(tensor, self.mode)
        } else {
            tensor.clone()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    
    #[test]
    fn test_grad_scaler() {
        let config = MixedPrecisionConfig::fp16();
        let scaler = GradScaler::new(config);
        
        // Test loss scaling
        let loss = Tensor::from_slice(&[1.0f32], &[1]).unwrap();
        let scaled_loss = scaler.scale_loss(&loss);
        
        let scaled_data = scaled_loss.data_f32();
        assert_eq!(scaled_data[0], 65536.0);
    }
    
    #[test]
    fn test_unscale_gradients() {
        let config = MixedPrecisionConfig::fp16();
        let scaler = GradScaler::new(config);
        
        let mut gradients = HashMap::new();
        gradients.insert(
            "weight".to_string(),
            Tensor::from_slice(&[65536.0f32, 131072.0], &[2]).unwrap()
        );
        
        let success = scaler.unscale_gradients(&mut gradients);
        assert!(success);
        
        let grad = gradients.get("weight").unwrap();
        let data = grad.data_f32();
        assert!((data[0] - 1.0).abs() < 1e-5);
        assert!((data[1] - 2.0).abs() < 1e-5);
    }
    
    #[test]
    fn test_fp16_conversion() {
        let tensor = Tensor::from_slice(&[1.5f32, -2.5, 100.0], &[3]).unwrap();
        let fp16 = convert_to_fp16(&tensor);
        let data = fp16.data_f32();
        
        // Check values are approximately preserved
        assert!((data[0] - 1.5).abs() < 0.01);
        assert!((data[1] + 2.5).abs() < 0.01);
        assert!((data[2] - 100.0).abs() < 0.1);
    }
    
    #[test]
    fn test_bf16_conversion() {
        let tensor = Tensor::from_slice(&[1.5f32, -2.5, 1000.0], &[3]).unwrap();
        let bf16 = convert_to_bf16(&tensor);
        let data = bf16.data_f32();
        
        // BF16 should preserve larger values better than FP16
        assert!((data[2] - 1000.0).abs() < 10.0);
    }
    
    #[test]
    fn test_inf_detection() {
        let config = MixedPrecisionConfig::fp16();
        let scaler = GradScaler::new(config);
        
        let tensor = Tensor::from_slice(&[1.0f32, f32::INFINITY, 2.0], &[3]).unwrap();
        assert!(scaler.has_inf_or_nan(&tensor));
        
        let tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        assert!(!scaler.has_inf_or_nan(&tensor));
    }
    
    #[test]
    fn test_autocast_context() {
        let mut ctx = AutocastContext::new(PrecisionMode::FP16);
        // Use values that lose bits under FP16's reduced mantissa; exactly
        // representable values (e.g. 1.5) would survive the cast unchanged.
        let tensor = Tensor::from_slice(&[0.1f32, 0.2], &[2]).unwrap();
        
        let casted = ctx.cast(&tensor);
        assert_ne!(casted.data_f32(), tensor.data_f32());
        
        ctx.disable();
        let not_casted = ctx.cast(&tensor);
        assert_eq!(not_casted.data_f32(), tensor.data_f32());
    }
    
    #[test]
    fn test_dynamic_loss_scaling() {
        let config = MixedPrecisionConfig::fp16();
        let mut scaler = GradScaler::new(config);
        
        let initial_scale = scaler.get_scale();
        
        // Simulate successful steps
        for _ in 0..2000 {
            scaler.update_scale(false);
        }
        
        let grown_scale = scaler.get_scale();
        assert!(grown_scale > initial_scale);
        
        // Simulate inf/nan
        scaler.update_scale(true);
        let reduced_scale = scaler.get_scale();
        assert!(reduced_scale < grown_scale);
    }
}