burn_core/optim/rmsprop.rs

use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};

use super::{
    SimpleOptimizer,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend, ops::Device};
use burn_tensor::backend::Backend;

/// Configuration to create the [RmsProp](RmsProp) optimizer.
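///
/// # Example
///
/// A minimal usage sketch, mirroring the builder calls used in the tests below:
///
/// ```ignore
/// let mut optimizer = RmsPropConfig::new()
///     .with_alpha(0.99)
///     .with_momentum(0.9)
///     .with_weight_decay(WeightDecayConfig::new(0.05).into())
///     .init();
///
/// // The returned optimizer is then driven with `optimizer.step(lr, module, grads)`.
/// ```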
#[derive(Config)]
pub struct RmsPropConfig {
    /// Smoothing constant.
    #[config(default = 0.99)]
    alpha: f32,
    /// Momentum for RmsProp.
    #[config(default = 0.9)]
    momentum: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// If true, compute centered RmsProp: the gradient is normalized by an estimate of its variance.
    #[config(default = false)]
    centered: bool,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

impl RmsPropConfig {
    /// Initializes the RmsProp optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<RmsProp, M, B> {
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);

        let mut optim = OptimizerAdaptor::from(RmsProp {
            alpha: self.alpha,
            centered: self.centered,
            weight_decay,
            momentum: RmsPropMomentum {
                momentum: self.momentum,
                epsilon: self.epsilon,
            },
        });

        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }

        optim
    }
}

/// Optimizer that implements the RMSProp algorithm, which scales gradients by a running
/// average of their squared values, optionally with momentum and centering.
/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
#[derive(Clone)]
pub struct RmsProp {
    alpha: f32,
    // epsilon: f32,
    centered: bool,
    // momentum: Option<Momentum<B>>,
    momentum: RmsPropMomentum,
    weight_decay: Option<WeightDecay>,
}

impl<B: Backend> SimpleOptimizer<B> for RmsProp {
    type State<const D: usize> = RmsPropState<B, D>;

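    // Summary of the update computed by `step` below (see each state's `transform`):
    //   square_avg = alpha * square_avg + (1 - alpha) * grad^2
    //   centered:    grad_avg = alpha * grad_avg + (1 - alpha) * grad
    //                avg = square_avg - grad_avg^2
    //   otherwise:   avg = square_avg
    //   grad = grad / (sqrt(avg) + epsilon)
    //   if momentum > 0: buf = momentum * buf + grad, and buf replaces grad
    //   param = param - lr * grad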
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // fetch state for params
        let mut state_square_avg = None;
        let mut state_centered = None;
        let mut state_momentum = None;
        if let Some(state) = state {
            state_square_avg = Some(state.square_avg);
            state_centered = Some(state.centered);
            state_momentum = state.momentum;
        }

        // weight_decay transform
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // square_avg transform
        let (grad, state_square_avg) =
            SquareAvgState::transform(self.alpha, grad, state_square_avg);

        // centered transform
        let (grad, state_square_avg, state_centered) = CenteredState::transform(
            self.alpha,
            self.centered,
            grad,
            state_square_avg,
            state_centered,
        );

        // momentum transform
        let (grad, state_centered, state_momentum) =
            self.momentum
                .transform(grad, state_centered, state_momentum);

        // transition state
        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);

        // tensor param transform
        let delta = grad.mul_scalar(lr);
        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.square_avg = state.square_avg.to_device(device);
        state.centered = state.centered.to_device(device);
        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
        state
    }
}

/// State of [RmsProp](RmsProp).
#[derive(Record, Clone, new)]
pub struct RmsPropState<B: Backend, const D: usize> {
    /// Current squared average state.
    pub square_avg: SquareAvgState<B, D>,
    /// Current centered state.
    pub centered: CenteredState<B, D>,
    /// Current gradient momentum, if any.
    pub momentum: Option<RmsPropMomentumState<B, D>>,
}

/// [SquareAvgState](SquareAvgState) stores and passes optimizer step parameters.
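/// It holds the running average of squared gradients: `square_avg = alpha * square_avg + (1 - alpha) * grad^2`.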
#[derive(Record, Clone, new)]
pub struct SquareAvgState<B: Backend, const D: usize> {
    /// Current squared average.
    pub square_avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> SquareAvgState<B, D> {
    /// Transforms [SquareAvgState] to the next step.
    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
        match state {
            Some(state) => {
                let square_avg = state
                    .square_avg
                    .mul_scalar(alpha)
                    .add(grad.clone().powi_scalar(2).mul_scalar(1. - alpha));
                (grad, Self { square_avg })
            }
            _ => {
                let square_avg = grad.clone().powi_scalar(2).mul_scalar(1. - alpha);
                (grad, Self { square_avg })
            }
        }
    }

    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.square_avg = self.square_avg.to_device(device);
        self
    }
}

/// [CenteredState](CenteredState) stores and passes optimizer step parameters.
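/// When centered RmsProp is enabled, it tracks `grad_avg = alpha * grad_avg + (1 - alpha) * grad` and
/// uses `avg = square_avg - grad_avg^2`; otherwise `avg` is simply the squared average.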
#[derive(Record, Clone, new)]
pub struct CenteredState<B: Backend, const D: usize> {
    /// The averaged gradient to calculate the centered gradient, if available.
    pub grad_avg: Option<Tensor<B, D>>,
    /// The current average value.
    pub avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> CenteredState<B, D> {
    /// Transforms [CenteredState] to the next step.
    fn transform(
        alpha: f32,
        centered: bool,
        grad: Tensor<B, D>,
        square_avg_state: SquareAvgState<B, D>,
        centered_state: Option<Self>,
    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
        if centered {
            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
            let grad_avg = match centered_state {
                Some(state) => state
                    .grad_avg
                    .map_or(grad_avg_constant.clone(), move |grad_avg| {
                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                    }),
                _ => grad_avg_constant,
            };
            let avg = square_avg_state
                .square_avg
                .clone()
                .sub(grad_avg.clone().powi_scalar(2));

            (
                grad,
                square_avg_state,
                Self {
                    grad_avg: Some(grad_avg),
                    avg,
                },
            )
        } else {
            (
                grad,
                square_avg_state.clone(),
                Self {
                    grad_avg: None,
                    avg: square_avg_state.square_avg,
                },
            )
        }
    }

    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
        self.avg = self.avg.to_device(device);
        self
    }
}

/// [RmsPropMomentum](RmsPropMomentum) stores the momentum configuration for the optimizer.
/// It is held by the [optimizer](RmsProp) itself and is not passed in during the `step()` calculation.
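/// It divides the gradient by `sqrt(avg) + epsilon` and, when `momentum > 0`, accumulates
/// `buf = momentum * buf + grad`, which is then used as the update direction.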
#[derive(Clone)]
pub struct RmsPropMomentum {
    momentum: f32,
    epsilon: f32,
}

impl RmsPropMomentum {
    /// Transforms [grad](Tensor) and [RmsPropMomentumState] to the next step.
    fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        centered_state: CenteredState<B, D>,
        momentum_state: Option<RmsPropMomentumState<B, D>>,
    ) -> (
        Tensor<B, D>,
        CenteredState<B, D>,
        Option<RmsPropMomentumState<B, D>>,
    ) {
        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

        if self.momentum > 0. {
            let buf = match momentum_state {
                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
                _ => grad,
            };
            (
                buf.clone(),
                centered_state,
                Some(RmsPropMomentumState { buf }),
            )
        } else {
            (grad, centered_state, None)
        }
    }
}

/// [RmsPropMomentumState](RmsPropMomentumState) stores and passes optimizer step parameters.
#[derive(Record, Clone, new)]
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
    buf: Tensor<B, D>,
}

impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.buf = self.buf.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::ops::FloatElem;
    use burn_tensor::{Shape, Tolerance};

    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_rmsprop_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_rmsprop();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_rmsprop"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_rmsprop();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    /// Used for testing value differences and for debugging.
    #[test]
    fn test_rmsprop_optimizer_with_numbers_basic() {
        let linear = given_linear_layer(
            TensorData::from([
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
            ]),
            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        // println!("linear is {:?}", linear);
        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        // println!("linear is {:?}", linear);
        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        // println!("linear is {:?}", linear);
        let state_updated = linear.into_record();

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        // println!("\nweight_updated\n{:?}", weight_updated);
        // println!("\nbias_updated\n{:?}", bias_updated);

        let weights_expected = TensorData::from([
            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
        ]);
        let bias_expected =
            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [
                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
            ],
            [
                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
            ],
            [
                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
            ],
            [
                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
            ],
            [
                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
            ],
            [
                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        // println!("\nweight_updated\n{:?}", weight_updated);
        // println!("\nbias_updated\n{:?}", bias_updated);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    #[allow(dead_code)]
    fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::random(
            Shape::new([2, 20]),
            Distribution::Default,
            &Default::default(),
        )
    }

    fn create_rmsprop()
    -> OptimizerAdaptor<RmsProp, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: 0.9,
            grad_clipping: None,
        }
        .init()
    }
}