// axonml_optim/grad_scaler.rs

/// Dynamic loss scaler (as used in mixed-precision training).
///
/// The loss is multiplied by `scale` before backpropagation and the gradients
/// are divided by it again in [`GradScaler::unscale_grads`]. After each step,
/// [`GradScaler::update`] shrinks the scale when non-finite gradients were
/// seen and grows it after `growth_interval` consecutive clean steps.
#[derive(Clone, Debug)]
pub struct GradScaler {
    /// Current loss-scale factor.
    scale: f32,
    /// Multiplier applied to `scale` when it grows.
    growth_factor: f32,
    /// Multiplier applied to `scale` when non-finite gradients were found.
    backoff_factor: f32,
    /// Consecutive clean steps required before the scale grows.
    growth_interval: usize,
    /// Clean steps counted since the last growth or backoff.
    growth_tracker: usize,
    /// Whether the most recent unscale pass saw a non-finite gradient.
    found_inf: bool,
    /// When `false`, scaling is bypassed (effective scale of `1.0`).
    enabled: bool,
}
62
63impl Default for GradScaler {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl GradScaler {
70 #[must_use]
78 pub fn new() -> Self {
79 Self {
80 scale: 65536.0,
81 growth_factor: 2.0,
82 backoff_factor: 0.5,
83 growth_interval: 2000,
84 growth_tracker: 0,
85 found_inf: false,
86 enabled: true,
87 }
88 }
89
90 #[must_use]
92 pub fn with_scale(init_scale: f32) -> Self {
93 Self {
94 scale: init_scale,
95 ..Self::new()
96 }
97 }
98
99 #[must_use]
101 pub fn with_options(
102 init_scale: f32,
103 growth_factor: f32,
104 backoff_factor: f32,
105 growth_interval: usize,
106 ) -> Self {
107 Self {
108 scale: init_scale,
109 growth_factor,
110 backoff_factor,
111 growth_interval,
112 growth_tracker: 0,
113 found_inf: false,
114 enabled: true,
115 }
116 }
117
118 #[must_use]
120 pub fn growth_factor(mut self, factor: f32) -> Self {
121 self.growth_factor = factor;
122 self
123 }
124
125 #[must_use]
127 pub fn backoff_factor(mut self, factor: f32) -> Self {
128 self.backoff_factor = factor;
129 self
130 }
131
132 #[must_use]
134 pub fn growth_interval(mut self, interval: usize) -> Self {
135 self.growth_interval = interval;
136 self
137 }
138
139 #[must_use]
141 pub fn enabled(mut self, enabled: bool) -> Self {
142 self.enabled = enabled;
143 self
144 }
145
146 #[must_use]
148 pub fn get_scale(&self) -> f32 {
149 if self.enabled {
150 self.scale
151 } else {
152 1.0
153 }
154 }
155
156 pub fn set_scale(&mut self, scale: f32) {
158 self.scale = scale;
159 }
160
161 #[must_use]
163 pub fn is_enabled(&self) -> bool {
164 self.enabled
165 }
166
167 pub fn set_enabled(&mut self, enabled: bool) {
169 self.enabled = enabled;
170 }
171
172 #[must_use]
176 pub fn scale_loss(&self, loss: f32) -> f32 {
177 if self.enabled {
178 loss * self.scale
179 } else {
180 loss
181 }
182 }
183
184 pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
188 if !self.enabled {
189 self.found_inf = false;
190 return true;
191 }
192
193 let inv_scale = 1.0 / self.scale;
194 self.found_inf = false;
195
196 for g in grads.iter_mut() {
197 if g.is_infinite() || g.is_nan() {
198 self.found_inf = true;
199 }
202 *g *= inv_scale;
203 }
204
205 !self.found_inf
206 }
207
208 #[must_use]
210 pub fn check_grads(&self, grads: &[f32]) -> bool {
211 grads.iter().all(|g| g.is_finite())
212 }
213
214 #[must_use]
216 pub fn found_inf(&self) -> bool {
217 self.found_inf
218 }
219
220 pub fn set_found_inf(&mut self, found: bool) {
222 self.found_inf = found;
223 }
224
225 pub fn update(&mut self) {
231 if !self.enabled {
232 return;
233 }
234
235 if self.found_inf {
236 self.scale *= self.backoff_factor;
238 self.growth_tracker = 0;
239 self.scale = self.scale.max(1.0);
241 } else {
242 self.growth_tracker += 1;
244 if self.growth_tracker >= self.growth_interval {
245 self.scale *= self.growth_factor;
247 self.growth_tracker = 0;
248 self.scale = self.scale.min(f32::MAX / 2.0);
250 }
251 }
252 }
253
254 #[must_use]
256 pub fn state_dict(&self) -> GradScalerState {
257 GradScalerState {
258 scale: self.scale,
259 growth_tracker: self.growth_tracker,
260 }
261 }
262
263 pub fn load_state_dict(&mut self, state: GradScalerState) {
265 self.scale = state.scale;
266 self.growth_tracker = state.growth_tracker;
267 }
268}
269
/// Snapshot of a [`GradScaler`]'s dynamic state, produced by
/// [`GradScaler::state_dict`] and consumed by [`GradScaler::load_state_dict`].
#[derive(Clone, Copy, Debug)]
pub struct GradScalerState {
    /// Loss-scale factor at the time of the snapshot.
    pub scale: f32,
    /// Consecutive clean steps counted toward the next scale increase.
    pub growth_tracker: usize,
}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two `f32` values agree to within `1e-6`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert_close(scaler.get_scale(), 65536.0);
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        let scaler = GradScaler::with_scale(1024.0);
        assert_close(scaler.get_scale(), 1024.0);
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        assert_close(scaler.scale_loss(0.5), 50.0);
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(valid);
        assert!(!scaler.found_inf());
        assert_close(grads[0], 1.0);
        assert_close(grads[1], 2.0);
        assert_close(grads[2], 3.0);
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];

        let valid = scaler.unscale_grads(&mut grads);

        assert!(!valid);
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_empty_grads() {
        let mut scaler = GradScaler::new();
        let mut grads: Vec<f32> = Vec::new();
        assert!(scaler.unscale_grads(&mut grads));
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_check_grads() {
        let scaler = GradScaler::new();
        assert!(scaler.check_grads(&[1.0, -2.0, 0.0]));
        assert!(scaler.check_grads(&[]));
        assert!(!scaler.check_grads(&[1.0, f32::INFINITY]));
        assert!(!scaler.check_grads(&[f32::NEG_INFINITY]));
        assert!(!scaler.check_grads(&[f32::NAN]));
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;

        scaler.update();

        assert_close(scaler.get_scale(), 500.0);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_backoff_floors_scale_at_one() {
        let mut scaler = GradScaler::with_scale(1.5);
        scaler.set_found_inf(true);
        scaler.update();
        assert_close(scaler.get_scale(), 1.0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);

        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }

        assert_close(scaler.get_scale(), 200.0);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_no_growth_before_interval() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);
        scaler.update();
        scaler.update();
        assert_close(scaler.get_scale(), 100.0);
        assert_eq!(scaler.growth_tracker, 2);
    }

    #[test]
    fn test_overflow_resets_growth_progress() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);
        scaler.update();
        scaler.update();
        scaler.set_found_inf(true);
        scaler.update();
        assert_close(scaler.get_scale(), 50.0);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_growth_is_capped() {
        let mut scaler = GradScaler::with_options(f32::MAX / 2.0, 2.0, 0.5, 1);
        scaler.update();
        assert!(scaler.scale.is_finite());
        assert_eq!(scaler.scale, f32::MAX / 2.0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);

        assert!(!scaler.is_enabled());
        assert_close(scaler.get_scale(), 1.0);
        assert_close(scaler.scale_loss(0.5), 0.5);

        let mut grads = vec![1.0, 2.0, 3.0];
        let valid = scaler.unscale_grads(&mut grads);
        assert!(valid);
        assert_close(grads[0], 1.0);
    }

    #[test]
    fn test_disabled_update_is_noop() {
        let mut scaler = GradScaler::with_scale(1000.0).enabled(false);
        scaler.set_found_inf(true);
        scaler.update();
        scaler.set_enabled(true);
        assert_close(scaler.get_scale(), 1000.0);
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;

        let state = scaler.state_dict();
        assert_close(state.scale, 500.0);
        assert_eq!(state.growth_tracker, 10);

        let mut new_scaler = GradScaler::new();
        new_scaler.load_state_dict(state);
        assert_close(new_scaler.get_scale(), 500.0);
        assert_eq!(new_scaler.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);

        assert_close(scaler.growth_factor, 3.0);
        assert_close(scaler.backoff_factor, 0.25);
        assert_eq!(scaler.growth_interval, 100);
    }
}