//! `RMSprop` Optimizer
//!
//! # File
//! `crates/axonml-optim/src/rmsprop.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty
//! of any kind, express or implied. The author and AutomataNexus shall not be
//! held liable for any damages arising from the use of this software.

use axonml_core::Device;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// RMSprop
// =============================================================================

/// `RMSprop` optimizer.
///
/// Maintains a moving average of squared gradients to normalize updates.
///
/// Update rule:
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// param = param - lr * grad / (sqrt(v_t) + eps)
/// ```
///
/// With momentum:
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// buf_t = momentum * buf_{t-1} + grad / (sqrt(v_t) + eps)
/// param = param - lr * buf_t
/// ```
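///
/// With `centered = true`, the squared-gradient average is additionally
/// centered by a running estimate of the gradient mean (see the denominator
/// computation in `step`):
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// g_t = alpha * g_{t-1} + (1 - alpha) * grad
/// param = param - lr * grad / (sqrt(v_t - g_t^2) + eps)
/// ```
///
/// # Examples
///
/// A minimal usage sketch, mirroring the API exercised in this module's
/// tests (marked `ignore` since the doctest depends on sibling crates):
/// ```ignore
/// // `params` is a `Vec<Parameter>`, typically collected from a model.
/// let mut optimizer = RMSprop::new(params, 0.01).alpha(0.95).momentum(0.9);
///
/// // After a backward pass has populated gradients:
/// optimizer.step();
/// optimizer.zero_grad();
/// ```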
pub struct RMSprop {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Smoothing constant (decay rate for moving average).
    alpha: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization).
    weight_decay: f32,
    /// Momentum factor.
    momentum: f32,
    /// Whether to center the gradient (subtract mean).
    centered: bool,
    /// Per-parameter state.
    state: Vec<RMSpropState>,
}

/// Tensor-based state for `RMSprop` optimizer.
///
/// All buffers are stored as `Tensor<f32>` so they stay GPU-resident when
/// parameters are on GPU, avoiding round-trip copies through `to_vec()`.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Square average of gradients.
    square_avg: Tensor<f32>,
    /// Momentum buffer.
    momentum_buffer: Option<Tensor<f32>>,
    /// Gradient average (for centered `RMSprop`).
    grad_avg: Option<Tensor<f32>>,
}

impl RMSpropState {
    fn new(shape: &[usize], device: Device, momentum: bool, centered: bool) -> Self {
        // Allocate a zeroed buffer on the parameter's device so optimizer
        // state lives alongside the parameters it tracks.
        let zeros = || {
            let t = Tensor::zeros(shape);
            if device.is_gpu() {
                t.to_device(device).unwrap()
            } else {
                t
            }
        };
        Self {
            square_avg: zeros(),
            momentum_buffer: if momentum { Some(zeros()) } else { None },
            grad_avg: if centered { Some(zeros()) } else { None },
        }
    }
}

impl RMSprop {
    /// Creates a new `RMSprop` optimizer with default settings.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            alpha: 0.99,
            eps: 1e-8,
            weight_decay: 0.0,
            momentum: 0.0,
            centered: false,
            state: Vec::new(),
        }
    }

    /// Creates `RMSprop` with specified alpha (smoothing constant).
    #[must_use]
    pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
        Self {
            params,
            lr,
            alpha,
            eps: 1e-8,
            weight_decay: 0.0,
            momentum: 0.0,
            centered: false,
            state: Vec::new(),
        }
    }

    /// Creates `RMSprop` with all options.
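    ///
    /// All arguments after `params` are positional; a quick sketch of the
    /// order (values illustrative):
    /// ```ignore
    /// // params, lr, alpha, eps, weight_decay, momentum, centered
    /// let opt = RMSprop::with_options(params, 0.01, 0.99, 1e-8, 0.0, 0.9, true);
    /// ```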
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        alpha: f32,
        eps: f32,
        weight_decay: f32,
        momentum: f32,
        centered: bool,
    ) -> Self {
        Self {
            params,
            lr,
            alpha,
            eps,
            weight_decay,
            momentum,
            centered,
            state: Vec::new(),
        }
    }

    /// Builder method to set alpha.
    #[must_use]
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to set momentum.
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to enable centered `RMSprop`.
    #[must_use]
    pub fn centered(mut self, centered: bool) -> Self {
        self.centered = centered;
        self
    }

    /// Lazily initializes per-parameter state on the first call to `step()`.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    RMSpropState::new(
                        data.shape(),
                        data.device(),
                        self.momentum != 0.0,
                        self.centered,
                    )
                })
                .collect();
        }
    }
}

impl Optimizer for RMSprop {
    fn step(&mut self) {
        self.ensure_state_initialized();

        // ============================================================
        // Tensor-op path: works on both CPU and GPU without to_vec()
        // All ops (add, mul, mul_scalar, div, sqrt, add_scalar, sub)
        // dispatch to CUDA when the tensors are GPU-resident.
        // ============================================================

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let param_data = param.data();
            let state = &mut self.state[i];

            // Apply weight decay: d = grad + weight_decay * param
            let d = if self.weight_decay == 0.0 {
                grad.clone()
            } else {
                grad.add(&param_data.mul_scalar(self.weight_decay)).unwrap()
            };

            // Update square average: sq_avg = alpha * sq_avg + (1 - alpha) * d^2
            let d_sq = d.mul(&d).unwrap();
            state.square_avg = state
                .square_avg
                .mul_scalar(self.alpha)
                .add(&d_sq.mul_scalar(1.0 - self.alpha))
                .unwrap();

            // Compute denominator
            let denom = if self.centered {
                // Update gradient average: grad_avg = alpha * grad_avg + (1 - alpha) * d
                let grad_avg = state.grad_avg.as_mut().unwrap();
                *grad_avg = grad_avg
                    .mul_scalar(self.alpha)
                    .add(&d.mul_scalar(1.0 - self.alpha))
                    .unwrap();

                // denom = sqrt(sq_avg - grad_avg^2) + eps
                let ga_sq = grad_avg.mul(grad_avg).unwrap();
                state
                    .square_avg
                    .sub(&ga_sq)
                    .unwrap()
                    .sqrt()
                    .add_scalar(self.eps)
            } else {
                // denom = sqrt(sq_avg) + eps
                state.square_avg.sqrt().add_scalar(self.eps)
            };

            // Apply update with or without momentum
            let update = if self.momentum == 0.0 {
                // update = d / denom
                d.div(&denom).unwrap()
            } else {
                // buf = momentum * buf + d / denom
                let normalized = d.div(&denom).unwrap();
                let buf = state.momentum_buffer.as_mut().unwrap();
                *buf = buf.mul_scalar(self.momentum).add(&normalized).unwrap();
                buf.clone()
            };

            // param = param - lr * update
            let new_param = param_data.sub(&update.mul_scalar(self.lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_rmsprop_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = RMSprop::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }
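
    // On the very first step with the default hyper-parameters (alpha = 0.99,
    // eps = 1e-8), v_1 = (1 - alpha) * g^2, so each element moves by roughly
    // lr * g / sqrt((1 - alpha) * g^2) = lr / sqrt(1 - alpha) = 0.1,
    // independent of the gradient magnitude. This test is a sketch of a
    // numeric check built on that observation.
    #[test]
    fn test_rmsprop_first_step_magnitude() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // Each parameter should have decreased by approximately 0.1.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.9).abs() < 1e-4);
        assert!((new_data[1] - 1.9).abs() < 1e-4);
        assert!((new_data[2] - 2.9).abs() < 1e-4);
    }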

    #[test]
    fn test_rmsprop_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_centered() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}
396}