//! Gradient Scaler for Mixed Precision Training
//!
//! # File
//! `crates/axonml-optim/src/grad_scaler.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// =============================================================================
// GradScaler
// =============================================================================

/// Gradient scaler for mixed precision training.
///
/// Scales the loss to prevent gradient underflow when using F16,
/// then unscales gradients before the optimizer step.
///
/// The scale is automatically adjusted based on whether gradients overflow.
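///
/// # Example
///
/// A minimal sketch of the intended training-loop flow. The import path
/// `axonml_optim::GradScaler` is an assumption here; adjust it to this
/// crate's actual re-exports.
///
/// ```
/// use axonml_optim::GradScaler; // NOTE: import path assumed
///
/// let mut scaler = GradScaler::new();
///
/// // Scale the loss before running backward() on it.
/// let loss = 0.5_f32;
/// let _scaled_loss = scaler.scale_loss(loss);
///
/// // After backward(), unscale the raw gradients in place.
/// let mut grads = vec![65536.0_f32, 131072.0];
/// if scaler.unscale_grads(&mut grads) {
///     // All gradients are finite: safe to take the optimizer step.
/// }
///
/// // Adjust the scale based on whether an overflow was detected.
/// scaler.update();
/// ```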
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current scale factor
    scale: f32,
    /// Factor to multiply scale by on successful steps
    growth_factor: f32,
    /// Factor to multiply scale by when overflow detected
    backoff_factor: f32,
    /// Number of successful steps before growing scale
    growth_interval: usize,
    /// Counter for successful steps since last growth
    growth_tracker: usize,
    /// Whether inf/nan was found in last unscale
    found_inf: bool,
    /// Whether the scaler is enabled
    enabled: bool,
}

impl Default for GradScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl GradScaler {
    /// Creates a new gradient scaler with default settings.
    ///
    /// Default configuration:
    /// - Initial scale: 65536.0 (2^16)
    /// - Growth factor: 2.0
    /// - Backoff factor: 0.5
    /// - Growth interval: 2000 steps
    #[must_use]
    pub fn new() -> Self {
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Creates a gradient scaler with custom initial scale.
    #[must_use]
    pub fn with_scale(init_scale: f32) -> Self {
        Self {
            scale: init_scale,
            ..Self::new()
        }
    }

    /// Creates a gradient scaler with all custom settings.
    #[must_use]
    pub fn with_options(
        init_scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
    ) -> Self {
        Self {
            scale: init_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Builder: set growth factor
    #[must_use]
    pub fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }

    /// Builder: set backoff factor
    #[must_use]
    pub fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Builder: set growth interval
    #[must_use]
    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }

    /// Builder: set enabled state
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns the current scale factor.
    #[must_use]
    pub fn get_scale(&self) -> f32 {
        if self.enabled { self.scale } else { 1.0 }
    }

    /// Sets the scale factor.
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Returns whether the scaler is enabled.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Enables or disables the scaler.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }
    /// Scales a loss value for the backward pass.
    ///
    /// Run the backward pass on the returned scaled loss, then unscale the
    /// gradients before the optimizer step.
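    ///
    /// # Example
    ///
    /// A minimal sketch (import path assumed, as in the type-level example):
    ///
    /// ```
    /// use axonml_optim::GradScaler; // NOTE: import path assumed
    ///
    /// let scaler = GradScaler::with_scale(1024.0);
    /// let scaled = scaler.scale_loss(0.5);
    /// assert!((scaled - 512.0).abs() < 1e-6);
    /// ```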
    #[must_use]
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }

    /// Unscales gradients in place and checks for inf/nan.
    ///
    /// Returns `true` if all gradients are finite, `false` if any gradient
    /// is infinite or NaN.
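    ///
    /// # Example
    ///
    /// A minimal sketch of overflow detection (import path assumed):
    ///
    /// ```
    /// use axonml_optim::GradScaler; // NOTE: import path assumed
    ///
    /// let mut scaler = GradScaler::with_scale(2.0);
    /// let mut grads = vec![4.0_f32, f32::INFINITY];
    ///
    /// assert!(!scaler.unscale_grads(&mut grads));
    /// assert!(scaler.found_inf());
    /// // Finite entries are still unscaled: 4.0 / 2.0 == 2.0.
    /// assert!((grads[0] - 2.0).abs() < 1e-6);
    /// ```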
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }

        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;

        for g in grads.iter_mut() {
            if g.is_infinite() || g.is_nan() {
                self.found_inf = true;
                // Keep going: the remaining gradients still need unscaling;
                // found_inf records that an overflow was seen.
            }
            *g *= inv_scale;
        }

        !self.found_inf
    }

    /// Unscales gradients on all optimizer parameters in place.
    ///
    /// Equivalent to PyTorch's `GradScaler.unscale_(optimizer)`.
    /// Returns `true` if all gradients are finite.
    pub fn unscale_optimizer<O: crate::Optimizer>(&mut self, optimizer: &O) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }

        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;

        for param in optimizer.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_vec = grad.to_vec();
                for g in &mut grad_vec {
                    if g.is_infinite() || g.is_nan() {
                        self.found_inf = true;
                    }
                    *g *= inv_scale;
                }
                let unscaled = axonml_tensor::Tensor::from_vec(grad_vec, grad.shape())
                    .expect("grad_scaler: tensor creation failed");
                param.set_grad(unscaled);
            }
        }

        !self.found_inf
    }

    /// Checks a slice of gradients for inf/nan without modifying them.
    #[must_use]
    pub fn check_grads(&self, grads: &[f32]) -> bool {
        grads.iter().all(|g| g.is_finite())
    }

    /// Returns whether inf/nan was found in the last unscale operation.
    #[must_use]
    pub fn found_inf(&self) -> bool {
        self.found_inf
    }

    /// Marks that inf was found (for external gradient checking).
    pub fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }

    /// Updates the scale factor based on overflow history.
    ///
    /// Call this after each optimizer step:
    /// - If overflow was detected, the scale is reduced by `backoff_factor`
    /// - If no overflow occurred for `growth_interval` steps, the scale is
    ///   increased by `growth_factor`
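    ///
    /// # Example
    ///
    /// A minimal sketch of the backoff-and-growth policy (import path assumed):
    ///
    /// ```
    /// use axonml_optim::GradScaler; // NOTE: import path assumed
    ///
    /// let mut scaler = GradScaler::with_options(1000.0, 2.0, 0.5, 2);
    ///
    /// // Overflow: the scale backs off by 0.5.
    /// scaler.set_found_inf(true);
    /// scaler.update();
    /// assert!((scaler.get_scale() - 500.0).abs() < 1e-6);
    ///
    /// // Two clean steps: the scale grows by 2.0.
    /// for _ in 0..2 {
    ///     scaler.set_found_inf(false);
    ///     scaler.update();
    /// }
    /// assert!((scaler.get_scale() - 1000.0).abs() < 1e-6);
    /// ```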
    pub fn update(&mut self) {
        if !self.enabled {
            return;
        }

        if self.found_inf {
            // Reduce scale on overflow
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
            // Clamp so the scale never drops below 1.0
            self.scale = self.scale.max(1.0);
        } else {
            // Track successful steps
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                // Increase scale
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
                // Clamp to avoid floating-point overflow of the scale itself
                self.scale = self.scale.min(f32::MAX / 2.0);
            }
        }
    }

    /// Returns the current state for checkpointing.
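    ///
    /// # Example
    ///
    /// A minimal checkpoint round trip (import path assumed):
    ///
    /// ```
    /// use axonml_optim::GradScaler; // NOTE: import path assumed
    ///
    /// let scaler = GradScaler::with_scale(512.0);
    /// let state = scaler.state_dict();
    ///
    /// let mut restored = GradScaler::new();
    /// restored.load_state_dict(state);
    /// assert!((restored.get_scale() - 512.0).abs() < 1e-6);
    /// ```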
    #[must_use]
    pub fn state_dict(&self) -> GradScalerState {
        GradScalerState {
            scale: self.scale,
            growth_tracker: self.growth_tracker,
        }
    }

    /// Loads state from a checkpoint.
    pub fn load_state_dict(&mut self, state: GradScalerState) {
        self.scale = state.scale;
        self.growth_tracker = state.growth_tracker;
    }
}

/// Serializable state for `GradScaler` checkpointing.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
    /// Current scale factor
    pub scale: f32,
    /// Growth tracker value
    pub growth_tracker: usize,
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert!((scaler.get_scale() - 65536.0).abs() < 1e-6);
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        let scaler = GradScaler::with_scale(1024.0);
        assert!((scaler.get_scale() - 1024.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert!((scaled - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(valid);
        assert!(!scaler.found_inf());
        assert!((grads[0] - 1.0).abs() < 1e-6);
        assert!((grads[1] - 2.0).abs() < 1e-6);
        assert!((grads[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        assert!((scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Simulate 3 successful steps
        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }

        assert!((scaler.get_scale() - 200.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert!((scaler.get_scale() - 1.0).abs() < 1e-6);
        assert!((scaler.scale_loss(0.5) - 0.5).abs() < 1e-6);

        let mut grads = vec![1.0, 2.0, 3.0];
        let valid = scaler.unscale_grads(&mut grads);
        assert!(valid);
        // Grads should be unchanged
        assert!((grads[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;

        let state = scaler.state_dict();
        assert!((state.scale - 500.0).abs() < 1e-6);
        assert_eq!(state.growth_tracker, 10);

        let mut new_scaler = GradScaler::new();
        new_scaler.load_state_dict(state);
        assert!((new_scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(new_scaler.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert!((scaler.growth_factor - 3.0).abs() < 1e-6);
        assert!((scaler.backoff_factor - 0.25).abs() < 1e-6);
        assert_eq!(scaler.growth_interval, 100);
    }
}