
//! Gradient Scaler for Mixed Precision Training
//!
//! # File
//! `crates/axonml-optim/src/grad_scaler.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// =============================================================================
// GradScaler
// =============================================================================

/// Gradient scaler for mixed precision training.
///
/// Scales the loss to prevent gradient underflow when using F16,
/// then unscales gradients before the optimizer step.
///
/// The scale is automatically adjusted based on whether gradients overflow.
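///
/// # Examples
///
/// A minimal sketch of one mixed-precision training step. The `model`,
/// `optimizer`, and `batch` calls are placeholders (not part of this crate),
/// and the import path assumes `grad_scaler` is a public module:
///
/// ```ignore
/// use axonml_optim::grad_scaler::GradScaler;
///
/// let mut scaler = GradScaler::new();
///
/// // Forward pass; `loss` is a plain f32 here.
/// let loss = model.forward(&batch);
///
/// // Scale the loss so small F16 gradients do not underflow.
/// let scaled_loss = scaler.scale_loss(loss);
/// model.backward(scaled_loss);
///
/// // Unscale gradients; skip the step if any were inf/NaN.
/// if scaler.unscale_grads(model.grads_mut()) {
///     optimizer.step();
/// }
///
/// // Adjust the scale based on whether an overflow occurred.
/// scaler.update();
/// ```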
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current scale factor
    scale: f32,
    /// Factor to multiply scale by on successful steps
    growth_factor: f32,
    /// Factor to multiply scale by when overflow is detected
    backoff_factor: f32,
    /// Number of successful steps before growing scale
    growth_interval: usize,
    /// Counter for successful steps since last growth
    growth_tracker: usize,
    /// Whether inf/nan was found in the last unscale
    found_inf: bool,
    /// Whether the scaler is enabled
    enabled: bool,
}

impl Default for GradScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl GradScaler {
    /// Creates a new gradient scaler with default settings.
    ///
    /// Default configuration:
    /// - Initial scale: 65536.0 (2^16)
    /// - Growth factor: 2.0
    /// - Backoff factor: 0.5
    /// - Growth interval: 2000 steps
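    ///
    /// # Examples
    ///
    /// ```
    /// // The import path assumes `grad_scaler` is a public module of this crate.
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let scaler = GradScaler::new();
    /// assert_eq!(scaler.get_scale(), 65536.0);
    /// ```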
    #[must_use]
    pub fn new() -> Self {
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Creates a gradient scaler with a custom initial scale.
    #[must_use]
    pub fn with_scale(init_scale: f32) -> Self {
        Self {
            scale: init_scale,
            ..Self::new()
        }
    }

    /// Creates a gradient scaler with all custom settings.
    #[must_use]
    pub fn with_options(
        init_scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
    ) -> Self {
        Self {
            scale: init_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Builder: set growth factor
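    ///
    /// # Examples
    ///
    /// Builder calls return `Self` and can be chained after any constructor
    /// (the import path assumes `grad_scaler` is a public module of this crate):
    ///
    /// ```
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let scaler = GradScaler::with_scale(1024.0)
    ///     .growth_factor(1.5)
    ///     .backoff_factor(0.25)
    ///     .growth_interval(500);
    /// assert_eq!(scaler.get_scale(), 1024.0);
    /// ```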
    #[must_use]
    pub fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }

    /// Builder: set backoff factor
    #[must_use]
    pub fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Builder: set growth interval
    #[must_use]
    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }

    /// Builder: set enabled state
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns the current scale factor, or 1.0 if the scaler is disabled.
    #[must_use]
    pub fn get_scale(&self) -> f32 {
        if self.enabled { self.scale } else { 1.0 }
    }

    /// Sets the scale factor.
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Returns whether the scaler is enabled.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Enables or disables the scaler.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Scales a loss value for the backward pass.
    ///
    /// Call this on the loss before invoking backward(); it multiplies the
    /// loss by the current scale (or returns it unchanged when disabled).
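    ///
    /// # Examples
    ///
    /// ```
    /// // The import path assumes `grad_scaler` is a public module of this crate.
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let scaler = GradScaler::with_scale(128.0);
    /// assert_eq!(scaler.scale_loss(0.25), 32.0);
    /// ```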
    #[must_use]
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }

    /// Unscales gradients in place and checks for inf/nan.
    ///
    /// Returns `true` if all gradients are finite, `false` if any overflow.
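    ///
    /// # Examples
    ///
    /// ```
    /// // The import path assumes `grad_scaler` is a public module of this crate.
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let mut scaler = GradScaler::with_scale(8.0);
    /// let mut grads = vec![8.0_f32, 16.0, f32::INFINITY];
    ///
    /// // Every gradient is divided by the scale; the inf is reported.
    /// assert!(!scaler.unscale_grads(&mut grads));
    /// assert_eq!(grads[0], 1.0);
    /// assert!(scaler.found_inf());
    /// ```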
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }

        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;

        for g in grads.iter_mut() {
            if g.is_infinite() || g.is_nan() {
                self.found_inf = true;
                // Don't return early: the remaining grads still need unscaling.
            }
            *g *= inv_scale;
        }

        !self.found_inf
    }

    /// Checks a slice of gradients for inf/nan without modifying them.
    ///
    /// Returns `true` if all gradients are finite.
    #[must_use]
    pub fn check_grads(&self, grads: &[f32]) -> bool {
        grads.iter().all(|g| g.is_finite())
    }

    /// Returns whether inf/nan was found in the last unscale operation.
    #[must_use]
    pub fn found_inf(&self) -> bool {
        self.found_inf
    }

    /// Sets the found-inf flag, e.g. after checking gradients externally
    /// with [`Self::check_grads`].
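    ///
    /// # Examples
    ///
    /// A sketch of an external overflow check feeding back into the scaler
    /// (the import path assumes `grad_scaler` is a public module of this crate):
    ///
    /// ```
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let mut scaler = GradScaler::new();
    /// let grads = vec![1.0_f32, f32::NAN];
    ///
    /// // Check gradients without modifying them, then record the result.
    /// let all_finite = scaler.check_grads(&grads);
    /// scaler.set_found_inf(!all_finite);
    /// assert!(scaler.found_inf());
    /// ```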
    pub fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }

    /// Updates the scale factor based on overflow history.
    ///
    /// Call this after each optimizer step:
    /// - If overflow was detected, the scale is reduced by `backoff_factor`
    /// - If there was no overflow for `growth_interval` consecutive steps,
    ///   the scale is increased by `growth_factor`
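    ///
    /// # Examples
    ///
    /// ```
    /// // The import path assumes `grad_scaler` is a public module of this crate.
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let mut scaler = GradScaler::with_scale(1024.0);
    ///
    /// // An overflow halves the scale (default backoff factor is 0.5).
    /// scaler.set_found_inf(true);
    /// scaler.update();
    /// assert_eq!(scaler.get_scale(), 512.0);
    /// ```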
    pub fn update(&mut self) {
        if !self.enabled {
            return;
        }

        if self.found_inf {
            // Reduce scale on overflow
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
            // Clamp so the scale never drops below 1.0
            self.scale = self.scale.max(1.0);
        } else {
            // Track successful steps
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                // Increase scale
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
                // Clamp to avoid f32 overflow
                self.scale = self.scale.min(f32::MAX / 2.0);
            }
        }
    }

    /// Returns the current state for checkpointing.
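    ///
    /// # Examples
    ///
    /// A save/restore round trip (the import path assumes `grad_scaler` is a
    /// public module of this crate):
    ///
    /// ```
    /// use axonml_optim::grad_scaler::GradScaler;
    ///
    /// let trained = GradScaler::with_scale(512.0);
    /// let state = trained.state_dict();
    ///
    /// let mut restored = GradScaler::new();
    /// restored.load_state_dict(state);
    /// assert_eq!(restored.get_scale(), 512.0);
    /// ```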
    #[must_use]
    pub fn state_dict(&self) -> GradScalerState {
        GradScalerState {
            scale: self.scale,
            growth_tracker: self.growth_tracker,
        }
    }

    /// Loads state from a checkpoint.
    pub fn load_state_dict(&mut self, state: GradScalerState) {
        self.scale = state.scale;
        self.growth_tracker = state.growth_tracker;
    }
}

/// Serializable state for `GradScaler` checkpointing.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
    /// Current scale factor
    pub scale: f32,
    /// Growth tracker value
    pub growth_tracker: usize,
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert!((scaler.get_scale() - 65536.0).abs() < 1e-6);
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        let scaler = GradScaler::with_scale(1024.0);
        assert!((scaler.get_scale() - 1024.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert!((scaled - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(valid);
        assert!(!scaler.found_inf());
        assert!((grads[0] - 1.0).abs() < 1e-6);
        assert!((grads[1] - 2.0).abs() < 1e-6);
        assert!((grads[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        assert!((scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Simulate 3 successful steps
        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }

        assert!((scaler.get_scale() - 200.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert!((scaler.get_scale() - 1.0).abs() < 1e-6);
        assert!((scaler.scale_loss(0.5) - 0.5).abs() < 1e-6);

        let mut grads = vec![1.0, 2.0, 3.0];
        let valid = scaler.unscale_grads(&mut grads);
        assert!(valid);
        // Grads should be unchanged
        assert!((grads[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;

        let state = scaler.state_dict();
        assert!((state.scale - 500.0).abs() < 1e-6);
        assert_eq!(state.growth_tracker, 10);

        let mut new_scaler = GradScaler::new();
        new_scaler.load_state_dict(state);
        assert!((new_scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(new_scaler.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert!((scaler.growth_factor - 3.0).abs() < 1e-6);
        assert!((scaler.backoff_factor - 0.25).abs() < 1e-6);
        assert_eq!(scaler.growth_interval, 100);
    }
}
389}