use zyx::Tensor;
use zyx_derive::Module;

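/// RMSprop optimizer.
///
/// Keeps a per-parameter running average of squared gradients and scales each
/// gradient by the inverse square root of that average:
///
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * g_t^2
/// p_t = p_{t-1} - learning_rate * g_t / (sqrt(v_t) + eps)
/// ```
///
/// Setting `centered` normalizes by an estimate of the gradient variance
/// instead, and `momentum > 0.0` accumulates steps in a velocity buffer.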
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct RMSprop {
    /// Step size applied to each parameter update.
    pub learning_rate: f32,
    /// Smoothing constant for the running average of squared gradients.
    pub alpha: f32,
    /// Term added to the denominator for numerical stability.
    pub eps: f32,
    /// Momentum factor; `0.0` disables momentum.
    pub momentum: f32,
    /// If `true`, normalize by an estimate of the gradient variance instead
    /// of the uncentered second moment.
    pub centered: bool,
    /// Decoupled weight decay factor; `0.0` disables weight decay.
    pub weight_decay: f32,
    /// Number of `update` calls performed so far.
    pub t: usize,
    // Running average of squared gradients, one entry per parameter.
    buffer: Vec<Tensor>,
    // Velocity buffers, used when `momentum > 0.0`.
    momentum_buf: Vec<Tensor>,
    // Running average of gradients, maintained only when `centered` is set.
    grad_avg: Vec<Tensor>,
}

// Defaults mirror torch.optim.RMSprop: lr = 0.01, alpha = 0.99, eps = 1e-8,
// momentum and weight decay disabled, uncentered.
impl Default for RMSprop {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            alpha: 0.99,
            eps: 1e-8,
            momentum: 0.0,
            centered: false,
            weight_decay: 0.0,
            t: 0,
            buffer: Vec::new(),
            momentum_buf: Vec::new(),
            grad_avg: Vec::new(),
        }
    }
}

impl RMSprop {
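    /// Runs one RMSprop step, pairing each parameter with the gradient at the
    /// same position. A `None` gradient leaves its parameter untouched while
    /// still reserving state buffers so that later indices stay aligned.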
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
            let Some(grad) = grad else {
                // No gradient for this parameter: still allocate its state
                // buffers so that indices stay aligned with `parameters`.
                if self.buffer.len() <= i {
                    self.buffer.push(Tensor::zeros_like(&*param));
                    self.momentum_buf.push(Tensor::zeros_like(&*param));
                    if self.centered {
                        self.grad_avg.push(Tensor::zeros_like(&*param));
                    }
                }
                continue;
            };

            // Lazily allocate state buffers on the first gradient. They start
            // at zero; the exponential moving averages below then produce the
            // correct first-step values (initializing `buffer` with
            // `grad * grad * (1 - alpha)` here would apply the decay twice,
            // since the update below runs unconditionally).
            if self.buffer.len() <= i {
                self.buffer.push(Tensor::zeros_like(&*param));
                self.momentum_buf.push(Tensor::zeros_like(&*param));
                if self.centered {
                    self.grad_avg.push(Tensor::zeros_like(&*param));
                }
            }

            // Running average of squared gradients:
            // v_t = alpha * v_{t-1} + (1 - alpha) * g_t^2
            self.buffer[i] = &self.buffer[i] * self.alpha + &grad * &grad * (1.0 - self.alpha);

            let denom = if self.centered {
                // Centered variant: subtract the squared running mean of the
                // gradients to estimate the variance. `relu` clamps small
                // negative values caused by floating-point rounding.
                self.grad_avg[i] = &self.grad_avg[i] * self.alpha + &grad * (1.0 - self.alpha);
                let avg = &self.grad_avg[i];
                (&self.buffer[i] - avg * avg).relu().sqrt() + self.eps
            } else {
                self.buffer[i].sqrt() + self.eps
            };

            // Base RMSprop step: g_t / (sqrt(v_t) + eps), scaled by the
            // learning rate.
            let update = &grad / denom * self.learning_rate;

            if self.momentum > 0.0 {
                // Accumulate steps in a velocity buffer; the learning rate is
                // already folded into `update`.
                self.momentum_buf[i] = &self.momentum_buf[i] * self.momentum + &update;
                *param = &*param - &self.momentum_buf[i];
            } else {
                *param = &*param - update;
            }

            // Decoupled weight decay: shrink the parameter directly rather
            // than adding `weight_decay * param` to the gradient.
            if self.weight_decay > 0.0 {
                *param = &*param * (1.0 - self.learning_rate * self.weight_decay);
            }
        }

        self.t += 1;
    }
}
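
// A training loop would look roughly like the comment below; tensor creation
// and autograd entry points belong to zyx's wider API and are not shown in
// this file, so the helper names here are assumptions, not zyx's actual
// surface:
//
//     let mut opt = RMSprop::default();
//     opt.learning_rate = 1e-3;
//     for _ in 0..steps {
//         let grads = compute_gradients(&params); // hypothetical helper
//         opt.update(params.iter_mut(), grads);
//     }

#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test using only items defined in this file: an update over no
    // parameters should still advance the step counter.
    #[test]
    fn empty_update_advances_step_counter() {
        let mut opt = RMSprop::default();
        opt.update(std::iter::empty(), std::iter::empty());
        assert_eq!(opt.t, 1);
    }
}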