use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

use burn::config::Config;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};

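/// Configuration to create the [RmsProp] optimizer.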
#[derive(Config, Debug)]
pub struct RmsPropConfig {
    /// Smoothing constant for the moving average of squared gradients.
    #[config(default = 0.99)]
    alpha: f32,
    /// Momentum factor applied to the update buffer.
    #[config(default = 0.9)]
    momentum: f32,
    /// Term added to the denominator to improve numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// If `true`, computes the centered variant, normalizing the gradient by an
    /// estimate of its variance.
    #[config(default = false)]
    centered: bool,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

impl RmsPropConfig {
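    /// Builds the [RmsProp] optimizer (the [SimpleOptimizer] implementation) from this
    /// configuration, without wrapping it for a specific module.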
    pub fn build(&self) -> RmsProp {
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
        RmsProp {
            alpha: self.alpha,
            centered: self.centered,
            weight_decay,
            momentum: RmsPropMomentum {
                momentum: self.momentum,
                epsilon: self.epsilon,
            },
        }
    }

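    /// Initializes the [RmsProp] optimizer for a module `M`, wrapping it in an
    /// [OptimizerAdaptor] and attaching gradient clipping when configured.
    ///
    /// A minimal usage sketch, assuming a hypothetical module type `MyModel<B>` and an
    /// autodiff backend `B`:
    ///
    /// ```ignore
    /// let mut optim = RmsPropConfig::new()
    ///     .with_momentum(0.9)
    ///     .init::<B, MyModel<B>>();
    /// ```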
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<RmsProp, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }

        optim
    }
}

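/// Optimizer that implements the RMSProp algorithm: gradients are normalized by a
/// moving average of their squared magnitude, optionally centered and with momentum.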
#[derive(Clone)]
pub struct RmsProp {
    alpha: f32,
    centered: bool,
    momentum: RmsPropMomentum,
    weight_decay: Option<WeightDecay>,
}

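// Update rule applied by `step` for each parameter tensor (gradient `g`, learning
// rate `lr`), matching the state transforms below. Weight decay, if configured, is
// applied to `g` first.
//
//   square_avg = alpha * square_avg + (1 - alpha) * g^2
//   grad_avg   = alpha * grad_avg + (1 - alpha) * g      (centered only)
//   avg        = square_avg - grad_avg^2                  (centered; otherwise square_avg)
//   g          = g / (sqrt(avg) + epsilon)
//   buf        = momentum * buf + g                       (when momentum > 0, else buf = g)
//   param      = param - lr * buf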
impl<B: Backend> SimpleOptimizer<B> for RmsProp {
    type State<const D: usize> = RmsPropState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Unpack the previous state, if any.
        let mut state_square_avg = None;
        let mut state_centered = None;
        let mut state_momentum = None;
        if let Some(state) = state {
            state_square_avg = Some(state.square_avg);
            state_centered = Some(state.centered);
            state_momentum = state.momentum;
        }

        // Apply weight decay to the gradient before any averaging.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // Update the moving average of squared gradients.
        let (grad, state_square_avg) =
            SquareAvgState::transform(self.alpha, grad, state_square_avg);

        // Optionally center the estimate with the moving average of the gradients.
        let (grad, state_square_avg, state_centered) = CenteredState::transform(
            self.alpha,
            self.centered,
            grad,
            state_square_avg,
            state_centered,
        );

        // Normalize the gradient and apply momentum when configured.
        let (grad, state_centered, state_momentum) =
            self.momentum
                .transform(grad, state_centered, state_momentum);

        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);

        let delta = grad.mul_scalar(lr);
        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.square_avg = state.square_avg.to_device(device);
        state.centered = state.centered.to_device(device);
        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
        state
    }
}

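/// State of [RmsProp] for a single parameter tensor.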
#[derive(Record, Clone, new)]
pub struct RmsPropState<B: Backend, const D: usize> {
    pub square_avg: SquareAvgState<B, D>,
    pub centered: CenteredState<B, D>,
    pub momentum: Option<RmsPropMomentumState<B, D>>,
}

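/// Moving average of the squared gradients.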
#[derive(Record, Clone, new)]
pub struct SquareAvgState<B: Backend, const D: usize> {
    pub square_avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> SquareAvgState<B, D> {
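    /// Updates the squared-gradient average: `square_avg = alpha * square_avg + (1 - alpha) * grad^2`,
    /// treating the previous average as zero on the first step.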
    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
        match state {
            Some(state) => {
                let square_avg = state
                    .square_avg
                    .mul_scalar(alpha)
                    .add(grad.clone().square().mul_scalar(1. - alpha));
                (grad, Self { square_avg })
            }
            _ => {
                let square_avg = grad.clone().square().mul_scalar(1. - alpha);
                (grad, Self { square_avg })
            }
        }
    }

    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.square_avg = self.square_avg.to_device(device);
        self
    }
}

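/// Centered averages: `grad_avg` tracks the gradients themselves (centered variant only)
/// and `avg` is the estimate used to normalize the update.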
#[derive(Record, Clone, new)]
pub struct CenteredState<B: Backend, const D: usize> {
    pub grad_avg: Option<Tensor<B, D>>,
    pub avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> CenteredState<B, D> {
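    /// When `centered` is enabled, updates `grad_avg = alpha * grad_avg + (1 - alpha) * grad`
    /// and sets `avg = square_avg - grad_avg^2`; otherwise `avg` is simply `square_avg`.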
    fn transform(
        alpha: f32,
        centered: bool,
        grad: Tensor<B, D>,
        square_avg_state: SquareAvgState<B, D>,
        centered_state: Option<Self>,
    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
        if centered {
            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
            let grad_avg = match centered_state {
                Some(state) => state
                    .grad_avg
                    .map_or(grad_avg_constant.clone(), move |grad_avg| {
                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                    }),
                _ => grad_avg_constant,
            };
            let avg = square_avg_state
                .square_avg
                .clone()
                .sub(grad_avg.clone().square());

            (
                grad,
                square_avg_state,
                Self {
                    grad_avg: Some(grad_avg),
                    avg,
                },
            )
        } else {
            (
                grad,
                square_avg_state.clone(),
                Self {
                    grad_avg: None,
                    avg: square_avg_state.square_avg,
                },
            )
        }
    }

    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
        self.avg = self.avg.to_device(device);
        self
    }
}

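/// Momentum and epsilon settings of [RmsProp]; they are stored on the optimizer itself
/// rather than in the per-parameter state.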
#[derive(Clone)]
pub struct RmsPropMomentum {
    momentum: f32,
    epsilon: f32,
}

impl RmsPropMomentum {
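    /// Normalizes the gradient by `sqrt(avg) + epsilon` and, when `momentum > 0`, folds it
    /// into the momentum buffer `buf = momentum * buf + grad`.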
    fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        centered_state: CenteredState<B, D>,
        momentum_state: Option<RmsPropMomentumState<B, D>>,
    ) -> (
        Tensor<B, D>,
        CenteredState<B, D>,
        Option<RmsPropMomentumState<B, D>>,
    ) {
        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

        if self.momentum > 0. {
            let buf = match momentum_state {
                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
                _ => grad,
            };
            (
                buf.clone(),
                centered_state,
                Some(RmsPropMomentumState { buf }),
            )
        } else {
            (grad, centered_state, None)
        }
    }
}

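/// Momentum buffer of [RmsProp].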
#[derive(Record, Clone, new)]
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
    buf: Tensor<B, D>,
}

impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.buf = self.buf.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{Shape, Tolerance};

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::optim::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_rmsprop_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_rmsprop();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_rmsprop"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_rmsprop();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers_basic() {
        let linear = given_linear_layer(
            TensorData::from([
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
            ]),
            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let weights_expected = TensorData::from([
            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
        ]);
        let bias_expected =
            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [
                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
            ],
            [
                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
            ],
            [
                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
            ],
            [
                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
            ],
            [
                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
            ],
            [
                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    #[allow(dead_code)]
    fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::random(
            Shape::new([2, 20]),
            Distribution::Default,
            &Default::default(),
        )
    }

    fn create_rmsprop()
    -> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: 0.9,
            grad_clipping: None,
        }
        .init()
    }
}