burn_core/optim/
momentum.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::record::Record;
5use crate::tensor::{ElementConversion, Tensor};
6use burn_tensor::backend::Backend;
7
8/// Configuration to create [momentum](Momentum).
9#[derive(Config)]
10pub struct MomentumConfig {
11    /// Momemtum factor
12    #[config(default = 0.9)]
13    pub momentum: f64,
14    /// Dampening factor.
15    #[config(default = 0.1)]
16    pub dampening: f64,
17    /// Enables Nesterov momentum, see [On the importance of initialization and
18    /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
19    #[config(default = false)]
20    pub nesterov: bool,
21}
22
23/// State of [momentum](Momentum).
24#[derive(Record, Clone, new)]
25pub struct MomentumState<B: Backend, const D: usize> {
26    velocity: Tensor<B, D>,
27}
28
29/// Momemtum implementation that transforms gradients.
30#[derive(Clone)]
31pub struct Momentum<B: Backend> {
32    momentum: B::FloatElem,
33    dampening: f64,
34    nesterov: bool,
35}
36
37impl<B: Backend> Momentum<B> {
38    /// Creates a new [momentum](Momentum) from a [config](MomentumConfig).
39    pub fn new(config: &MomentumConfig) -> Self {
40        Self {
41            momentum: config.momentum.elem(),
42            dampening: config.dampening,
43            nesterov: config.nesterov,
44        }
45    }
46
47    /// Transforms a gradient.
48    ///
49    /// # Arguments
50    ///
51    /// * `grad` - Gradient to transform.
52    /// * `state` - State of the optimizer.
53    ///
54    /// # Returns
55    ///
56    /// * `grad` - Transformed gradient.
57    /// * `state` - State of the optimizer.
58    pub fn transform<const D: usize>(
59        &self,
60        grad: Tensor<B, D>,
61        state: Option<MomentumState<B, D>>,
62    ) -> (Tensor<B, D>, MomentumState<B, D>) {
63        let velocity = if let Some(state) = state {
64            grad.clone()
65                .mul_scalar(1.0 - self.dampening)
66                .add(state.velocity.mul_scalar(self.momentum))
67        } else {
68            grad.clone()
69        };
70
71        let grad = match self.nesterov {
72            true => velocity.clone().mul_scalar(self.momentum).add(grad),
73            false => velocity.clone(),
74        };
75
76        (grad, MomentumState::new(velocity))
77    }
78}
79
80impl<B: Backend, const D: usize> MomentumState<B, D> {
81    /// Moves the state to a device.
82    ///
83    /// # Arguments
84    ///
85    /// * `device` - Device to move the state to.
86    ///
87    /// # Returns
88    ///
89    /// * `self` - Moved state.
90    pub fn to_device(mut self, device: &B::Device) -> Self {
91        self.velocity = self.velocity.to_device(device);
92        self
93    }
94}