1use axonml_nn::Parameter;
9use axonml_tensor::Tensor;
10
11use crate::optimizer::Optimizer;
12
13pub struct RMSprop {
34 params: Vec<Parameter>,
36 lr: f32,
38 alpha: f32,
40 eps: f32,
42 weight_decay: f32,
44 momentum: f32,
46 centered: bool,
48 state: Vec<RMSpropState>,
50}
51
/// Running statistics kept by [`RMSprop`] for a single parameter tensor.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Exponential moving average of the squared gradient, one entry per element.
    square_avg: Vec<f32>,
    /// Momentum buffer; present only when the optimizer was created with non-zero momentum.
    momentum_buffer: Option<Vec<f32>>,
    /// Exponential moving average of the gradient; present only in centered mode.
    grad_avg: Option<Vec<f32>>,
}

impl RMSpropState {
    /// Creates zero-filled state for a parameter with `size` elements.
    ///
    /// The optional buffers are allocated only when `momentum` / `centered` is true.
    fn new(size: usize, momentum: bool, centered: bool) -> Self {
        let zeros = move || vec![0.0_f32; size];
        Self {
            square_avg: zeros(),
            momentum_buffer: momentum.then(zeros),
            grad_avg: centered.then(zeros),
        }
    }
}
80
81impl RMSprop {
82 #[must_use] pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
84 Self {
85 params,
86 lr,
87 alpha: 0.99,
88 eps: 1e-8,
89 weight_decay: 0.0,
90 momentum: 0.0,
91 centered: false,
92 state: Vec::new(),
93 }
94 }
95
96 #[must_use] pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
98 Self {
99 params,
100 lr,
101 alpha,
102 eps: 1e-8,
103 weight_decay: 0.0,
104 momentum: 0.0,
105 centered: false,
106 state: Vec::new(),
107 }
108 }
109
110 #[must_use] pub fn with_options(
112 params: Vec<Parameter>,
113 lr: f32,
114 alpha: f32,
115 eps: f32,
116 weight_decay: f32,
117 momentum: f32,
118 centered: bool,
119 ) -> Self {
120 Self {
121 params,
122 lr,
123 alpha,
124 eps,
125 weight_decay,
126 momentum,
127 centered,
128 state: Vec::new(),
129 }
130 }
131
132 #[must_use] pub fn alpha(mut self, alpha: f32) -> Self {
134 self.alpha = alpha;
135 self
136 }
137
138 #[must_use] pub fn eps(mut self, eps: f32) -> Self {
140 self.eps = eps;
141 self
142 }
143
144 #[must_use] pub fn weight_decay(mut self, weight_decay: f32) -> Self {
146 self.weight_decay = weight_decay;
147 self
148 }
149
150 #[must_use] pub fn momentum(mut self, momentum: f32) -> Self {
152 self.momentum = momentum;
153 self
154 }
155
156 #[must_use] pub fn centered(mut self, centered: bool) -> Self {
158 self.centered = centered;
159 self
160 }
161
162 fn ensure_state_initialized(&mut self) {
163 if self.state.is_empty() {
164 self.state = self
165 .params
166 .iter()
167 .map(|p| RMSpropState::new(p.numel(), self.momentum != 0.0, self.centered))
168 .collect();
169 }
170 }
171}
172
173impl Optimizer for RMSprop {
174 fn step(&mut self) {
175 self.ensure_state_initialized();
176
177 for (i, param) in self.params.iter().enumerate() {
178 if !param.requires_grad() {
179 continue;
180 }
181
182 let grad = match param.grad() {
183 Some(g) => g,
184 None => continue,
185 };
186
187 let mut grad_vec = grad.to_vec();
188 let state = &mut self.state[i];
189
190 let param_data = param.data();
191 let mut param_vec = param_data.to_vec();
192
193 if self.weight_decay != 0.0 {
195 for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
196 *g += self.weight_decay * p;
197 }
198 }
199
200 for (sq, g) in state.square_avg.iter_mut().zip(grad_vec.iter()) {
202 *sq = self.alpha * *sq + (1.0 - self.alpha) * g * g;
203 }
204
205 let avg: Vec<f32> = if self.centered {
207 let grad_avg = state.grad_avg.as_mut().unwrap();
209 for (ga, g) in grad_avg.iter_mut().zip(grad_vec.iter()) {
210 *ga = self.alpha * *ga + (1.0 - self.alpha) * g;
211 }
212 state
214 .square_avg
215 .iter()
216 .zip(grad_avg.iter())
217 .map(|(sq, ga)| (sq - ga * ga).sqrt() + self.eps)
218 .collect()
219 } else {
220 state
221 .square_avg
222 .iter()
223 .map(|sq| sq.sqrt() + self.eps)
224 .collect()
225 };
226
227 if self.momentum == 0.0 {
229 for ((p, g), a) in param_vec.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
231 *p -= self.lr * g / a;
232 }
233 } else {
234 let buf = state.momentum_buffer.as_mut().unwrap();
236 for ((b, g), a) in buf.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
237 *b = self.momentum * *b + g / a;
238 }
239 for (p, b) in param_vec.iter_mut().zip(buf.iter()) {
240 *p -= self.lr * b;
241 }
242 }
243
244 let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
245 param.update_data(update);
246 }
247 }
248
249 fn zero_grad(&mut self) {
250 for param in &self.params {
251 param.zero_grad();
252 }
253 }
254
255 fn get_lr(&self) -> f32 {
256 self.lr
257 }
258
259 fn set_lr(&mut self, lr: f32) {
260 self.lr = lr;
261 }
262
263 fn parameters(&self) -> &[Parameter] {
264 &self.params
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds the parameter `[1, 2, 3]` with gradient `[0.1, 0.2, 0.3]` already set.
    fn param_with_grad() -> Parameter {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
        param
    }

    /// Asserts element-wise closeness of two equally long slices.
    fn assert_all_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < tol, "expected {}, got {}", e, a);
        }
    }

    #[test]
    fn test_rmsprop_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = RMSprop::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = param_with_grad();

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
        // First step: square_avg = (1 - 0.99) * g², so g / sqrt(square_avg) = 10 and every
        // element moves by lr * 10 = 0.1 regardless of the gradient's magnitude.
        assert_all_close(&new_data, &[0.9, 1.9, 2.9], 1e-4);
    }

    #[test]
    fn test_rmsprop_centered_step() {
        let param = param_with_grad();

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01).centered(true);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!(new_data.iter().all(|v| v.is_finite()));
        // variance = 0.01 g² - (0.01 g)² = 0.0099 g², so each element moves by
        // 0.01 / sqrt(0.0099) ≈ 0.1005038.
        assert_all_close(&new_data, &[0.899_496_2, 1.899_496_2, 2.899_496_2], 1e-4);
    }

    #[test]
    fn test_rmsprop_momentum_two_steps() {
        let param = param_with_grad();

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01).momentum(0.9);
        optimizer.step();
        optimizer.step();

        // Step 1: buf = 10, p -= 0.1.
        // Step 2: square_avg = 0.0199 g², g / denom ≈ 7.088807,
        //         buf = 0.9 * 10 + 7.088807, p -= 0.01 * buf ≈ 0.160888.
        let new_data = param.data().to_vec();
        assert_all_close(&new_data, &[0.739_112, 1.739_112, 2.739_112], 1e-4);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_centered() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}