axonml_optim/
rmsprop.rs

1//! `RMSprop` Optimizer
2//!
3//! Implements `RMSprop` (Root Mean Square Propagation) optimizer.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_nn::Parameter;
9use axonml_tensor::Tensor;
10
11use crate::optimizer::Optimizer;
12
13// =============================================================================
14// RMSprop
15// =============================================================================
16
17/// `RMSprop` optimizer.
18///
19/// Maintains a moving average of squared gradients to normalize updates.
20///
21/// Update rule:
22/// ```text
23/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
24/// param = param - lr * grad / (sqrt(v_t) + eps)
25/// ```
26///
27/// With momentum:
28/// ```text
29/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
30/// buf_t = momentum * buf_{t-1} + grad / (sqrt(v_t) + eps)
31/// param = param - lr * buf_t
32/// ```
33pub struct RMSprop {
34    /// Parameters to optimize.
35    params: Vec<Parameter>,
36    /// Learning rate.
37    lr: f32,
38    /// Smoothing constant (decay rate for moving average).
39    alpha: f32,
40    /// Small constant for numerical stability.
41    eps: f32,
42    /// Weight decay (L2 regularization).
43    weight_decay: f32,
44    /// Momentum factor.
45    momentum: f32,
46    /// Whether to center the gradient (subtract mean).
47    centered: bool,
48    /// Per-parameter state.
49    state: Vec<RMSpropState>,
50}
51
/// Running statistics that `RMSprop` keeps for a single parameter.
///
/// All buffers are flat and hold one `f32` per parameter element.
#[derive(Clone, Debug)]
struct RMSpropState {
    /// Moving average of the squared gradient.
    square_avg: Vec<f32>,
    /// Momentum accumulator; `None` when it was not requested at construction.
    momentum_buffer: Option<Vec<f32>>,
    /// Moving average of the gradient (centered `RMSprop`); `None` when it was
    /// not requested at construction.
    grad_avg: Option<Vec<f32>>,
}

impl RMSpropState {
    /// Creates zero-filled state for a parameter with `size` elements.
    ///
    /// The momentum buffer is allocated only when `momentum` is `true`, and
    /// the gradient average only when `centered` is `true`.
    fn new(size: usize, momentum: bool, centered: bool) -> Self {
        let zeros = || vec![0.0_f32; size];
        Self {
            square_avg: zeros(),
            momentum_buffer: momentum.then(zeros),
            grad_avg: centered.then(zeros),
        }
    }
}
80
81impl RMSprop {
82    /// Creates a new `RMSprop` optimizer with default settings.
83    #[must_use] pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
84        Self {
85            params,
86            lr,
87            alpha: 0.99,
88            eps: 1e-8,
89            weight_decay: 0.0,
90            momentum: 0.0,
91            centered: false,
92            state: Vec::new(),
93        }
94    }
95
96    /// Creates `RMSprop` with specified alpha (smoothing constant).
97    #[must_use] pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
98        Self {
99            params,
100            lr,
101            alpha,
102            eps: 1e-8,
103            weight_decay: 0.0,
104            momentum: 0.0,
105            centered: false,
106            state: Vec::new(),
107        }
108    }
109
110    /// Creates `RMSprop` with all options.
111    #[must_use] pub fn with_options(
112        params: Vec<Parameter>,
113        lr: f32,
114        alpha: f32,
115        eps: f32,
116        weight_decay: f32,
117        momentum: f32,
118        centered: bool,
119    ) -> Self {
120        Self {
121            params,
122            lr,
123            alpha,
124            eps,
125            weight_decay,
126            momentum,
127            centered,
128            state: Vec::new(),
129        }
130    }
131
132    /// Builder method to set alpha.
133    #[must_use] pub fn alpha(mut self, alpha: f32) -> Self {
134        self.alpha = alpha;
135        self
136    }
137
138    /// Builder method to set epsilon.
139    #[must_use] pub fn eps(mut self, eps: f32) -> Self {
140        self.eps = eps;
141        self
142    }
143
144    /// Builder method to set weight decay.
145    #[must_use] pub fn weight_decay(mut self, weight_decay: f32) -> Self {
146        self.weight_decay = weight_decay;
147        self
148    }
149
150    /// Builder method to set momentum.
151    #[must_use] pub fn momentum(mut self, momentum: f32) -> Self {
152        self.momentum = momentum;
153        self
154    }
155
156    /// Builder method to enable centered `RMSprop`.
157    #[must_use] pub fn centered(mut self, centered: bool) -> Self {
158        self.centered = centered;
159        self
160    }
161
162    fn ensure_state_initialized(&mut self) {
163        if self.state.is_empty() {
164            self.state = self
165                .params
166                .iter()
167                .map(|p| RMSpropState::new(p.numel(), self.momentum != 0.0, self.centered))
168                .collect();
169        }
170    }
171}
172
173impl Optimizer for RMSprop {
174    fn step(&mut self) {
175        self.ensure_state_initialized();
176
177        for (i, param) in self.params.iter().enumerate() {
178            if !param.requires_grad() {
179                continue;
180            }
181
182            let grad = match param.grad() {
183                Some(g) => g,
184                None => continue,
185            };
186
187            let mut grad_vec = grad.to_vec();
188            let state = &mut self.state[i];
189
190            let param_data = param.data();
191            let mut param_vec = param_data.to_vec();
192
193            // Apply weight decay
194            if self.weight_decay != 0.0 {
195                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
196                    *g += self.weight_decay * p;
197                }
198            }
199
200            // Update square average
201            for (sq, g) in state.square_avg.iter_mut().zip(grad_vec.iter()) {
202                *sq = self.alpha * *sq + (1.0 - self.alpha) * g * g;
203            }
204
205            // Compute denominator
206            let avg: Vec<f32> = if self.centered {
207                // Update gradient average for centered RMSprop
208                let grad_avg = state.grad_avg.as_mut().unwrap();
209                for (ga, g) in grad_avg.iter_mut().zip(grad_vec.iter()) {
210                    *ga = self.alpha * *ga + (1.0 - self.alpha) * g;
211                }
212                // avg = sqrt(square_avg - grad_avg^2)
213                state
214                    .square_avg
215                    .iter()
216                    .zip(grad_avg.iter())
217                    .map(|(sq, ga)| (sq - ga * ga).sqrt() + self.eps)
218                    .collect()
219            } else {
220                state
221                    .square_avg
222                    .iter()
223                    .map(|sq| sq.sqrt() + self.eps)
224                    .collect()
225            };
226
227            // Update parameters
228            if self.momentum == 0.0 {
229                // Without momentum
230                for ((p, g), a) in param_vec.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
231                    *p -= self.lr * g / a;
232                }
233            } else {
234                // With momentum
235                let buf = state.momentum_buffer.as_mut().unwrap();
236                for ((b, g), a) in buf.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
237                    *b = self.momentum * *b + g / a;
238                }
239                for (p, b) in param_vec.iter_mut().zip(buf.iter()) {
240                    *p -= self.lr * b;
241                }
242            }
243
244            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
245            param.update_data(update);
246        }
247    }
248
249    fn zero_grad(&mut self) {
250        for param in &self.params {
251            param.zero_grad();
252        }
253    }
254
255    fn get_lr(&self) -> f32 {
256        self.lr
257    }
258
259    fn set_lr(&mut self, lr: f32) {
260        self.lr = lr;
261    }
262
263    fn parameters(&self) -> &[Parameter] {
264        &self.params
265    }
266}
267
268// =============================================================================
269// Tests
270// =============================================================================
271
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Absolute tolerance for comparing `f32` hyperparameters.
    const TOL: f32 = 1e-6;

    /// Builds a three-element parameter holding `[1.0, 2.0, 3.0]`, wrapping a
    /// `Variable` created with its boolean flag set to `true`.
    fn sample_param() -> Parameter {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        Parameter::from_variable(Variable::new(data, true))
    }

    /// `new` stores the learning rate and falls back to `alpha = 0.99`.
    #[test]
    fn test_rmsprop_creation() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < TOL);
        assert!((optimizer.alpha - 0.99).abs() < TOL);
    }

    /// A single step with a non-zero gradient moves the parameter values.
    #[test]
    fn test_rmsprop_step() {
        let param = sample_param();

        // Attach a gradient for the optimizer to consume.
        let grad = Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap();
        param.variable().set_grad(grad);

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // The first element started at 1.0 and must have moved.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > TOL);
    }

    /// The `momentum` builder stores the requested coefficient.
    #[test]
    fn test_rmsprop_with_momentum() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < TOL);
    }

    /// The `centered` builder switches the centered variant on.
    #[test]
    fn test_rmsprop_centered() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    /// Chained builder calls each store their own hyperparameter.
    #[test]
    fn test_rmsprop_builder_pattern() {
        let optimizer = RMSprop::new(vec![sample_param()], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < TOL);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < TOL);
        assert!((optimizer.momentum - 0.9).abs() < TOL);
        assert!(optimizer.centered);
    }
}
343}