// burn_optim/optim/rmsprop.rs — RmsProp optimizer.

1use burn_core as burn;
2
3use burn::{module::AutodiffModule, record::Record};
4
5use super::{
6    SimpleOptimizer,
7    adaptor::OptimizerAdaptor,
8    decay::{WeightDecay, WeightDecayConfig},
9};
10use crate::{LearningRate, grad_clipping::GradientClippingConfig};
11
12use burn::config::Config;
13use burn::tensor::backend::Backend;
14use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};
15
16/// Configuration to create the [RmsProp](RmsProp) optimizer.
17#[derive(Config, Debug)]
18pub struct RmsPropConfig {
19    /// Smoothing constant.
20    #[config(default = 0.99)]
21    alpha: f32,
22    /// momentum for RmsProp.
23    #[config(default = 0.9)]
24    momentum: f32,
25    /// A value required for numerical stability.
26    #[config(default = 1e-5)]
27    epsilon: f32,
28    /// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance
29    #[config(default = false)]
30    centered: bool,
31    /// [Weight decay](WeightDecayConfig) config.
32    weight_decay: Option<WeightDecayConfig>,
33    /// [Gradient Clipping](GradientClippingConfig) config.
34    grad_clipping: Option<GradientClippingConfig>,
35}
36
37impl RmsPropConfig {
38    /// Build a [`RmsProp`] from the config.
39    pub fn build(&self) -> RmsProp {
40        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
41        RmsProp {
42            alpha: self.alpha,
43            centered: self.centered,
44            weight_decay,
45            momentum: RmsPropMomentum {
46                momentum: self.momentum,
47                epsilon: self.epsilon,
48            },
49        }
50    }
51
52    /// Initialize RmsProp optimizer.
53    ///
54    /// # Returns
55    ///
56    /// Returns an optimizer that can be used to optimize a module.
57    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
58        &self,
59    ) -> OptimizerAdaptor<RmsProp, M, B> {
60        let mut optim = OptimizerAdaptor::from(self.build());
61        if let Some(config) = &self.grad_clipping {
62            optim = optim.with_grad_clipping(config.init());
63        }
64
65        optim
66    }
67}
68
69/// Optimizer that implements stochastic gradient descent with momentum.
70/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
71#[derive(Clone)]
72pub struct RmsProp {
73    alpha: f32,
74    // epsilon: f32,
75    centered: bool,
76    // momentum: Option<Momentum<B>>,
77    momentum: RmsPropMomentum,
78    weight_decay: Option<WeightDecay>,
79}
80
81impl<B: Backend> SimpleOptimizer<B> for RmsProp {
82    type State<const D: usize> = RmsPropState<B, D>;
83
84    fn step<const D: usize>(
85        &self,
86        lr: LearningRate,
87        tensor: Tensor<B, D>,
88        mut grad: Tensor<B, D>,
89        state: Option<Self::State<D>>,
90    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
91        // fetch state for params
92        let mut state_square_avg = None;
93        let mut state_centered = None;
94        let mut state_momentum = None;
95        if let Some(state) = state {
96            state_square_avg = Some(state.square_avg);
97            state_centered = Some(state.centered);
98            state_momentum = state.momentum;
99        }
100
101        // weight_decay transform
102        if let Some(weight_decay) = &self.weight_decay {
103            grad = weight_decay.transform(grad, tensor.clone());
104        }
105
106        // square_avg transform
107        let (grad, state_square_avg) =
108            SquareAvgState::transform(self.alpha, grad, state_square_avg);
109
110        // centered transform
111        let (grad, state_square_avg, state_centered) = CenteredState::transform(
112            self.alpha,
113            self.centered,
114            grad,
115            state_square_avg,
116            state_centered,
117        );
118
119        // momentum transform
120        let (grad, state_centered, state_momentum) =
121            self.momentum
122                .transform(grad, state_centered, state_momentum);
123
124        // transition state
125        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);
126
127        // tensor param transform
128        let delta = grad.mul_scalar(lr);
129        (tensor - delta, Some(state))
130    }
131
132    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
133        state.square_avg = state.square_avg.to_device(device);
134        state.centered = state.centered.to_device(device);
135        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
136        state
137    }
138}
139
140/// State of [RmsProp](RmsProp)
141#[derive(Record, Clone, new)]
142pub struct RmsPropState<B: Backend, const D: usize> {
143    /// Current squared average state.
144    pub square_avg: SquareAvgState<B, D>,
145    /// Current centered state
146    pub centered: CenteredState<B, D>,
147    /// Current gradient momentum, if any.
148    pub momentum: Option<RmsPropMomentumState<B, D>>,
149}
150
151/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
152#[derive(Record, Clone, new)]
153pub struct SquareAvgState<B: Backend, const D: usize> {
154    /// Current squared average.
155    pub square_avg: Tensor<B, D>,
156}
157
158impl<B: Backend, const D: usize> SquareAvgState<B, D> {
159    /// transform [SquareAvgState] to the next step
160    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
161        match state {
162            Some(state) => {
163                let square_avg = state
164                    .square_avg
165                    .mul_scalar(alpha)
166                    .add(grad.clone().square().mul_scalar(1. - alpha));
167                (grad, Self { square_avg })
168            }
169            _ => {
170                let square_avg = grad.clone().square().mul_scalar(1. - alpha);
171                (grad, Self { square_avg })
172            }
173        }
174    }
175
176    /// Moves the state to a device.
177    ///
178    /// # Arguments
179    ///
180    /// * `device` - Device to move the state to.
181    ///
182    /// # Returns
183    ///
184    /// * `self` - Moved state.
185    pub fn to_device(mut self, device: &B::Device) -> Self {
186        self.square_avg = self.square_avg.to_device(device);
187        self
188    }
189}
190
191/// [CenteredState](CenteredState) is to store and pass optimizer step params.
192#[derive(Record, Clone, new)]
193pub struct CenteredState<B: Backend, const D: usize> {
194    /// The averaged gradient to calculate the centered gradient, if available.
195    pub grad_avg: Option<Tensor<B, D>>,
196    /// The current average value.
197    pub avg: Tensor<B, D>,
198}
199
200impl<B: Backend, const D: usize> CenteredState<B, D> {
201    /// transform [CenteredState] to the next step
202    fn transform(
203        alpha: f32,
204        centered: bool,
205        grad: Tensor<B, D>,
206        square_avg_state: SquareAvgState<B, D>,
207        centered_state: Option<Self>,
208    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
209        if centered {
210            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
211            let grad_avg = match centered_state {
212                Some(state) => state
213                    .grad_avg
214                    .map_or(grad_avg_constant.clone(), move |grad_avg| {
215                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)
216                    }),
217                _ => grad_avg_constant,
218            };
219            let avg = square_avg_state
220                .square_avg
221                .clone()
222                .sub(grad_avg.clone().square());
223
224            (
225                grad,
226                square_avg_state,
227                Self {
228                    grad_avg: Some(grad_avg),
229                    avg,
230                },
231            )
232        } else {
233            (
234                grad,
235                square_avg_state.clone(),
236                Self {
237                    grad_avg: None,
238                    avg: square_avg_state.square_avg,
239                },
240            )
241        }
242    }
243
244    /// Moves the state to a device.
245    ///
246    /// # Arguments
247    ///
248    /// * `device` - Device to move the state to.
249    ///
250    /// # Returns
251    ///
252    /// * `self` - Moved state.
253    pub fn to_device(mut self, device: &B::Device) -> Self {
254        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
255        self.avg = self.avg.to_device(device);
256        self
257    }
258}
259
260/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
261/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
262#[derive(Clone)]
263pub struct RmsPropMomentum {
264    momentum: f32,
265    epsilon: f32,
266}
267
268impl RmsPropMomentum {
269    /// transform [grad](Tensor) and [RmsPropMomentumState] to the next step
270    fn transform<B: Backend, const D: usize>(
271        &self,
272        grad: Tensor<B, D>,
273        centered_state: CenteredState<B, D>,
274        momentum_state: Option<RmsPropMomentumState<B, D>>,
275    ) -> (
276        Tensor<B, D>,
277        CenteredState<B, D>,
278        Option<RmsPropMomentumState<B, D>>,
279    ) {
280        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
281
282        if self.momentum > 0. {
283            let buf = match momentum_state {
284                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
285                _ => grad,
286            };
287            (
288                buf.clone(),
289                centered_state,
290                Some(RmsPropMomentumState { buf }),
291            )
292        } else {
293            (grad, centered_state, None)
294        }
295    }
296}
297
298/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.
299#[derive(Record, Clone, new)]
300pub struct RmsPropMomentumState<B: Backend, const D: usize> {
301    buf: Tensor<B, D>,
302}
303
304impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
305    /// Moves the state to a device.
306    ///
307    /// # Arguments
308    ///
309    /// * `device` - Device to move the state to.
310    ///
311    /// # Returns
312    ///
313    /// * `self` - Moved state.
314    pub fn to_device(mut self, device: &B::Device) -> Self {
315        self.buf = self.buf.to_device(device);
316        self
317    }
318}
319
#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::optim::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_rmsprop_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_rmsprop();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().join("test_optim_rmsprop"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_rmsprop();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    /// Simple all-ones starting point; handy when comparing against a reference
    /// implementation while debugging.
    #[test]
    fn test_rmsprop_optimizer_with_numbers_basic() {
        let linear = given_linear_layer(
            TensorData::from([
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
            ]),
            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        );

        let (weight_updated, bias_updated) = optimize_two_steps(linear);

        let weights_expected = TensorData::from([
            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
        ]);
        let bias_expected =
            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let (weight_updated, bias_updated) = optimize_two_steps(linear);

        let weights_expected = TensorData::from([
            [
                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
            ],
            [
                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
            ],
            [
                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
            ],
            [
                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
            ],
            [
                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
            ],
            [
                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
        ]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    /// Run two RmsProp steps (first on [`given_first_input`], then on
    /// [`given_second_input`]) with [`create_numeric_rmsprop`] and return the updated
    /// `(weight, bias)` data.
    fn optimize_two_steps(linear: Linear<TestAutodiffBackend>) -> (TensorData, TensorData) {
        let mut optimizer = create_numeric_rmsprop();
        let mut linear = linear;

        for x in [given_first_input(), given_second_input()] {
            let grads = linear.forward(x).backward();
            let grads = GradientsParams::from_grads(grads, &linear);
            linear = optimizer.step(LEARNING_RATE, linear, grads);
        }

        let record = linear.into_record();
        (record.weight.to_data(), record.bias.unwrap().to_data())
    }

    /// First batch fed to the numeric tests.
    fn given_first_input() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &Default::default(),
        )
        .require_grad()
    }

    /// Second batch fed to the numeric tests.
    fn given_second_input() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad()
    }

    /// Optimizer used by the numeric tests: non-centered, with momentum and weight decay.
    fn create_numeric_rmsprop()
    -> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init()
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_rmsprop()
    -> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: 0.9,
            grad_clipping: None,
        }
        .init()
    }
}