Trait Optimizer

pub trait Optimizer<R: Runtime<DType = DType>> {
    // Required methods
    fn step<C>(
        &mut self,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
        grads: &GradStore<R>,
    ) -> Result<()>
    where
        C: RuntimeClient<R> + BinaryOps<R> + UnaryOps<R> + ScalarOps<R>
            + ReduceOps<R> + FusedOptimizerOps<R>;
    fn set_lr(&mut self, lr: f64);
    fn lr(&self) -> f64;
    fn reset(&mut self);
}

Trait for parameter optimizers.

All optimizers (AdamW, SGD, etc.) implement this trait so trainers can work with any optimizer without hardcoding a specific one.

Required Methods

fn step<C>( &mut self, client: &C, params: &mut HashMap<TensorId, Tensor<R>>, grads: &GradStore<R>, ) -> Result<()>

Perform one optimization step.

Updates all parameters in params using gradients from grads. Parameters without gradients are skipped.

fn set_lr(&mut self, lr: f64)

Set the learning rate.

fn lr(&self) -> f64

Get the current learning rate.

fn reset(&mut self)

Reset all optimizer state (moments, velocities, timestep).

Dyn Compatibility

This trait is not dyn compatible: `step` is generic over the client type `C`, and a trait with a generic method cannot be used as a trait object. Use the trait as a generic bound (`O: Optimizer<R>`) rather than behind `dyn Optimizer<R>`.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors

impl<R: Runtime<DType = DType>> Optimizer<R> for AdaGrad<R>

impl<R: Runtime<DType = DType>> Optimizer<R> for AdamW<R>

impl<R: Runtime<DType = DType>> Optimizer<R> for Lamb<R>

impl<R: Runtime<DType = DType>> Optimizer<R> for Sgd<R>