pub struct GradScaler { /* private fields */ }
GradScaler for mixed precision training.
Scales loss before backward to prevent gradient underflow in float16, then unscales gradients before optimizer step. Dynamically adjusts scale factor based on whether inf/nan gradients are detected.
let mut scaler = GradScaler::new();
let scaled_loss = scaler.scale(&loss)?;
scaled_loss.backward()?;
let stepped = scaler.step(&params, &mut || optim.step())?;
scaler.update();
Implementations§
impl GradScaler
pub fn new() -> Self
Create a new GradScaler with default settings.
Initial scale: 2^16 = 65536, growth: 2.0, backoff: 0.5, interval: 2000.
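The dynamic adjustment described above can be sketched as follows, using the documented defaults (initial scale 2^16, growth 2.0, backoff 0.5, interval 2000). The struct and field names here are illustrative, not the crate's private internals:

```rust
// Hypothetical sketch of the scale-update rule: halve on inf/nan,
// double after `growth_interval` consecutive clean steps.
struct ScaleState {
    scale: f64,
    growth_factor: f64,
    backoff_factor: f64,
    growth_interval: u32,
    good_steps: u32, // consecutive steps without inf/nan gradients
}

impl ScaleState {
    fn new() -> Self {
        ScaleState {
            scale: 65536.0, // 2^16, the documented default
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            good_steps: 0,
        }
    }

    // Called once per iteration after checking gradients for inf/nan.
    fn update(&mut self, found_inf: bool) {
        if found_inf {
            // Overflow detected: back off and restart the growth counter.
            self.scale *= self.backoff_factor;
            self.good_steps = 0;
        } else {
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                // A full interval of clean steps: grow the scale.
                self.scale *= self.growth_factor;
                self.good_steps = 0;
            }
        }
    }
}

fn main() {
    let mut s = ScaleState::new();
    s.update(true); // inf detected: 65536 * 0.5
    assert_eq!(s.scale, 32768.0);
    for _ in 0..2000 {
        s.update(false); // 2000 clean steps: scale doubles back
    }
    assert_eq!(s.scale, 65536.0);
}
```

Backoff reacts immediately to overflow, while growth is deliberately slow, so the scale settles just under the largest value the float16 gradients can tolerate.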
pub fn scale(&self, loss: &Variable) -> Result<Variable>
Scale the loss before backward. Returns loss * scale.
pub fn scale_factor(&self) -> f64
Current scale factor.
Trait Implementations§
impl Default for GradScaler
impl Stateful for GradScaler
fn save_state<W: Write>(&self, w: &mut W) -> Result<()>
Serialize the implementor's state (lr, momentum buffers, etc. for optimizers) to a writer.
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()>
Restore optimizer state from a reader.
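A minimal round-trip sketch of this save/load pattern. The actual wire format of GradScaler's state is not documented here, so the byte layout below (little-endian f64 scale plus a u32 step counter) is an assumption for illustration only:

```rust
use std::io::{Read, Result, Write};

// Hypothetical stand-in for the state a GradScaler would persist.
struct ScalerState {
    scale: f64,
    good_steps: u32,
}

impl ScalerState {
    // Write state as fixed-width little-endian fields (assumed layout).
    fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
        w.write_all(&self.scale.to_le_bytes())?;
        w.write_all(&self.good_steps.to_le_bytes())
    }

    // Read the fields back in the same order they were written.
    fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
        let mut f = [0u8; 8];
        r.read_exact(&mut f)?;
        self.scale = f64::from_le_bytes(f);
        let mut u = [0u8; 4];
        r.read_exact(&mut u)?;
        self.good_steps = u32::from_le_bytes(u);
        Ok(())
    }
}

fn main() -> Result<()> {
    // Save to an in-memory buffer, then restore into a fresh instance,
    // as one would when checkpointing mid-training.
    let saved = ScalerState { scale: 32768.0, good_steps: 1500 };
    let mut buf = Vec::new();
    saved.save_state(&mut buf)?;

    let mut restored = ScalerState { scale: 65536.0, good_steps: 0 };
    restored.load_state(&mut buf.as_slice())?;
    assert_eq!(restored.scale, 32768.0);
    assert_eq!(restored.good_steps, 1500);
    Ok(())
}
```

Persisting the scaler alongside the optimizer matters: resuming from a checkpoint with a freshly-initialized scale can trigger a burst of overflowed steps until the scale backs off again.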
Auto Trait Implementations§
impl Freeze for GradScaler
impl RefUnwindSafe for GradScaler
impl Send for GradScaler
impl Sync for GradScaler
impl Unpin for GradScaler
impl UnsafeUnpin for GradScaler
impl UnwindSafe for GradScaler
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.