burn_core/optim/base.rs

use super::GradientsParams;
use crate::LearningRate;
use crate::module::AutodiffModule;
use crate::record::Record;
use crate::tensor::backend::AutodiffBackend;

/// General trait to optimize [module](AutodiffModule).
pub trait Optimizer<M, B>: Send
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Optimizer associated type used when saving and loading the state.
    type Record: Record<B>;

    /// Perform the optimizer step using the given learning rate and gradients.
    /// The updated module is returned.
    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;

    /// Get the current state of the optimizer as a [record](Record).
    fn to_record(&self) -> Self::Record;
    /// Load the state of the optimizer from a [record](Record).
    fn load_record(self, record: Self::Record) -> Self;
}
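
// Illustrative sketch (not part of the original file): a minimal generic helper showing
// how the `Optimizer` trait is driven for one training step. It relies only on the items
// imported above; the helper's name and signature are hypothetical, not part of Burn's API.
pub fn apply_step<M, B, O>(
    optimizer: &mut O,
    lr: LearningRate,
    module: M,
    grads: GradientsParams,
) -> M
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: Optimizer<M, B>,
{
    // Delegate to the trait's `step`; the module is consumed and returned in updated form.
    optimizer.step(lr, module, grads)
}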