pub trait Optimizer<M, B>: Send + Syncwhere
M: ADModule<B>,
B: ADBackend,{
type Record: Record;
// Required methods
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
fn to_record(&self) -> Self::Record;
fn load_record(self, record: Self::Record) -> Self;
}Expand description
General trait to optimize a module.
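Required Associated Types
type Record: Record
The record type holding the optimizer's state; it is returned by to_record and accepted by load_record.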
Required Methods
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M
Performs the optimizer step using the given learning rate and gradients, and returns the updated module.
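fn to_record(&self) -> Self::Record
Returns the current state of the optimizer as a record.
fn load_record(self, record: Self::Record) -> Self
Loads the state of the optimizer from the given record, consuming the optimizer and returning it with that state.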