axonml_optim/grad_scaler.rs

//! Gradient Scaler for Mixed Precision Training
//!
//! Provides gradient scaling to prevent underflow when using F16 gradients.
//! Essential for stable mixed precision training.
//!
//! # Example
//! ```rust,ignore
//! use axonml_optim::{Adam, GradScaler};
//! use axonml_autograd::amp::autocast;
//!
//! let mut optimizer = Adam::new(model.parameters(), 0.001);
//! let mut scaler = GradScaler::new();
//!
//! for (input, target) in dataloader {
//!     optimizer.zero_grad();
//!
//!     // Forward pass with autocast
//!     let output = autocast(DType::F16, || model.forward(&input));
//!     let loss = loss_fn.compute(&output, &target);
//!
//!     // Scale loss and backward
//!     let scaled_loss = loss * scaler.get_scale();
//!     scaled_loss.backward();
//!
//!     // Unscale gradients and step
//!     if scaler.step(&mut optimizer) {
//!         // Optimizer step was taken
//!     }
//!     scaler.update();
//! }
//! ```
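//!
//! A plainer `f32`-only sketch that exercises just this module's API (no
//! autograd types; the gradient values below are illustrative and stand in
//! for what a backward pass on the scaled loss would produce):
//! ```rust,ignore
//! use axonml_optim::GradScaler;
//!
//! let mut scaler = GradScaler::new();
//!
//! // Scale the loss before backpropagation so small F16 gradients survive.
//! let _scaled_loss = scaler.scale_loss(0.75);
//!
//! // Pretend these came out of a backward pass on the scaled loss.
//! let mut grads = vec![0.5_f32 * scaler.get_scale(), 1e-3 * scaler.get_scale()];
//!
//! // Unscale in place; skip the optimizer step if anything overflowed.
//! if scaler.unscale_grads(&mut grads) {
//!     // optimizer.step() would go here
//! }
//! scaler.update();
//! ```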
//!
//! @version 0.1.0

// =============================================================================
// GradScaler
// =============================================================================

/// Gradient scaler for mixed precision training.
///
/// Scales the loss to prevent gradient underflow when using F16,
/// then unscales gradients before the optimizer step.
///
/// The scale is automatically adjusted based on whether gradients overflow.
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current scale factor
    scale: f32,
    /// Factor to multiply the scale by after `growth_interval` consecutive successful steps
    growth_factor: f32,
    /// Factor to multiply the scale by when overflow is detected
    backoff_factor: f32,
    /// Number of successful steps before growing the scale
    growth_interval: usize,
    /// Counter for successful steps since the last growth
    growth_tracker: usize,
    /// Whether inf/nan was found in the last unscale
    found_inf: bool,
    /// Whether the scaler is enabled
    enabled: bool,
}

impl Default for GradScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl GradScaler {
    /// Creates a new gradient scaler with default settings.
    ///
    /// Default configuration:
    /// - Initial scale: 65536.0 (2^16)
    /// - Growth factor: 2.0
    /// - Backoff factor: 0.5
    /// - Growth interval: 2000 steps
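    ///
    /// A minimal sketch (crate-root re-export assumed, as in the module example):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let scaler = GradScaler::new();
    /// assert!((scaler.get_scale() - 65536.0).abs() < 1e-6);
    /// ```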
    #[must_use]
    pub fn new() -> Self {
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Creates a gradient scaler with custom initial scale.
    #[must_use]
    pub fn with_scale(init_scale: f32) -> Self {
        Self {
            scale: init_scale,
            ..Self::new()
        }
    }

    /// Creates a gradient scaler with all custom settings.
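    ///
    /// For instance, a sketch of a more conservative configuration (the values
    /// here are illustrative only):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// // Start lower, back off harder, grow only after 1000 clean steps.
    /// let scaler = GradScaler::with_options(1024.0, 2.0, 0.25, 1000);
    /// assert!((scaler.get_scale() - 1024.0).abs() < 1e-6);
    /// ```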
    #[must_use]
    pub fn with_options(
        init_scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
    ) -> Self {
        Self {
            scale: init_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Builder: set growth factor
    #[must_use]
    pub fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }

    /// Builder: set backoff factor
    #[must_use]
    pub fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Builder: set growth interval
    #[must_use]
    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }

    /// Builder: set enabled state
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns the current scale factor, or `1.0` if the scaler is disabled.
    #[must_use]
    pub fn get_scale(&self) -> f32 {
        if self.enabled {
            self.scale
        } else {
            1.0
        }
    }

    /// Sets the scale factor.
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Returns whether the scaler is enabled.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Enables or disables the scaler.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Scales a loss value for the backward pass.
    ///
    /// Call this on the loss before invoking `backward()`; the loss is
    /// returned unchanged when the scaler is disabled.
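    ///
    /// A small sketch of the scaling arithmetic (illustrative values):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let scaler = GradScaler::with_scale(1024.0);
    /// // 0.5 * 1024.0 = 512.0
    /// assert!((scaler.scale_loss(0.5) - 512.0).abs() < 1e-6);
    /// ```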
    #[must_use]
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }

    /// Unscales gradients in place and checks for inf/nan.
    ///
    /// Returns true if all gradients are finite, false if any overflow.
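    ///
    /// A sketch with plain `f32` gradients (illustrative values):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let mut scaler = GradScaler::with_scale(100.0);
    /// let mut grads = vec![100.0_f32, 200.0];
    ///
    /// assert!(scaler.unscale_grads(&mut grads)); // all finite
    /// assert!((grads[0] - 1.0).abs() < 1e-6);    // 100.0 / 100.0
    /// assert!((grads[1] - 2.0).abs() < 1e-6);    // 200.0 / 100.0
    /// ```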
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }

        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;

        for g in grads.iter_mut() {
            if g.is_infinite() || g.is_nan() {
                self.found_inf = true;
                // Keep going: the remaining gradients still need to be unscaled.
            }
            *g *= inv_scale;
        }

        !self.found_inf
    }

    /// Checks a slice of gradients for inf/nan without modifying them.
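    ///
    /// For example (illustrative values):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let scaler = GradScaler::new();
    /// assert!(scaler.check_grads(&[1.0, 2.0]));
    /// assert!(!scaler.check_grads(&[1.0, f32::NAN]));
    /// ```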
    #[must_use]
    pub fn check_grads(&self, grads: &[f32]) -> bool {
        grads.iter().all(|g| g.is_finite())
    }

    /// Returns whether inf/nan was found in the last unscale operation.
    #[must_use]
    pub fn found_inf(&self) -> bool {
        self.found_inf
    }

    /// Marks that inf was found (for external gradient checking).
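    ///
    /// One way to wire in an external finiteness check, sketched with
    /// illustrative values (default backoff factor 0.5 assumed):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let mut scaler = GradScaler::new();
    /// let grads = vec![1.0_f32, f32::INFINITY];
    ///
    /// // Record the overflow so the next `update()` backs the scale off.
    /// let found = !scaler.check_grads(&grads);
    /// scaler.set_found_inf(found);
    /// scaler.update();
    /// assert!((scaler.get_scale() - 32768.0).abs() < 1e-6);
    /// ```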
    pub fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }

    /// Updates the scale factor based on overflow history.
    ///
    /// Call this after each optimizer step:
    /// - If overflow was detected, scale is reduced by `backoff_factor`
    /// - If no overflow for `growth_interval` steps, scale is increased by `growth_factor`
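    ///
    /// A sketch of the backoff behaviour (illustrative values):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let mut scaler = GradScaler::with_scale(1000.0);
    ///
    /// // Overflow: the scale is halved (default backoff factor 0.5).
    /// scaler.set_found_inf(true);
    /// scaler.update();
    /// assert!((scaler.get_scale() - 500.0).abs() < 1e-6);
    /// ```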
    pub fn update(&mut self) {
        if !self.enabled {
            return;
        }

        if self.found_inf {
            // Reduce scale on overflow
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
            // Clamp so the scale never drops below 1.0
            self.scale = self.scale.max(1.0);
        } else {
            // Track successful steps
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                // Increase scale
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
                // Clamp so the scale cannot overflow f32
                self.scale = self.scale.min(f32::MAX / 2.0);
            }
        }
    }

    /// Returns the current state for checkpointing.
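    ///
    /// Round-trip sketch (illustrative values):
    /// ```rust,ignore
    /// use axonml_optim::GradScaler;
    ///
    /// let scaler = GradScaler::with_scale(500.0);
    /// let state = scaler.state_dict();
    ///
    /// let mut restored = GradScaler::new();
    /// restored.load_state_dict(state);
    /// assert!((restored.get_scale() - 500.0).abs() < 1e-6);
    /// ```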
    #[must_use]
    pub fn state_dict(&self) -> GradScalerState {
        GradScalerState {
            scale: self.scale,
            growth_tracker: self.growth_tracker,
        }
    }

    /// Loads state from a checkpoint.
    pub fn load_state_dict(&mut self, state: GradScalerState) {
        self.scale = state.scale;
        self.growth_tracker = state.growth_tracker;
    }
}

/// Serializable state for `GradScaler` checkpointing.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
    /// Current scale factor
    pub scale: f32,
    /// Growth tracker value
    pub growth_tracker: usize,
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert!((scaler.get_scale() - 65536.0).abs() < 1e-6);
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        let scaler = GradScaler::with_scale(1024.0);
        assert!((scaler.get_scale() - 1024.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert!((scaled - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(valid);
        assert!(!scaler.found_inf());
        assert!((grads[0] - 1.0).abs() < 1e-6);
        assert!((grads[1] - 2.0).abs() < 1e-6);
        assert!((grads[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        assert!((scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Simulate 3 successful steps
        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }

        assert!((scaler.get_scale() - 200.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert!((scaler.get_scale() - 1.0).abs() < 1e-6);
        assert!((scaler.scale_loss(0.5) - 0.5).abs() < 1e-6);

        let mut grads = vec![1.0, 2.0, 3.0];
        let valid = scaler.unscale_grads(&mut grads);
        assert!(valid);
        // Grads should be unchanged
        assert!((grads[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;

        let state = scaler.state_dict();
        assert!((state.scale - 500.0).abs() < 1e-6);
        assert_eq!(state.growth_tracker, 10);

        let mut new_scaler = GradScaler::new();
        new_scaler.load_state_dict(state);
        assert!((new_scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(new_scaler.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert!((scaler.growth_factor - 3.0).abs() < 1e-6);
        assert!((scaler.backoff_factor - 0.25).abs() < 1e-6);
        assert_eq!(scaler.growth_interval, 100);
    }
}