use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};

/// Configuration for mixed-precision (FP16/FP32) training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Whether mixed-precision training is enabled.
    pub enabled: bool,
    /// Initial loss scale.
    pub init_scale: f32,
    /// Factor by which the loss scale grows after a stable interval.
    pub scale_growth_factor: f32,
    /// Factor by which the loss scale shrinks after an overflow.
    pub scale_backoff_factor: f32,
    /// Number of consecutive overflow-free steps before the scale grows.
    pub scale_growth_interval: usize,
    /// Whether the loss scale is adjusted dynamically.
    pub dynamic_loss_scale: bool,
    /// Gradient-norm threshold for clipping.
    pub grad_clip_threshold: f32,
    /// Whether gradients are accumulated over several micro-batches.
    pub gradient_accumulation: bool,
    /// Number of micro-batches to accumulate before applying an update.
    pub accumulation_steps: usize,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            init_scale: 65536.0,
            scale_growth_factor: 2.0,
            scale_backoff_factor: 0.5,
            scale_growth_interval: 2000,
            dynamic_loss_scale: true,
            grad_clip_threshold: 1.0,
            gradient_accumulation: false,
            accumulation_steps: 1,
        }
    }
}

/// Runtime statistics collected during mixed-precision training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionStats {
    /// Current loss scale.
    pub current_scale: f32,
    /// Number of gradient overflows detected.
    pub num_overflows: usize,
    /// Number of successful optimization steps.
    pub num_successful_steps: usize,
    /// Number of loss-scale adjustments.
    pub num_scale_updates: usize,
    /// Running average of the gradient norm.
    pub avg_gradient_norm: f32,
    /// Estimated memory saved by using FP16, in bytes.
    pub memory_saved_bytes: usize,
}

impl Default for MixedPrecisionStats {
    fn default() -> Self {
        Self {
            current_scale: 1.0,
            num_overflows: 0,
            num_successful_steps: 0,
            num_scale_updates: 0,
            avg_gradient_norm: 0.0,
            memory_saved_bytes: 0,
        }
    }
}

/// Trainer that simulates FP16 mixed-precision training with dynamic loss
/// scaling, gradient clipping, and optional gradient accumulation.
pub struct MixedPrecisionTrainer {
    config: MixedPrecisionConfig,
    stats: MixedPrecisionStats,
    steps_since_overflow: usize,
    accumulated_gradients: HashMap<String, Array1<f32>>,
    accumulation_count: usize,
}

impl MixedPrecisionTrainer {
    /// Creates a new trainer from the given configuration.
    pub fn new(config: MixedPrecisionConfig) -> Self {
        // When mixed precision is disabled, loss scaling is a no-op.
        let initial_scale = if config.enabled {
            config.init_scale
        } else {
            1.0
        };

        info!(
            "Initialized mixed precision trainer: enabled={}, init_scale={}",
            config.enabled, initial_scale
        );

        Self {
            config,
            stats: MixedPrecisionStats {
                current_scale: initial_scale,
                ..Default::default()
            },
            steps_since_overflow: 0,
            accumulated_gradients: HashMap::new(),
            accumulation_count: 0,
        }
    }

    /// Simulates conversion to FP16 by clamping values to the representable range.
    pub fn to_fp16(&self, tensor: &Array1<f32>) -> Array1<f32> {
        if !self.config.enabled {
            return tensor.clone();
        }

        // Largest finite magnitude representable in IEEE 754 half precision.
        const FP16_MAX: f32 = 65504.0;
        const FP16_MIN: f32 = -65504.0;

        tensor.mapv(|x| x.clamp(FP16_MIN, FP16_MAX))
    }

    /// Simulates conversion back to FP32 (values are already stored as f32).
    pub fn to_fp32(&self, tensor: &Array1<f32>) -> Array1<f32> {
        tensor.clone()
    }

    /// Scales the loss by the current loss scale to prevent gradient underflow in FP16.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if !self.config.enabled {
            return loss;
        }

        loss * self.stats.current_scale
    }

    /// Unscales gradients by the current loss scale and clips them to the
    /// configured norm threshold. Returns an error on overflow (inf/NaN).
    pub fn unscale_gradients(&self, gradients: &Array1<f32>) -> Result<Array1<f32>> {
        if !self.config.enabled {
            return Ok(gradients.clone());
        }

        if self.has_inf_or_nan(gradients) {
            return Err(anyhow!("Gradient overflow detected"));
        }

        let unscaled = gradients / self.stats.current_scale;

        // Clip by global L2 norm.
        let grad_norm = self.compute_gradient_norm(&unscaled);

        if grad_norm > self.config.grad_clip_threshold {
            let scale_factor = self.config.grad_clip_threshold / grad_norm;
            Ok(&unscaled * scale_factor)
        } else {
            Ok(unscaled)
        }
    }

    /// Applies a gradient update to `parameters`, handling loss-scale
    /// unscaling, overflow recovery, and optional gradient accumulation.
    pub fn update_parameters(
        &mut self,
        parameters: &mut Array1<f32>,
        gradients: &Array1<f32>,
        learning_rate: f32,
    ) -> Result<()> {
        if !self.config.enabled {
            // Plain FP32 SGD update.
            *parameters = &*parameters - &(gradients * learning_rate);
            return Ok(());
        }

        let unscaled_grads = match self.unscale_gradients(gradients) {
            Ok(grads) => grads,
            Err(_) => {
                // Skip this step and back off the loss scale.
                self.handle_overflow();
                return Ok(());
            }
        };

        if self.config.gradient_accumulation {
            // Key accumulated gradients by the parameter tensor's address.
            let param_key = format!("{:p}", parameters);

            let accumulated = self
                .accumulated_gradients
                .entry(param_key)
                .or_insert_with(|| Array1::zeros(parameters.len()));

            *accumulated = &*accumulated + &unscaled_grads;
            self.accumulation_count += 1;

            if self.accumulation_count >= self.config.accumulation_steps {
                // Apply the averaged gradient once enough micro-batches accumulated.
                let avg_grad = &*accumulated / (self.config.accumulation_steps as f32);

                *parameters = &*parameters - &(&avg_grad * learning_rate);

                self.accumulated_gradients.clear();
                self.accumulation_count = 0;

                self.on_successful_step();
            }
        } else {
            *parameters = &*parameters - &(&unscaled_grads * learning_rate);

            self.on_successful_step();
        }

        Ok(())
    }

    /// Records a gradient overflow and reduces the loss scale if dynamic
    /// loss scaling is enabled.
    fn handle_overflow(&mut self) {
        self.stats.num_overflows += 1;
        self.steps_since_overflow = 0;

        if self.config.dynamic_loss_scale {
            self.stats.current_scale *= self.config.scale_backoff_factor;
            self.stats.num_scale_updates += 1;

            warn!(
                "Gradient overflow detected! Reducing loss scale to {}",
                self.stats.current_scale
            );
        }
    }

    /// Records a successful step and grows the loss scale after a stable interval.
    fn on_successful_step(&mut self) {
        self.stats.num_successful_steps += 1;
        self.steps_since_overflow += 1;

        if self.config.dynamic_loss_scale
            && self.steps_since_overflow >= self.config.scale_growth_interval
        {
            self.stats.current_scale *= self.config.scale_growth_factor;
            self.stats.num_scale_updates += 1;
            self.steps_since_overflow = 0;

            debug!(
                "Increasing loss scale to {} after {} successful steps",
                self.stats.current_scale, self.config.scale_growth_interval
            );
        }
    }

    /// Returns true if any element is infinite or NaN.
    fn has_inf_or_nan(&self, tensor: &Array1<f32>) -> bool {
        tensor.iter().any(|&x| x.is_infinite() || x.is_nan())
    }

    /// Computes the L2 norm of the gradient vector.
    fn compute_gradient_norm(&self, gradients: &Array1<f32>) -> f32 {
        gradients.dot(gradients).sqrt()
    }

    /// Returns the current training statistics.
    pub fn get_stats(&self) -> &MixedPrecisionStats {
        &self.stats
    }

    /// Resets statistics, restoring the initial loss scale.
    pub fn reset_stats(&mut self) {
        self.stats = MixedPrecisionStats {
            // Mirror `new()`: loss scaling is a no-op when disabled.
            current_scale: if self.config.enabled {
                self.config.init_scale
            } else {
                1.0
            },
            ..Default::default()
        };
        self.steps_since_overflow = 0;
    }

    /// Estimates memory savings from storing parameters in FP16
    /// (2 bytes saved per parameter relative to FP32).
    pub fn estimate_memory_savings(&mut self, num_parameters: usize) {
        if self.config.enabled {
            self.stats.memory_saved_bytes = num_parameters * 2;
        } else {
            self.stats.memory_saved_bytes = 0;
        }
    }

    /// Updates the running average of the gradient norm.
    pub fn update_gradient_stats(&mut self, gradients: &Array1<f32>) {
        let norm = self.compute_gradient_norm(gradients);
        let n = self.stats.num_successful_steps as f32;

        if n > 0.0 {
            self.stats.avg_gradient_norm = (self.stats.avg_gradient_norm * (n - 1.0) + norm) / n;
        } else {
            self.stats.avg_gradient_norm = norm;
        }
    }

    /// Returns true if training is numerically stable (overflow rate below 10%).
    pub fn is_stable(&self) -> bool {
        if !self.config.enabled {
            return true;
        }

        let overflow_rate =
            self.stats.num_overflows as f32 / (self.stats.num_successful_steps + 1) as f32;

        overflow_rate < 0.1
    }

    /// Returns the trainer configuration.
    pub fn config(&self) -> &MixedPrecisionConfig {
        &self.config
    }
}

/// Conversion helpers for embeddings that participate in mixed-precision training.
pub trait MixedPrecisionEmbedding {
    /// Converts the embedding to (simulated) FP16 precision.
    fn to_mixed_precision(&self, trainer: &MixedPrecisionTrainer) -> Self;

    /// Converts the embedding back to full FP32 precision.
    fn to_full_precision(&self, trainer: &MixedPrecisionTrainer) -> Self;
}

impl MixedPrecisionEmbedding for Array1<f32> {
    fn to_mixed_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        trainer.to_fp16(self)
    }

    fn to_full_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        trainer.to_fp32(self)
    }
}

impl MixedPrecisionEmbedding for HashMap<String, Array1<f32>> {
    fn to_mixed_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        self.iter()
            .map(|(k, v)| (k.clone(), trainer.to_fp16(v)))
            .collect()
    }

    fn to_full_precision(&self, trainer: &MixedPrecisionTrainer) -> Self {
        self.iter()
            .map(|(k, v)| (k.clone(), trainer.to_fp32(v)))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_mixed_precision_creation() {
        let config = MixedPrecisionConfig::default();
        let trainer = MixedPrecisionTrainer::new(config);

        assert_eq!(trainer.stats.current_scale, 65536.0);
        assert_eq!(trainer.stats.num_overflows, 0);
    }

    #[test]
    fn test_fp16_conversion() {
        let config = MixedPrecisionConfig::default();
        let trainer = MixedPrecisionTrainer::new(config);

        let tensor = array![1.0, 2.0, 3.0];
        let fp16 = trainer.to_fp16(&tensor);
        let fp32 = trainer.to_fp32(&fp16);

        assert_eq!(tensor.len(), fp32.len());
    }

    #[test]
    fn test_loss_scaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            ..Default::default()
        };

        let trainer = MixedPrecisionTrainer::new(config);

        let loss = 0.5;
        let scaled_loss = trainer.scale_loss(loss);

        assert_eq!(scaled_loss, 512.0);
    }

    #[test]
    fn test_gradient_unscaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            grad_clip_threshold: 10.0,
            ..Default::default()
        };

        let trainer = MixedPrecisionTrainer::new(config);

        let scaled_grads = array![1024.0, 2048.0, 512.0];
        let unscaled = trainer.unscale_gradients(&scaled_grads).unwrap();

        assert!((unscaled[0] - 1.0).abs() < 1e-5);
        assert!((unscaled[1] - 2.0).abs() < 1e-5);
        assert!((unscaled[2] - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_gradient_clipping() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            grad_clip_threshold: 1.0,
            ..Default::default()
        };

        let trainer = MixedPrecisionTrainer::new(config.clone());

        let grads = array![10.0, 10.0, 10.0];
        let clipped = trainer.unscale_gradients(&grads).unwrap();

        let norm = clipped.dot(&clipped).sqrt();
        assert!(norm <= config.grad_clip_threshold + 1e-5);
    }

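    // Illustrative additional test (not part of the original suite): checks
    // that `to_fp16` clamps values outside the FP16 representable range to
    // +/-65504, per the constants defined above.
    #[test]
    fn test_fp16_clamping() {
        let trainer = MixedPrecisionTrainer::new(MixedPrecisionConfig::default());

        let tensor = array![1.0e9_f32, -1.0e9, 0.5];
        let fp16 = trainer.to_fp16(&tensor);

        assert_eq!(fp16[0], 65504.0);
        assert_eq!(fp16[1], -65504.0);
        assert_eq!(fp16[2], 0.5);
    }
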
    #[test]
    fn test_overflow_handling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1024.0,
            dynamic_loss_scale: true,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config.clone());

        let bad_grads = array![f32::INFINITY, 1.0, 2.0];

        let result = trainer.unscale_gradients(&bad_grads);
        assert!(result.is_err());

        trainer.handle_overflow();

        assert_eq!(trainer.stats.current_scale, 512.0);
        assert_eq!(trainer.stats.num_overflows, 1);
    }

    #[test]
    fn test_parameter_update() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config);

        let mut params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];
        let lr = 0.1;

        trainer.update_parameters(&mut params, &grads, lr).unwrap();

        assert!((params[0] - 0.99).abs() < 1e-5);
        assert!((params[1] - 1.98).abs() < 1e-5);
        assert!((params[2] - 2.97).abs() < 1e-5);
    }

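    // Illustrative additional test (not part of the original suite): with
    // `accumulation_steps = 2`, the first update only accumulates gradients
    // and the second applies the averaged gradient.
    #[test]
    fn test_gradient_accumulation() {
        let config = MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            gradient_accumulation: true,
            accumulation_steps: 2,
            ..Default::default()
        };

        let mut trainer = MixedPrecisionTrainer::new(config);

        let mut params = array![1.0, 1.0];
        let grads = array![0.2, 0.4];
        let lr = 0.5;

        // First call: gradients are accumulated, parameters stay unchanged.
        trainer.update_parameters(&mut params, &grads, lr).unwrap();
        assert!((params[0] - 1.0).abs() < 1e-6);

        // Second call: parameters move by lr * average gradient.
        trainer.update_parameters(&mut params, &grads, lr).unwrap();
        assert!((params[0] - 0.9).abs() < 1e-5);
        assert!((params[1] - 0.8).abs() < 1e-5);
    }
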
    #[test]
    fn test_stability_check() {
        let config = MixedPrecisionConfig::default();
        let mut trainer = MixedPrecisionTrainer::new(config);

        trainer.stats.num_successful_steps = 100;
        trainer.stats.num_overflows = 5;
        assert!(trainer.is_stable());

        trainer.stats.num_overflows = 15;
        assert!(!trainer.is_stable());
    }
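
    // Illustrative additional test (not part of the original suite): the
    // savings estimate assumes 2 bytes saved per parameter when FP16 is
    // enabled and zero savings otherwise.
    #[test]
    fn test_memory_savings_estimate() {
        let mut trainer = MixedPrecisionTrainer::new(MixedPrecisionConfig::default());
        trainer.estimate_memory_savings(1_000);
        assert_eq!(trainer.get_stats().memory_saved_bytes, 2_000);

        let mut disabled = MixedPrecisionTrainer::new(MixedPrecisionConfig {
            enabled: false,
            ..Default::default()
        });
        disabled.estimate_memory_savings(1_000);
        assert_eq!(disabled.get_stats().memory_saved_bytes, 0);
    }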
}