Struct GradScaler 

pub struct GradScaler { /* private fields */ }

Gradient scaler for mixed precision training.

Scales the loss to prevent gradient underflow when using F16, then unscales gradients before the optimizer step.

The scale is adjusted automatically: it is reduced when gradients overflow and increased again after a run of overflow-free steps.
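
To make the workflow concrete, here is a minimal sketch of one training step written against a self-contained stand-in. `MiniScaler`, `train_step` and the one-parameter toy model are hypothetical and only mirror the behavior described on this page (documented defaults, scaling, unscaling, update); they are not this crate's implementation. Skipping the optimizer step when an overflow is detected is an assumption borrowed from the usual convention (for example PyTorch's GradScaler); this page does not say whether callers are expected to do it.

```rust
/// Hypothetical stand-in mirroring the documented GradScaler behavior.
/// It is NOT this crate's implementation.
struct MiniScaler {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: usize,
    clean_steps: usize, // overflow-free steps since the last scale change (assumed bookkeeping)
    found_inf: bool,
}

impl MiniScaler {
    /// Uses the documented defaults of GradScaler::new().
    fn new() -> Self {
        MiniScaler {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            clean_steps: 0,
            found_inf: false,
        }
    }

    /// Multiply the loss by the current scale.
    fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Divide gradients by the scale in place; return false on inf/NaN.
    fn unscale_grads(&mut self, grads: &mut [f32]) -> bool {
        let mut all_finite = true;
        for g in grads.iter_mut() {
            *g /= self.scale;
            all_finite &= g.is_finite();
        }
        self.found_inf = !all_finite;
        all_finite
    }

    /// Back off after an overflow; grow after growth_interval clean steps.
    fn update(&mut self) {
        if self.found_inf {
            self.scale *= self.backoff_factor;
            self.clean_steps = 0;
        } else {
            self.clean_steps += 1;
            if self.clean_steps >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.clean_steps = 0;
            }
        }
    }
}

/// One training step for the toy model loss(w) = (w*x - y)^2.
/// Returns true if the optimizer step was applied.
fn train_step(scaler: &mut MiniScaler, w: &mut f32, x: f32, y: f32, lr: f32) -> bool {
    let loss = (*w * x - y).powi(2);
    // A real framework would call backward() on this scaled loss.
    let _scaled_loss = scaler.scale_loss(loss);
    // Hand-written gradient of the scaled loss: scale * d(loss)/dw.
    let mut grads = vec![scaler.scale * 2.0 * (*w * x - y) * x];
    let ok = scaler.unscale_grads(&mut grads);
    if ok {
        *w -= lr * grads[0]; // plain SGD step on the unscaled gradient
    } // otherwise the step is skipped (assumed convention)
    scaler.update();
    ok
}

fn main() {
    let mut scaler = MiniScaler::new();
    let mut w = 0.0_f32;
    for _ in 0..50 {
        train_step(&mut scaler, &mut w, 1.0, 1.0, 0.1);
    }
    println!("w = {w:.4}, scale = {}", scaler.scale);
}
```

Setting the scale high enough to overflow shows the other branch: the gradient becomes infinite, the step is skipped and the scale is halved.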

Implementations

impl GradScaler

pub fn new() -> Self

Creates a new gradient scaler with default settings.

Default configuration:

  • Initial scale: 65536.0 (2^16)
  • Growth factor: 2.0
  • Backoff factor: 0.5
  • Growth interval: 2000 steps
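
Written out explicitly, the defaults form a small configuration record. The sketch below uses a hypothetical `ScalerConfig` type (not part of this crate) to show that `new()` corresponds to `with_options(65536.0, 2.0, 0.5, 2000)`. Its `with_scale` assumes the three remaining settings keep their defaults, which this page does not state.

```rust
/// Hypothetical configuration record; field names are illustrative only.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ScalerConfig {
    init_scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: usize,
}

impl Default for ScalerConfig {
    /// The documented defaults of GradScaler::new().
    fn default() -> Self {
        ScalerConfig {
            init_scale: 65536.0, // 2^16
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        }
    }
}

impl ScalerConfig {
    /// Mirrors with_scale; ASSUMES the other settings keep their defaults.
    fn with_scale(init_scale: f32) -> Self {
        ScalerConfig { init_scale, ..Self::default() }
    }

    /// Mirrors with_options: every setting given explicitly.
    fn with_options(init_scale: f32, growth_factor: f32, backoff_factor: f32, growth_interval: usize) -> Self {
        ScalerConfig { init_scale, growth_factor, backoff_factor, growth_interval }
    }
}

fn main() {
    let defaults = ScalerConfig::default();
    // 65536 is exactly 2^16.
    assert_eq!(defaults.init_scale, (1u32 << 16) as f32);
    assert_eq!(defaults, ScalerConfig::with_options(65536.0, 2.0, 0.5, 2000));
    println!("{:?}", ScalerConfig::with_scale(1024.0));
}
```

With these defaults each overflow halves the scale, so ten consecutive overflows take it from 2^16 down to 2^6 = 64.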

pub fn with_scale(init_scale: f32) -> Self

Creates a gradient scaler with a custom initial scale.

pub fn with_options( init_scale: f32, growth_factor: f32, backoff_factor: f32, growth_interval: usize, ) -> Self

Creates a gradient scaler with every setting specified explicitly.

pub fn growth_factor(self, factor: f32) -> Self

Builder: sets the growth factor.

pub fn backoff_factor(self, factor: f32) -> Self

Builder: sets the backoff factor.

pub fn growth_interval(self, interval: usize) -> Self

Builder: sets the growth interval.

pub fn enabled(self, enabled: bool) -> Self

Builder: sets whether the scaler is enabled.
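
The builder methods take `self` by value and return it, so configuration can be chained. With the real type that would read `GradScaler::new().growth_factor(1.5).backoff_factor(0.25).growth_interval(500)`, using only the signatures listed above. The runnable sketch below reproduces that shape with a hypothetical `ScalerSketch` struct; its `enabled: true` default is an assumption, since the page does not document the initial enabled state.

```rust
/// Hypothetical stand-in with the same consuming-builder shape as the
/// documented GradScaler methods; not the crate's implementation.
#[derive(Debug, Clone)]
struct ScalerSketch {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: usize,
    enabled: bool,
}

impl ScalerSketch {
    fn new() -> Self {
        ScalerSketch {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            enabled: true, // assumed default
        }
    }

    // Each builder consumes `self`, changes one field and hands it back,
    // which is what lets the calls be chained.
    fn growth_factor(mut self, factor: f32) -> Self {
        self.growth_factor = factor;
        self
    }

    fn backoff_factor(mut self, factor: f32) -> Self {
        self.backoff_factor = factor;
        self
    }

    fn growth_interval(mut self, interval: usize) -> Self {
        self.growth_interval = interval;
        self
    }

    fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }
}

fn main() {
    let scaler = ScalerSketch::new()
        .growth_factor(1.5)
        .backoff_factor(0.25)
        .growth_interval(500)
        .enabled(true);
    println!("{scaler:?}");
}
```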

pub fn get_scale(&self) -> f32

Returns the current scale factor.

pub fn set_scale(&mut self, scale: f32)

Sets the scale factor.

pub fn is_enabled(&self) -> bool

Returns whether the scaler is enabled.

pub fn set_enabled(&mut self, enabled: bool)

Enables or disables the scaler.

pub fn scale_loss(&self, loss: f32) -> f32

Scales a loss value for the backward pass.

Call backward() on the returned, scaled loss instead of on the raw loss.
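
Why scaling helps, and why it is harmless, can be checked with a little arithmetic. The smallest positive F16 value is the subnormal 2^-24 (about 5.96e-8), and under IEEE 754 round-to-nearest-even anything of magnitude 2^-25 or less rounds to zero. A gradient of 1e-8 is therefore lost in F16, while the same gradient multiplied by the default scale 2^16 (about 6.6e-4) survives. Because differentiation is linear, the gradient of the scaled loss is the scaled gradient, and dividing by a power-of-two scale recovers it exactly. Rust has no stable f16 type, so `flushes_to_zero_in_f16` checks this arithmetically; it and the other helper names are hypothetical.

```rust
/// Smallest positive IEEE 754 binary16 subnormal: 2^-24.
const F16_MIN_SUBNORMAL: f32 = 1.0 / 16_777_216.0;

/// Assuming round-to-nearest-even, a nonzero magnitude at or below half the
/// smallest subnormal (2^-25) rounds to zero when stored as f16.
fn flushes_to_zero_in_f16(x: f32) -> bool {
    x != 0.0 && x.abs() <= F16_MIN_SUBNORMAL / 2.0
}

/// Hypothetical mirror of scale_loss: multiply the loss by the current scale.
fn scale_loss(loss: f32, scale: f32) -> f32 {
    loss * scale
}

/// Gradient of c * L(w) for the toy loss L(w) = (w - 3)^2, i.e. c * 2(w - 3).
fn grad_of_scaled_loss(w: f32, c: f32) -> f32 {
    c * 2.0 * (w - 3.0)
}

fn main() {
    let scale = 65536.0_f32; // default initial scale, 2^16
    let tiny_grad = 1e-8_f32;
    println!("raw {tiny_grad:e}: lost in f16? {}", flushes_to_zero_in_f16(tiny_grad));
    println!(
        "scaled {:e}: lost in f16? {}",
        tiny_grad * scale,
        flushes_to_zero_in_f16(tiny_grad * scale)
    );

    // Scaling the loss scales its gradient by the same factor, so dividing
    // by a power-of-two scale recovers the true gradient exactly.
    let w = 3.5_f32;
    assert_eq!(grad_of_scaled_loss(w, scale) / scale, grad_of_scaled_loss(w, 1.0));
    println!("scaled loss for loss = 0.25: {}", scale_loss(0.25, scale));
}
```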

pub fn unscale_grads(&mut self, grads: &mut [f32]) -> bool

Unscales gradients in place and checks them for inf/nan.

Returns true if all gradients are finite, or false if any gradient overflowed to inf or nan.
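
A minimal sketch of the unscale-and-check step, assuming `unscale_grads` divides each gradient by the current scale (the page says only that it unscales in place). The free function `unscale_grads` below is hypothetical and takes the scale as a parameter; the real method also records the result so that found_inf() can report it, which this sketch omits.

```rust
/// Hypothetical mirror of unscale_grads: divide every gradient by `scale`
/// in place and report whether all results are finite (no inf, no NaN).
fn unscale_grads(grads: &mut [f32], scale: f32) -> bool {
    let mut all_finite = true;
    for g in grads.iter_mut() {
        *g /= scale;
        // inf and NaN stay non-finite after division, so one pass suffices.
        all_finite &= g.is_finite();
    }
    all_finite
}

fn main() {
    let mut grads = vec![131072.0_f32, -32768.0, 0.0];
    let finite = unscale_grads(&mut grads, 65536.0);
    println!("finite = {finite}, grads = {grads:?}");
}
```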

pub fn check_grads(&self, grads: &[f32]) -> bool

Checks a slice of gradients for inf/nan without modifying them.

pub fn found_inf(&self) -> bool

Returns whether inf/nan was found in the last unscale operation.

pub fn set_found_inf(&mut self, found: bool)

Sets the inf/nan flag explicitly, for when gradients are checked outside unscale_grads (for example with check_grads).
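
set_found_inf is useful when gradients live in several buffers (for example one per layer) and are checked separately. The sketch below assumes that pattern: each buffer is checked with a check_grads-style test and one combined flag is recorded for the step. `InfTracker` and `record_overflow` are hypothetical names standing in for the documented methods.

```rust
/// Hypothetical stand-in exposing only the inf/NaN bookkeeping methods
/// documented above; not the crate's implementation.
#[derive(Default)]
struct InfTracker {
    found_inf: bool,
}

impl InfTracker {
    /// Like check_grads: true if every value is finite; modifies nothing.
    fn check_grads(&self, grads: &[f32]) -> bool {
        grads.iter().all(|g| g.is_finite())
    }

    fn set_found_inf(&mut self, found: bool) {
        self.found_inf = found;
    }

    fn found_inf(&self) -> bool {
        self.found_inf
    }
}

/// Checks several gradient buffers and records a single overflow flag
/// for the whole step (true if any buffer contains inf or NaN).
fn record_overflow(tracker: &mut InfTracker, buffers: &[&[f32]]) {
    let any_overflow = buffers.iter().any(|b| !tracker.check_grads(b));
    tracker.set_found_inf(any_overflow);
}

fn main() {
    let mut tracker = InfTracker::default();
    let layer1 = [0.5_f32, -0.25];
    let layer2 = [1.0_f32, f32::INFINITY];
    record_overflow(&mut tracker, &[&layer1[..], &layer2[..]]);
    println!("overflow this step: {}", tracker.found_inf());
}
```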

pub fn update(&mut self)

Updates the scale factor based on overflow history.

Call this after each optimizer step:

  • If an overflow was detected, the scale is multiplied by backoff_factor.
  • If no overflow occurred for growth_interval steps, the scale is multiplied by growth_factor.
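
The update rule can be sketched as a small state machine. `ScaleSchedule` and its `clean_steps` counter are hypothetical; resetting the counter after an overflow, and clearing the overflow flag at the end of each update, are assumptions about bookkeeping this page does not describe.

```rust
/// Hypothetical stand-in for the scale-update logic described above;
/// not this crate's implementation.
struct ScaleSchedule {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: usize,
    clean_steps: usize, // overflow-free updates since the last scale change
    found_inf: bool,    // set by the unscale/check step before update()
}

impl ScaleSchedule {
    /// Documented defaults of GradScaler::new().
    fn new() -> Self {
        ScaleSchedule {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            clean_steps: 0,
            found_inf: false,
        }
    }

    fn update(&mut self) {
        if self.found_inf {
            // Overflow: shrink the scale and restart the count (assumed).
            self.scale *= self.backoff_factor;
            self.clean_steps = 0;
        } else {
            self.clean_steps += 1;
            if self.clean_steps >= self.growth_interval {
                // growth_interval overflow-free steps: grow the scale.
                self.scale *= self.growth_factor;
                self.clean_steps = 0;
            }
        }
        self.found_inf = false; // assumed: the flag applies to one step only
    }
}

fn main() {
    let mut s = ScaleSchedule::new();
    s.found_inf = true; // pretend the last unscale saw an overflow
    s.update();
    println!("after one overflow: {}", s.scale);
    for _ in 0..s.growth_interval {
        s.update();
    }
    println!("after {} clean steps: {}", s.growth_interval, s.scale);
}
```

With the documented defaults, one overflow takes the scale from 65536 to 32768, and 2000 overflow-free updates afterwards bring it back to 65536.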

pub fn state_dict(&self) -> GradScalerState

Returns the current state for checkpointing.

pub fn load_state_dict(&mut self, state: GradScalerState)

Loads state from a checkpoint.
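
Saving the scaler state alongside the model and optimizer means a resumed run continues from the adapted scale instead of restarting at 2^16. The fields of GradScalerState are not listed on this page, so the sketch below uses a hypothetical `CkptState` holding only a scale and a growth counter, and an equally hypothetical `CkptScaler`.

```rust
/// Hypothetical snapshot; the real GradScalerState's fields are not documented here.
#[derive(Debug, Clone, Copy, PartialEq)]
struct CkptState {
    scale: f32,
    growth_tracker: usize,
}

/// Hypothetical scaler exposing only the checkpoint methods.
struct CkptScaler {
    scale: f32,
    growth_tracker: usize,
}

impl CkptScaler {
    fn new() -> Self {
        CkptScaler { scale: 65536.0, growth_tracker: 0 }
    }

    /// Mirrors state_dict(): copy out everything needed to resume.
    fn state_dict(&self) -> CkptState {
        CkptState { scale: self.scale, growth_tracker: self.growth_tracker }
    }

    /// Mirrors load_state_dict(): take the snapshot by value and restore it.
    fn load_state_dict(&mut self, state: CkptState) {
        self.scale = state.scale;
        self.growth_tracker = state.growth_tracker;
    }
}

fn main() {
    // Pretend training has backed off twice and counted 123 clean steps.
    let mut trained = CkptScaler::new();
    trained.scale = 16384.0;
    trained.growth_tracker = 123;
    let checkpoint = trained.state_dict(); // save with model/optimizer state

    // On resume, restore into a fresh scaler instead of restarting at 2^16.
    let mut resumed = CkptScaler::new();
    resumed.load_state_dict(checkpoint);
    assert_eq!(resumed.state_dict(), checkpoint);
    println!("resumed at scale {}", resumed.scale);
}
```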

Trait Implementations

impl Clone for GradScaler

fn clone(&self) -> GradScaler

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for GradScaler

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl Default for GradScaler

fn default() -> Self

Returns the “default value” for a type.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

This is a nightly-only experimental API (clone_to_uninit).
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V