use scivex_core::Float;
use crate::optim::Optimizer;
use crate::variable::Variable;
/// Dynamic loss/gradient scaler.
///
/// Holds a multiplicative scale factor that is applied to a loss before the
/// backward pass and divided back out of the gradients before the optimizer
/// step. The factor is shrunk after a step in which a non-finite gradient was
/// detected and grown after a run of consecutive clean steps.
#[derive(Debug, Clone)]
pub struct GradScaler<T: Float> {
    /// Current scale factor applied to the loss.
    scale: T,
    /// Multiplier applied to `scale` once `growth_interval` clean steps have
    /// accumulated.
    growth_factor: T,
    /// Multiplier applied to `scale` after a step that found a non-finite
    /// gradient.
    backoff_factor: T,
    /// Number of consecutive clean updates required before `scale` is grown.
    growth_interval: usize,
    /// Clean updates counted since the last growth or backoff.
    steps_since_growth: usize,
    /// Whether the most recent `step` saw a non-finite unscaled gradient.
    found_inf: bool,
}
impl<T: Float> GradScaler<T> {
    /// Creates a scaler with the default configuration: initial scale
    /// `65536`, growth factor `2`, backoff factor `0.5`, and a growth interval
    /// of `2000` updates.
    pub fn new() -> Self {
        Self {
            scale: T::from_f64(65536.0),
            growth_factor: T::from_f64(2.0),
            backoff_factor: T::from_f64(0.5),
            growth_interval: 2000,
            steps_since_growth: 0,
            found_inf: false,
        }
    }

    /// Builder: replaces the initial scale factor.
    #[must_use]
    pub fn with_init_scale(mut self, scale: T) -> Self {
        self.scale = scale;
        self
    }

    /// Builder: replaces the factor the scale is multiplied by when it grows.
    #[must_use]
    pub fn with_growth_factor(mut self, factor: T) -> Self {
        self.growth_factor = factor;
        self
    }

    /// Builder: replaces the factor the scale is multiplied by after a step
    /// that found a non-finite gradient.
    #[must_use]
    pub fn with_backoff_factor(mut self, factor: T) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Builder: replaces the number of consecutive clean updates required
    /// before the scale grows.
    #[must_use]
    pub fn with_growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }

    /// Returns the current scale factor.
    pub fn get_scale(&self) -> T {
        self.scale
    }

    /// Returns `loss` multiplied by the current scale factor.
    pub fn scale(&self, loss: &Variable<T>) -> Variable<T> {
        crate::ops::scalar_mul(loss, self.scale)
    }

    /// Unscales the gradients of `params` and, if all of them are finite,
    /// runs `optimizer.step()`.
    ///
    /// Every gradient is multiplied by the reciprocal of the current scale.
    /// Parameters without a gradient are ignored. If any unscaled value is
    /// non-finite, the overflow flag is raised, the optimizer step is skipped,
    /// and **no** gradient is modified, so the parameters are never left in a
    /// partially unscaled state. Call [`update`](Self::update) afterwards to
    /// adjust the scale factor.
    ///
    /// NOTE(review): `optimizer` is assumed to read the gradients stored on
    /// these same `params`; confirm against the `Optimizer` implementations.
    pub fn step<O: Optimizer<T>>(&mut self, optimizer: &mut O, params: &[Variable<T>]) {
        self.found_inf = false;
        let inv_scale = self.scale.recip();

        // Stage the unscaled gradients instead of writing them back right
        // away: an overflow found in a later parameter must not leave earlier
        // parameters unscaled while the rest are still scaled.
        let mut staged = Vec::with_capacity(params.len());
        for p in params {
            if let Some(grad) = p.grad() {
                let unscaled = grad.map(|g| g * inv_scale);
                if unscaled.as_slice().iter().any(|&v| !v.is_finite()) {
                    self.found_inf = true;
                    return;
                }
                staged.push((p, unscaled));
            }
        }

        // Every gradient is finite: commit them all, then let the optimizer run.
        for (p, unscaled) in staged {
            p.set_grad(unscaled);
        }
        optimizer.step();
    }

    /// Adjusts the scale factor based on the outcome of the last
    /// [`step`](Self::step).
    ///
    /// After an overflow the scale is multiplied by the backoff factor and the
    /// clean-step counter is reset. Otherwise the counter is incremented, and
    /// once it reaches the growth interval the scale is multiplied by the
    /// growth factor and the counter is reset.
    pub fn update(&mut self) {
        if self.found_inf {
            self.scale *= self.backoff_factor;
            self.steps_since_growth = 0;
        } else {
            self.steps_since_growth += 1;
            if self.steps_since_growth >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.steps_since_growth = 0;
            }
        }
    }
}
impl<T: Float> Default for GradScaler<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scivex_core::Tensor;

    /// Scaling a loss multiplies its value by the default scale (65536).
    #[test]
    fn test_grad_scaler_scale() {
        let loss = Variable::new(Tensor::from_vec(vec![2.0], vec![1]).unwrap(), true);
        let scaled = GradScaler::<f64>::new().scale(&loss);
        let expected: f64 = 2.0 * 65536.0;
        let observed: f64 = scaled.data().as_slice()[0];
        assert!((observed - expected).abs() < 1e-6);
    }

    /// An update after an overflow halves the scale.
    #[test]
    fn test_grad_scaler_backoff_on_inf() {
        let mut scaler = GradScaler::<f64>::new();
        let before = scaler.get_scale();
        // Simulate a step that detected a non-finite gradient.
        scaler.found_inf = true;
        scaler.update();
        let after = scaler.get_scale();
        assert!((after - before * 0.5).abs() < 1e-6);
    }

    /// The scale only doubles once the growth interval has been reached.
    #[test]
    fn test_grad_scaler_growth() {
        let mut scaler = GradScaler::<f64>::new().with_growth_interval(2);
        let before = scaler.get_scale();
        scaler.found_inf = false;

        // First clean update: still below the interval, scale unchanged.
        scaler.update();
        let after_first = scaler.get_scale();
        assert!((after_first - before).abs() < 1e-6);

        // Second clean update: interval reached, scale doubles.
        scaler.update();
        let after_second = scaler.get_scale();
        assert!((after_second - before * 2.0).abs() < 1e-6);
    }

    /// `Default` yields the same initial scale as `new`.
    #[test]
    fn test_grad_scaler_default() {
        let scale: f32 = GradScaler::<f32>::default().get_scale();
        assert!((scale - 65536.0).abs() < 1e-2);
    }
}