/// Dynamic loss/gradient scaler (AMP-style, modeled after PyTorch's
/// `GradScaler`): the loss is multiplied by `scale` before backprop and
/// gradients are divided back down before the optimizer step, with the
/// scale shrinking on inf/NaN gradients and growing after a run of clean
/// steps.
#[derive(Debug, Clone)]
pub struct GradScaler {
// Current loss-scale multiplier (default 65536.0 = 2^16).
scale: f32,
// Multiplier applied to `scale` after `growth_interval` clean updates (default 2.0).
growth_factor: f32,
// Multiplier applied to `scale` when an overflow is detected (default 0.5).
backoff_factor: f32,
// Number of consecutive overflow-free updates required before growing (default 2000).
growth_interval: usize,
// Count of consecutive overflow-free updates accumulated so far.
growth_tracker: usize,
// Set when the last unscale pass saw an inf/NaN gradient.
found_inf: bool,
// When false, the scaler acts as an identity (scale 1.0, no tracking).
enabled: bool,
}
impl Default for GradScaler {
fn default() -> Self {
Self::new()
}
}
impl GradScaler {
    /// Creates a scaler with the standard defaults: initial scale 2^16,
    /// growth factor 2.0, backoff factor 0.5, growth interval 2000.
    #[must_use]
    pub fn new() -> Self {
        Self {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }
    /// Creates a scaler with a custom initial scale; every other setting
    /// keeps the defaults from [`GradScaler::new`].
    #[must_use]
    pub fn with_scale(init_scale: f32) -> Self {
        Self {
            scale: init_scale,
            ..Self::new()
        }
    }
    /// Creates a scaler with all dynamics parameters specified explicitly.
    #[must_use]
    pub fn with_options(
        init_scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
    ) -> Self {
        Self {
            scale: init_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            growth_tracker: 0,
            found_inf: false,
            enabled: true,
        }
    }
    /// Builder: sets the multiplier applied when the scale grows.
    #[must_use]
    pub fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }
    /// Builder: sets the multiplier applied when an overflow is detected.
    #[must_use]
    pub fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }
    /// Builder: sets how many clean updates must pass before the scale grows.
    #[must_use]
    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }
    /// Builder: enables or disables scaling entirely.
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }
    /// Returns the effective scale: the configured value when enabled,
    /// or 1.0 (identity) when disabled.
    #[must_use]
    pub fn get_scale(&self) -> f32 {
        if self.enabled { self.scale } else { 1.0 }
    }
    /// Overrides the current scale directly (no validation is performed).
    pub fn set_scale(&mut self, scale: f32) {
        self.scale = scale;
    }
    /// Returns whether scaling is active.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
    /// Turns scaling on or off at runtime.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }
    /// Multiplies the loss by the current scale (identity when disabled).
    #[must_use]
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }
    /// Divides `grads` by the current scale in place, recording whether any
    /// gradient was non-finite. Returns `true` when all gradients are finite
    /// (i.e. the optimizer step may proceed).
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }
        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;
        for g in grads.iter_mut() {
            // `!is_finite()` covers +inf, -inf and NaN in one check.
            if !g.is_finite() {
                self.found_inf = true;
            }
            *g *= inv_scale;
        }
        !self.found_inf
    }
    /// Unscales the gradient of every parameter held by `optimizer`,
    /// replacing each grad tensor with its unscaled copy. Returns `true`
    /// when all gradients are finite.
    ///
    /// # Panics
    /// Panics if rebuilding a gradient tensor fails, which would indicate an
    /// internal shape/length mismatch.
    pub fn unscale_optimizer<O: crate::Optimizer>(&mut self, optimizer: &O) -> bool {
        if !self.enabled {
            self.found_inf = false;
            return true;
        }
        let inv_scale = 1.0 / self.scale;
        self.found_inf = false;
        for param in optimizer.parameters() {
            if let Some(grad) = param.grad() {
                let mut grad_vec = grad.to_vec();
                for g in &mut grad_vec {
                    // `!is_finite()` covers +inf, -inf and NaN in one check.
                    if !g.is_finite() {
                        self.found_inf = true;
                    }
                    *g *= inv_scale;
                }
                let unscaled = axonml_tensor::Tensor::from_vec(grad_vec, grad.shape())
                    .expect("grad_scaler: tensor creation failed");
                param.set_grad(unscaled);
            }
        }
        !self.found_inf
    }
    /// Returns `true` when every gradient in `grads` is finite.
    /// Read-only companion to [`GradScaler::unscale_grads`].
    #[must_use]
    pub fn check_grads(&self, grads: &[f32]) -> bool {
        grads.iter().all(|g| g.is_finite())
    }
    /// Returns whether the last unscale pass (or a manual
    /// [`GradScaler::set_found_inf`]) recorded an inf/NaN gradient.
    #[must_use]
    pub fn found_inf(&self) -> bool {
        self.found_inf
    }
    /// Manually records an overflow, e.g. when gradients are inspected
    /// outside of the `unscale_*` methods.
    pub fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }
    /// Adjusts the scale after an optimizer step.
    ///
    /// On overflow the scale is multiplied by `backoff_factor` (floored at
    /// 1.0) and the growth counter resets; otherwise the counter advances
    /// and, once it reaches `growth_interval`, the scale is multiplied by
    /// `growth_factor` (capped to avoid overflowing to infinity).
    pub fn update(&mut self) {
        if !self.enabled {
            return;
        }
        if self.found_inf {
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
            // Never let the scale collapse below 1.0.
            self.scale = self.scale.max(1.0);
        } else {
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
                // Cap so repeated growth cannot overflow to +inf.
                self.scale = self.scale.min(f32::MAX / 2.0);
            }
        }
        // Bug fix: clear the per-step overflow flag so a stale `true`
        // cannot keep shrinking the scale on every later update().
        // Mirrors PyTorch's GradScaler.update(), which resets its
        // found-inf state at the end of each update.
        self.found_inf = false;
    }
    /// Returns the persistable portion of the scaler's state.
    #[must_use]
    pub fn state_dict(&self) -> GradScalerState {
        GradScalerState {
            scale: self.scale,
            growth_tracker: self.growth_tracker,
        }
    }
    /// Restores state previously captured by [`GradScaler::state_dict`].
    pub fn load_state_dict(&mut self, state: GradScalerState) {
        self.scale = state.scale;
        self.growth_tracker = state.growth_tracker;
    }
}
/// Serializable snapshot of a [`GradScaler`]'s mutable state, produced by
/// `GradScaler::state_dict` and consumed by `GradScaler::load_state_dict`.
#[derive(Debug, Clone, Copy)]
pub struct GradScalerState {
/// Loss-scale value at the time of the snapshot.
pub scale: f32,
/// Consecutive overflow-free updates accumulated at the time of the snapshot.
pub growth_tracker: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate f32 equality shared by the assertions below.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::new();
        assert!(close(scaler.get_scale(), 65536.0));
        assert!(scaler.is_enabled());
        assert!(!scaler.found_inf());
    }

    #[test]
    fn test_grad_scaler_with_scale() {
        assert!(close(GradScaler::with_scale(1024.0).get_scale(), 1024.0));
    }

    #[test]
    fn test_scale_loss() {
        let scaler = GradScaler::with_scale(100.0);
        assert!(close(scaler.scale_loss(0.5), 50.0));
    }

    #[test]
    fn test_unscale_grads() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, 200.0, 300.0];
        assert!(scaler.unscale_grads(&mut grads));
        assert!(!scaler.found_inf());
        for (got, want) in grads.iter().zip([1.0_f32, 2.0, 3.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_unscale_grads_with_inf() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::INFINITY, 300.0];
        assert!(!scaler.unscale_grads(&mut grads));
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_unscale_grads_with_nan() {
        let mut scaler = GradScaler::with_scale(100.0);
        let mut grads = vec![100.0, f32::NAN, 300.0];
        assert!(!scaler.unscale_grads(&mut grads));
        assert!(scaler.found_inf());
    }

    #[test]
    fn test_update_on_overflow() {
        let mut scaler = GradScaler::with_scale(1000.0);
        scaler.found_inf = true;
        scaler.update();
        // Backoff halves the scale and resets the growth counter.
        assert!(close(scaler.get_scale(), 500.0));
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_update_growth() {
        let mut scaler = GradScaler::with_options(100.0, 2.0, 0.5, 3);
        for _ in 0..3 {
            scaler.found_inf = false;
            scaler.update();
        }
        // Three clean updates hit the growth interval and double the scale.
        assert!(close(scaler.get_scale(), 200.0));
        assert_eq!(scaler.growth_tracker, 0);
    }

    #[test]
    fn test_disabled_scaler() {
        let mut scaler = GradScaler::new().enabled(false);
        assert!(!scaler.is_enabled());
        assert!(close(scaler.get_scale(), 1.0));
        assert!(close(scaler.scale_loss(0.5), 0.5));
        let mut grads = vec![1.0, 2.0, 3.0];
        // A disabled scaler leaves gradients untouched and reports valid.
        assert!(scaler.unscale_grads(&mut grads));
        assert!(close(grads[0], 1.0));
    }

    #[test]
    fn test_state_dict() {
        let mut scaler = GradScaler::with_scale(500.0);
        scaler.growth_tracker = 10;
        let state = scaler.state_dict();
        assert!(close(state.scale, 500.0));
        assert_eq!(state.growth_tracker, 10);
        let mut restored = GradScaler::new();
        restored.load_state_dict(state);
        assert!(close(restored.get_scale(), 500.0));
        assert_eq!(restored.growth_tracker, 10);
    }

    #[test]
    fn test_builder_pattern() {
        let scaler = GradScaler::with_scale(1000.0)
            .growth_factor(3.0)
            .backoff_factor(0.25)
            .growth_interval(100);
        assert!(close(scaler.growth_factor, 3.0));
        assert!(close(scaler.backoff_factor, 0.25));
        assert_eq!(scaler.growth_interval, 100);
    }
}