pub trait Optimizer<R: Runtime<DType = DType>> {
    // Required methods
    fn step<C>(
        &mut self,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
        grads: &GradStore<R>,
    ) -> Result<()>
    where
        C: RuntimeClient<R> + BinaryOps<R> + UnaryOps<R> + ScalarOps<R> + ReduceOps<R> + FusedOptimizerOps<R>;
    fn set_lr(&mut self, lr: f64);
    fn lr(&self) -> f64;
    fn reset(&mut self);
}
Trait for parameter optimizers.
All optimizers (AdamW, SGD, etc.) implement this trait so trainers can work with any optimizer without hardcoding a specific one.
Required Methods
fn step<C>(
    &mut self,
    client: &C,
    params: &mut HashMap<TensorId, Tensor<R>>,
    grads: &GradStore<R>,
) -> Result<()>
where
    C: RuntimeClient<R> + BinaryOps<R> + UnaryOps<R> + ScalarOps<R> + ReduceOps<R> + FusedOptimizerOps<R>,
Perform one optimization step.
Updates every parameter in params using the corresponding gradient from grads; parameters without gradients are skipped.
Dyn Compatibility
This trait is not dyn compatible: step takes a generic type parameter C, and generic methods cannot be dispatched through a vtable.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.