1use axonml_nn::Parameter;
18use axonml_tensor::Tensor;
19
20use crate::optimizer::Optimizer;
21
22use axonml_core;
24
25pub struct RMSprop {
46 params: Vec<Parameter>,
48 lr: f32,
50 alpha: f32,
52 eps: f32,
54 weight_decay: f32,
56 momentum: f32,
58 centered: bool,
60 state: Vec<RMSpropState>,
62}
63
64#[derive(Debug, Clone)]
69struct RMSpropState {
70 square_avg: Tensor<f32>,
72 momentum_buffer: Option<Tensor<f32>>,
74 grad_avg: Option<Tensor<f32>>,
76}
77
78impl RMSpropState {
79 fn new(shape: &[usize], device: axonml_core::Device, momentum: bool, centered: bool) -> Self {
80 let square_avg = {
81 let t = Tensor::zeros(shape);
82 if device.is_gpu() {
83 t.to_device(device).unwrap()
84 } else {
85 t
86 }
87 };
88 let momentum_buffer = if momentum {
89 let t = Tensor::zeros(shape);
90 Some(if device.is_gpu() {
91 t.to_device(device).unwrap()
92 } else {
93 t
94 })
95 } else {
96 None
97 };
98 let grad_avg = if centered {
99 let t = Tensor::zeros(shape);
100 Some(if device.is_gpu() {
101 t.to_device(device).unwrap()
102 } else {
103 t
104 })
105 } else {
106 None
107 };
108 Self {
109 square_avg,
110 momentum_buffer,
111 grad_avg,
112 }
113 }
114}
115
116impl RMSprop {
117 #[must_use]
119 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
120 Self {
121 params,
122 lr,
123 alpha: 0.99,
124 eps: 1e-8,
125 weight_decay: 0.0,
126 momentum: 0.0,
127 centered: false,
128 state: Vec::new(),
129 }
130 }
131
132 #[must_use]
134 pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
135 Self {
136 params,
137 lr,
138 alpha,
139 eps: 1e-8,
140 weight_decay: 0.0,
141 momentum: 0.0,
142 centered: false,
143 state: Vec::new(),
144 }
145 }
146
147 #[must_use]
149 pub fn with_options(
150 params: Vec<Parameter>,
151 lr: f32,
152 alpha: f32,
153 eps: f32,
154 weight_decay: f32,
155 momentum: f32,
156 centered: bool,
157 ) -> Self {
158 Self {
159 params,
160 lr,
161 alpha,
162 eps,
163 weight_decay,
164 momentum,
165 centered,
166 state: Vec::new(),
167 }
168 }
169
170 #[must_use]
172 pub fn alpha(mut self, alpha: f32) -> Self {
173 self.alpha = alpha;
174 self
175 }
176
177 #[must_use]
179 pub fn eps(mut self, eps: f32) -> Self {
180 self.eps = eps;
181 self
182 }
183
184 #[must_use]
186 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
187 self.weight_decay = weight_decay;
188 self
189 }
190
191 #[must_use]
193 pub fn momentum(mut self, momentum: f32) -> Self {
194 self.momentum = momentum;
195 self
196 }
197
198 #[must_use]
200 pub fn centered(mut self, centered: bool) -> Self {
201 self.centered = centered;
202 self
203 }
204
205 fn ensure_state_initialized(&mut self) {
206 if self.state.is_empty() {
207 self.state = self
208 .params
209 .iter()
210 .map(|p| {
211 let data = p.data();
212 RMSpropState::new(
213 data.shape(),
214 data.device(),
215 self.momentum != 0.0,
216 self.centered,
217 )
218 })
219 .collect();
220 }
221 }
222}
223
224impl Optimizer for RMSprop {
225 fn step(&mut self) {
226 self.ensure_state_initialized();
227
228 for (i, param) in self.params.iter().enumerate() {
235 if !param.requires_grad() {
236 continue;
237 }
238
239 let grad = match param.grad() {
240 Some(g) => g,
241 None => continue,
242 };
243
244 let param_data = param.data();
245 let state = &mut self.state[i];
246
247 let d = if self.weight_decay == 0.0 {
249 grad.clone()
250 } else {
251 grad.add(¶m_data.mul_scalar(self.weight_decay)).unwrap()
252 };
253
254 let d_sq = d.mul(&d).unwrap();
256 state.square_avg = state
257 .square_avg
258 .mul_scalar(self.alpha)
259 .add(&d_sq.mul_scalar(1.0 - self.alpha))
260 .unwrap();
261
262 let denom = if self.centered {
264 let grad_avg = state.grad_avg.as_mut().unwrap();
266 *grad_avg = grad_avg
267 .mul_scalar(self.alpha)
268 .add(&d.mul_scalar(1.0 - self.alpha))
269 .unwrap();
270
271 let ga_sq = grad_avg.mul(grad_avg).unwrap();
273 state
274 .square_avg
275 .sub(&ga_sq)
276 .unwrap()
277 .sqrt()
278 .add_scalar(self.eps)
279 } else {
280 state.square_avg.sqrt().add_scalar(self.eps)
282 };
283
284 let update = if self.momentum == 0.0 {
286 d.div(&denom).unwrap()
288 } else {
289 let normalized = d.div(&denom).unwrap();
291 let buf = state.momentum_buffer.as_mut().unwrap();
292 *buf = buf.mul_scalar(self.momentum).add(&normalized).unwrap();
293 buf.clone()
294 };
295
296 let new_param = param_data.sub(&update.mul_scalar(self.lr)).unwrap();
298 param.update_data(new_param);
299 }
300 }
301
302 fn zero_grad(&mut self) {
303 for param in &self.params {
304 param.zero_grad();
305 }
306 }
307
308 fn get_lr(&self) -> f32 {
309 self.lr
310 }
311
312 fn set_lr(&mut self, lr: f32) {
313 self.lr = lr;
314 }
315
316 fn parameters(&self) -> &[Parameter] {
317 &self.params
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a trainable 1-D parameter holding `values`.
    fn make_param(values: &[f32]) -> Parameter {
        let tensor = Tensor::from_vec(values.to_vec(), &[values.len()]).unwrap();
        Parameter::from_variable(Variable::new(tensor, true))
    }

    /// Attaches a 1-D gradient to `param`.
    fn set_grad(param: &Parameter, grad: &[f32]) {
        param
            .variable()
            .set_grad(Tensor::from_vec(grad.to_vec(), &[grad.len()]).unwrap());
    }

    /// Asserts element-wise closeness of two slices.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < tol,
                "expected {:?}, got {:?}",
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_rmsprop_creation() {
        let optimizer = RMSprop::new(vec![make_param(&[1.0, 2.0, 3.0])], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = make_param(&[1.0, 2.0, 3.0]);
        set_grad(&param, &[0.1, 0.2, 0.3]);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // First step: square_avg = 0.01 * g², so d / sqrt(square_avg) = 10 for
        // every element, and each parameter moves by lr * 10 = 0.1.
        assert_close(&param.data().to_vec(), &[0.9, 1.9, 2.9], 1e-4);
    }

    #[test]
    fn test_rmsprop_step_skips_param_without_grad() {
        let param = make_param(&[1.0, 2.0, 3.0]);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        assert_close(&param.data().to_vec(), &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let optimizer = RMSprop::new(vec![make_param(&[1.0, 2.0, 3.0])], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_momentum_two_steps() {
        let param = make_param(&[1.0, 2.0, 3.0]);
        set_grad(&param, &[0.1, 0.2, 0.3]);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01).momentum(0.9);
        optimizer.step();
        optimizer.step();

        // Step 1: buf = 10, p -= 0.1.
        // Step 2: square_avg = 0.0199 * g², normalized = 1 / sqrt(0.0199) ≈ 7.08881,
        //         buf = 0.9 * 10 + 7.08881 = 16.08881, p -= 0.1608881.
        assert_close(&param.data().to_vec(), &[0.739_112, 1.739_112, 2.739_112], 1e-4);
    }

    #[test]
    fn test_rmsprop_centered() {
        let optimizer = RMSprop::new(vec![make_param(&[1.0, 2.0, 3.0])], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_centered_step() {
        let param = make_param(&[1.0, 2.0, 3.0]);
        set_grad(&param, &[0.1, 0.2, 0.3]);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01).centered(true);
        optimizer.step();

        // square_avg - grad_avg² = 0.01 g² - 0.0001 g² = 0.0099 g², so the
        // step is lr / sqrt(0.0099) ≈ 0.1005038 for every element.
        assert_close(
            &param.data().to_vec(),
            &[0.899_496_2, 1.899_496_2, 2.899_496_2],
            1e-4,
        );
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let optimizer = RMSprop::new(vec![make_param(&[1.0, 2.0, 3.0])], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}