burn_core/optim/simple/
base.rs

1use crate::{LearningRate, record::Record};
2use burn_tensor::{Tensor, backend::Backend};
3
4/// Simple optimizer is an opinionated trait to simplify the process of implementing an
5/// optimizer.
6///
7/// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
8/// module parameter structure, handle tracked and untracked tensors, and the likes.
/// Simple optimizer is an opinionated trait to simplify the process of implementing an
/// optimizer.
///
/// Implementations don't have to handle missing gradients, load or export records, navigate
/// the module parameter structure, or distinguish tracked from untracked tensors, and the like.
pub trait SimpleOptimizer<B>: Send + Sync + Clone
where
    B: Backend,
{
    /// The per-tensor state of the optimizer (e.g. momentum buffers). It also implements
    /// [record](Record), so that it can be saved.
    ///
    /// Generic over the tensor rank `D`, so one state type covers parameters of any rank.
    type State<const D: usize>: Record<B> + Clone + 'static;

    /// The optimizer step is performed for one tensor at a time with its gradient and state.
    ///
    /// Note that the state is passed as parameter, so implementations don't have to handle
    /// the saving and loading of recorded states.
    ///
    /// # Arguments
    ///
    /// - `lr`: the learning rate to apply for this step.
    /// - `tensor`: the parameter tensor to update.
    /// - `grad`: the gradient of `tensor`.
    /// - `state`: the previous state for this tensor, or `None` on the first step.
    ///
    /// Returns the updated tensor together with the new state (if any) to be recorded.
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>);

    /// Change the device of the state.
    ///
    /// This function will be called accordingly to have the state on the same device as the
    /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
    fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
}