use crate::autograd::Variable;
use crate::optim::Optimizer;
use crate::tensor::Tensor;
use std::collections::HashMap;
/// Outcome of a scaler-driven optimizer step, as returned by
/// `GradScaler::step_with_clipping`.
#[derive(Clone, Debug)]
pub enum StepResult {
    /// The optimizer was stepped with the supplied gradients.
    Success {
        /// Loss scale reported for this step.
        scale: f32,
        /// Gradient norm reported by clipping, when one is available.
        grad_norm: Option<f32>,
    },
    /// The incoming (still scaled) gradients failed the overflow check, so the
    /// optimizer was not stepped.
    Overflow {
        /// Loss scale that was in effect when the overflow was detected.
        scale: f32,
        /// Scale value reported as the backed-off replacement.
        new_scale: f32,
    },
    /// The gradients failed the overflow check after unscaling/clipping, so
    /// the optimizer was not stepped.
    InfNan {
        /// Loss scale that was in effect for the failed step.
        scale: f32,
    },
}
/// Plain-data snapshot of everything a `GradScaler` needs to resume scaling.
#[derive(Debug, Clone)]
pub struct ScalerState {
    /// Current loss-scale multiplier.
    pub scale: f32,
    /// Factor the scale is multiplied by when it grows.
    pub growth_factor: f32,
    /// Factor the scale is multiplied by after an overflow.
    pub backoff_factor: f32,
    /// Threshold `growth_tracker` must reach before the scale grows.
    pub growth_interval: usize,
    /// Running step counter compared against `growth_interval`; reset to zero
    /// whenever the scale grows.
    pub growth_tracker: usize,
    /// Number of scale updates since the last overflow.
    pub consecutive_non_overflow: usize,
    /// When `false`, the scaler passes values through without scaling.
    pub enabled: bool,
}
impl Default for ScalerState {
fn default() -> Self {
Self {
scale: 65536.0, growth_factor: 2.0,
backoff_factor: 0.5,
growth_interval: 2000,
growth_tracker: 0,
consecutive_non_overflow: 0,
enabled: true,
}
}
}
/// Dynamic loss scaler: multiplies the loss by a scale factor and adjusts that
/// factor depending on whether gradients overflow.
pub struct GradScaler {
    /// Tunable parameters and counters; exposed through `state_dict` /
    /// `load_state_dict`.
    state: ScalerState,
    /// Latched by the overflow check and consumed by the next scale update.
    found_overflow: bool,
    /// Tensors keyed by `usize`; within this file the map is only created
    /// empty and cleared by `reset`.
    scaled_grads: HashMap<usize, Tensor<f32>>,
}
impl Default for GradScaler {
    /// Builds a scaler in which every setting keeps its `ScalerState` default.
    fn default() -> Self {
        // Passing `None` everywhere means "override nothing".
        Self::new(None, None, None, None, None)
    }
}
impl GradScaler {
pub fn new(
init_scale: Option<f32>,
growth_factor: Option<f32>,
backoff_factor: Option<f32>,
growth_interval: Option<usize>,
enabled: Option<bool>,
) -> Self {
let mut state = ScalerState::default();
if let Some(scale) = init_scale {
state.scale = scale;
}
if let Some(factor) = growth_factor {
state.growth_factor = factor;
}
if let Some(factor) = backoff_factor {
state.backoff_factor = factor;
}
if let Some(interval) = growth_interval {
state.growth_interval = interval;
}
if let Some(enabled) = enabled {
state.enabled = enabled;
}
Self {
state,
found_overflow: false,
scaled_grads: HashMap::new(),
}
}
#[allow(clippy::should_implement_trait)]
pub fn default() -> Self {
Self::new(None, None, None, None, None)
}
pub fn scale(&self, loss: &Variable<f32>) -> Variable<f32> {
if !self.state.enabled {
return loss.clone();
}
let scale_tensor = Tensor::from_vec(vec![self.state.scale], vec![1]);
let scale_var = Variable::new(scale_tensor, false);
loss * &scale_var
}
pub fn scale_tensor(&self, tensor: &Tensor<f32>) -> Tensor<f32> {
if !self.state.enabled {
return tensor.clone();
}
tensor * self.state.scale
}
pub fn unscale_grads(&mut self, _optimizer: &mut dyn Optimizer) {
if !self.state.enabled {}
}
pub fn check_overflow(&mut self, gradients: &[Tensor<f32>]) -> bool {
if !self.state.enabled {
return false;
}
for grad in gradients {
if let Some(slice) = grad.as_slice() {
for &value in slice {
if !value.is_finite() || value.abs() > 65504.0 {
self.found_overflow = true;
return true;
}
}
}
}
false
}
pub fn step<O: Optimizer>(
&mut self,
optimizer: &mut O,
params: &[Tensor<f32>],
grads: &[Tensor<f32>],
) {
if !self.state.enabled {
for (param, grad) in params.iter().zip(grads.iter()) {
optimizer.step(param, grad);
}
return;
}
if self.check_overflow(grads) {
self.found_overflow = true;
return;
}
let unscaled_grads: Vec<Tensor<f32>> = grads.iter().map(|g| g / self.state.scale).collect();
for (param, grad) in params.iter().zip(unscaled_grads.iter()) {
optimizer.step(param, grad);
}
self.update_scale();
}
pub fn step_with_clipping<O: Optimizer>(
&mut self,
optimizer: &mut O,
params: &[Tensor<f32>],
grads: &mut [Tensor<f32>],
max_grad_norm: Option<f32>,
) -> StepResult {
if !self.state.enabled {
if let Some(max_norm) = max_grad_norm {
crate::amp::dtype_utils::utils::clip_grad_norm(grads, max_norm);
}
for (param, grad) in params.iter().zip(grads.iter()) {
optimizer.step(param, grad);
}
return StepResult::Success {
scale: 1.0,
grad_norm: None,
};
}
if self.check_overflow(grads) {
self.found_overflow = true;
return StepResult::Overflow {
scale: self.state.scale,
new_scale: self.state.scale * self.state.backoff_factor,
};
}
for grad in grads.iter_mut() {
*grad = grad.clone() / self.state.scale;
}
let grad_norm = max_grad_norm
.map(|max_norm| crate::amp::dtype_utils::utils::clip_grad_norm(grads, max_norm));
if self.check_overflow(grads) {
self.found_overflow = true;
return StepResult::InfNan {
scale: self.state.scale,
};
}
for (param, grad) in params.iter().zip(grads.iter()) {
optimizer.step(param, grad);
}
let old_scale = self.state.scale;
self.update_scale();
StepResult::Success {
scale: old_scale,
grad_norm,
}
}
pub fn update_scale(&mut self) {
if !self.state.enabled {
return;
}
self.state.growth_tracker += 1;
if self.found_overflow {
self.state.scale *= self.state.backoff_factor;
self.state.consecutive_non_overflow = 0;
self.found_overflow = false;
} else {
self.state.consecutive_non_overflow += 1;
if self.state.growth_tracker >= self.state.growth_interval {
self.state.scale *= self.state.growth_factor;
self.state.growth_tracker = 0;
}
}
self.state.scale = self.state.scale.clamp(1.0, 65536.0 * 65536.0);
}
pub fn get_scale(&self) -> f32 {
self.state.scale
}
pub fn set_scale(&mut self, scale: f32) {
self.state.scale = scale;
}
pub fn load_state_dict(&mut self, state: ScalerState) {
self.state = state;
}
pub fn state_dict(&self) -> ScalerState {
self.state.clone()
}
pub fn is_enabled(&self) -> bool {
self.state.enabled
}
pub fn set_enabled(&mut self, enabled: bool) {
self.state.enabled = enabled;
}
pub fn reset(&mut self) {
self.found_overflow = false;
self.scaled_grads.clear();
self.state.growth_tracker = 0;
self.state.consecutive_non_overflow = 0;
}
pub fn get_stats(&self) -> ScalerStats {
ScalerStats {
current_scale: self.state.scale,
growth_factor: self.state.growth_factor,
backoff_factor: self.state.backoff_factor,
growth_interval: self.state.growth_interval,
growth_tracker: self.state.growth_tracker,
consecutive_non_overflow: self.state.consecutive_non_overflow,
enabled: self.state.enabled,
has_overflow: self.found_overflow,
}
}
pub fn set_scale_bounds(&mut self, min_scale: f32, max_scale: f32) {
self.state.scale = self.state.scale.max(min_scale).min(max_scale);
}
pub fn adaptive_growth_interval(&mut self, overflow_rate: f32) {
if overflow_rate > 0.1 {
self.state.growth_interval = (self.state.growth_interval * 2).min(10000);
} else if overflow_rate < 0.01 {
self.state.growth_interval = (self.state.growth_interval / 2).max(100);
}
}
}
/// Read-only snapshot of a scaler, as produced by `GradScaler::get_stats`.
#[derive(Clone, Debug)]
pub struct ScalerStats {
    /// Scale at the time of the snapshot.
    pub current_scale: f32,
    /// Multiplier applied when the scale grows.
    pub growth_factor: f32,
    /// Multiplier applied when the scale backs off.
    pub backoff_factor: f32,
    /// Threshold the growth tracker must reach before the scale grows.
    pub growth_interval: usize,
    /// Current value of the growth tracker.
    pub growth_tracker: usize,
    /// Number of scale updates since the last overflow.
    pub consecutive_non_overflow: usize,
    /// Whether scaling was enabled.
    pub enabled: bool,
    /// Whether an overflow was latched and not yet consumed by a scale update.
    pub has_overflow: bool,
}
/// Public `utils` submodule; it currently declares no items.
pub mod utils {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default scaler starts enabled at a scale of 65536.
    #[test]
    fn test_grad_scaler_creation() {
        let scaler = GradScaler::default();

        assert_eq!(scaler.get_scale(), 65536.0);
        assert!(scaler.is_enabled());
    }

    /// Explicit constructor arguments override the defaults.
    #[test]
    fn test_grad_scaler_custom() {
        let scaler = GradScaler::new(
            Some(1024.0),
            Some(3.0),
            Some(0.3),
            Some(1000),
            Some(true),
        );

        assert_eq!(scaler.get_scale(), 1024.0);
        assert_eq!(scaler.state.growth_factor, 3.0);
        assert_eq!(scaler.state.backoff_factor, 0.3);
    }

    /// `scale_tensor` matches a plain multiplication by the current scale.
    #[test]
    fn test_scale_tensor() {
        let scaler = GradScaler::default();
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);

        let actual = scaler.scale_tensor(&input);
        let wanted = input * scaler.get_scale();

        assert_eq!(actual.as_slice(), wanted.as_slice());
    }

    /// Small finite values pass; oversized and NaN values are flagged.
    #[test]
    fn test_overflow_detection() {
        let mut scaler = GradScaler::default();

        let finite = vec![
            Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]),
            Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]),
        ];
        assert!(!scaler.check_overflow(&finite));

        // 100000.0 lies above the 65504.0 threshold.
        let too_large = vec![Tensor::from_vec(vec![1.0, 2.0, 100000.0], vec![3])];
        assert!(scaler.check_overflow(&too_large));

        let with_nan = vec![Tensor::from_vec(vec![1.0, f32::NAN, 3.0], vec![3])];
        assert!(scaler.check_overflow(&with_nan));
    }

    /// An overflow halves the scale; later clean updates double it back.
    #[test]
    fn test_scale_update() {
        let mut scaler = GradScaler::new(
            Some(1024.0),
            Some(2.0),
            Some(0.5),
            Some(2),
            Some(true),
        );
        let start = scaler.get_scale();

        // Overflowing update: backoff by 0.5.
        scaler.found_overflow = true;
        scaler.update_scale();
        assert_eq!(scaler.get_scale(), start * 0.5);

        // Two clean updates with a growth interval of 2: one growth by 2.0.
        scaler.found_overflow = false;
        scaler.update_scale();
        scaler.update_scale();
        assert_eq!(scaler.get_scale(), start * 0.5 * 2.0);
    }
}