burn_core/optim/base.rs
1use super::GradientsParams;
2use crate::LearningRate;
3use crate::module::AutodiffModule;
4use crate::record::Record;
5use crate::tensor::backend::AutodiffBackend;
6
7/// General trait to optimize [module](AutodiffModule).
8pub trait Optimizer<M, B>: Send
9where
10 M: AutodiffModule<B>,
11 B: AutodiffBackend,
12{
13 /// Optimizer associative type to be used when saving and loading the state.
14 type Record: Record<B>;
15
16 /// Perform the optimizer step using the given learning rate and gradients.
17 /// The updated module is returned.
18 fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
19
20 /// Get the current state of the optimizer as a [record](Record).
21 fn to_record(&self) -> Self::Record;
22
23 /// Load the state of the optimizer as a [record](Record).
24 fn load_record(self, record: Self::Record) -> Self;
25}