1use axonml_nn::Parameter;
9use axonml_tensor::Tensor;
10
11use crate::optimizer::Optimizer;
12
13pub struct RMSprop {
34 params: Vec<Parameter>,
36 lr: f32,
38 alpha: f32,
40 eps: f32,
42 weight_decay: f32,
44 momentum: f32,
46 centered: bool,
48 state: Vec<RMSpropState>,
50}
51
/// Per-parameter RMSprop buffers, each holding one `f32` per parameter element.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Running average of squared gradients.
    square_avg: Vec<f32>,
    /// Momentum accumulator; allocated only when momentum is enabled.
    momentum_buffer: Option<Vec<f32>>,
    /// Running average of gradients; allocated only for the centered variant.
    grad_avg: Option<Vec<f32>>,
}

impl RMSpropState {
    /// Creates zero-filled buffers for a parameter with `size` elements.
    ///
    /// `square_avg` is always allocated; `momentum_buffer` and `grad_avg` are allocated
    /// only when `momentum` and `centered` are respectively `true`.
    fn new(size: usize, momentum: bool, centered: bool) -> Self {
        let zeros = move || vec![0.0_f32; size];
        Self {
            square_avg: zeros(),
            momentum_buffer: momentum.then(zeros),
            grad_avg: centered.then(zeros),
        }
    }
}
80
81impl RMSprop {
82 #[must_use]
84 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
85 Self {
86 params,
87 lr,
88 alpha: 0.99,
89 eps: 1e-8,
90 weight_decay: 0.0,
91 momentum: 0.0,
92 centered: false,
93 state: Vec::new(),
94 }
95 }
96
97 #[must_use]
99 pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
100 Self {
101 params,
102 lr,
103 alpha,
104 eps: 1e-8,
105 weight_decay: 0.0,
106 momentum: 0.0,
107 centered: false,
108 state: Vec::new(),
109 }
110 }
111
112 #[must_use]
114 pub fn with_options(
115 params: Vec<Parameter>,
116 lr: f32,
117 alpha: f32,
118 eps: f32,
119 weight_decay: f32,
120 momentum: f32,
121 centered: bool,
122 ) -> Self {
123 Self {
124 params,
125 lr,
126 alpha,
127 eps,
128 weight_decay,
129 momentum,
130 centered,
131 state: Vec::new(),
132 }
133 }
134
135 #[must_use]
137 pub fn alpha(mut self, alpha: f32) -> Self {
138 self.alpha = alpha;
139 self
140 }
141
142 #[must_use]
144 pub fn eps(mut self, eps: f32) -> Self {
145 self.eps = eps;
146 self
147 }
148
149 #[must_use]
151 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
152 self.weight_decay = weight_decay;
153 self
154 }
155
156 #[must_use]
158 pub fn momentum(mut self, momentum: f32) -> Self {
159 self.momentum = momentum;
160 self
161 }
162
163 #[must_use]
165 pub fn centered(mut self, centered: bool) -> Self {
166 self.centered = centered;
167 self
168 }
169
170 fn ensure_state_initialized(&mut self) {
171 if self.state.is_empty() {
172 self.state = self
173 .params
174 .iter()
175 .map(|p| RMSpropState::new(p.numel(), self.momentum != 0.0, self.centered))
176 .collect();
177 }
178 }
179}
180
181impl Optimizer for RMSprop {
182 fn step(&mut self) {
183 self.ensure_state_initialized();
184
185 for (i, param) in self.params.iter().enumerate() {
186 if !param.requires_grad() {
187 continue;
188 }
189
190 let grad = match param.grad() {
191 Some(g) => g,
192 None => continue,
193 };
194
195 let mut grad_vec = grad.to_vec();
196 let state = &mut self.state[i];
197
198 let param_data = param.data();
199 let mut param_vec = param_data.to_vec();
200
201 if self.weight_decay != 0.0 {
203 for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
204 *g += self.weight_decay * p;
205 }
206 }
207
208 for (sq, g) in state.square_avg.iter_mut().zip(grad_vec.iter()) {
210 *sq = self.alpha * *sq + (1.0 - self.alpha) * g * g;
211 }
212
213 let lr = self.lr;
215 let eps = self.eps;
216
217 if self.centered {
218 let grad_avg = state.grad_avg.as_mut().unwrap();
220 if self.momentum == 0.0 {
221 for i in 0..param_vec.len() {
222 grad_avg[i] = self.alpha * grad_avg[i] + (1.0 - self.alpha) * grad_vec[i];
223 let avg = (state.square_avg[i] - grad_avg[i] * grad_avg[i]).sqrt() + eps;
224 param_vec[i] -= lr * grad_vec[i] / avg;
225 }
226 } else {
227 let buf = state.momentum_buffer.as_mut().unwrap();
228 for i in 0..param_vec.len() {
229 grad_avg[i] = self.alpha * grad_avg[i] + (1.0 - self.alpha) * grad_vec[i];
230 let avg = (state.square_avg[i] - grad_avg[i] * grad_avg[i]).sqrt() + eps;
231 buf[i] = self.momentum * buf[i] + grad_vec[i] / avg;
232 param_vec[i] -= lr * buf[i];
233 }
234 }
235 } else if self.momentum == 0.0 {
236 for i in 0..param_vec.len() {
238 let avg = state.square_avg[i].sqrt() + eps;
239 param_vec[i] -= lr * grad_vec[i] / avg;
240 }
241 } else {
242 let buf = state.momentum_buffer.as_mut().unwrap();
244 for i in 0..param_vec.len() {
245 let avg = state.square_avg[i].sqrt() + eps;
246 buf[i] = self.momentum * buf[i] + grad_vec[i] / avg;
247 param_vec[i] -= lr * buf[i];
248 }
249 }
250
251 let mut update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
252 let device = param_data.device();
254 if device.is_gpu() {
255 update = update.to_device(device).unwrap();
256 }
257 param.update_data(update);
258 }
259 }
260
261 fn zero_grad(&mut self) {
262 for param in &self.params {
263 param.zero_grad();
264 }
265 }
266
267 fn get_lr(&self) -> f32 {
268 self.lr
269 }
270
271 fn set_lr(&mut self, lr: f32) {
272 self.lr = lr;
273 }
274
275 fn parameters(&self) -> &[Parameter] {
276 &self.params
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a trainable 3-element parameter initialised to `[1.0, 2.0, 3.0]`.
    fn make_param() -> Parameter {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        Parameter::from_variable(var)
    }

    /// Attaches the gradient `[0.1, 0.2, 0.3]` to `param`.
    fn set_test_grad(param: &Parameter) {
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
    }

    #[test]
    fn test_rmsprop_creation() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = make_param();
        set_test_grad(&param);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // First step: v = (1 - alpha) * g², so each element moves by
        // lr * g / (sqrt(1 - alpha) * |g|) = 0.01 / 0.1 = 0.1 against the gradient sign.
        let new_data = param.data().to_vec();
        for (new, old) in new_data.iter().zip([1.0_f32, 2.0, 3.0]) {
            let expected = old - 0.1;
            assert!(
                (new - expected).abs() < 1e-4,
                "got {}, expected {}",
                new,
                expected
            );
        }
    }

    #[test]
    fn test_rmsprop_centered_momentum_step() {
        let param = make_param();
        set_test_grad(&param);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01)
            .momentum(0.9)
            .centered(true);
        optimizer.step();

        // First centered step: v - m² = 0.0099 * g², so buf = 1 / sqrt(0.0099) ≈ 10.0504
        // and each element moves by lr * buf ≈ 0.1005.
        let new_data = param.data().to_vec();
        for (new, old) in new_data.iter().zip([1.0_f32, 2.0, 3.0]) {
            assert!(new.is_finite());
            let expected = old - 0.1005;
            assert!(
                (new - expected).abs() < 1e-3,
                "got {}, expected {}",
                new,
                expected
            );
        }
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_centered() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}