
Struct GradScaler

pub struct GradScaler { /* private fields */ }

GradScaler for mixed-precision training.

Scales the loss before the backward pass so that small float16 gradients do not underflow to zero. It then unscales the gradients before the optimizer step. The scale factor is adjusted dynamically, depending on whether inf/NaN gradients were detected.

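// loss, params, and optim are assumed to come from the surrounding training loop.
let mut scaler = GradScaler::new();
let scaled_loss = scaler.scale(&loss)?;  // loss * current scale factor
scaled_loss.backward()?;                 // gradients are scaled by the same factor
// Unscale, check for inf/NaN, and call optim.step() only if all gradients are finite.
let stepped = scaler.step(&params, &mut || optim.step())?;
scaler.update();                         // call every iteration, even when stepped == false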
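A quick magnitude check shows why the loss is scaled. The sketch below compares values against two IEEE 754 binary16 (float16) limits: the largest finite value, 65504, and the smallest positive subnormal, 2^-24 (about 5.96e-8). The classify function and the F16Range enum are hypothetical helpers, not part of this crate. They ignore rounding near the boundaries and only check magnitude ranges.

```rust
/// Largest finite IEEE 754 binary16 (float16) value.
const F16_MAX: f64 = 65504.0;
/// Smallest positive subnormal binary16 value, 2^-24.
const F16_MIN_SUBNORMAL: f64 = 1.0 / 16_777_216.0;

/// Hypothetical classification of a magnitude against the float16 range.
/// Boundaries are approximate: round-to-nearest is not modeled.
#[derive(Debug, PartialEq)]
enum F16Range {
    Underflow,     // too small, would flush to zero
    Representable, // fits (as a normal or subnormal value)
    Overflow,      // too large, would become infinite
}

fn classify(x: f64) -> F16Range {
    let m = x.abs();
    if m != 0.0 && m < F16_MIN_SUBNORMAL {
        F16Range::Underflow
    } else if m > F16_MAX {
        F16Range::Overflow
    } else {
        F16Range::Representable
    }
}

fn main() {
    let scale = 65536.0; // documented default initial scale, 2^16
    let small_grad = 1e-8;
    println!("unscaled small gradient: {:?}", classify(small_grad));
    println!("scaled small gradient:   {:?}", classify(small_grad * scale));
    // A large gradient can overflow once scaled; this is what the inf/NaN check is for.
    println!("scaled large gradient:   {:?}", classify(2.0 * scale));
}
```

With the default initial scale of 65536, a gradient of 1e-8 moves into the representable float16 range. A gradient of 2.0, by contrast, overflows once scaled. Overflow is the case the inf/NaN check in step() is designed to catch.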

Implementations

impl GradScaler

pub fn new() -> Self

Create a new GradScaler with default settings.

Initial scale: 2^16 = 65536; growth factor: 2.0; backoff factor: 0.5; growth interval: 2000 steps.
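The defaults can be read as a multiplicative schedule. This reading assumes the conventional dynamic loss-scaling rule, which this page does not spell out: multiply by the backoff factor on a step with inf/NaN gradients, and multiply by the growth factor after the growth interval of consecutive finite steps. Under that assumption, the defaults give:

```latex
\begin{aligned}
s_0 &= 2^{16} = 65536,\\
s &\leftarrow 0.5\,s \quad \text{(step with inf/NaN gradients)},\\
s &\leftarrow 2.0\,s \quad \text{(after 2000 consecutive finite steps)},\\
s_k &= 2^{16}\cdot 0.5^{k} = 2^{16-k} \quad \text{(after } k \text{ backoffs from } s_0 \text{ with no growth in between)}.
\end{aligned}
```

Climbing back from 2^{16-k} to 2^{16} takes k doublings. Under this rule, that means at least 2000k finite steps.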

pub fn scale(&self, loss: &Variable) -> Result<Variable>

Scale the loss before the backward pass. Returns the loss multiplied by the current scale factor.
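Scaling the loss scales every gradient by the same factor, because differentiation is linear: the gradient of s·L is s·∇L. The toy sketch below checks this on a hypothetical linear loss L(w) = Σ w_i·x_i in plain f64, not with this crate's Variable type. Because 65536 is a power of two, multiplying and dividing by it only shifts the floating-point exponent. Unscaling therefore recovers the original gradients exactly, as long as nothing overflows or underflows.

```rust
/// Gradient with respect to w of the toy loss k * L(w), where L(w) = sum_i w_i * x_i.
/// dL/dw_i = x_i, so d(k * L)/dw_i = k * x_i: the loss factor k scales every gradient.
fn grad_of_scaled_linear_loss(x: &[f64], k: f64) -> Vec<f64> {
    x.iter().map(|xi| k * xi).collect()
}

fn main() {
    let x = [3.0e-5, -1.5e-3, 0.25];
    let scale = 65536.0; // 2^16

    let plain = grad_of_scaled_linear_loss(&x, 1.0);
    let scaled = grad_of_scaled_linear_loss(&x, scale);

    // Unscale: divide each gradient by the same power-of-two factor.
    let unscaled: Vec<f64> = scaled.iter().map(|g| g / scale).collect();

    // Power-of-two scaling only changes the exponent, so the round trip is exact here.
    assert_eq!(plain, unscaled);
    println!("{:?}", unscaled);
}
```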

pub fn scale_factor(&self) -> f64

Return the current scale factor.

pub fn step(
    &mut self,
    params: &[Parameter],
    step_fn: &mut dyn FnMut() -> Result<()>,
) -> Result<bool>

Unscale the gradients, check them for inf/NaN values, and call step_fn to step the optimizer only if all of them are finite.

Returns true if the step was taken (all gradients finite). Returns false if any inf/NaN gradient was detected; in that case the optimizer step is skipped.
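The unscale, check, and conditional-step sequence can be sketched as follows. This is a minimal sketch under stated assumptions. Gradients are modeled as plain Vec<f32> buffers instead of this crate's Parameter type. A String-based Result alias stands in for the crate's error type, and unscale_and_step is a hypothetical name. Only the step_fn parameter mirrors the documented &mut dyn FnMut() -> Result<()> signature.

```rust
/// Hypothetical stand-in for the crate's Result type.
type Result<T> = std::result::Result<T, String>;

/// Hypothetical sketch: unscale every gradient in place, then call step_fn
/// only if all unscaled gradients are finite. Returns whether the step was taken.
fn unscale_and_step(
    grads: &mut [Vec<f32>],
    scale: f64,
    step_fn: &mut dyn FnMut() -> Result<()>,
) -> Result<bool> {
    let inv = (1.0 / scale) as f32;
    let mut all_finite = true;
    for g in grads.iter_mut() {
        for v in g.iter_mut() {
            *v *= inv; // undo the loss scaling; inf and NaN stay non-finite
            if !v.is_finite() {
                all_finite = false;
            }
        }
    }
    if all_finite {
        step_fn()?; // optimizer step only on finite gradients
    }
    Ok(all_finite)
}

fn main() -> Result<()> {
    let mut grads = vec![vec![65536.0_f32, -32768.0]];
    let mut steps = 0;
    let took = unscale_and_step(&mut grads, 65536.0, &mut || -> Result<()> {
        steps += 1; // stands in for optim.step()
        Ok(())
    })?;
    println!("took = {}, grads = {:?}, steps = {}", took, grads, steps);
    Ok(())
}
```

In this sketch the division happens before the finiteness check. That way, step_fn sees true-magnitude gradients, and inf or NaN values still survive the division, so the check catches them.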

pub fn update(&mut self)

Update the scale factor after each step.

Call this after every step() call, whether or not the optimizer step was taken.
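A plausible update rule follows the conventional dynamic loss-scaling scheme, as in PyTorch's GradScaler; this is an assumption, not something this page confirms. On a step with inf/NaN gradients, the scale is multiplied by the backoff factor and the count of consecutive finite steps resets. After growth_interval consecutive finite steps, the scale is multiplied by the growth factor. ScaleState and its fields are hypothetical. The real update() takes no arguments and presumably records the overflow result during step(); this sketch receives that result explicitly instead.

```rust
/// Hypothetical scaler state using the documented default values.
struct ScaleState {
    scale: f64,
    growth_factor: f64,
    backoff_factor: f64,
    growth_interval: u32,
    good_steps: u32, // consecutive steps without inf/NaN gradients
}

impl ScaleState {
    fn with_defaults() -> Self {
        ScaleState {
            scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            good_steps: 0,
        }
    }

    /// found_inf is true when the preceding step was skipped.
    fn update(&mut self, found_inf: bool) {
        if found_inf {
            self.scale *= self.backoff_factor; // back off and restart the count
            self.good_steps = 0;
        } else {
            self.good_steps += 1;
            if self.good_steps == self.growth_interval {
                self.scale *= self.growth_factor; // grow after a long finite run
                self.good_steps = 0;
            }
        }
    }
}

fn main() {
    let mut s = ScaleState::with_defaults();
    s.update(true); // one overflow step: 65536 -> 32768
    for _ in 0..2000 {
        s.update(false); // 2000 finite steps: 32768 -> 65536
    }
    println!("scale = {}", s.scale);
}
```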

Trait Implementations

impl Default for GradScaler

fn default() -> Self

Returns the “default value” for a type.

impl Stateful for GradScaler

fn save_state<W: Write>(&self, w: &mut W) -> Result<()>

Serialize state to a writer.

fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()>

Restore state from a reader.

fn save_state_file(&self, path: &str) -> Result<()>

Save state to a file. Uses gzip compression if path ends with .gz.

fn load_state_file(&mut self, path: &str) -> Result<()>

Load state from a file. Detects gzip from .gz extension.
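For illustration, a save_state/load_state pair for a scaler could write its few scalar fields using std::io alone. The struct, its fields, and the 12-byte little-endian layout below are hypothetical; this page does not document the crate's private fields or on-disk format. Gzip handling for .gz paths is omitted.

```rust
use std::io::{self, Cursor, Read, Write};

/// Hypothetical scaler state; field names and byte layout are illustrative only.
#[derive(Debug, PartialEq)]
struct ScalerState {
    scale: f64,
    good_steps: u32,
}

impl ScalerState {
    /// Writes 8 bytes of scale followed by 4 bytes of good_steps, little-endian.
    fn save_state<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_all(&self.scale.to_le_bytes())?;
        w.write_all(&self.good_steps.to_le_bytes())?;
        Ok(())
    }

    /// Reads the same layout back, overwriting the current fields.
    fn load_state<R: Read>(&mut self, r: &mut R) -> io::Result<()> {
        let mut scale_bytes = [0u8; 8];
        r.read_exact(&mut scale_bytes)?;
        let mut steps_bytes = [0u8; 4];
        r.read_exact(&mut steps_bytes)?;
        self.scale = f64::from_le_bytes(scale_bytes);
        self.good_steps = u32::from_le_bytes(steps_bytes);
        Ok(())
    }
}

fn main() -> io::Result<()> {
    let saved = ScalerState { scale: 32768.0, good_steps: 17 };
    let mut buf: Vec<u8> = Vec::new();
    saved.save_state(&mut buf)?;

    let mut restored = ScalerState { scale: 65536.0, good_steps: 0 };
    restored.load_state(&mut Cursor::new(buf))?;
    assert_eq!(restored, saved);
    println!("{:?}", restored);
    Ok(())
}
```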

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.