use super::MixedPrecisionConfig;
/// Default number of consecutive overflow-free steps before the scale is grown.
const DEFAULT_SCALE_GROWTH_INTERVAL: usize = 2000;
/// Dynamic gradient/loss scaler for mixed-precision training.
///
/// The loss is multiplied by `scale` before backprop so that small
/// low-precision gradients do not underflow to zero, and gradients are
/// divided by the same factor before the optimizer step. When `dynamic`
/// is true, `update` grows the scale after `growth_interval` consecutive
/// finite-gradient steps and backs it off when an overflow is observed.
#[derive(Debug)]
pub struct GradScaler {
// Current loss-scale factor applied in `scale_loss` / removed in `unscale_*`.
scale: f32,
// Multiplier applied to `scale` after `growth_interval` clean steps.
growth_factor: f32,
// Multiplier applied to `scale` when non-finite gradients are seen.
backoff_factor: f32,
// Number of consecutive valid steps required before the scale grows.
pub(crate) growth_interval: usize,
// Valid steps seen since the last growth or backoff event.
steps_since_growth: usize,
// When false, `update` is a no-op and the scale stays fixed.
dynamic: bool,
// Total number of overflow (non-finite gradient) events observed.
overflow_count: usize,
// Total number of valid steps observed while dynamic scaling was on.
successful_steps: usize,
}
impl GradScaler {
    /// Creates a scaler with the given initial scale and default dynamics:
    /// x2 growth every `DEFAULT_SCALE_GROWTH_INTERVAL` clean steps,
    /// x0.5 backoff on overflow.
    pub fn new(initial_scale: f32) -> Self {
        Self {
            scale: initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: DEFAULT_SCALE_GROWTH_INTERVAL,
            steps_since_growth: 0,
            dynamic: true,
            overflow_count: 0,
            successful_steps: 0,
        }
    }

    /// Builds a scaler from an explicit mixed-precision configuration.
    pub fn from_config(config: &MixedPrecisionConfig) -> Self {
        Self {
            scale: config.initial_scale,
            growth_factor: config.scale_growth_factor,
            backoff_factor: config.scale_backoff_factor,
            growth_interval: config.scale_growth_interval,
            steps_since_growth: 0,
            dynamic: config.dynamic_scaling,
            overflow_count: 0,
            successful_steps: 0,
        }
    }

    /// Current loss-scale factor.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Multiplies the loss by the current scale (call before backprop).
    pub fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Divides a single gradient value by the current scale.
    pub fn unscale_grad(&self, grad: f32) -> f32 {
        grad / self.scale
    }

    /// Unscales all gradients in place and checks them for overflow.
    ///
    /// Returns `true` when every unscaled gradient is finite (safe to step),
    /// `false` when any NaN/Inf is present. Note that the slice is modified
    /// even when overflow is detected; callers are expected to skip the step.
    pub fn unscale_and_check(&self, grads: &mut [f32]) -> bool {
        let inv_scale = 1.0 / self.scale;
        let mut has_overflow = false;
        for grad in grads.iter_mut() {
            *grad *= inv_scale;
            if !grad.is_finite() {
                has_overflow = true;
            }
        }
        !has_overflow
    }

    /// Updates the scale after a step.
    ///
    /// When `grads_valid`, the scale grows by `growth_factor` once every
    /// `growth_interval` consecutive valid steps. On overflow the scale is
    /// multiplied by `backoff_factor` (floored at 1.0) and the growth
    /// counter resets. No-op when dynamic scaling is disabled.
    pub fn update(&mut self, grads_valid: bool) {
        if !self.dynamic {
            return;
        }
        if grads_valid {
            self.successful_steps += 1;
            self.steps_since_growth += 1;
            if self.steps_since_growth >= self.growth_interval {
                // Guard against the scale running away to infinity: once
                // `scale` is Inf, backoff (Inf * 0.5 == Inf) can never
                // recover, and 1/scale becomes 0, which would silently zero
                // every gradient while `unscale_and_check` reports success.
                let grown = self.scale * self.growth_factor;
                if grown.is_finite() {
                    self.scale = grown;
                }
                self.steps_since_growth = 0;
            }
        } else {
            self.overflow_count += 1;
            // Back off, but never let the scale drop below 1.0.
            self.scale = (self.scale * self.backoff_factor).max(1.0);
            self.steps_since_growth = 0;
        }
    }

    /// Total number of overflow events observed so far.
    pub fn overflow_count(&self) -> usize {
        self.overflow_count
    }

    /// Total number of valid (finite-gradient) steps observed so far.
    pub fn successful_steps(&self) -> usize {
        self.successful_steps
    }

    /// Whether dynamic scale adjustment is enabled.
    pub fn is_dynamic(&self) -> bool {
        self.dynamic
    }

    /// Enables or disables dynamic scale adjustment.
    pub fn set_dynamic(&mut self, enabled: bool) {
        self.dynamic = enabled;
    }
}
impl Default for GradScaler {
fn default() -> Self {
Self::new(65536.0)
}
}