//! `GradScaler` — dynamic loss scaling for AMP (mixed-precision) training.
//!
//! Scales the loss before backward to prevent F16 underflow, then unscales
//! gradients before the optimizer step. The scale factor adapts: it grows
//! (doubles by default) after `growth_interval` overflow-free steps and
//! shrinks (halves by default) whenever a NaN/Inf gradient is detected.
//! Pairs with `AutocastGuard` in the `amp` module of `axonml-autograd`.
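//!
//! # Example
//!
//! A minimal sketch of the intended training-loop usage; `loss`, `grads`,
//! `backward`, and `apply_update` are placeholders standing in for the
//! surrounding training code, not items from this crate:
//!
//! ```ignore
//! let mut scaler = GradScaler::new();
//!
//! // Scale the loss before backward so small F16 gradients
//! // don't flush to zero.
//! let scaled_loss = scaler.scale_loss(loss);
//! backward(scaled_loss);
//!
//! // Unscale in place; skip the optimizer step if any grad was inf/NaN.
//! if scaler.unscale_grads(&mut grads) {
//!     apply_update(&grads);
//! }
//!
//! // Grow or back off the scale based on what the last unscale found.
//! scaler.update();
//! ```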
//!
//! # File
//! `crates/axonml-optim/src/grad_scaler.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// =============================================================================
// GradScaler
// =============================================================================

/// Gradient scaler for mixed precision training.
///
/// Scales the loss to prevent gradient underflow when using F16,
/// then unscales gradients before the optimizer step.
///
/// The scale is automatically adjusted based on whether gradients overflow.
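///
/// # Example
///
/// A construction sketch using the builder methods defined below (marked
/// `ignore` only because the public re-export path isn't shown in this file):
///
/// ```ignore
/// let scaler = GradScaler::with_scale(1024.0)
///     .growth_factor(2.0)
///     .backoff_factor(0.5)
///     .growth_interval(500);
/// assert!((scaler.get_scale() - 1024.0).abs() < 1e-6);
/// ```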
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current scale factor
    scale: f32,
    /// Multiplier applied to the scale after `growth_interval` overflow-free steps
    growth_factor: f32,
    /// Multiplier applied to the scale when an overflow is detected
    backoff_factor: f32,
    /// Number of overflow-free steps required before the scale grows
    growth_interval: usize,
    /// Overflow-free steps counted since the last growth or backoff
    growth_tracker: usize,
    /// Whether inf/nan was found in the last unscale
    found_inf: bool,
    /// Whether the scaler is enabled
    enabled: bool,
}

impl Default for GradScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl GradScaler {
    /// Creates a new gradient scaler with default settings.
    ///
    /// Default configuration:
    /// - Initial scale: 65536.0 (2^16)
    /// - Growth factor: 2.0
    /// - Backoff factor: 0.5
    /// - Growth interval: 2000 steps
    #[must_use]
    pub fn new() -> Self {
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Creates a gradient scaler with custom initial scale.
    #[must_use]
    pub fn with_scale(init_scale: f32) -> Self {
        Self {
            scale: init_scale,
            ..Self::new()
        }
    }

    /// Creates a gradient scaler with all custom settings.
    #[must_use]
    pub fn with_options(
        init_scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
    ) -> Self {
        Self {
            scale: init_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }

    /// Builder: sets the growth factor.
    #[must_use]
    pub fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }

    /// Builder: sets the backoff factor.
    #[must_use]
    pub fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Builder: sets the growth interval.
    #[must_use]
    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }

    /// Builder: sets the enabled state.
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns the current scale factor.
    #[must_use]
    pub fn get_scale(&self) -> f32 {
        if self.enabled { self.scale } else { 1.0 }
    }

    /// Sets the scale factor.
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }

    /// Returns whether the scaler is enabled.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Enables or disables the scaler.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Scales a loss value for the backward pass.
    ///
    /// Multiply the loss by this before calling backward().
    #[must_use]
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }

    /// Unscales gradients in place and checks for inf/nan.
    ///
    /// Returns true if all gradients are finite, false if any overflow.
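    ///
    /// # Example
    ///
    /// A sketch of the unscale-then-check flow (marked `ignore` only because
    /// the public re-export path for `GradScaler` isn't shown in this file):
    ///
    /// ```ignore
    /// let mut scaler = GradScaler::with_scale(128.0);
    /// let mut grads = vec![256.0_f32, 64.0];
    /// assert!(scaler.unscale_grads(&mut grads)); // all finite
    /// assert_eq!(grads, vec![2.0, 0.5]);
    /// ```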
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }

        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;

        for g in grads.iter_mut() {
            if g.is_infinite() || g.is_nan() {
                // Don't return early: the remaining grads still need
                // unscaling; the flag records that an overflow occurred.
                self.found_inf = true;
            }
            *g *= inv_scale;
        }

        !self.found_inf
    }

    /// Unscales gradients on all optimizer parameters in place.
    ///
    /// Equivalent to PyTorch's `GradScaler.unscale_(optimizer)`.
    /// Returns true if all gradients are finite.
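    ///
    /// # Example
    ///
    /// A hedged sketch, assuming `optimizer` implements this crate's
    /// `Optimizer` trait and exposes a `step()` method (the exact step
    /// signature is an assumption, not shown in this file):
    ///
    /// ```ignore
    /// if scaler.unscale_optimizer(&optimizer) {
    ///     optimizer.step(); // hypothetical; skip the step on overflow
    /// }
    /// scaler.update();
    /// ```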
    pub fn unscale_optimizer<O: crate::Optimizer>(&mut self, optimizer: &O) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }

        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;

        for param in optimizer.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_vec = grad.to_vec();
                for g in &mut grad_vec {
                    if g.is_infinite() || g.is_nan() {
                        self.found_inf = true;
                    }
                    *g *= inv_scale;
                }
                let unscaled = axonml_tensor::Tensor::from_vec(grad_vec, grad.shape())
                    .expect("grad_scaler: tensor creation failed");
                param.set_grad(unscaled);
            }
        }

        !self.found_inf
    }

    /// Checks a slice of gradients for inf/nan without modifying them.
    #[must_use]
    pub fn check_grads(&self, grads: &[f32]) -> bool {
        grads.iter().all(|g| g.is_finite())
    }

    /// Returns whether inf/nan was found in the last unscale operation.
    #[must_use]
    pub fn found_inf(&self) -> bool {
        self.found_inf
    }

    /// Marks that inf was found (for external gradient checking).
    pub fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }

    /// Updates the scale factor based on overflow history.
    ///
    /// Call this after each optimizer step:
    /// - If overflow was detected, the scale is multiplied by `backoff_factor`
    /// - If no overflow occurred for `growth_interval` steps, the scale is
    ///   multiplied by `growth_factor`
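    ///
    /// # Example
    ///
    /// A sketch of the backoff path (marked `ignore` only because the public
    /// re-export path isn't shown in this file):
    ///
    /// ```ignore
    /// let mut scaler = GradScaler::with_scale(1000.0);
    /// scaler.set_found_inf(true); // pretend the last unscale overflowed
    /// scaler.update();
    /// assert!((scaler.get_scale() - 500.0).abs() < 1e-6); // halved
    /// ```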
    pub fn update(&mut self) {
        if !self.enabled {
            return;
        }

        if self.found_inf {
            // Reduce scale on overflow
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
            // Clamp to avoid too small a scale
            self.scale = self.scale.max(1.0);
        } else {
            // Track successful steps
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                // Increase scale
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
                // Clamp to avoid overflow
                self.scale = self.scale.min(f32::MAX / 2.0);
            }
        }
    }

    /// Returns the current state for checkpointing.
    #[must_use]
    pub fn state_dict(&self) -> GradScalerState {
        GradScalerState {
            scale: self.scale,
            growth_tracker: self.growth_tracker,
        }
    }

    /// Loads state from a checkpoint.
    pub fn load_state_dict(&mut self, state: GradScalerState) {
        self.scale = state.scale;
        self.growth_tracker = state.growth_tracker;
    }
}

/// Serializable state for `GradScaler` checkpointing.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
    /// Current scale factor
    pub scale: f32,
    /// Growth tracker value
    pub growth_tracker: usize,
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert!((scaler.get_scale() - 65536.0).abs() < 1e-6);
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        let scaler = GradScaler::with_scale(1024.0);
        assert!((scaler.get_scale() - 1024.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert!((scaled - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(valid);
        assert!(!scaler.found_inf());
        assert!((grads[0] - 1.0).abs() < 1e-6);
        assert!((grads[1] - 2.0).abs() < 1e-6);
        assert!((grads[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        assert!((scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Simulate 3 successful steps
        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }

        assert!((scaler.get_scale() - 200.0).abs() < 1e-6);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert!((scaler.get_scale() - 1.0).abs() < 1e-6);
        assert!((scaler.scale_loss(0.5) - 0.5).abs() < 1e-6);

        let mut grads = vec![1.0, 2.0, 3.0];
        let valid = scaler.unscale_grads(&mut grads);
        assert!(valid);
        // Grads should be unchanged
        assert!((grads[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;

        let state = scaler.state_dict();
        assert!((state.scale - 500.0).abs() < 1e-6);
        assert_eq!(state.growth_tracker, 10);

        let mut new_scaler = GradScaler::new();
        new_scaler.load_state_dict(state);
        assert!((new_scaler.get_scale() - 500.0).abs() < 1e-6);
        assert_eq!(new_scaler.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert!((scaler.growth_factor - 3.0).abs() < 1e-6);
        assert!((scaler.backoff_factor - 0.25).abs() < 1e-6);
        assert_eq!(scaler.growth_interval, 100);
    }
}