burn_optim/optim/rmsprop.rs

use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

use burn::config::Config;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};

/// Configuration to create the [RmsProp](RmsProp) optimizer.
#[derive(Config, Debug)]
pub struct RmsPropConfig {
    /// Smoothing constant.
    #[config(default = 0.99)]
    alpha: f32,
    /// Momentum for RmsProp.
    #[config(default = 0.9)]
    momentum: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// If `true`, computes the centered RmsProp, where the gradient is normalized by an
    /// estimate of its variance.
    #[config(default = false)]
    centered: bool,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

impl RmsPropConfig {
    /// Initialize RmsProp optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
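    ///
    /// # Example
    ///
    /// A minimal usage sketch, modeled on the tests below; `model`, `x`, and the
    /// learning rate are placeholders for your own module, input batch, and value:
    ///
    /// ```rust,ignore
    /// let mut optimizer = RmsPropConfig::new()
    ///     .with_alpha(0.99)
    ///     .with_momentum(0.9)
    ///     .init();
    ///
    /// let grads = model.forward(x).backward();
    /// let grads = GradientsParams::from_grads(grads, &model);
    /// let model = optimizer.step(1e-2, model, grads);
    /// ```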
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<RmsProp, M, B> {
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);

        let mut optim = OptimizerAdaptor::from(RmsProp {
            alpha: self.alpha,
            centered: self.centered,
            weight_decay,
            momentum: RmsPropMomentum {
                momentum: self.momentum,
                epsilon: self.epsilon,
            },
        });

        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }

        optim
    }
}

/// Optimizer that implements RMSProp (root mean square propagation).
/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
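///
/// With `alpha` the smoothing constant and `lr` the learning rate, each `step()` applies
/// the conventional RMSProp update (after the optional [weight decay](WeightDecay) has been
/// applied to the gradient); roughly:
///
/// ```text
/// square_avg = alpha * square_avg + (1 - alpha) * grad^2
/// avg        = square_avg - grad_avg^2   // `avg = square_avg` when not `centered`
/// grad       = grad / (sqrt(avg) + epsilon)
/// buf        = momentum * buf + grad     // `buf = grad` when `momentum == 0`
/// param      = param - lr * buf
/// ```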
#[derive(Clone)]
pub struct RmsProp {
    alpha: f32,
    // epsilon: f32,
    centered: bool,
    // momentum: Option<Momentum<B>>,
    momentum: RmsPropMomentum,
    weight_decay: Option<WeightDecay>,
}

impl<B: Backend> SimpleOptimizer<B> for RmsProp {
    type State<const D: usize> = RmsPropState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // fetch state for params
        let mut state_square_avg = None;
        let mut state_centered = None;
        let mut state_momentum = None;
        if let Some(state) = state {
            state_square_avg = Some(state.square_avg);
            state_centered = Some(state.centered);
            state_momentum = state.momentum;
        }
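        // On the first step `state` is `None`; each transform below then
        // initializes its own piece of the state from the incoming gradient.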

        // weight_decay transform
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // square_avg transform
        let (grad, state_square_avg) =
            SquareAvgState::transform(self.alpha, grad, state_square_avg);

        // centered transform
        let (grad, state_square_avg, state_centered) = CenteredState::transform(
            self.alpha,
            self.centered,
            grad,
            state_square_avg,
            state_centered,
        );

        // momentum transform
        let (grad, state_centered, state_momentum) =
            self.momentum
                .transform(grad, state_centered, state_momentum);

        // transition state
        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);

        // apply the update to the parameter tensor
        let delta = grad.mul_scalar(lr);
        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.square_avg = state.square_avg.to_device(device);
        state.centered = state.centered.to_device(device);
        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
        state
    }
}

/// State of [RmsProp](RmsProp).
#[derive(Record, Clone, new)]
pub struct RmsPropState<B: Backend, const D: usize> {
    /// Current squared average state.
    pub square_avg: SquareAvgState<B, D>,
    /// Current centered state.
    pub centered: CenteredState<B, D>,
    /// Current gradient momentum, if any.
    pub momentum: Option<RmsPropMomentumState<B, D>>,
}

/// [SquareAvgState](SquareAvgState) stores and passes optimizer step parameters.
#[derive(Record, Clone, new)]
pub struct SquareAvgState<B: Backend, const D: usize> {
    /// Current squared average.
    pub square_avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> SquareAvgState<B, D> {
    /// Transforms [SquareAvgState] to the next step.
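    ///
    /// Computes `square_avg = alpha * square_avg + (1 - alpha) * grad^2`
    /// (equivalent to starting from a zero `square_avg` on the first step).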
    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
        match state {
            Some(state) => {
                let square_avg = state
                    .square_avg
                    .mul_scalar(alpha)
                    .add(grad.clone().powi_scalar(2).mul_scalar(1. - alpha));
                (grad, Self { square_avg })
            }
            _ => {
                let square_avg = grad.clone().powi_scalar(2).mul_scalar(1. - alpha);
                (grad, Self { square_avg })
            }
        }
    }

    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.square_avg = self.square_avg.to_device(device);
        self
    }
}

/// [CenteredState](CenteredState) stores and passes optimizer step parameters.
#[derive(Record, Clone, new)]
pub struct CenteredState<B: Backend, const D: usize> {
    /// The averaged gradient to calculate the centered gradient, if available.
    pub grad_avg: Option<Tensor<B, D>>,
    /// The current average value.
    pub avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> CenteredState<B, D> {
    /// Transforms [CenteredState] to the next step.
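    ///
    /// When `centered`, tracks `grad_avg = alpha * grad_avg + (1 - alpha) * grad` and
    /// computes `avg = square_avg - grad_avg^2`; otherwise `avg` is simply `square_avg`.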
    fn transform(
        alpha: f32,
        centered: bool,
        grad: Tensor<B, D>,
        square_avg_state: SquareAvgState<B, D>,
        centered_state: Option<Self>,
    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
        if centered {
            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
            let grad_avg = match centered_state {
                Some(state) => state
                    .grad_avg
                    .map_or(grad_avg_constant.clone(), move |grad_avg| {
                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                    }),
                _ => grad_avg_constant,
            };
            let avg = square_avg_state
                .square_avg
                .clone()
                .sub(grad_avg.clone().powi_scalar(2));

            (
                grad,
                square_avg_state,
                Self {
                    grad_avg: Some(grad_avg),
                    avg,
                },
            )
        } else {
            (
                grad,
                square_avg_state.clone(),
                Self {
                    grad_avg: None,
                    avg: square_avg_state.square_avg,
                },
            )
        }
    }

    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
        self.avg = self.avg.to_device(device);
        self
    }
}

/// [RmsPropMomentum](RmsPropMomentum) stores the momentum configuration for the optimizer
/// (it is kept in the [optimizer](RmsProp) itself and is not passed in during the `step()` calculation).
#[derive(Clone)]
pub struct RmsPropMomentum {
    momentum: f32,
    epsilon: f32,
}

impl RmsPropMomentum {
    /// Transforms [grad](Tensor) and [RmsPropMomentumState] to the next step.
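    ///
    /// Normalizes the gradient as `grad / (sqrt(avg) + epsilon)` and, when
    /// `momentum > 0`, accumulates it into `buf = momentum * buf + grad`,
    /// returning `buf` as the effective gradient.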
    fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        centered_state: CenteredState<B, D>,
        momentum_state: Option<RmsPropMomentumState<B, D>>,
    ) -> (
        Tensor<B, D>,
        CenteredState<B, D>,
        Option<RmsPropMomentumState<B, D>>,
    ) {
        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

        if self.momentum > 0. {
            let buf = match momentum_state {
                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
                _ => grad,
            };
            (
                buf.clone(),
                centered_state,
                Some(RmsPropMomentumState { buf }),
            )
        } else {
            (grad, centered_state, None)
        }
    }
}

/// [RmsPropMomentumState](RmsPropMomentumState) stores and passes optimizer step parameters.
#[derive(Record, Clone, new)]
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
    buf: Tensor<B, D>,
}

impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.buf = self.buf.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{Shape, Tolerance};

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::optim::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_rmsprop_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_rmsprop();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_rmsprop"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_rmsprop();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    /// Used for checking differences and debugging.
    #[test]
    fn test_rmsprop_optimizer_with_numbers_basic() {
        let linear = given_linear_layer(
            TensorData::from([
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
            ]),
            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        // println!("linear is {:?}", linear);
        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        // println!("linear is {:?}", linear);
        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        // println!("linear is {:?}", linear);
        let state_updated = linear.into_record();

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        // println!("\nweight_updated\n{:?}", weight_updated);
        // println!("\nbias_updated\n{:?}", bias_updated);

        let weights_expected = TensorData::from([
            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
        ]);
        let bias_expected =
            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [
                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
            ],
            [
                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
            ],
            [
                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
            ],
            [
                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
            ],
            [
                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
            ],
            [
                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        // println!("\nweight_updated\n{:?}", weight_updated);
        // println!("\nbias_updated\n{:?}", bias_updated);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    #[allow(dead_code)]
    fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::random(
            Shape::new([2, 20]),
            Distribution::Default,
            &Default::default(),
        )
    }

    fn create_rmsprop()
    -> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: 0.9,
            grad_clipping: None,
        }
        .init()
    }
}