Skip to main content

entrenar/autograd/precision/
scaler.rs

1//! Gradient scaler for mixed-precision training.
2
3use super::MixedPrecisionConfig;
4
/// How many consecutive overflow-free steps a scaler built with
/// `GradScaler::new` waits before multiplying its loss scale by the growth factor.
const DEFAULT_SCALE_GROWTH_INTERVAL: usize = 2_000;
7
/// Dynamic loss scaler for mixed-precision training.
///
/// Multiplying the loss by a large scale keeps small gradients from underflowing
/// in fp16. The scale is reduced when gradients overflow and periodically grown
/// while they stay finite.
///
/// All state is plain `Copy` data, so the scaler is `Clone`. This lets callers
/// snapshot and restore it (e.g. alongside optimizer state).
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current loss scale. Losses are multiplied by it and gradients divided by it.
    scale: f32,
    /// Multiplier used to grow the scale after `growth_interval` consecutive
    /// overflow-free steps.
    growth_factor: f32,
    /// Multiplier applied to the scale when an overflow is reported.
    backoff_factor: f32,
    /// Number of consecutive overflow-free steps required before the scale grows.
    pub(crate) growth_interval: usize,
    /// Overflow-free steps counted since the scale last grew or backed off.
    steps_since_growth: usize,
    /// Whether `update` adjusts the scale. When `false`, the scale stays fixed.
    dynamic: bool,
    /// Number of overflow steps recorded by `update`.
    overflow_count: usize,
    /// Number of overflow-free steps recorded by `update`.
    successful_steps: usize,
}
30
31impl GradScaler {
32    /// Create a new gradient scaler
33    pub fn new(initial_scale: f32) -> Self {
34        Self {
35            scale: initial_scale,
36            growth_factor: 2.0,
37            backoff_factor: 0.5,
38            growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
39            steps_since_growth: 0,
40            dynamic: true,
41            overflow_count: 0,
42            successful_steps: 0,
43        }
44    }
45
46    /// Create from config
47    pub fn from_config(config: &MixedPrecisionConfig) -> Self {
48        Self {
49            scale: config.initial_scale,
50            growth_factor: config.scale_growth_factor,
51            backoff_factor: config.scale_backoff_factor,
52            growth_interval: config.scale_growth_interval,
53            steps_since_growth: 0,
54            dynamic: config.dynamic_scaling,
55            overflow_count: 0,
56            successful_steps: 0,
57        }
58    }
59
60    /// Get current scale
61    pub fn scale(&self) -> f32 {
62        self.scale
63    }
64
65    /// Scale a loss value
66    pub fn scale_loss(&self, loss: f32) -> f32 {
67        loss * self.scale
68    }
69
70    /// Unscale a gradient value
71    pub fn unscale_grad(&self, grad: f32) -> f32 {
72        grad / self.scale
73    }
74
75    /// Unscale gradients in place and check for overflow
76    ///
77    /// Returns true if gradients are valid (no overflow), false otherwise.
78    pub fn unscale_and_check(&self, grads: &mut [f32]) -> bool {
79        let inv_scale = 1.0 / self.scale;
80        let mut has_overflow = false;
81
82        for grad in grads.iter_mut() {
83            *grad *= inv_scale;
84            if !grad.is_finite() {
85                has_overflow = true;
86            }
87        }
88
89        !has_overflow
90    }
91
92    /// Update the scale after a step
93    ///
94    /// Call this after each optimizer step. Pass `true` if gradients were valid.
95    pub fn update(&mut self, grads_valid: bool) {
96        contract_pre_update!();
97        if !self.dynamic {
98            return;
99        }
100
101        if grads_valid {
102            self.successful_steps += 1;
103            self.steps_since_growth += 1;
104
105            // Grow scale after interval of successful steps
106            if self.steps_since_growth >= self.growth_interval {
107                self.scale *= self.growth_factor;
108                self.steps_since_growth = 0;
109            }
110        } else {
111            // Overflow detected - reduce scale
112            self.overflow_count += 1;
113            self.scale *= self.backoff_factor;
114            self.steps_since_growth = 0;
115
116            // Ensure scale doesn't go too low
117            self.scale = self.scale.max(1.0);
118        }
119    }
120
121    /// Get overflow count
122    pub fn overflow_count(&self) -> usize {
123        self.overflow_count
124    }
125
126    /// Get successful step count
127    pub fn successful_steps(&self) -> usize {
128        self.successful_steps
129    }
130
131    /// Check if dynamic scaling is enabled
132    pub fn is_dynamic(&self) -> bool {
133        self.dynamic
134    }
135
136    /// Enable/disable dynamic scaling
137    pub fn set_dynamic(&mut self, enabled: bool) {
138        self.dynamic = enabled;
139    }
140}
141
142impl Default for GradScaler {
143    fn default() -> Self {
144        Self::new(65536.0)
145    }
146}