oxirs_embed/mixed_precision.rs

//! Mixed Precision Training for Knowledge Graph Embeddings
//!
//! This module provides mixed precision training support to accelerate training
//! and reduce memory usage while maintaining numerical stability. It uses a
//! simulated float16 range for forward/backward passes and float32 for parameter updates.
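//!
//! A minimal usage sketch (illustrative only; `raw_loss`, `params`, and the
//! `backprop` helper are placeholders, not items defined in this module):
//!
//! ```ignore
//! let mut trainer = MixedPrecisionTrainer::new(MixedPrecisionConfig::default());
//!
//! // One training step: scale the loss, backpropagate under the scaled loss,
//! // then let the trainer unscale, clip, and apply the gradients in FP32.
//! let scaled_loss = trainer.scale_loss(raw_loss);
//! let scaled_grads = backprop(scaled_loss, &params); // hypothetical helper
//! trainer.update_parameters(&mut params, &scaled_grads, 0.01)?;
//! ```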

use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};

/// Mixed precision training configuration
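///
/// An illustrative override of the defaults (values are arbitrary):
///
/// ```ignore
/// let config = MixedPrecisionConfig {
///     init_scale: 1024.0,
///     gradient_accumulation: true,
///     accumulation_steps: 4,
///     ..Default::default()
/// };
/// ```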
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Enable mixed precision training
    pub enabled: bool,
    /// Initial loss scale factor
    pub init_scale: f32,
    /// Scale factor growth rate
    pub scale_growth_factor: f32,
    /// Backoff factor when overflow detected
    pub scale_backoff_factor: f32,
    /// Number of successful steps before increasing scale
    pub scale_growth_interval: usize,
    /// Use dynamic loss scaling
    pub dynamic_loss_scale: bool,
    /// Gradient clipping threshold
    pub grad_clip_threshold: f32,
    /// Enable gradient accumulation
    pub gradient_accumulation: bool,
    /// Number of steps to accumulate gradients
    pub accumulation_steps: usize,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            init_scale: 65536.0, // 2^16
            scale_growth_factor: 2.0,
            scale_backoff_factor: 0.5,
            scale_growth_interval: 2000,
            dynamic_loss_scale: true,
            grad_clip_threshold: 1.0,
            gradient_accumulation: false,
            accumulation_steps: 1,
        }
    }
}

/// Mixed precision training statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionStats {
    /// Current loss scale
    pub current_scale: f32,
    /// Number of overflow events
    pub num_overflows: usize,
    /// Number of successful steps
    pub num_successful_steps: usize,
    /// Number of scale updates
    pub num_scale_updates: usize,
    /// Average gradient norm
    pub avg_gradient_norm: f32,
    /// Memory saved (estimated, in bytes)
    pub memory_saved_bytes: usize,
}

impl Default for MixedPrecisionStats {
    fn default() -> Self {
        Self {
            current_scale: 1.0,
            num_overflows: 0,
            num_successful_steps: 0,
            num_scale_updates: 0,
            avg_gradient_norm: 0.0,
            memory_saved_bytes: 0,
        }
    }
}

/// Mixed precision trainer for embeddings
pub struct MixedPrecisionTrainer {
    config: MixedPrecisionConfig,
    stats: MixedPrecisionStats,
    steps_since_overflow: usize,
    accumulated_gradients: HashMap<String, Array1<f32>>,
    accumulation_count: usize,
}

impl MixedPrecisionTrainer {
    /// Create new mixed precision trainer
    pub fn new(config: MixedPrecisionConfig) -> Self {
        let initial_scale = if config.enabled {
            config.init_scale
        } else {
            1.0
        };

        info!(
            "Initialized mixed precision trainer: enabled={}, init_scale={}",
            config.enabled, initial_scale
        );

        Self {
            config,
            stats: MixedPrecisionStats {
                current_scale: initial_scale,
                ..Default::default()
            },
            steps_since_overflow: 0,
            accumulated_gradients: HashMap::new(),
            accumulation_count: 0,
        }
    }

    /// Convert float32 tensor to float16 (simulated with f32)
    ///
    /// Note: Rust doesn't have native float16, so we simulate by clamping range
    pub fn to_fp16(&self, tensor: &Array1<f32>) -> Array1<f32> {
        if !self.config.enabled {
            return tensor.clone();
        }

        // Simulate FP16 range: approximately [-65504, 65504]
        const FP16_MAX: f32 = 65504.0;
        const FP16_MIN: f32 = -65504.0;

        tensor.mapv(|x| x.clamp(FP16_MIN, FP16_MAX))
    }

    /// Convert float16 back to float32 (no-op in simulation)
    pub fn to_fp32(&self, tensor: &Array1<f32>) -> Array1<f32> {
        tensor.clone()
    }

    /// Scale loss for backward pass
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if !self.config.enabled {
            return loss;
        }

        loss * self.stats.current_scale
    }

    /// Unscale gradients after backward pass
    pub fn unscale_gradients(&self, gradients: &Array1<f32>) -> Result<Array1<f32>> {
        if !self.config.enabled {
            return Ok(gradients.clone());
        }

        // Check for overflow/underflow
        if self.has_inf_or_nan(gradients) {
            return Err(anyhow!("Gradient overflow detected"));
        }

        // Unscale
        let unscaled = gradients / self.stats.current_scale;

        // Gradient clipping
        let grad_norm = self.compute_gradient_norm(&unscaled);

        if grad_norm > self.config.grad_clip_threshold {
            let scale_factor = self.config.grad_clip_threshold / grad_norm;
            Ok(&unscaled * scale_factor)
        } else {
            Ok(unscaled)
        }
    }

    /// Update parameters with mixed precision
    pub fn update_parameters(
        &mut self,
        parameters: &mut Array1<f32>,
        gradients: &Array1<f32>,
        learning_rate: f32,
    ) -> Result<()> {
        if !self.config.enabled {
            // Standard full-precision update
            *parameters = &*parameters - &(gradients * learning_rate);
            return Ok(());
        }

        // Unscale gradients; a detected overflow shrinks the loss scale and
        // skips this update instead of propagating NaN/inf into the parameters.
        let unscaled_grads = match self.unscale_gradients(gradients) {
            Ok(grads) => grads,
            Err(_) => {
                self.handle_overflow();
                return Ok(()); // Skip this update
            }
        };

        if self.config.gradient_accumulation {
            // Accumulate gradients, keyed by the address of the parameter array.
            // This assumes each parameter tensor keeps a stable address across
            // accumulation steps; note that `accumulation_count` is shared
            // across all parameter tensors.
            let param_key = format!("{:p}", parameters);

            let accumulated = self
                .accumulated_gradients
                .entry(param_key)
                .or_insert_with(|| Array1::zeros(parameters.len()));

            *accumulated = &*accumulated + &unscaled_grads;
            self.accumulation_count += 1;

            // Only update when we've accumulated enough
            if self.accumulation_count >= self.config.accumulation_steps {
                let avg_grad = &*accumulated / (self.config.accumulation_steps as f32);

                // Update in FP32
                *parameters = &*parameters - &(&avg_grad * learning_rate);

                // Reset accumulation
                self.accumulated_gradients.clear();
                self.accumulation_count = 0;

                self.on_successful_step();
            }
        } else {
            // Direct update in FP32
            *parameters = &*parameters - &(&unscaled_grads * learning_rate);

            self.on_successful_step();
        }

        Ok(())
    }

    /// Handle gradient overflow
    fn handle_overflow(&mut self) {
        self.stats.num_overflows += 1;
        self.steps_since_overflow = 0;

        if self.config.dynamic_loss_scale {
            self.stats.current_scale *= self.config.scale_backoff_factor;
            self.stats.num_scale_updates += 1;

            warn!(
                "Gradient overflow detected! Reducing loss scale to {}",
                self.stats.current_scale
            );
        }
    }

    /// Called after successful parameter update
    fn on_successful_step(&mut self) {
        self.stats.num_successful_steps += 1;
        self.steps_since_overflow += 1;

        // Increase scale if we've had many successful steps
        if self.config.dynamic_loss_scale
            && self.steps_since_overflow >= self.config.scale_growth_interval
        {
            self.stats.current_scale *= self.config.scale_growth_factor;
            self.stats.num_scale_updates += 1;
            self.steps_since_overflow = 0;

            debug!(
                "Increasing loss scale to {} after {} successful steps",
                self.stats.current_scale, self.config.scale_growth_interval
            );
        }
    }

    /// Check if tensor contains inf or nan
    fn has_inf_or_nan(&self, tensor: &Array1<f32>) -> bool {
        tensor.iter().any(|&x| x.is_infinite() || x.is_nan())
    }

    /// Compute gradient norm
    fn compute_gradient_norm(&self, gradients: &Array1<f32>) -> f32 {
        gradients.dot(gradients).sqrt()
    }

    /// Get current statistics
    pub fn get_stats(&self) -> &MixedPrecisionStats {
        &self.stats
    }

    /// Reset statistics
    pub fn reset_stats(&mut self) {
        self.stats = MixedPrecisionStats {
            // Mirror `new()`: fall back to a loss scale of 1.0 when disabled
            current_scale: if self.config.enabled {
                self.config.init_scale
            } else {
                1.0
            },
            ..Default::default()
        };
        self.steps_since_overflow = 0;
    }

    /// Estimate memory savings
    pub fn estimate_memory_savings(&mut self, num_parameters: usize) {
        // FP16 uses 2 bytes vs FP32's 4 bytes
        if self.config.enabled {
            self.stats.memory_saved_bytes = num_parameters * 2;
        } else {
            self.stats.memory_saved_bytes = 0;
        }
    }

    /// Update average gradient norm
    pub fn update_gradient_stats(&mut self, gradients: &Array1<f32>) {
        let norm = self.compute_gradient_norm(gradients);
        let n = self.stats.num_successful_steps as f32;

        if n > 0.0 {
            self.stats.avg_gradient_norm = (self.stats.avg_gradient_norm * (n - 1.0) + norm) / n;
        } else {
            self.stats.avg_gradient_norm = norm;
        }
    }

    /// Check if training is stable
    pub fn is_stable(&self) -> bool {
        if !self.config.enabled {
            return true;
        }

        // Consider training unstable if too many overflows
        let overflow_rate =
            self.stats.num_overflows as f32 / (self.stats.num_successful_steps + 1) as f32;

        overflow_rate < 0.1 // Less than 10% overflow rate
    }

    /// Get configuration
    pub fn config(&self) -> &MixedPrecisionConfig {
        &self.config
    }
}

/// Helper trait for mixed precision operations on embeddings
pub trait MixedPrecisionEmbedding {
    /// Convert embeddings to mixed precision format
    fn to_mixed_precision(&self, trainer: &MixedPrecisionTrainer) -> Self;

    /// Convert back to full precision
    fn to_full_precision(&self, trainer: &MixedPrecisionTrainer) -> Self;
}

impl MixedPrecisionEmbedding for Array1<f32> {
    fn to_mixed_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        trainer.to_fp16(self)
    }

    fn to_full_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        trainer.to_fp32(self)
    }
}

impl MixedPrecisionEmbedding for HashMap<String, Array1<f32>> {
    fn to_mixed_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        self.iter()
            .map(|(k, v)| (k.clone(), trainer.to_fp16(v)))
            .collect()
    }

    fn to_full_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        self.iter()
            .map(|(k, v)| (k.clone(), trainer.to_fp32(v)))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_mixed_precision_creation() {
        let config = MixedPrecisionConfig::default();
        let trainer = MixedPrecisionTrainer::new(config);

        assert_eq!(trainer.stats.current_scale, 65536.0);
        assert_eq!(trainer.stats.num_overflows, 0);
    }

    #[test]
    fn test_fp16_conversion() {
        let config = MixedPrecisionConfig::default();
        let trainer = MixedPrecisionTrainer::new(config);

        let tensor = array![1.0, 2.0, 3.0];
        let fp16 = trainer.to_fp16(&tensor);
        let fp32 = trainer.to_fp32(&fp16);

        assert_eq!(tensor.len(), fp32.len());
    }
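
    // Illustrative addition (not in the original test suite): exercises the
    // MixedPrecisionEmbedding trait on a map of embeddings; in-range values
    // survive the simulated FP16 round trip unchanged.
    #[test]
    fn test_embedding_trait_roundtrip() {
        let trainer = MixedPrecisionTrainer::new(MixedPrecisionConfig::default());

        let original = array![1.0f32, -2.0, 3.0];
        let mut embeddings = std::collections::HashMap::new();
        embeddings.insert("entity_a".to_string(), original.clone());

        let mixed = embeddings.to_mixed_precision(&trainer);
        let restored = mixed.to_full_precision(&trainer);

        assert_eq!(restored["entity_a"], original);
    }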

    #[test]
    fn test_loss_scaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            ..Default::default()
        };

        let trainer = MixedPrecisionTrainer::new(config);

        let loss = 0.5;
        let scaled_loss = trainer.scale_loss(loss);

        assert_eq!(scaled_loss, 512.0);
    }

    #[test]
    fn test_gradient_unscaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            grad_clip_threshold: 10.0,
            ..Default::default()
        };

        let trainer = MixedPrecisionTrainer::new(config);

        let scaled_grads = array![1024.0, 2048.0, 512.0];
        let unscaled = trainer.unscale_gradients(&scaled_grads).unwrap();

        // Should be divided by scale (1024.0)
        assert!((unscaled[0] - 1.0).abs() < 1e-5);
        assert!((unscaled[1] - 2.0).abs() < 1e-5);
        assert!((unscaled[2] - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_gradient_clipping() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            grad_clip_threshold: 1.0,
            ..Default::default()
        };

        let trainer = MixedPrecisionTrainer::new(config.clone());

        // Large gradients that exceed threshold
        let grads = array![10.0, 10.0, 10.0];
        let clipped = trainer.unscale_gradients(&grads).unwrap();

        let norm = clipped.dot(&clipped).sqrt();
        assert!(norm <= config.grad_clip_threshold + 1e-5);
    }

    #[test]
    fn test_overflow_handling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            dynamic_loss_scale: true,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config.clone());

        // Simulate overflow with inf gradients
        let bad_grads = array![f32::INFINITY, 1.0, 2.0];

        let result = trainer.unscale_gradients(&bad_grads);
        assert!(result.is_err());

        // Manually handle overflow
        trainer.handle_overflow();

        // Scale should be reduced
        assert_eq!(trainer.stats.current_scale, 512.0);
        assert_eq!(trainer.stats.num_overflows, 1);
    }

    #[test]
    fn test_parameter_update() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config);

        let mut params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        let lr = 0.1;

        trainer.update_parameters(&mut params, &grads, lr).unwrap();

        // params should be updated: params -= lr * grads
        assert!((params[0] - 0.99).abs() < 1e-5);
        assert!((params[1] - 1.98).abs() < 1e-5);
        assert!((params[2] - 2.97).abs() < 1e-5);
    }
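
    // Illustrative addition (not in the original test suite): with
    // accumulation_steps = 2, the first call only accumulates and the second
    // call applies the averaged gradient.
    #[test]
    fn test_gradient_accumulation() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            gradient_accumulation: true,
            accumulation_steps: 2,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config);

        let mut params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        let lr = 0.1;

        // First call: gradients are accumulated, parameters stay unchanged.
        trainer.update_parameters(&mut params, &grads, lr).unwrap();
        assert!((params[0] - 1.0).abs() < 1e-6);

        // Second call: the averaged gradient (identical here) is applied once.
        trainer.update_parameters(&mut params, &grads, lr).unwrap();
        assert!((params[0] - 0.99).abs() < 1e-5);
        assert!((params[1] - 1.98).abs() < 1e-5);
        assert!((params[2] - 2.97).abs() < 1e-5);
    }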

    #[test]
    fn test_stability_check() {
        let config = MixedPrecisionConfig::default();
        let mut trainer = MixedPrecisionTrainer::new(config);

        trainer.stats.num_successful_steps = 100;
        trainer.stats.num_overflows = 5; // 5% overflow rate

        assert!(trainer.is_stable());

        trainer.stats.num_overflows = 15; // 15% overflow rate
        assert!(!trainer.is_stable());
    }
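
    // Illustrative addition (not in the original test suite): the loss scale
    // grows by scale_growth_factor once scale_growth_interval successful
    // steps have passed without an overflow.
    #[test]
    fn test_dynamic_scale_growth() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            scale_growth_factor: 2.0,
            scale_growth_interval: 3,
            dynamic_loss_scale: true,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config);

        // Two successful steps: below the growth interval, scale unchanged.
        trainer.on_successful_step();
        trainer.on_successful_step();
        assert_eq!(trainer.stats.current_scale, 1024.0);

        // Third successful step triggers growth.
        trainer.on_successful_step();
        assert_eq!(trainer.stats.current_scale, 2048.0);
    }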
}