axonml_optim/
/// Loss/gradient scaler whose scale factor is adjusted dynamically.
///
/// The scaler multiplies a loss by `scale`, divides gradients by the same
/// factor afterwards, and grows or shrinks `scale` depending on whether
/// non-finite gradients were observed.
#[derive(Clone, Debug)]
pub struct GradScaler {
    /// Current scale factor applied to the loss.
    scale: f32,
    /// Multiplier applied to `scale` when it is grown.
    growth_factor: f32,
    /// Multiplier applied to `scale` after non-finite gradients were seen.
    backoff_factor: f32,
    /// Number of consecutive clean updates required before `scale` grows.
    growth_interval: usize,
    /// Count of consecutive clean updates since the last growth or backoff.
    growth_tracker: usize,
    /// Whether the most recent unscale pass saw a non-finite gradient.
    found_inf: bool,
    /// When `false`, scaling is bypassed entirely.
    enabled: bool,
}
44
45impl Default for GradScaler {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl GradScaler {
52 #[must_use]
60 pub fn new() -> Self {
61 Self {
62 scale: 65536.0,
63 growth_factor: 2.0,
64 backoff_factor: 0.5,
65 growth_interval: 2000,
66 growth_tracker: 0,
67 found_inf: false,
68 enabled: true,
69 }
70 }
71
72 #[must_use]
74 pub fn with_scale(init_scale: f32) -> Self {
75 Self {
76 scale: init_scale,
77 ..Self::new()
78 }
79 }
80
81 #[must_use]
83 pub fn with_options(
84 init_scale: f32,
85 growth_factor: f32,
86 backoff_factor: f32,
87 growth_interval: usize,
88 ) -> Self {
89 Self {
90 scale: init_scale,
91 growth_factor,
92 backoff_factor,
93 growth_interval,
94 growth_tracker: 0,
95 found_inf: false,
96 enabled: true,
97 }
98 }
99
100 #[must_use]
102 pub fn growth_factor(mut self, factor: f32) -> Self {
103 self.growth_factor = factor;
104 self
105 }
106
107 #[must_use]
109 pub fn backoff_factor(mut self, factor: f32) -> Self {
110 self.backoff_factor = factor;
111 self
112 }
113
114 #[must_use]
116 pub fn growth_interval(mut self, interval: usize) -> Self {
117 self.growth_interval = interval;
118 self
119 }
120
121 #[must_use]
123 pub fn enabled(mut self, enabled: bool) -> Self {
124 self.enabled = enabled;
125 self
126 }
127
128 #[must_use]
130 pub fn get_scale(&self) -> f32 {
131 if self.enabled { self.scale } else { 1.0 }
132 }
133
134 pub fn set_scale(&mut self, scale: f32) {
136 self.scale = scale;
137 }
138
139 #[must_use]
141 pub fn is_enabled(&self) -> bool {
142 self.enabled
143 }
144
145 pub fn set_enabled(&mut self, enabled: bool) {
147 self.enabled = enabled;
148 }
149
150 #[must_use]
154 pub fn scale_loss(&self, loss: f32) -> f32 {
155 if self.enabled {
156 loss * self.scale
157 } else {
158 loss
159 }
160 }
161
162 pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
166 if !self.enabled {
167 self.found_inf = false;
168 return true;
169 }
170
171 let inv_scale = 1.0 / self.scale;
172 self.found_inf = false;
173
174 for g in grads.iter_mut() {
175 if g.is_infinite() || g.is_nan() {
176 self.found_inf = true;
177 }
180 *g *= inv_scale;
181 }
182
183 !self.found_inf
184 }
185
186 #[must_use]
188 pub fn check_grads(&self, grads: &[f32]) -> bool {
189 grads.iter().all(|g| g.is_finite())
190 }
191
192 #[must_use]
194 pub fn found_inf(&self) -> bool {
195 self.found_inf
196 }
197
198 pub fn set_found_inf(&mut self, found: bool) {
200 self.found_inf = found;
201 }
202
203 pub fn update(&mut self) {
209 if !self.enabled {
210 return;
211 }
212
213 if self.found_inf {
214 self.scale *= self.backoff_factor;
216 self.growth_tracker = 0;
217 self.scale = self.scale.max(1.0);
219 } else {
220 self.growth_tracker += 1;
222 if self.growth_tracker >= self.growth_interval {
223 self.scale *= self.growth_factor;
225 self.growth_tracker = 0;
226 self.scale = self.scale.min(f32::MAX / 2.0);
228 }
229 }
230 }
231
232 #[must_use]
234 pub fn state_dict(&self) -> GradScalerState {
235 GradScalerState {
236 scale: self.scale,
237 growth_tracker: self.growth_tracker,
238 }
239 }
240
241 pub fn load_state_dict(&mut self, state: GradScalerState) {
243 self.scale = state.scale;
244 self.growth_tracker = state.growth_tracker;
245 }
246}
247
/// Snapshot of the mutable part of a [`GradScaler`], as produced by
/// `state_dict` and consumed by `load_state_dict`.
#[derive(Clone, Copy, Debug)]
pub struct GradScalerState {
    /// Scale factor at the time of the snapshot.
    pub scale: f32,
    /// Growth tracker value at the time of the snapshot.
    pub growth_tracker: usize,
}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every `f32` comparison in this module.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` lies within [`TOL`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_grad_scaler_creation() {
        let fresh = GradScaler::new();
        assert_close(fresh.get_scale(), 65536.0);
        assert!(fresh.is_enabled());
        assert!(!fresh.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        assert_close(GradScaler::with_scale(1024.0).get_scale(), 1024.0);
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        assert_close(scaler.scale_loss(0.5), 50.0);
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = [100.0_f32, 200.0, 300.0];

        let all_finite = scaler.unscale_grads(&mut grads);

        assert!(all_finite);
        assert!(!scaler.found_inf());
        for (got, want) in grads.iter().zip([1.0_f32, 2.0, 3.0]) {
            assert_close(*got, want);
        }
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = [100.0_f32, f32::INFINITY, 300.0];

        let all_finite = scaler.unscale_grads(&mut grads);

        assert!(!all_finite);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = [100.0_f32, f32::NAN, 300.0];

        let all_finite = scaler.unscale_grads(&mut grads);

        assert!(!all_finite);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        // One backoff with the default factor of 0.5 halves the scale.
        assert_close(scaler.get_scale(), 500.0);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        // Three clean updates reach the growth interval exactly once.
        (0..3).for_each(|_| {
            scaler.found_inf = false;
            scaler.update();
        });

        assert_close(scaler.get_scale(), 200.0);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert_close(scaler.get_scale(), 1.0);
        assert_close(scaler.scale_loss(0.5), 0.5);

        let mut grads = [1.0_f32, 2.0, 3.0];
        let all_finite = scaler.unscale_grads(&mut grads);
        assert!(all_finite);
        assert_close(grads[0], 1.0);
    }

    #[test]
    fn test_state_dict() {
        let mut source = GradScaler::with_scale(500.0);
        source.growth_tracker = 10;

        let snapshot = source.state_dict();
        assert_close(snapshot.scale, 500.0);
        assert_eq!(snapshot.growth_tracker, 10);

        let mut restored = GradScaler::new();
        restored.load_state_dict(snapshot);
        assert_close(restored.get_scale(), 500.0);
        assert_eq!(restored.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let built = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert_close(built.growth_factor, 3.0);
        assert_close(built.backoff_factor, 0.25);
        assert_eq!(built.growth_interval, 100);
    }
}