burn_optim/optim/simple/base.rs
1use burn_core as burn;
2
3use crate::LearningRate;
4use burn::record::Record;
5use burn::tensor::{Tensor, backend::Backend};
6
7/// Simple optimizer is an opinionated trait to simplify the process of implementing an
8/// optimizer.
9///
10/// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
11/// module parameter structure, handle tracked and untracked tensors, and the likes.
12pub trait SimpleOptimizer<B>: Send + Sync + Clone
13where
14 B: Backend,
15{
16 /// The state of the optimizer. It also implements [record](Record), so that it can be saved.
17 type State<const D: usize>: Record<B> + Clone + 'static;
18
19 /// The optimizer step is performed for one tensor at a time with its gradient and state.
20 ///
21 /// Note that the state is passed as parameter, so implementations don't have to handle
22 /// the saving and loading of recorded states.
23 fn step<const D: usize>(
24 &self,
25 lr: LearningRate,
26 tensor: Tensor<B, D>,
27 grad: Tensor<B, D>,
28 state: Option<Self::State<D>>,
29 ) -> (Tensor<B, D>, Option<Self::State<D>>);
30
31 /// Change the device of the state.
32 ///
33 /// This function will be called accordingly to have the state on the same device as the
34 /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
35 fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
36}