pub trait Optimizer<M, B>: Send + Syncwhere
    M: ADModule<B>,
    B: ADBackend,{
    type Record: Record;

    // Required methods
    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}
Expand description

General trait for optimizing a module.

Required Associated Types

source

type Record: Record

Associated record type used when saving and loading the optimizer state.

Required Methods

source

fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M

Perform the optimizer step using the given learning rate and gradients. The updated module is returned.

source

fn to_record(&self) -> Self::Record

Get the current state of the optimizer as a record.

source

fn load_record(self, record: Self::Record) -> Self

Load the state of the optimizer from a record. The optimizer is consumed and returned with the loaded state.

Implementors

source§

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>where B: ADBackend, M: ADModule<B>, O: SimpleOptimizer<B::InnerBackend>,