Skip to main content

axonml_optim/
rmsprop.rs

1//! `RMSprop` Optimizer
2//!
3//! Implements `RMSprop` (Root Mean Square Propagation) optimizer.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_nn::Parameter;
9use axonml_tensor::Tensor;
10
11use crate::optimizer::Optimizer;
12
13// =============================================================================
14// RMSprop
15// =============================================================================
16
17/// `RMSprop` optimizer.
18///
19/// Maintains a moving average of squared gradients to normalize updates.
20///
21/// Update rule:
22/// ```text
23/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
24/// param = param - lr * grad / (sqrt(v_t) + eps)
25/// ```
26///
27/// With momentum:
28/// ```text
29/// v_t = alpha * v_{t-1} + (1 - alpha) * grad^2
30/// buf_t = momentum * buf_{t-1} + grad / (sqrt(v_t) + eps)
31/// param = param - lr * buf_t
32/// ```
33pub struct RMSprop {
34    /// Parameters to optimize.
35    params: Vec<Parameter>,
36    /// Learning rate.
37    lr: f32,
38    /// Smoothing constant (decay rate for moving average).
39    alpha: f32,
40    /// Small constant for numerical stability.
41    eps: f32,
42    /// Weight decay (L2 regularization).
43    weight_decay: f32,
44    /// Momentum factor.
45    momentum: f32,
46    /// Whether to center the gradient (subtract mean).
47    centered: bool,
48    /// Per-parameter state.
49    state: Vec<RMSpropState>,
50}
51
/// Running statistics that `RMSprop` keeps for a single parameter tensor.
///
/// All buffers are flat and have one entry per element of the parameter.
#[derive(Debug, Clone)]
struct RMSpropState {
    /// Exponential moving average of the squared gradient.
    square_avg: Vec<f32>,
    /// Velocity buffer; only allocated when momentum is in use.
    momentum_buffer: Option<Vec<f32>>,
    /// Exponential moving average of the gradient; only allocated for the
    /// centered variant.
    grad_avg: Option<Vec<f32>>,
}

impl RMSpropState {
    /// Builds zero-initialized state for a parameter with `size` elements.
    ///
    /// The optional buffers are allocated only when the corresponding feature
    /// (`momentum` / `centered`) is requested.
    fn new(size: usize, momentum: bool, centered: bool) -> Self {
        let zeros = || vec![0.0; size];
        Self {
            square_avg: zeros(),
            momentum_buffer: momentum.then(zeros),
            grad_avg: centered.then(zeros),
        }
    }
}
80
81impl RMSprop {
82    /// Creates a new `RMSprop` optimizer with default settings.
83    #[must_use]
84    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
85        Self {
86            params,
87            lr,
88            alpha: 0.99,
89            eps: 1e-8,
90            weight_decay: 0.0,
91            momentum: 0.0,
92            centered: false,
93            state: Vec::new(),
94        }
95    }
96
97    /// Creates `RMSprop` with specified alpha (smoothing constant).
98    #[must_use]
99    pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
100        Self {
101            params,
102            lr,
103            alpha,
104            eps: 1e-8,
105            weight_decay: 0.0,
106            momentum: 0.0,
107            centered: false,
108            state: Vec::new(),
109        }
110    }
111
112    /// Creates `RMSprop` with all options.
113    #[must_use]
114    pub fn with_options(
115        params: Vec<Parameter>,
116        lr: f32,
117        alpha: f32,
118        eps: f32,
119        weight_decay: f32,
120        momentum: f32,
121        centered: bool,
122    ) -> Self {
123        Self {
124            params,
125            lr,
126            alpha,
127            eps,
128            weight_decay,
129            momentum,
130            centered,
131            state: Vec::new(),
132        }
133    }
134
135    /// Builder method to set alpha.
136    #[must_use]
137    pub fn alpha(mut self, alpha: f32) -> Self {
138        self.alpha = alpha;
139        self
140    }
141
142    /// Builder method to set epsilon.
143    #[must_use]
144    pub fn eps(mut self, eps: f32) -> Self {
145        self.eps = eps;
146        self
147    }
148
149    /// Builder method to set weight decay.
150    #[must_use]
151    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
152        self.weight_decay = weight_decay;
153        self
154    }
155
156    /// Builder method to set momentum.
157    #[must_use]
158    pub fn momentum(mut self, momentum: f32) -> Self {
159        self.momentum = momentum;
160        self
161    }
162
163    /// Builder method to enable centered `RMSprop`.
164    #[must_use]
165    pub fn centered(mut self, centered: bool) -> Self {
166        self.centered = centered;
167        self
168    }
169
170    fn ensure_state_initialized(&mut self) {
171        if self.state.is_empty() {
172            self.state = self
173                .params
174                .iter()
175                .map(|p| RMSpropState::new(p.numel(), self.momentum != 0.0, self.centered))
176                .collect();
177        }
178    }
179}
180
181impl Optimizer for RMSprop {
182    fn step(&mut self) {
183        self.ensure_state_initialized();
184
185        for (i, param) in self.params.iter().enumerate() {
186            if !param.requires_grad() {
187                continue;
188            }
189
190            let grad = match param.grad() {
191                Some(g) => g,
192                None => continue,
193            };
194
195            let mut grad_vec = grad.to_vec();
196            let state = &mut self.state[i];
197
198            let param_data = param.data();
199            let mut param_vec = param_data.to_vec();
200
201            // Apply weight decay
202            if self.weight_decay != 0.0 {
203                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
204                    *g += self.weight_decay * p;
205                }
206            }
207
208            // Update square average
209            for (sq, g) in state.square_avg.iter_mut().zip(grad_vec.iter()) {
210                *sq = self.alpha * *sq + (1.0 - self.alpha) * g * g;
211            }
212
213            // Compute denominator
214            let avg: Vec<f32> = if self.centered {
215                // Update gradient average for centered RMSprop
216                let grad_avg = state.grad_avg.as_mut().unwrap();
217                for (ga, g) in grad_avg.iter_mut().zip(grad_vec.iter()) {
218                    *ga = self.alpha * *ga + (1.0 - self.alpha) * g;
219                }
220                // avg = sqrt(square_avg - grad_avg^2)
221                state
222                    .square_avg
223                    .iter()
224                    .zip(grad_avg.iter())
225                    .map(|(sq, ga)| (sq - ga * ga).sqrt() + self.eps)
226                    .collect()
227            } else {
228                state
229                    .square_avg
230                    .iter()
231                    .map(|sq| sq.sqrt() + self.eps)
232                    .collect()
233            };
234
235            // Update parameters
236            if self.momentum == 0.0 {
237                // Without momentum
238                for ((p, g), a) in param_vec.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
239                    *p -= self.lr * g / a;
240                }
241            } else {
242                // With momentum
243                let buf = state.momentum_buffer.as_mut().unwrap();
244                for ((b, g), a) in buf.iter_mut().zip(grad_vec.iter()).zip(avg.iter()) {
245                    *b = self.momentum * *b + g / a;
246                }
247                for (p, b) in param_vec.iter_mut().zip(buf.iter()) {
248                    *p -= self.lr * b;
249                }
250            }
251
252            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
253            param.update_data(update);
254        }
255    }
256
257    fn zero_grad(&mut self) {
258        for param in &self.params {
259            param.zero_grad();
260        }
261    }
262
263    fn get_lr(&self) -> f32 {
264        self.lr
265    }
266
267    fn set_lr(&mut self, lr: f32) {
268        self.lr = lr;
269    }
270
271    fn parameters(&self) -> &[Parameter] {
272        &self.params
273    }
274}
275
276// =============================================================================
277// Tests
278// =============================================================================
279
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a 3-element parameter `[1, 2, 3]` with gradient `[0.1, 0.2, 0.3]`.
    fn param_with_grad() -> Parameter {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
        param
    }

    /// Asserts `actual` matches `expected` element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < tol, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_rmsprop_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = RMSprop::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert!((optimizer.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = param_with_grad();

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01);
        optimizer.step();

        // First step: square_avg = (1 - alpha) * g^2 = 0.01 * g^2, so
        // sqrt(square_avg) = 0.1 * g and every element moves by ~lr / 0.1 = 0.1
        // against the gradient.
        assert_close(&param.data().to_vec(), &[0.9, 1.9, 2.9], 1e-4);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).momentum(0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_momentum_step() {
        let param = param_with_grad();

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01).momentum(0.9);
        optimizer.step();

        // With an empty buffer the first momentum step equals the plain step.
        assert_close(&param.data().to_vec(), &[0.9, 1.9, 2.9], 1e-4);
    }

    #[test]
    fn test_rmsprop_centered() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01).centered(true);

        assert!(optimizer.centered);
    }

    #[test]
    fn test_rmsprop_centered_step() {
        let param = param_with_grad();

        let mut optimizer = RMSprop::new(vec![param.clone()], 0.01).centered(true);
        optimizer.step();

        // square_avg - grad_avg^2 = 0.01 g^2 - 0.0001 g^2 = 0.0099 g^2,
        // so each element moves by lr / sqrt(0.0099).
        let shift = 0.01 / 0.0099_f32.sqrt();
        let new_data = param.data().to_vec();
        assert!(new_data.iter().all(|v| v.is_finite()));
        assert_close(&new_data, &[1.0 - shift, 2.0 - shift, 3.0 - shift], 1e-4);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = RMSprop::new(vec![param], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);

        assert!((optimizer.alpha - 0.95).abs() < 1e-6);
        assert!((optimizer.eps - 1e-6).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!(optimizer.centered);
    }
}