// axonml_optim/rmsprop.rs
1//! `RMSprop` Optimizer
2//!
3//! Implements `RMSprop` (Root Mean Square Propagation) optimizer.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_nn::Parameter;
9use axonml_tensor::Tensor;
10
11use crate::optimizer::Optimizer;
12
13// =============================================================================
14// RMSprop
15// =============================================================================
16
/// `RMSprop` optimizer.
///
/// Maintains a moving average of squared gradients to normalize updates.
///
/// Update rule:
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// param = param - lr * grad / (sqrt(v_t) + eps)
/// ```
///
/// With momentum:
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// buf_t = momentum * buf_{t-1} + grad / (sqrt(v_t) + eps)
/// param = param - lr * buf_t
/// ```
pub struct RMSprop {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Smoothing constant (decay rate for the squared-gradient moving
    /// average). Defaults to 0.99 in `RMSprop::new`.
    alpha: f32,
    /// Small constant added to the denominator for numerical stability.
    /// Defaults to 1e-8 in `RMSprop::new`.
    eps: f32,
    /// Weight decay (L2 regularization) coefficient; 0.0 disables it.
    weight_decay: f32,
    /// Momentum factor; 0.0 disables the momentum buffer.
    momentum: f32,
    /// Whether to center the gradient (normalize by second moment minus
    /// squared gradient mean instead of the raw second moment).
    centered: bool,
    /// Per-parameter state; empty until the first `step()` call lazily
    /// allocates it via `ensure_state_initialized`.
    state: Vec<RMSpropState>,
}
51
/// Per-parameter state for the `RMSprop` optimizer.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Moving average of squared gradients (`v_t` in the update rule).
    square_avg: Vec<f32>,
    /// Momentum buffer; `Some` only when the momentum factor is non-zero.
    momentum_buffer: Option<Vec<f32>>,
    /// Moving average of raw gradients; `Some` only for centered `RMSprop`.
    grad_avg: Option<Vec<f32>>,
}
62
63impl RMSpropState {
64    fn new(size: usize, momentum: bool, centered: bool) -> Self {
65        Self {
66            square_avg: vec![0.0; size],
67            momentum_buffer: if momentum {
68                Some(vec![0.0; size])
69            } else {
70                None
71            },
72            grad_avg: if centered {
73                Some(vec![0.0; size])
74            } else {
75                None
76            },
77        }
78    }
79}
80
81impl RMSprop {
82    /// Creates a new `RMSprop` optimizer with default settings.
83    #[must_use]
84    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
85        Self {
86            params,
87            lr,
88            alpha: 0.99,
89            eps: 1e-8,
90            weight_decay: 0.0,
91            momentum: 0.0,
92            centered: false,
93            state: Vec::new(),
94        }
95    }
96
97    /// Creates `RMSprop` with specified alpha (smoothing constant).
98    #[must_use]
99    pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
100        Self {
101            params,
102            lr,
103            alpha,
104            eps: 1e-8,
105            weight_decay: 0.0,
106            momentum: 0.0,
107            centered: false,
108            state: Vec::new(),
109        }
110    }
111
112    /// Creates `RMSprop` with all options.
113    #[must_use]
114    pub fn with_options(
115        params: Vec<Parameter>,
116        lr: f32,
117        alpha: f32,
118        eps: f32,
119        weight_decay: f32,
120        momentum: f32,
121        centered: bool,
122    ) -> Self {
123        Self {
124            params,
125            lr,
126            alpha,
127            eps,
128            weight_decay,
129            momentum,
130            centered,
131            state: Vec::new(),
132        }
133    }
134
135    /// Builder method to set alpha.
136    #[must_use]
137    pub fn alpha(mut self, alpha: f32) -> Self {
138        self.alpha = alpha;
139        self
140    }
141
142    /// Builder method to set epsilon.
143    #[must_use]
144    pub fn eps(mut self, eps: f32) -> Self {
145        self.eps = eps;
146        self
147    }
148
149    /// Builder method to set weight decay.
150    #[must_use]
151    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
152        self.weight_decay = weight_decay;
153        self
154    }
155
156    /// Builder method to set momentum.
157    #[must_use]
158    pub fn momentum(mut self, momentum: f32) -> Self {
159        self.momentum = momentum;
160        self
161    }
162
163    /// Builder method to enable centered `RMSprop`.
164    #[must_use]
165    pub fn centered(mut self, centered: bool) -> Self {
166        self.centered = centered;
167        self
168    }
169
170    fn ensure_state_initialized(&mut self) {
171        if self.state.is_empty() {
172            self.state = self
173                .params
174                .iter()
175                .map(|p| RMSpropState::new(p.numel(), self.momentum != 0.0, self.centered))
176                .collect();
177        }
178    }
179}
180
181impl Optimizer for RMSprop {
182    fn step(&mut self) {
183        self.ensure_state_initialized();
184
185        for (i, param) in self.params.iter().enumerate() {
186            if !param.requires_grad() {
187                continue;
188            }
189
190            let grad = match param.grad() {
191                Some(g) => g,
192                None => continue,
193            };
194
195            let mut grad_vec = grad.to_vec();
196            let state = &mut self.state[i];
197
198            let param_data = param.data();
199            let mut param_vec = param_data.to_vec();
200
201            // Apply weight decay
202            if self.weight_decay != 0.0 {
203                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
204                    *g += self.weight_decay * p;
205                }
206            }
207
208            // Update square average
209            for (sq, g) in state.square_avg.iter_mut().zip(grad_vec.iter()) {
210                *sq = self.alpha * *sq + (1.0 - self.alpha) * g * g;
211            }
212
213            // Fused parameter update — no intermediate Vec allocation for denominator
214            let lr = self.lr;
215            let eps = self.eps;
216
217            if self.centered {
218                // Update gradient average for centered RMSprop
219                let grad_avg = state.grad_avg.as_mut().unwrap();
220                if self.momentum == 0.0 {
221                    for i in 0..param_vec.len() {
222                        grad_avg[i] = self.alpha * grad_avg[i] + (1.0 - self.alpha) * grad_vec[i];
223                        let avg = (state.square_avg[i] - grad_avg[i] * grad_avg[i]).sqrt() + eps;
224                        param_vec[i] -= lr * grad_vec[i] / avg;
225                    }
226                } else {
227                    let buf = state.momentum_buffer.as_mut().unwrap();
228                    for i in 0..param_vec.len() {
229                        grad_avg[i] = self.alpha * grad_avg[i] + (1.0 - self.alpha) * grad_vec[i];
230                        let avg = (state.square_avg[i] - grad_avg[i] * grad_avg[i]).sqrt() + eps;
231                        buf[i] = self.momentum * buf[i] + grad_vec[i] / avg;
232                        param_vec[i] -= lr * buf[i];
233                    }
234                }
235            } else if self.momentum == 0.0 {
236                // Without momentum, without centering
237                for i in 0..param_vec.len() {
238                    let avg = state.square_avg[i].sqrt() + eps;
239                    param_vec[i] -= lr * grad_vec[i] / avg;
240                }
241            } else {
242                // With momentum, without centering
243                let buf = state.momentum_buffer.as_mut().unwrap();
244                for i in 0..param_vec.len() {
245                    let avg = state.square_avg[i].sqrt() + eps;
246                    buf[i] = self.momentum * buf[i] + grad_vec[i] / avg;
247                    param_vec[i] -= lr * buf[i];
248                }
249            }
250
251            let mut update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
252            // Preserve device: from_vec creates CPU, move back to param device
253            let device = param_data.device();
254            if device.is_gpu() {
255                update = update.to_device(device).unwrap();
256            }
257            param.update_data(update);
258        }
259    }
260
261    fn zero_grad(&mut self) {
262        for param in &self.params {
263            param.zero_grad();
264        }
265    }
266
267    fn get_lr(&self) -> f32 {
268        self.lr
269    }
270
271    fn set_lr(&mut self, lr: f32) {
272        self.lr = lr;
273    }
274
275    fn parameters(&self) -> &[Parameter] {
276        &self.params
277    }
278}
279
280// =============================================================================
281// Tests
282// =============================================================================
283
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// `new` must apply the documented defaults (lr as given, alpha = 0.99).
    #[test]
    fn test_rmsprop_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = RMSprop::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    /// A single `step` with a non-zero gradient must move the parameters.
    #[test]
    fn test_rmsprop_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    /// The `momentum` builder must store the given factor.
    #[test]
    fn test_rmsprop_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    /// The `centered` builder must set the flag.
    #[test]
    fn test_rmsprop_centered() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    /// All builder methods must chain and store every hyperparameter.
    #[test]
    fn test_rmsprop_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}