1use axonml_nn::Parameter;
23use axonml_tensor::Tensor;
24
25use crate::optimizer::Optimizer;
26
27use axonml_core;
29
30pub struct RMSprop {
51 params: Vec<Parameter>,
53 lr: f32,
55 alpha: f32,
57 eps: f32,
59 weight_decay: f32,
61 momentum: f32,
63 centered: bool,
65 state: Vec<RMSpropState>,
67}
68
69#[derive(Debug, Clone)]
74struct RMSpropState {
75 square_avg: Tensor<f32>,
77 momentum_buffer: Option<Tensor<f32>>,
79 grad_avg: Option<Tensor<f32>>,
81}
82
83impl RMSpropState {
84 fn new(shape: &[usize], device: axonml_core::Device, momentum: bool, centered: bool) -> Self {
85 let square_avg = {
86 let t = Tensor::zeros(shape);
87 if device.is_gpu() {
88 t.to_device(device).expect("device transfer failed")
89 } else {
90 t
91 }
92 };
93 let momentum_buffer = if momentum {
94 let t = Tensor::zeros(shape);
95 Some(if device.is_gpu() {
96 t.to_device(device).expect("device transfer failed")
97 } else {
98 t
99 })
100 } else {
101 None
102 };
103 let grad_avg = if centered {
104 let t = Tensor::zeros(shape);
105 Some(if device.is_gpu() {
106 t.to_device(device).expect("device transfer failed")
107 } else {
108 t
109 })
110 } else {
111 None
112 };
113 Self {
114 square_avg,
115 momentum_buffer,
116 grad_avg,
117 }
118 }
119}
120
121impl RMSprop {
122 #[must_use]
124 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
125 Self {
126 params,
127 lr,
128 alpha: 0.99,
129 eps: 1e-8,
130 weight_decay: 0.0,
131 momentum: 0.0,
132 centered: false,
133 state: Vec::new(),
134 }
135 }
136
137 #[must_use]
139 pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
140 Self {
141 params,
142 lr,
143 alpha,
144 eps: 1e-8,
145 weight_decay: 0.0,
146 momentum: 0.0,
147 centered: false,
148 state: Vec::new(),
149 }
150 }
151
152 #[must_use]
154 pub fn with_options(
155 params: Vec<Parameter>,
156 lr: f32,
157 alpha: f32,
158 eps: f32,
159 weight_decay: f32,
160 momentum: f32,
161 centered: bool,
162 ) -> Self {
163 Self {
164 params,
165 lr,
166 alpha,
167 eps,
168 weight_decay,
169 momentum,
170 centered,
171 state: Vec::new(),
172 }
173 }
174
175 #[must_use]
177 pub fn alpha(mut self, alpha: f32) -> Self {
178 self.alpha = alpha;
179 self
180 }
181
182 #[must_use]
184 pub fn eps(mut self, eps: f32) -> Self {
185 self.eps = eps;
186 self
187 }
188
189 #[must_use]
191 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
192 self.weight_decay = weight_decay;
193 self
194 }
195
196 #[must_use]
198 pub fn momentum(mut self, momentum: f32) -> Self {
199 self.momentum = momentum;
200 self
201 }
202
203 #[must_use]
205 pub fn centered(mut self, centered: bool) -> Self {
206 self.centered = centered;
207 self
208 }
209
210 fn ensure_state_initialized(&mut self) {
211 if self.state.is_empty() {
212 self.state = self
213 .params
214 .iter()
215 .map(|p| {
216 let data = p.data();
217 RMSpropState::new(
218 data.shape(),
219 data.device(),
220 self.momentum != 0.0,
221 self.centered,
222 )
223 })
224 .collect();
225 }
226 }
227}
228
229impl Optimizer for RMSprop {
230 fn step(&mut self) {
231 self.ensure_state_initialized();
232
233 for (i, param) in self.params.iter().enumerate() {
240 if !param.requires_grad() {
241 continue;
242 }
243
244 let grad = match param.grad() {
245 Some(g) => g,
246 None => continue,
247 };
248
249 let param_data = param.data();
250 let state = &mut self.state[i];
251
252 let d = if self.weight_decay == 0.0 {
254 grad.clone()
255 } else {
256 grad.add(¶m_data.mul_scalar(self.weight_decay)).unwrap()
257 };
258
259 let d_sq = d.mul(&d).unwrap();
261 state.square_avg = state
262 .square_avg
263 .mul_scalar(self.alpha)
264 .add(&d_sq.mul_scalar(1.0 - self.alpha))
265 .unwrap();
266
267 let denom = if self.centered {
269 let grad_avg = state.grad_avg.as_mut().unwrap();
271 *grad_avg = grad_avg
272 .mul_scalar(self.alpha)
273 .add(&d.mul_scalar(1.0 - self.alpha))
274 .unwrap();
275
276 let ga_sq = grad_avg.mul(grad_avg).unwrap();
278 state
279 .square_avg
280 .sub(&ga_sq)
281 .unwrap()
282 .sqrt()
283 .add_scalar(self.eps)
284 } else {
285 state.square_avg.sqrt().add_scalar(self.eps)
287 };
288
289 let update = if self.momentum == 0.0 {
291 d.div(&denom).unwrap()
293 } else {
294 let normalized = d.div(&denom).unwrap();
296 let buf = state.momentum_buffer.as_mut().unwrap();
297 *buf = buf.mul_scalar(self.momentum).add(&normalized).unwrap();
298 buf.clone()
299 };
300
301 let new_param = param_data.sub(&update.mul_scalar(self.lr)).unwrap();
303 param.update_data(new_param);
304 }
305 }
306
307 fn zero_grad(&mut self) {
308 for param in &self.params {
309 param.zero_grad();
310 }
311 }
312
313 fn get_lr(&self) -> f32 {
314 self.lr
315 }
316
317 fn set_lr(&mut self, lr: f32) {
318 self.lr = lr;
319 }
320
321 fn parameters(&self) -> &[Parameter] {
322 &self.params
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a trainable parameter holding `[1.0, 2.0, 3.0]`.
    fn sample_param() -> Parameter {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        Parameter::from_variable(var)
    }

    /// Attaches the gradient `[0.1, 0.2, 0.3]` to `param`.
    fn set_sample_grad(param: &Parameter) {
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));
    }

    #[test]
    fn test_rmsprop_creation() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = sample_param();
        set_sample_grad(&param);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // First step from zero state: square_avg = 0.01 * g^2, so
        // g / (sqrt(square_avg) + eps) is ~10 for every element and each
        // parameter moves down by lr * 10 = 0.1.
        let new_data = param.data().to_vec();
        let expected = [0.9_f32, 1.9, 2.9];
        assert_eq!(new_data.len(), expected.len());
        for (got, want) in new_data.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-4, "got {got}, expected {want}");
        }
    }

    #[test]
    fn test_rmsprop_step_centered_with_momentum() {
        let param = sample_param();
        set_sample_grad(&param);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01)
            .momentum(0.9)
            .centered(true);
        optimizer.step();
        optimizer.step();

        // Positive gradients must push every element below its start value,
        // and the centered denominator must not produce NaN/inf.
        let start = [1.0_f32, 2.0, 3.0];
        let new_data = param.data().to_vec();
        assert_eq!(new_data.len(), start.len());
        for (got, before) in new_data.iter().zip(start.iter()) {
            assert!(got.is_finite());
            assert!(got < before);
        }
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_centered() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}