pub struct GradScaler { /* private fields */ }
Gradient scaler for mixed precision training.
Scales the loss to prevent gradient underflow when using F16, then unscales gradients before the optimizer step.
The scale is automatically adjusted based on whether gradients overflow.
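The order of calls in one training step can be sketched with a hypothetical stand-in type, `MiniScaler`, that is not part of this crate. It implements only the backoff half of the update rule, and it models the backward pass on a single weight by passing the raw gradient through `scale_loss`. That is valid because backpropagation is linear, so scaling the loss scales every gradient by the same factor. Treat it as an illustration of the documented flow (scale, backward, unscale, conditional optimizer step, update), not as this library's implementation.

```rust
/// Hypothetical stand-in for GradScaler, for illustration only; the real type
/// keeps its fields private. Growth is omitted; only backoff is modelled.
pub struct MiniScaler {
    scale: f32,
    backoff_factor: f32,
    found_inf: bool,
}

impl MiniScaler {
    pub fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Divides every gradient by the scale in place; returns true if all are finite.
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        let inv = 1.0 / self.scale;
        for g in grads.iter_mut() {
            *g *= inv;
        }
        self.found_inf = grads.iter().any(|g| !g.is_finite());
        !self.found_inf
    }

    /// Backoff-only update: shrink the scale after an overflow, then clear the flag.
    pub fn update(&mut self) {
        if self.found_inf {
            self.scale *= self.backoff_factor;
        }
        self.found_inf = false;
    }
}

/// One plain-SGD step on a single weight. `raw_grad` is dLoss/dw; the
/// "backward pass" is modelled by passing it through scale_loss.
pub fn train_step(scaler: &mut MiniScaler, w: &mut f32, raw_grad: f32, lr: f32) -> bool {
    let mut grads = [scaler.scale_loss(raw_grad)];
    let finite = scaler.unscale_grads(&mut grads);
    if finite {
        *w -= lr * grads[0]; // optimizer step only when gradients are finite
    }
    scaler.update(); // adjust the scale whether or not the step was skipped
    finite
}
```

A step whose scaled gradient overflows leaves the weight untouched and halves the scale, so the next step retries with more headroom.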
Implementations
impl GradScaler
pub fn new() -> Self
Creates a new gradient scaler with default settings.
Default configuration:
- Initial scale: 65536.0 (2^16)
- Growth factor: 2.0
- Backoff factor: 0.5
- Growth interval: 2000 steps
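The defaults make the scale react asymmetrically: each overflow halves it immediately, while each doubling needs 2000 consecutive clean steps. The arithmetic below is a sketch under an assumed update rule (multiply by the backoff factor on every overflow step, by the growth factor once per full growth interval); `scale_after` is a hypothetical helper, not part of this API.

```rust
// Documented defaults.
pub const INIT_SCALE: f32 = 65536.0; // 2^16
pub const GROWTH_FACTOR: f32 = 2.0;
pub const BACKOFF_FACTOR: f32 = 0.5;
pub const GROWTH_INTERVAL: usize = 2000;

/// Hypothetical helper: scale reached from the defaults after `overflows`
/// overflow steps followed by `clean_steps` consecutive overflow-free steps.
pub fn scale_after(overflows: i32, clean_steps: usize) -> f32 {
    let backed_off = INIT_SCALE * BACKOFF_FACTOR.powi(overflows);
    // Only complete growth intervals count; a partial interval adds nothing.
    let growths = (clean_steps / GROWTH_INTERVAL) as i32;
    backed_off * GROWTH_FACTOR.powi(growths)
}
```

For example, three overflows drop the scale from 65536 to 8192; 1999 clean steps later it is still 8192, and it takes 4000 clean steps to climb back to 32768.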
pub fn with_scale(init_scale: f32) -> Self
Creates a gradient scaler with a custom initial scale.
pub fn with_options(
    init_scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: usize,
) -> Self
Creates a gradient scaler with every setting specified explicitly.
pub fn growth_factor(self, factor: f32) -> Self
Builder method: sets the growth factor.
pub fn backoff_factor(self, factor: f32) -> Self
Builder method: sets the backoff factor.
pub fn growth_interval(self, interval: usize) -> Self
Builder method: sets the growth interval.
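Because the builder methods take `self` by value and return `Self`, they chain after any constructor. The sketch below mirrors that pattern on a hypothetical settings struct, `ScalerSettings`, since GradScaler's own fields are private. It assumes `with_scale` keeps the remaining defaults, which the documentation implies but does not state.

```rust
/// Hypothetical mirror of GradScaler's configurable settings.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalerSettings {
    pub init_scale: f32,
    pub growth_factor: f32,
    pub backoff_factor: f32,
    pub growth_interval: usize,
}

impl ScalerSettings {
    /// Documented defaults: 2^16, 2.0, 0.5, 2000.
    pub fn new() -> Self {
        Self::with_options(65536.0, 2.0, 0.5, 2000)
    }

    /// Assumed: overrides only the initial scale, keeping the other defaults.
    pub fn with_scale(init_scale: f32) -> Self {
        Self { init_scale, ..Self::new() }
    }

    pub fn with_options(init_scale: f32, growth_factor: f32, backoff_factor: f32, growth_interval: usize) -> Self {
        Self { init_scale, growth_factor, backoff_factor, growth_interval }
    }

    // Each builder method consumes `self` and returns it, so calls chain.
    pub fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }

    pub fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }

    pub fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }
}
```

Under these assumptions, `ScalerSettings::with_scale(1024.0).growth_factor(1.5).growth_interval(500)` ends up equal to `ScalerSettings::with_options(1024.0, 1.5, 0.5, 500)`.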
pub fn is_enabled(&self) -> bool
Returns whether the scaler is enabled.
pub fn set_enabled(&mut self, enabled: bool)
Enables or disables the scaler.
pub fn scale_loss(&self, loss: f32) -> f32
Scales a loss value for the backward pass.
Returns the loss multiplied by the current scale factor; call backward() on the returned value.
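Scaling matters because small gradients underflow in f16: its smallest positive subnormal is about 5.96e-8, so a gradient of 1e-8 would flush to zero, whereas 1e-8 × 65536 ≈ 6.6e-4 is comfortably representable. A minimal sketch of the method on a hypothetical `LossScaler` type follows; the pass-through behaviour when the scaler is disabled is an assumption, not documented above.

```rust
/// Hypothetical stand-in holding only what scale_loss needs.
pub struct LossScaler {
    scale: f32,
    enabled: bool,
}

impl LossScaler {
    /// Multiplies the loss by the current scale.
    /// Assumption: a disabled scaler returns the loss unchanged.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        if self.enabled {
            loss * self.scale
        } else {
            loss
        }
    }
}
```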
pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool
Unscales gradients in place and checks for inf/nan.
Returns true if all gradients are finite, or false if any gradient is inf or NaN (an overflow).
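Unscaling is the inverse of `scale_loss`: every gradient is multiplied by 1/scale, and the finiteness check runs on the same pass. The sketch uses a hypothetical `Unscaler` type; it assumes the stored flag describes only the most recent call, which matches the wording of `found_inf` but is not confirmed by the source.

```rust
/// Hypothetical stand-in holding only what unscale_grads needs.
pub struct Unscaler {
    scale: f32,
    found_inf: bool,
}

impl Unscaler {
    /// Multiplies each gradient by 1/scale in place, records whether any
    /// value is inf or NaN, and returns true only if all are finite.
    pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        let inv_scale = 1.0 / self.scale;
        let mut all_finite = true;
        for g in grads.iter_mut() {
            *g *= inv_scale;
            all_finite &= g.is_finite();
        }
        // Assumption: the flag reflects only this call.
        self.found_inf = !all_finite;
        all_finite
    }
}
```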
pub fn check_grads(&self, grads: &[f32]) -> bool
Checks a slice of gradients for inf/nan without modifying them.
pub fn found_inf(&self) -> bool
Returns whether inf/nan was found in the last unscale operation.
pub fn set_found_inf(&mut self, found: bool)
Sets the inf/nan flag directly, for use when gradients are checked outside the scaler.
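When gradients live in several buffers that are checked without unscaling (for example, separate parameter groups), one plausible pattern is to reduce the per-buffer checks into a single boolean and pass its negation to `set_found_inf`. The helper `all_finite` and the `InfFlag` type below are hypothetical and only mirror the documented `check_grads` / `found_inf` / `set_found_inf` behaviour.

```rust
/// Hypothetical helper: true if every value in every buffer is finite,
/// i.e. the check_grads test applied across several buffers.
pub fn all_finite(buffers: &[&[f32]]) -> bool {
    buffers.iter().all(|buf| buf.iter().all(|g| g.is_finite()))
}

/// Hypothetical stand-in for the flag behind found_inf / set_found_inf.
#[derive(Default)]
pub struct InfFlag {
    found_inf: bool,
}

impl InfFlag {
    pub fn found_inf(&self) -> bool {
        self.found_inf
    }

    pub fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }
}
```

Usage: `flag.set_found_inf(!all_finite(&[a, b]))` records an overflow if either buffer contains inf or NaN, so a subsequent `update()` would back off.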
pub fn update(&mut self)
Updates the scale factor based on overflow history.
Call this after each optimizer step:
- If overflow was detected, the scale is multiplied by backoff_factor
- If no overflow occurred for growth_interval consecutive steps, the scale is multiplied by growth_factor
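The update rule can be sketched as follows. The field `growth_tracker` (a counter of consecutive overflow-free steps), the counter reset after a growth or backoff, and the clearing of the flag at the end of each update are assumptions modelled on common dynamic loss-scaling designs; the documentation above specifies only the two bullet points. `ScaleUpdater` is a hypothetical type.

```rust
/// Hypothetical re-implementation of the documented update rule.
pub struct ScaleUpdater {
    pub scale: f32,
    pub growth_factor: f32,
    pub backoff_factor: f32,
    pub growth_interval: usize,
    pub growth_tracker: usize, // assumed: consecutive overflow-free steps
    pub found_inf: bool,
}

impl ScaleUpdater {
    pub fn update(&mut self) {
        if self.found_inf {
            // Overflow: shrink the scale and restart the clean-step streak.
            self.scale *= self.backoff_factor;
            self.growth_tracker = 0;
        } else {
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                // growth_interval clean steps in a row: grow the scale.
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
            }
        }
        // Assumption: each step is judged on its own overflow status.
        self.found_inf = false;
    }
}
```

With `growth_interval = 3`, the scale stays put for two clean updates, doubles on the third, and halves on the next update that follows an overflow.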
pub fn state_dict(&self) -> GradScalerState
Returns the current state for checkpointing.
pub fn load_state_dict(&mut self, state: GradScalerState)
Loads state from a checkpoint.
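Checkpointing the scaler matters because a resumed run would otherwise restart at 65536 and have to rediscover a safe scale through repeated overflows. The fields of GradScalerState are not shown here, so the sketch uses a hypothetical `SavedScalerState` holding the scale and an assumed growth counter, serialized to one text line with only the standard library. Rust's float `Display` output round-trips through `parse`, so the scale is restored exactly.

```rust
/// Hypothetical stand-in for GradScalerState; the real fields are not documented here.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SavedScalerState {
    pub scale: f32,
    pub growth_tracker: usize, // assumed field
}

impl SavedScalerState {
    /// Serializes as "<scale> <growth_tracker>".
    pub fn to_line(&self) -> String {
        format!("{} {}", self.scale, self.growth_tracker)
    }

    /// Parses the format written by to_line; returns None on malformed input.
    pub fn from_line(line: &str) -> Option<Self> {
        let mut parts = line.split_whitespace();
        let scale = parts.next()?.parse().ok()?;
        let growth_tracker = parts.next()?.parse().ok()?;
        Some(SavedScalerState { scale, growth_tracker })
    }
}
```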
Trait Implementations
impl Clone for GradScaler
fn clone(&self) -> GradScaler
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Debug for GradScaler
Auto Trait Implementations
impl Freeze for GradScaler
impl RefUnwindSafe for GradScaler
impl Send for GradScaler
impl Sync for GradScaler
impl Unpin for GradScaler
impl UnwindSafe for GradScaler
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.