use tch::{Kind, Tensor};
use super::{Optimizer, ZeroGrad};
use crate::error::SyntaxDotError;
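
/// Gradient scaler for mixed-precision training.
///
/// Wraps an [`Optimizer`] and scales the loss before the backward pass, so
/// that small `f16` gradients do not flush to zero. The gradients are
/// unscaled again before the optimizer step. When a non-finite gradient is
/// found, the step is skipped and the scale is backed off; after
/// `growth_interval` consecutive finite steps, the scale is increased again.
///
/// A minimal usage sketch; `model`, `batches`, and the concrete `optimizer`
/// are hypothetical, any [`Optimizer`] implementation works:
///
/// ```ignore
/// let mut scaler = GradScaler::new_with_defaults(true, optimizer)?;
/// for batch in batches {
///     let loss = model.loss(&batch);
///     scaler.backward_step(&loss);
/// }
/// ```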
pub struct GradScaler<O> {
    enabled: bool,
    growth_factor: f64,
    backoff_factor: f64,
    growth_interval: i64,
    optimizer: O,
    found_inf: Tensor,
    growth_tracker: Tensor,
    scale: Tensor,
}

impl<O> GradScaler<O>
where
    O: Optimizer,
{
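    /// Construct a gradient scaler wrapping `optimizer`.
    ///
    /// The loss is initially scaled by `init_scale`. The scale is multiplied
    /// by `growth_factor` after `growth_interval` consecutive steps with
    /// finite gradients, and by `backoff_factor` whenever a non-finite
    /// gradient is found. The scaler's state tensors are placed on the
    /// device of the optimizer's first trainable variable.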
    fn new(
        enabled: bool,
        optimizer: O,
        init_scale: f64,
        growth_factor: f64,
        backoff_factor: f64,
        growth_interval: i64,
    ) -> Result<Self, SyntaxDotError> {
        let device = match optimizer.trainable_variables().first() {
            Some(tensor) => tensor.device(),
            None => return Err(SyntaxDotError::NoTrainableVariables),
        };

        Ok(GradScaler {
            enabled,
            growth_factor,
            backoff_factor,
            growth_interval,
            optimizer,
            found_inf: Tensor::full(&[1], 0.0, (Kind::Float, device)),
            growth_tracker: Tensor::full(&[1], 0, (Kind::Int, device)),
            scale: Tensor::full(&[1], init_scale, (Kind::Float, device)),
        })
    }
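
    /// Construct a gradient scaler with default settings.
    ///
    /// The defaults follow PyTorch's `GradScaler`: an initial scale of
    /// 2^16, a growth factor of 2, a backoff factor of 0.5, and a growth
    /// interval of 2000 steps.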
    pub fn new_with_defaults(enabled: bool, optimizer: O) -> Result<Self, SyntaxDotError> {
        GradScaler::new(enabled, optimizer, 2f64.powi(16), 2., 0.5, 2000)
    }
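
    /// Get the current loss scale.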
    pub fn current_scale(&self) -> f32 {
        Vec::<f32>::from(&self.scale)[0]
    }
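
    /// Get a reference to the wrapped optimizer.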
    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }
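
    /// Get a mutable reference to the wrapped optimizer.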
    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
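
    /// Scale a tensor by the current loss scale. This is a no-op when the
    /// scaler is disabled.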
    fn scale(&self, t: &Tensor) -> Tensor {
        if !self.enabled {
            t.shallow_clone()
        } else {
            t * &self.scale
        }
    }
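
    /// Update the loss scale based on whether the last step saw a
    /// non-finite gradient.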
    fn update(&mut self) {
        if !self.enabled {
            return;
        }

        self.scale = Tensor::internal_amp_update_scale(
            &self.growth_tracker,
            &self.scale,
            &self.found_inf,
            self.growth_factor,
            self.backoff_factor,
            self.growth_interval,
        );
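
        // Reset the non-finite flag for the next step.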
        self.found_inf = self.found_inf.zeros_like();
    }
}

impl<O> Optimizer for GradScaler<O>
where
    O: Optimizer,
{
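    // Zero the gradients, do a backward pass on the scaled loss, step the
    // optimizer on the unscaled gradients, and update the loss scale.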
    fn backward_step(&mut self, loss: &Tensor) {
        self.optimizer.trainable_variables().zero_grad();
        self.scale(loss).backward();
        tch::no_grad(|| self.step());
        self.update();
    }

    fn set_lr_group(&mut self, group: usize, learning_rate: f64) {
        self.optimizer.set_lr_group(group, learning_rate)
    }

    fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) {
        self.optimizer.set_weight_decay_group(group, weight_decay)
    }
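
    // Unscale the gradients in place and step the wrapped optimizer,
    // skipping the step when a non-finite gradient is found.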
    fn step(&mut self) {
        if !self.enabled {
            return self.optimizer.step();
        }
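
        // Unscale all gradients, setting `found_inf` to 1 when any
        // gradient is infinite or NaN.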
        let inv_scale = self.scale.reciprocal().to_kind(Kind::Float);
        for tensor in self.optimizer.trainable_variables() {
            if !tensor.grad().defined() {
                continue;
            }

            tensor
                .grad()
                .internal_amp_non_finite_check_and_unscale(&mut self.found_inf, &inv_scale);
        }
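
        // The `found_inf` tensor holds 1.0 when at least one gradient was
        // non-finite; only step the optimizer when all gradients are finite.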
        let found_inf = (f32::from(&self.found_inf) - 1.0).abs() < f32::EPSILON;
        if !found_inf {
            self.optimizer.step()
        }
    }

    fn trainable_variables(&self) -> Vec<Tensor> {
        self.optimizer.trainable_variables()
    }
}