// axonml_optim/grad_scaler.rs — dynamic gradient (loss) scaling.
/// Dynamic loss scaler, typically used for mixed-precision training.
///
/// The loss is multiplied by `scale` before backprop (`scale_loss`) so small
/// gradients survive reduced precision; gradients are divided back down
/// (`unscale_grads` / `unscale_optimizer`) before the optimizer step. The
/// scale grows after a run of overflow-free steps and backs off whenever an
/// inf/NaN gradient is observed (`update`).
#[derive(Debug, Clone)]
pub struct GradScaler {
    // Current loss-scale factor.
    scale: f32,
    // Multiplier applied to `scale` after `growth_interval` clean updates.
    growth_factor: f32,
    // Multiplier applied to `scale` when an inf/NaN gradient was found.
    backoff_factor: f32,
    // Consecutive clean updates required before the scale grows.
    growth_interval: usize,
    // Clean updates accumulated since the last scale change.
    growth_tracker: usize,
    // Whether the most recent unscale pass saw an inf/NaN gradient.
    found_inf: bool,
    // When false, the scaler is a pass-through (effective scale 1.0).
    enabled: bool,
}
50
51impl Default for GradScaler {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl GradScaler {
58 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 scale: 65536.0,
69 growth_factor: 2.0,
70 backoff_factor: 0.5,
71 growth_interval: 2000,
72 growth_tracker: 0,
73 found_inf: false,
74 enabled: true,
75 }
76 }
77
78 #[must_use]
80 pub fn with_scale(init_scale: f32) -> Self {
81 Self {
82 scale: init_scale,
83 ..Self::new()
84 }
85 }
86
87 #[must_use]
89 pub fn with_options(
90 init_scale: f32,
91 growth_factor: f32,
92 backoff_factor: f32,
93 growth_interval: usize,
94 ) -> Self {
95 Self {
96 scale: init_scale,
97 growth_factor,
98 backoff_factor,
99 growth_interval,
100 growth_tracker: 0,
101 found_inf: false,
102 enabled: true,
103 }
104 }
105
106 #[must_use]
108 pub fn growth_factor(mut self, factor: f32) -> Self {
109 self.growth_factor = factor;
110 self
111 }
112
113 #[must_use]
115 pub fn backoff_factor(mut self, factor: f32) -> Self {
116 self.backoff_factor = factor;
117 self
118 }
119
120 #[must_use]
122 pub fn growth_interval(mut self, interval: usize) -> Self {
123 self.growth_interval = interval;
124 self
125 }
126
127 #[must_use]
129 pub fn enabled(mut self, enabled: bool) -> Self {
130 self.enabled = enabled;
131 self
132 }
133
134 #[must_use]
136 pub fn get_scale(&self) -> f32 {
137 if self.enabled { self.scale } else { 1.0 }
138 }
139
140 pub fn set_scale(&mut self, scale: f32) {
142 self.scale = scale;
143 }
144
145 #[must_use]
147 pub fn is_enabled(&self) -> bool {
148 self.enabled
149 }
150
151 pub fn set_enabled(&mut self, enabled: bool) {
153 self.enabled = enabled;
154 }
155
156 #[must_use]
160 pub fn scale_loss(&self, loss: f32) -> f32 {
161 if self.enabled {
162 loss * self.scale
163 } else {
164 loss
165 }
166 }
167
168 pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
172 if !self.enabled {
173 self.found_inf = false;
174 return true;
175 }
176
177 let inv_scale = 1.0 / self.scale;
178 self.found_inf = false;
179
180 for g in grads.iter_mut() {
181 if g.is_infinite() || g.is_nan() {
182 self.found_inf = true;
183 }
186 *g *= inv_scale;
187 }
188
189 !self.found_inf
190 }
191
192 pub fn unscale_optimizer<O: crate::Optimizer>(&mut self, optimizer: &O) -> bool {
197 if !self.enabled {
198 self.found_inf = false;
199 return true;
200 }
201
202 let inv_scale = 1.0 / self.scale;
203 self.found_inf = false;
204
205 for param in optimizer.parameters() {
206 if let Some(grad) = param.grad() {
207 let mut grad_vec = grad.to_vec();
208 for g in &mut grad_vec {
209 if g.is_infinite() || g.is_nan() {
210 self.found_inf = true;
211 }
212 *g *= inv_scale;
213 }
214 let unscaled = axonml_tensor::Tensor::from_vec(grad_vec, grad.shape())
215 .expect("grad_scaler: tensor creation failed");
216 param.set_grad(unscaled);
217 }
218 }
219
220 !self.found_inf
221 }
222
223 #[must_use]
225 pub fn check_grads(&self, grads: &[f32]) -> bool {
226 grads.iter().all(|g| g.is_finite())
227 }
228
229 #[must_use]
231 pub fn found_inf(&self) -> bool {
232 self.found_inf
233 }
234
235 pub fn set_found_inf(&mut self, found: bool) {
237 self.found_inf = found;
238 }
239
240 pub fn update(&mut self) {
246 if !self.enabled {
247 return;
248 }
249
250 if self.found_inf {
251 self.scale *= self.backoff_factor;
253 self.growth_tracker = 0;
254 self.scale = self.scale.max(1.0);
256 } else {
257 self.growth_tracker += 1;
259 if self.growth_tracker >= self.growth_interval {
260 self.scale *= self.growth_factor;
262 self.growth_tracker = 0;
263 self.scale = self.scale.min(f32::MAX / 2.0);
265 }
266 }
267 }
268
269 #[must_use]
271 pub fn state_dict(&self) -> GradScalerState {
272 GradScalerState {
273 scale: self.scale,
274 growth_tracker: self.growth_tracker,
275 }
276 }
277
278 pub fn load_state_dict(&mut self, state: GradScalerState) {
280 self.scale = state.scale;
281 self.growth_tracker = state.growth_tracker;
282 }
283}
284
/// Serializable snapshot of a [`GradScaler`]'s persistent state, produced by
/// `state_dict` and consumed by `load_state_dict`.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
    /// Loss-scale factor at the time of the snapshot.
    pub scale: f32,
    /// Clean-update count accumulated toward the next growth step.
    pub growth_tracker: usize,
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared tolerance-based float comparison for this module.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_grad_scaler_creation() {
        let s = GradScaler::new();
        assert!(approx(s.get_scale(), 65536.0));
        assert!(s.is_enabled());
        assert!(!s.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        assert!(approx(GradScaler::with_scale(1024.0).get_scale(), 1024.0));
    }

    #[test]
    fn test_scale_loss() {
        let s = GradScaler::with_scale(100.0);
        assert!(approx(s.scale_loss(0.5), 50.0));
    }

    #[test]
    fn test_unscale_grads() {
        let mut s = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];

        assert!(s.unscale_grads(&mut grads));
        assert!(!s.found_inf());
        for (got, want) in grads.iter().zip([1.0_f32, 2.0, 3.0]) {
            assert!(approx(*got, want));
        }
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut s = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        assert!(!s.unscale_grads(&mut grads));
        assert!(s.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut s = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        assert!(!s.unscale_grads(&mut grads));
        assert!(s.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut s = GradScaler::with_scale(1000.0);
        s.found_inf = true;

        s.update();

        // Backoff halves the scale and resets the growth counter.
        assert!(approx(s.get_scale(), 500.0));
        assert_eq!(s.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut s = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Three consecutive clean updates hit the growth interval exactly.
        for _ in 0..3 {
            s.found_inf = false;
            s.update();
        }

        assert!(approx(s.get_scale(), 200.0));
        assert_eq!(s.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut s = GradScaler::new().enabled(false);

        assert!(!s.is_enabled());
        assert!(approx(s.get_scale(), 1.0));
        assert!(approx(s.scale_loss(0.5), 0.5));

        // Disabled unscaling is a pass-through that reports success.
        let mut grads = vec![1.0, 2.0, 3.0];
        assert!(s.unscale_grads(&mut grads));
        assert!(approx(grads[0], 1.0));
    }

    #[test]
    fn test_state_dict() {
        let mut s = GradScaler::with_scale(500.0);
        s.growth_tracker = 10;

        let state = s.state_dict();
        assert!(approx(state.scale, 500.0));
        assert_eq!(state.growth_tracker, 10);

        let mut restored = GradScaler::new();
        restored.load_state_dict(state);
        assert!(approx(restored.get_scale(), 500.0));
        assert_eq!(restored.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let s = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert!(approx(s.growth_factor, 3.0));
        assert!(approx(s.backoff_factor, 0.25));
        assert_eq!(s.growth_interval, 100);
    }
}