burn_core/optim/
momentum.rs1use crate as burn;
2
3use crate::config::Config;
4use crate::record::Record;
5use crate::tensor::{ElementConversion, Tensor};
6use burn_tensor::backend::Backend;
7
8#[derive(Config)]
10pub struct MomentumConfig {
11 #[config(default = 0.9)]
13 pub momentum: f64,
14 #[config(default = 0.1)]
16 pub dampening: f64,
17 #[config(default = false)]
20 pub nesterov: bool,
21}
22
23#[derive(Record, Clone, new)]
25pub struct MomentumState<B: Backend, const D: usize> {
26 velocity: Tensor<B, D>,
27}
28
29#[derive(Clone)]
31pub struct Momentum<B: Backend> {
32 momentum: B::FloatElem,
33 dampening: f64,
34 nesterov: bool,
35}
36
37impl<B: Backend> Momentum<B> {
38 pub fn new(config: &MomentumConfig) -> Self {
40 Self {
41 momentum: config.momentum.elem(),
42 dampening: config.dampening,
43 nesterov: config.nesterov,
44 }
45 }
46
47 pub fn transform<const D: usize>(
59 &self,
60 grad: Tensor<B, D>,
61 state: Option<MomentumState<B, D>>,
62 ) -> (Tensor<B, D>, MomentumState<B, D>) {
63 let velocity = if let Some(state) = state {
64 grad.clone()
65 .mul_scalar(1.0 - self.dampening)
66 .add(state.velocity.mul_scalar(self.momentum))
67 } else {
68 grad.clone()
69 };
70
71 let grad = match self.nesterov {
72 true => velocity.clone().mul_scalar(self.momentum).add(grad),
73 false => velocity.clone(),
74 };
75
76 (grad, MomentumState::new(velocity))
77 }
78}
79
80impl<B: Backend, const D: usize> MomentumState<B, D> {
81 pub fn to_device(mut self, device: &B::Device) -> Self {
91 self.velocity = self.velocity.to_device(device);
92 self
93 }
94}