1use axonml_nn::Parameter;
9use axonml_tensor::Tensor;
10
11use crate::optimizer::Optimizer;
12
13pub struct RMSprop {
34 params: Vec<Parameter>,
36 lr: f32,
38 alpha: f32,
40 eps: f32,
42 weight_decay: f32,
44 momentum: f32,
46 centered: bool,
48 state: Vec<RMSpropState>,
50}
51
/// Running statistics kept for a single parameter.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Exponential moving average of the squared gradient, one entry per element.
    square_avg: Vec<f32>,
    /// Momentum accumulator; present only when momentum is in use.
    momentum_buffer: Option<Vec<f32>>,
    /// Exponential moving average of the gradient; present only in centered mode.
    grad_avg: Option<Vec<f32>>,
}

impl RMSpropState {
    /// Creates zero-filled state for a parameter with `size` elements.
    ///
    /// The optional buffers are allocated only when the corresponding flag
    /// (`momentum`, `centered`) is set; otherwise they are `None`.
    fn new(size: usize, momentum: bool, centered: bool) -> Self {
        fn zeros(len: usize) -> Vec<f32> {
            vec![0.0; len]
        }

        Self {
            square_avg: zeros(size),
            momentum_buffer: momentum.then(|| zeros(size)),
            grad_avg: centered.then(|| zeros(size)),
        }
    }
}
80
81impl RMSprop {
82 #[must_use]
84 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
85 Self {
86 params,
87 lr,
88 alpha: 0.99,
89 eps: 1e-8,
90 weight_decay: 0.0,
91 momentum: 0.0,
92 centered: false,
93 state: Vec::new(),
94 }
95 }
96
97 #[must_use]
99 pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
100 Self {
101 params,
102 lr,
103 alpha,
104 eps: 1e-8,
105 weight_decay: 0.0,
106 momentum: 0.0,
107 centered: false,
108 state: Vec::new(),
109 }
110 }
111
112 #[must_use]
114 pub fn with_options(
115 params: Vec<Parameter>,
116 lr: f32,
117 alpha: f32,
118 eps: f32,
119 weight_decay: f32,
120 momentum: f32,
121 centered: bool,
122 ) -> Self {
123 Self {
124 params,
125 lr,
126 alpha,
127 eps,
128 weight_decay,
129 momentum,
130 centered,
131 state: Vec::new(),
132 }
133 }
134
135 #[must_use]
137 pub fn alpha(mut self, alpha: f32) -> Self {
138 self.alpha = alpha;
139 self
140 }
141
142 #[must_use]
144 pub fn eps(mut self, eps: f32) -> Self {
145 self.eps = eps;
146 self
147 }
148
149 #[must_use]
151 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
152 self.weight_decay = weight_decay;
153 self
154 }
155
156 #[must_use]
158 pub fn momentum(mut self, momentum: f32) -> Self {
159 self.momentum = momentum;
160 self
161 }
162
163 #[must_use]
165 pub fn centered(mut self, centered: bool) -> Self {
166 self.centered = centered;
167 self
168 }
169
170 fn ensure_state_initialized(&mut self) {
171 if self.state.is_empty() {
172 self.state = self
173 .params
174 .iter()
175 .map(|p| RMSpropState::new(p.numel(), self.momentum != 0.0, self.centered))
176 .collect();
177 }
178 }
179}
180
181impl Optimizer for RMSprop {
182 fn step(&mut self) {
183 self.ensure_state_initialized();
184
185 for (i, param) in self.params.iter().enumerate() {
186 if !param.requires_grad() {
187 continue;
188 }
189
190 let grad = match param.grad() {
191 Some(g) => g,
192 None => continue,
193 };
194
195 let mut grad_vec = grad.to_vec();
196 let state = &mut self.state[i];
197
198 let param_data = param.data();
199 let mut param_vec = param_data.to_vec();
200
201 if self.weight_decay != 0.0 {
203 for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
204 *g += self.weight_decay * p;
205 }
206 }
207
208 for (sq, g) in state.square_avg.iter_mut().zip(grad_vec.iter()) {
210 *sq = self.alpha * *sq + (1.0 - self.alpha) * g * g;
211 }
212
213 let avg: Vec<f32> = if self.centered {
215 let grad_avg = state.grad_avg.as_mut().unwrap();
217 for (ga, g) in grad_avg.iter_mut().zip(grad_vec.iter()) {
218 *ga = self.alpha * *ga + (1.0 - self.alpha) * g;
219 }
220 state
222 .square_avg
223 .iter()
224 .zip(grad_avg.iter())
225 .map(|(sq, ga)| (sq - ga * ga).sqrt() + self.eps)
226 .collect()
227 } else {
228 state
229 .square_avg
230 .iter()
231 .map(|sq| sq.sqrt() + self.eps)
232 .collect()
233 };
234
235 if self.momentum == 0.0 {
237 for ((p, g), a) in param_vec.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
239 *p -= self.lr * g / a;
240 }
241 } else {
242 let buf = state.momentum_buffer.as_mut().unwrap();
244 for ((b, g), a) in buf.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
245 *b = self.momentum * *b + g / a;
246 }
247 for (p, b) in param_vec.iter_mut().zip(buf.iter()) {
248 *p -= self.lr * b;
249 }
250 }
251
252 let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
253 param.update_data(update);
254 }
255 }
256
257 fn zero_grad(&mut self) {
258 for param in &self.params {
259 param.zero_grad();
260 }
261 }
262
263 fn get_lr(&self) -> f32 {
264 self.lr
265 }
266
267 fn set_lr(&mut self, lr: f32) {
268 self.lr = lr;
269 }
270
271 fn parameters(&self) -> &[Parameter] {
272 &self.params
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds the trainable three-element parameter `[1.0, 2.0, 3.0]` used by
    /// every test in this module.
    fn sample_param() -> Parameter {
        let values = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        Parameter::from_variable(Variable::new(values, true))
    }

    /// `new` stores the learning rate and falls back to the default alpha.
    #[test]
    fn test_rmsprop_creation() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    /// A single step with a non-zero gradient moves the parameter.
    #[test]
    fn test_rmsprop_step() {
        let param = sample_param();
        let grad = Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap();
        param.variable().set_grad(grad);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    /// The momentum builder records the requested factor.
    #[test]
    fn test_rmsprop_with_momentum() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    /// The centered builder records the requested flag.
    #[test]
    fn test_rmsprop_centered() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    /// Chained builder calls each set their own field.
    #[test]
    fn test_rmsprop_builder_pattern() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}