//! axonml_optim/grad_scaler.rs — dynamic loss scaling (gradient scaler) for mixed-precision training.
/// Dynamic loss scaler for mixed-precision training.
///
/// The loss is multiplied by `scale` before backward; gradients are divided
/// by `scale` before the optimizer step. The scale grows after a run of
/// overflow-free steps and backs off when an inf/NaN gradient is detected.
#[derive(Debug, Clone)]
pub struct GradScaler {
    /// Current loss-scale factor applied by `scale_loss` / removed by `unscale_*`.
    scale: f32,
    /// Multiplier applied to `scale` after `growth_interval` clean updates.
    growth_factor: f32,
    /// Multiplier applied to `scale` when an overflow is detected.
    backoff_factor: f32,
    /// Number of consecutive overflow-free updates required before growth.
    growth_interval: usize,
    /// Consecutive overflow-free updates since the last scale change.
    growth_tracker: usize,
    /// Set by the `unscale_*` methods when any gradient was inf/NaN.
    found_inf: bool,
    /// When `false`, the scaler is a no-op (effective scale is 1.0).
    enabled: bool,
}
44
impl Default for GradScaler {
    /// Equivalent to [`GradScaler::new`].
    fn default() -> Self {
        Self::new()
    }
}
50
51impl GradScaler {
52 #[must_use]
60 pub fn new() -> Self {
61 Self {
62 scale: 65536.0,
63 growth_factor: 2.0,
64 backoff_factor: 0.5,
65 growth_interval: 2000,
66 growth_tracker: 0,
67 found_inf: false,
68 enabled: true,
69 }
70 }
71
72 #[must_use]
74 pub fn with_scale(init_scale: f32) -> Self {
75 Self {
76 scale: init_scale,
77 ..Self::new()
78 }
79 }
80
81 #[must_use]
83 pub fn with_options(
84 init_scale: f32,
85 growth_factor: f32,
86 backoff_factor: f32,
87 growth_interval: usize,
88 ) -> Self {
89 Self {
90 scale: init_scale,
91 growth_factor,
92 backoff_factor,
93 growth_interval,
94 growth_tracker: 0,
95 found_inf: false,
96 enabled: true,
97 }
98 }
99
100 #[must_use]
102 pub fn growth_factor(mut self, factor: f32) -> Self {
103 self.growth_factor = factor;
104 self
105 }
106
107 #[must_use]
109 pub fn backoff_factor(mut self, factor: f32) -> Self {
110 self.backoff_factor = factor;
111 self
112 }
113
114 #[must_use]
116 pub fn growth_interval(mut self, interval: usize) -> Self {
117 self.growth_interval = interval;
118 self
119 }
120
121 #[must_use]
123 pub fn enabled(mut self, enabled: bool) -> Self {
124 self.enabled = enabled;
125 self
126 }
127
128 #[must_use]
130 pub fn get_scale(&self) -> f32 {
131 if self.enabled { self.scale } else { 1.0 }
132 }
133
134 pub fn set_scale(&mut self, scale: f32) {
136 self.scale = scale;
137 }
138
139 #[must_use]
141 pub fn is_enabled(&self) -> bool {
142 self.enabled
143 }
144
145 pub fn set_enabled(&mut self, enabled: bool) {
147 self.enabled = enabled;
148 }
149
150 #[must_use]
154 pub fn scale_loss(&self, loss: f32) -> f32 {
155 if self.enabled {
156 loss * self.scale
157 } else {
158 loss
159 }
160 }
161
162 pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
166 if !self.enabled {
167 self.found_inf = false;
168 return true;
169 }
170
171 let inv_scale = 1.0 / self.scale;
172 self.found_inf = false;
173
174 for g in grads.iter_mut() {
175 if g.is_infinite() || g.is_nan() {
176 self.found_inf = true;
177 }
180 *g *= inv_scale;
181 }
182
183 !self.found_inf
184 }
185
186 pub fn unscale_optimizer<O: crate::Optimizer>(&mut self, optimizer: &O) -> bool {
191 if !self.enabled {
192 self.found_inf = false;
193 return true;
194 }
195
196 let inv_scale = 1.0 / self.scale;
197 self.found_inf = false;
198
199 for param in optimizer.parameters() {
200 if let Some(grad) = param.grad() {
201 let mut grad_vec = grad.to_vec();
202 for g in &mut grad_vec {
203 if g.is_infinite() || g.is_nan() {
204 self.found_inf = true;
205 }
206 *g *= inv_scale;
207 }
208 let unscaled = axonml_tensor::Tensor::from_vec(grad_vec, grad.shape())
209 .expect("grad_scaler: tensor creation failed");
210 param.set_grad(unscaled);
211 }
212 }
213
214 !self.found_inf
215 }
216
217 #[must_use]
219 pub fn check_grads(&self, grads: &[f32]) -> bool {
220 grads.iter().all(|g| g.is_finite())
221 }
222
223 #[must_use]
225 pub fn found_inf(&self) -> bool {
226 self.found_inf
227 }
228
229 pub fn set_found_inf(&mut self, found: bool) {
231 self.found_inf = found;
232 }
233
234 pub fn update(&mut self) {
240 if !self.enabled {
241 return;
242 }
243
244 if self.found_inf {
245 self.scale *= self.backoff_factor;
247 self.growth_tracker = 0;
248 self.scale = self.scale.max(1.0);
250 } else {
251 self.growth_tracker += 1;
253 if self.growth_tracker >= self.growth_interval {
254 self.scale *= self.growth_factor;
256 self.growth_tracker = 0;
257 self.scale = self.scale.min(f32::MAX / 2.0);
259 }
260 }
261 }
262
263 #[must_use]
265 pub fn state_dict(&self) -> GradScalerState {
266 GradScalerState {
267 scale: self.scale,
268 growth_tracker: self.growth_tracker,
269 }
270 }
271
272 pub fn load_state_dict(&mut self, state: GradScalerState) {
274 self.scale = state.scale;
275 self.growth_tracker = state.growth_tracker;
276 }
277}
278
/// Serializable snapshot of a [`GradScaler`]'s dynamic state.
/// Only the values that change during training are captured; the
/// configuration knobs (growth/backoff factors, interval) are not.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
    /// Loss-scale factor at the time of the snapshot.
    pub scale: f32,
    /// Consecutive overflow-free updates since the last scale change.
    pub growth_tracker: usize,
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison used throughout these tests.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert!(approx(scaler.get_scale(), 65536.0));
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        assert!(approx(GradScaler::with_scale(1024.0).get_scale(), 1024.0));
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        assert!(approx(scaler.scale_loss(0.5), 50.0));
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0_f32, 200.0, 300.0];

        assert!(scaler.unscale_grads(&mut grads));
        assert!(!scaler.found_inf());
        for (got, want) in grads.iter().zip([1.0_f32, 2.0, 3.0]) {
            assert!(approx(*got, want));
        }
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        assert!(!scaler.unscale_grads(&mut grads));
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        assert!(!scaler.unscale_grads(&mut grads));
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        // Backoff halves the scale and resets the growth countdown.
        assert!(approx(scaler.get_scale(), 500.0));
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Three clean updates reach the growth interval exactly once.
        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }

        assert!(approx(scaler.get_scale(), 200.0));
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert!(approx(scaler.get_scale(), 1.0));
        assert!(approx(scaler.scale_loss(0.5), 0.5));

        // Unscaling is a no-op while disabled.
        let mut grads = vec![1.0, 2.0, 3.0];
        assert!(scaler.unscale_grads(&mut grads));
        assert!(approx(grads[0], 1.0));
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;

        let snapshot = scaler.state_dict();
        assert!(approx(snapshot.scale, 500.0));
        assert_eq!(snapshot.growth_tracker, 10);

        let mut restored = GradScaler::new();
        restored.load_state_dict(snapshot);
        assert!(approx(restored.get_scale(), 500.0));
        assert_eq!(restored.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert!(approx(scaler.growth_factor, 3.0));
        assert!(approx(scaler.backoff_factor, 0.25));
        assert_eq!(scaler.growth_interval, 100);
    }
}