use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

use burn::config::Config;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};

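/// Configuration to create the [RmsProp](RmsProp) optimizer.
///
/// A minimal usage sketch is shown below; the module, backend, and input names
/// (`model`, `input`) are placeholders for illustration and not part of this crate. The
/// builder calls mirror the tests at the bottom of this file.
///
/// ```rust,ignore
/// let mut optimizer = RmsPropConfig::new()
///     .with_alpha(0.99)
///     .with_epsilon(1e-8)
///     .with_momentum(0.9)
///     .with_weight_decay(WeightDecayConfig::new(0.05).into())
///     .init();
///
/// // One optimization step: backprop, collect per-parameter gradients, then update.
/// let grads = model.forward(input).backward();
/// let grads = GradientsParams::from_grads(grads, &model);
/// let model = optimizer.step(1e-2, model, grads);
/// ```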
#[derive(Config, Debug)]
pub struct RmsPropConfig {
    /// Smoothing constant used for the running average of the squared gradient.
    #[config(default = 0.99)]
    alpha: f32,
    /// Momentum factor; a value of `0.` disables the momentum buffer.
    #[config(default = 0.9)]
    momentum: f32,
    /// A small value added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// If `true`, compute the centered variant, normalizing the gradient by an
    /// estimate of its variance.
    #[config(default = false)]
    centered: bool,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

impl RmsPropConfig {
    /// Initializes the [RmsProp](RmsProp) optimizer, wrapped in an
    /// [OptimizerAdaptor](OptimizerAdaptor), from this configuration.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<RmsProp, M, B> {
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);

        let mut optim = OptimizerAdaptor::from(RmsProp {
            alpha: self.alpha,
            centered: self.centered,
            weight_decay,
            momentum: RmsPropMomentum {
                momentum: self.momentum,
                epsilon: self.epsilon,
            },
        });

        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }

        optim
    }
}

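/// Optimizer that implements stochastic gradient descent with the RMSProp adaptive
/// learning-rate rule.
///
/// For every parameter, one call to `step` applies the transforms defined below. With
/// smoothing constant `alpha`, stability term `epsilon`, and learning rate `lr`:
///
/// - `square_avg = alpha * square_avg + (1 - alpha) * grad^2`
/// - if `centered`: `grad_avg = alpha * grad_avg + (1 - alpha) * grad` and
///   `avg = square_avg - grad_avg^2`; otherwise `avg = square_avg`
/// - `grad = grad / (sqrt(avg) + epsilon)`
/// - if `momentum > 0`: `buf = momentum * buf + grad` and the parameter update is
///   `lr * buf`; otherwise the update is `lr * grad`
///
/// Weight decay, when configured, is applied to the raw gradient before these transforms.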
#[derive(Clone)]
pub struct RmsProp {
    alpha: f32,
    centered: bool,
    momentum: RmsPropMomentum,
    weight_decay: Option<WeightDecay>,
}

impl<B: Backend> SimpleOptimizer<B> for RmsProp {
    type State<const D: usize> = RmsPropState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Unpack any previous state for this parameter.
        let mut state_square_avg = None;
        let mut state_centered = None;
        let mut state_momentum = None;
        if let Some(state) = state {
            state_square_avg = Some(state.square_avg);
            state_centered = Some(state.centered);
            state_momentum = state.momentum;
        }

        // Apply weight decay to the raw gradient, if configured.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // Update the exponential moving average of the squared gradient.
        let (grad, state_square_avg) =
            SquareAvgState::transform(self.alpha, grad, state_square_avg);

        // Optionally center the estimate by subtracting the squared gradient average.
        let (grad, state_square_avg, state_centered) = CenteredState::transform(
            self.alpha,
            self.centered,
            grad,
            state_square_avg,
            state_centered,
        );

        // Scale the gradient and, if enabled, fold it into the momentum buffer.
        let (grad, state_centered, state_momentum) =
            self.momentum
                .transform(grad, state_centered, state_momentum);

        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);

        // Apply the update to the parameter tensor.
        let delta = grad.mul_scalar(lr);
        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.square_avg = state.square_avg.to_device(device);
        state.centered = state.centered.to_device(device);
        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
        state
    }
}

/// State of [RmsProp](RmsProp) for a single parameter tensor.
#[derive(Record, Clone, new)]
pub struct RmsPropState<B: Backend, const D: usize> {
    /// Running average of the squared gradient.
    pub square_avg: SquareAvgState<B, D>,
    /// Centered-gradient state.
    pub centered: CenteredState<B, D>,
    /// Momentum buffer, present only when momentum is enabled.
    pub momentum: Option<RmsPropMomentumState<B, D>>,
}

/// Running average of the squared gradient.
#[derive(Record, Clone, new)]
pub struct SquareAvgState<B: Backend, const D: usize> {
    /// Exponential moving average of `grad^2`.
    pub square_avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> SquareAvgState<B, D> {
    /// Updates the running average with the incoming gradient and returns it
    /// together with the (unchanged) gradient.
    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
        match state {
            Some(state) => {
                let square_avg = state
                    .square_avg
                    .mul_scalar(alpha)
                    .add(grad.clone().powi_scalar(2).mul_scalar(1. - alpha));
                (grad, Self { square_avg })
            }
            _ => {
                let square_avg = grad.clone().powi_scalar(2).mul_scalar(1. - alpha);
                (grad, Self { square_avg })
            }
        }
    }

    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.square_avg = self.square_avg.to_device(device);
        self
    }
}

/// State produced by the centering transform.
#[derive(Record, Clone, new)]
pub struct CenteredState<B: Backend, const D: usize> {
    /// Exponential moving average of the gradient, tracked only when `centered` is enabled.
    pub grad_avg: Option<Tensor<B, D>>,
    /// Denominator estimate: `square_avg`, minus `grad_avg^2` when centered.
    pub avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> CenteredState<B, D> {
    /// Updates `grad_avg` and `avg` from the current gradient and square-average state.
    fn transform(
        alpha: f32,
        centered: bool,
        grad: Tensor<B, D>,
        square_avg_state: SquareAvgState<B, D>,
        centered_state: Option<Self>,
    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
        if centered {
            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
            let grad_avg = match centered_state {
                Some(state) => state
                    .grad_avg
                    .map_or(grad_avg_constant.clone(), move |grad_avg| {
                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                    }),
                _ => grad_avg_constant,
            };
            let avg = square_avg_state
                .square_avg
                .clone()
                .sub(grad_avg.clone().powi_scalar(2));

            (
                grad,
                square_avg_state,
                Self {
                    grad_avg: Some(grad_avg),
                    avg,
                },
            )
        } else {
            (
                grad,
                square_avg_state.clone(),
                Self {
                    grad_avg: None,
                    avg: square_avg_state.square_avg,
                },
            )
        }
    }

    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
        self.avg = self.avg.to_device(device);
        self
    }
}

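/// Momentum and numerical-stability settings used by the [RmsProp](RmsProp) gradient
/// transform: the gradient is scaled by `1 / (sqrt(avg) + epsilon)` and, when
/// `momentum > 0`, accumulated into a momentum buffer (`buf = momentum * buf + grad`).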
#[derive(Clone)]
pub struct RmsPropMomentum {
    momentum: f32,
    epsilon: f32,
}

impl RmsPropMomentum {
    /// Scales the gradient by `1 / (sqrt(avg) + epsilon)` and, when momentum is enabled,
    /// folds it into the momentum buffer.
    fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        centered_state: CenteredState<B, D>,
        momentum_state: Option<RmsPropMomentumState<B, D>>,
    ) -> (
        Tensor<B, D>,
        CenteredState<B, D>,
        Option<RmsPropMomentumState<B, D>>,
    ) {
        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

        if self.momentum > 0. {
            let buf = match momentum_state {
                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
                _ => grad,
            };
            (
                buf.clone(),
                centered_state,
                Some(RmsPropMomentumState { buf }),
            )
        } else {
            (grad, centered_state, None)
        }
    }
}

/// Momentum buffer of [RmsProp](RmsProp).
#[derive(Record, Clone, new)]
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
    buf: Tensor<B, D>,
}

impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.buf = self.buf.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{Shape, Tolerance};

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::optim::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_rmsprop_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_rmsprop();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_rmsprop"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_rmsprop();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers_basic() {
        let linear = given_linear_layer(
            TensorData::from([
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
            ]),
            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let weights_expected = TensorData::from([
            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
        ]);
        let bias_expected =
            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [
                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
            ],
            [
                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
            ],
            [
                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
            ],
            [
                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
            ],
            [
                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
            ],
            [
                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    #[allow(dead_code)]
    fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::random(
            Shape::new([2, 20]),
            Distribution::Default,
            &Default::default(),
        )
    }

    fn create_rmsprop()
    -> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: 0.9,
            grad_clipping: None,
        }
        .init()
    }
}