burn_optim/optim/base.rs

use burn_core as burn;

use super::GradientsParams;
use crate::LearningRate;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;

/// General trait to optimize [module](AutodiffModule).
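///
/// # Example
///
/// A minimal sketch of a single training step, assuming a module `model`, an
/// optimizer `optim` implementing this trait, a scalar `loss` tensor, and a
/// learning rate `lr` (illustrative names, not defined in this file):
///
/// ```ignore
/// // Compute gradients from the loss, then collect them per parameter.
/// let grads = loss.backward();
/// let grads = GradientsParams::from_grads(grads, &model);
/// // Apply one optimizer step; the updated module is returned by value.
/// let model = optim.step(lr, model, grads);
/// ```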
pub trait Optimizer<M, B>: Send + Clone
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Associated type to be used when saving and loading the optimizer state.
    type Record: Record<B>;

    /// Perform the optimizer step using the given learning rate and gradients.
    /// The updated module is returned.
    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;

    /// Get the current state of the optimizer as a [record](Record).
    fn to_record(&self) -> Self::Record;

    /// Load the state of the optimizer from the given [record](Record).
    fn load_record(self, record: Self::Record) -> Self;
}
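
// Illustration only (not part of the original file): a hedged sketch of how
// `to_record` and `load_record` pair up for checkpointing. `optim` stands for
// any value implementing `Optimizer`; how the record is persisted (for
// example, via a `burn::record::Recorder` implementation) is left out here.
//
// // Snapshot the optimizer state so it can be saved alongside the model.
// let record = optim.to_record();
// // ... persist and later reload `record` with a recorder of your choice ...
// // Restore the optimizer from the previously saved state.
// let optim = optim.load_record(record);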