1//! `RMSprop` — root mean square propagation optimizer.
2//!
3//! Exponential moving average of squared gradients for per-parameter
4//! adaptive learning rates. Config: alpha (decay), epsilon, optional
5//! momentum, optional centered mode (subtract mean of squared grads).
6//!
7//! # File
8//! `crates/axonml-optim/src/rmsprop.rs`
9//!
10//! # Author
11//! Andrew Jewell Sr. — AutomataNexus LLC
12//! ORCID: 0009-0005-2158-7060
13//!
14//! # Updated
15//! April 14, 2026 11:15 PM EST
16//!
17//! # Disclaimer
18//! Use at own risk. This software is provided "as is", without warranty of any
19//! kind, express or implied. The author and AutomataNexus shall not be held
20//! liable for any damages arising from the use of this software.
21
22use axonml_nn::Parameter;
23use axonml_tensor::Tensor;
24
25use crate::optimizer::Optimizer;
26
27// Re-import Device for state initialization
28use axonml_core;
29
30// =============================================================================
31// RMSprop
32// =============================================================================
33
/// `RMSprop` optimizer.
///
/// Maintains an exponential moving average of squared gradients to
/// normalize updates, giving each parameter an adaptive learning rate.
///
/// Update rule:
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// param = param - lr * grad / (sqrt(v_t) + eps)
/// ```
///
/// With momentum:
/// ```text
/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
/// buf_t = momentum * buf_{t-1} + grad / (sqrt(v_t) + eps)
/// param = param - lr * buf_t
/// ```
pub struct RMSprop {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate (step size multiplier).
    lr: f32,
    /// Smoothing constant (decay rate for the squared-gradient moving
    /// average). Default from `new()`: 0.99.
    alpha: f32,
    /// Small constant added to the denominator for numerical stability.
    /// Default from `new()`: 1e-8.
    eps: f32,
    /// Weight decay (L2 regularization) coefficient; 0.0 disables it.
    weight_decay: f32,
    /// Momentum factor; 0.0 disables the momentum buffer entirely.
    momentum: f32,
    /// Whether to center the gradient (subtract the squared mean of
    /// gradients from the squared-gradient average before the sqrt).
    centered: bool,
    /// Per-parameter state; empty until the first `step()` call, which
    /// lazily initializes it via `ensure_state_initialized`.
    state: Vec<RMSpropState>,
}
68
/// Tensor-based state for `RMSprop` optimizer.
///
/// All buffers are stored as `Tensor<f32>` so they stay GPU-resident when
/// parameters are on GPU, avoiding round-trip copies through `to_vec()`.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Exponential moving average of squared gradients (`v_t`).
    square_avg: Tensor<f32>,
    /// Momentum buffer (`buf_t`); allocated only when momentum != 0.
    momentum_buffer: Option<Tensor<f32>>,
    /// Exponential moving average of raw gradients, used by centered
    /// `RMSprop`; allocated only when `centered` is enabled.
    grad_avg: Option<Tensor<f32>>,
}
82
83impl RMSpropState {
84    fn new(shape: &[usize], device: axonml_core::Device, momentum: bool, centered: bool) -> Self {
85        let square_avg = {
86            let t = Tensor::zeros(shape);
87            if device.is_gpu() {
88                t.to_device(device).expect("device transfer failed")
89            } else {
90                t
91            }
92        };
93        let momentum_buffer = if momentum {
94            let t = Tensor::zeros(shape);
95            Some(if device.is_gpu() {
96                t.to_device(device).expect("device transfer failed")
97            } else {
98                t
99            })
100        } else {
101            None
102        };
103        let grad_avg = if centered {
104            let t = Tensor::zeros(shape);
105            Some(if device.is_gpu() {
106                t.to_device(device).expect("device transfer failed")
107            } else {
108                t
109            })
110        } else {
111            None
112        };
113        Self {
114            square_avg,
115            momentum_buffer,
116            grad_avg,
117        }
118    }
119}
120
121impl RMSprop {
122    /// Creates a new `RMSprop` optimizer with default settings.
123    #[must_use]
124    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
125        Self {
126            params,
127            lr,
128            alpha: 0.99,
129            eps: 1e-8,
130            weight_decay: 0.0,
131            momentum: 0.0,
132            centered: false,
133            state: Vec::new(),
134        }
135    }
136
137    /// Creates `RMSprop` with specified alpha (smoothing constant).
138    #[must_use]
139    pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
140        Self {
141            params,
142            lr,
143            alpha,
144            eps: 1e-8,
145            weight_decay: 0.0,
146            momentum: 0.0,
147            centered: false,
148            state: Vec::new(),
149        }
150    }
151
152    /// Creates `RMSprop` with all options.
153    #[must_use]
154    pub fn with_options(
155        params: Vec<Parameter>,
156        lr: f32,
157        alpha: f32,
158        eps: f32,
159        weight_decay: f32,
160        momentum: f32,
161        centered: bool,
162    ) -> Self {
163        Self {
164            params,
165            lr,
166            alpha,
167            eps,
168            weight_decay,
169            momentum,
170            centered,
171            state: Vec::new(),
172        }
173    }
174
175    /// Builder method to set alpha.
176    #[must_use]
177    pub fn alpha(mut self, alpha: f32) -> Self {
178        self.alpha = alpha;
179        self
180    }
181
182    /// Builder method to set epsilon.
183    #[must_use]
184    pub fn eps(mut self, eps: f32) -> Self {
185        self.eps = eps;
186        self
187    }
188
189    /// Builder method to set weight decay.
190    #[must_use]
191    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
192        self.weight_decay = weight_decay;
193        self
194    }
195
196    /// Builder method to set momentum.
197    #[must_use]
198    pub fn momentum(mut self, momentum: f32) -> Self {
199        self.momentum = momentum;
200        self
201    }
202
203    /// Builder method to enable centered `RMSprop`.
204    #[must_use]
205    pub fn centered(mut self, centered: bool) -> Self {
206        self.centered = centered;
207        self
208    }
209
210    fn ensure_state_initialized(&mut self) {
211        if self.state.is_empty() {
212            self.state = self
213                .params
214                .iter()
215                .map(|p| {
216                    let data = p.data();
217                    RMSpropState::new(
218                        data.shape(),
219                        data.device(),
220                        self.momentum != 0.0,
221                        self.centered,
222                    )
223                })
224                .collect();
225        }
226    }
227}
228
impl Optimizer for RMSprop {
    /// Performs one optimization step over all tracked parameters.
    ///
    /// Parameters with `requires_grad() == false` or with no accumulated
    /// gradient are skipped. State buffers are created lazily on the
    /// first call.
    ///
    /// NOTE(review): the tensor-op `unwrap()`s below assume gradient and
    /// parameter tensors are shape-compatible — confirm upstream invariant.
    fn step(&mut self) {
        self.ensure_state_initialized();

        // ============================================================
        // Tensor-op path: works on both CPU and GPU without to_vec()
        // All ops (add, mul, mul_scalar, div, sqrt, add_scalar, sub)
        // dispatch to CUDA when the tensors are GPU-resident.
        // ============================================================

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            // No gradient accumulated for this parameter — nothing to do.
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let param_data = param.data();
            let state = &mut self.state[i];

            // Apply weight decay: d = grad + weight_decay * param
            let d = if self.weight_decay == 0.0 {
                grad.clone()
            } else {
                grad.add(&param_data.mul_scalar(self.weight_decay)).unwrap()
            };

            // Update square average: sq_avg = alpha * sq_avg + (1 - alpha) * d^2
            let d_sq = d.mul(&d).unwrap();
            state.square_avg = state
                .square_avg
                .mul_scalar(self.alpha)
                .add(&d_sq.mul_scalar(1.0 - self.alpha))
                .unwrap();

            // Compute denominator
            let denom = if self.centered {
                // Update gradient average: grad_avg = alpha * grad_avg + (1 - alpha) * d
                let grad_avg = state.grad_avg.as_mut().unwrap();
                *grad_avg = grad_avg
                    .mul_scalar(self.alpha)
                    .add(&d.mul_scalar(1.0 - self.alpha))
                    .unwrap();

                // denom = sqrt(sq_avg - grad_avg^2) + eps
                // sq_avg - grad_avg^2 estimates the gradient variance; both
                // averages use the same exponential weights, so it is
                // non-negative up to floating-point rounding.
                let ga_sq = grad_avg.mul(grad_avg).unwrap();
                state
                    .square_avg
                    .sub(&ga_sq)
                    .unwrap()
                    .sqrt()
                    .add_scalar(self.eps)
            } else {
                // denom = sqrt(sq_avg) + eps
                state.square_avg.sqrt().add_scalar(self.eps)
            };

            // Apply update with or without momentum
            let update = if self.momentum == 0.0 {
                // update = d / denom
                d.div(&denom).unwrap()
            } else {
                // buf = momentum * buf + d / denom
                // (buffer exists: ensure_state_initialized allocated it
                // because momentum != 0 at first step)
                let normalized = d.div(&denom).unwrap();
                let buf = state.momentum_buffer.as_mut().unwrap();
                *buf = buf.mul_scalar(self.momentum).add(&normalized).unwrap();
                buf.clone()
            };

            // param = param - lr * update
            let new_param = param_data.sub(&update.mul_scalar(self.lr)).unwrap();
            param.update_data(new_param);
        }
    }

    /// Clears the accumulated gradient of every tracked parameter.
    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the learning rate (used by LR schedulers).
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    /// Returns the parameters this optimizer manages.
    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}
325
326// =============================================================================
327// Tests
328// =============================================================================
329
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a 3-element parameter [1.0, 2.0, 3.0] with grad tracking.
    fn make_param() -> Parameter {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        Parameter::from_variable(var)
    }

    #[test]
    fn test_rmsprop_creation() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = make_param();

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // Parameters should have changed
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_centered() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let optimizer = RMSprop::new(vec![make_param()], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}