use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};

use super::{
    SimpleOptimizer,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend, ops::Device};
use burn_tensor::backend::Backend;

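/// Configuration to create the [RmsProp](RmsProp) optimizer.
///
/// A minimal usage sketch (the backend `B`, the module type, and the values
/// `model` and `x` below are illustrative, not prescribed by this module):
///
/// ```ignore
/// let mut optim = RmsPropConfig::new()
///     .with_alpha(0.99)
///     .with_momentum(0.9)
///     .init::<B, nn::Linear<B>>();
///
/// let grads = GradientsParams::from_grads(model.forward(x).backward(), &model);
/// let model = optim.step(1e-2, model, grads);
/// ```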
#[derive(Config)]
pub struct RmsPropConfig {
    /// Smoothing constant used for the running average of the squared gradient.
    #[config(default = 0.99)]
    alpha: f32,
    /// Momentum factor; a value of 0 disables the momentum buffer.
    #[config(default = 0.9)]
    momentum: f32,
    /// Term added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// If `true`, the gradient is normalized by an estimate of its variance
    /// (centered RmsProp) instead of its raw second moment.
    #[config(default = false)]
    centered: bool,
    /// Optional [weight decay](WeightDecayConfig) configuration.
    weight_decay: Option<WeightDecayConfig>,
    /// Optional [gradient clipping](GradientClippingConfig) configuration.
    grad_clipping: Option<GradientClippingConfig>,
}

impl RmsPropConfig {
    /// Initializes the [RmsProp] optimizer, wrapping it in an [OptimizerAdaptor]
    /// and applying weight decay and gradient clipping when configured.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<RmsProp, M, B> {
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);

        let mut optim = OptimizerAdaptor::from(RmsProp {
            alpha: self.alpha,
            centered: self.centered,
            weight_decay,
            momentum: RmsPropMomentum {
                momentum: self.momentum,
                epsilon: self.epsilon,
            },
        });

        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }

        optim
    }
}

/// Optimizer that implements the RmsProp algorithm.
#[derive(Clone)]
pub struct RmsProp {
    alpha: f32,
    centered: bool,
    momentum: RmsPropMomentum,
    weight_decay: Option<WeightDecay>,
}

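// Summary of the update implemented by `step` below (g = incoming gradient,
// s = running square average, a = running gradient average used only when
// `centered == true`, b = momentum buffer, w = parameter, lr = learning rate):
//
//   s <- alpha * s + (1 - alpha) * g^2
//   a <- alpha * a + (1 - alpha) * g          (centered only)
//   g <- g / (sqrt(s - a^2) + epsilon)        (s - a^2 is just s when not centered)
//   b <- momentum * b + g                     (only when momentum > 0)
//   w <- w - lr * b                           (or w - lr * g without momentum)
//
// Weight decay, when configured, is applied to g before these steps.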
impl<B: Backend> SimpleOptimizer<B> for RmsProp {
    type State<const D: usize> = RmsPropState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Fetch the state recorded for this parameter, if any.
        let mut state_square_avg = None;
        let mut state_centered = None;
        let mut state_momentum = None;
        if let Some(state) = state {
            state_square_avg = Some(state.square_avg);
            state_centered = Some(state.centered);
            state_momentum = state.momentum;
        }

        // Apply weight decay to the gradient when configured.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // Update the running average of the squared gradient.
        let (grad, state_square_avg) =
            SquareAvgState::transform(self.alpha, grad, state_square_avg);

        // Optionally center the square average with the running gradient average.
        let (grad, state_square_avg, state_centered) = CenteredState::transform(
            self.alpha,
            self.centered,
            grad,
            state_square_avg,
            state_centered,
        );

        // Normalize the gradient and apply momentum when enabled.
        let (grad, state_centered, state_momentum) =
            self.momentum
                .transform(grad, state_centered, state_momentum);

        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);

        let delta = grad.mul_scalar(lr);
        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.square_avg = state.square_avg.to_device(device);
        state.centered = state.centered.to_device(device);
        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
        state
    }
}

/// State of [RmsProp] for a single tensor parameter.
#[derive(Record, Clone, new)]
pub struct RmsPropState<B: Backend, const D: usize> {
    /// Running average of the squared gradient.
    pub square_avg: SquareAvgState<B, D>,
    /// Centered state, holding the running gradient average when enabled.
    pub centered: CenteredState<B, D>,
    /// Momentum buffer, present only when `momentum > 0`.
    pub momentum: Option<RmsPropMomentumState<B, D>>,
}

/// Running average of the squared gradient.
#[derive(Record, Clone, new)]
pub struct SquareAvgState<B: Backend, const D: usize> {
    /// Current squared-gradient average.
    pub square_avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> SquareAvgState<B, D> {
    /// Updates the square average with the incoming gradient:
    /// `square_avg = alpha * square_avg + (1 - alpha) * grad^2`.
    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
        match state {
            Some(state) => {
                let square_avg = state
                    .square_avg
                    .mul_scalar(alpha)
                    .add(grad.clone().powi_scalar(2).mul_scalar(1. - alpha));
                (grad, Self { square_avg })
            }
            _ => {
                let square_avg = grad.clone().powi_scalar(2).mul_scalar(1. - alpha);
                (grad, Self { square_avg })
            }
        }
    }

    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.square_avg = self.square_avg.to_device(device);
        self
    }
}

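// When `centered == true`, the denominator uses an estimate of the gradient's
// variance rather than its raw second moment: `avg = E[g^2] - E[g]^2`, with both
// expectations tracked as exponential moving averages (see `CenteredState::transform`).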
/// Centered state for [RmsProp], used when the centered variant is enabled.
#[derive(Record, Clone, new)]
pub struct CenteredState<B: Backend, const D: usize> {
    /// Running average of the gradient (only kept when centered).
    pub grad_avg: Option<Tensor<B, D>>,
    /// Denominator term: the (possibly centered) square average.
    pub avg: Tensor<B, D>,
}

impl<B: Backend, const D: usize> CenteredState<B, D> {
    /// Computes `avg = square_avg - grad_avg^2` when centering is enabled;
    /// otherwise passes `square_avg` through unchanged.
    fn transform(
        alpha: f32,
        centered: bool,
        grad: Tensor<B, D>,
        square_avg_state: SquareAvgState<B, D>,
        centered_state: Option<Self>,
    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
        if centered {
            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
            let grad_avg = match centered_state {
                Some(state) => state
                    .grad_avg
                    .map_or(grad_avg_constant.clone(), move |grad_avg| {
                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                    }),
                _ => grad_avg_constant,
            };
            let avg = square_avg_state
                .square_avg
                .clone()
                .sub(grad_avg.clone().powi_scalar(2));

            (
                grad,
                square_avg_state,
                Self {
                    grad_avg: Some(grad_avg),
                    avg,
                },
            )
        } else {
            (
                grad,
                square_avg_state.clone(),
                Self {
                    grad_avg: None,
                    avg: square_avg_state.square_avg,
                },
            )
        }
    }

    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
        self.avg = self.avg.to_device(device);
        self
    }
}

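// Final stage of the update: scale the gradient by the inverse square root of the
// (possibly centered) average, guarded by `epsilon`, then optionally fold it into
// a momentum buffer when `momentum > 0`.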
/// Momentum and epsilon parameters for [RmsProp].
#[derive(Clone)]
pub struct RmsPropMomentum {
    momentum: f32,
    epsilon: f32,
}

impl RmsPropMomentum {
    /// Divides the gradient by `sqrt(avg) + epsilon` and, when `momentum > 0`,
    /// accumulates it into the momentum buffer.
    fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        centered_state: CenteredState<B, D>,
        momentum_state: Option<RmsPropMomentumState<B, D>>,
    ) -> (
        Tensor<B, D>,
        CenteredState<B, D>,
        Option<RmsPropMomentumState<B, D>>,
    ) {
        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

        if self.momentum > 0. {
            let buf = match momentum_state {
                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
                _ => grad,
            };
            (
                buf.clone(),
                centered_state,
                Some(RmsPropMomentumState { buf }),
            )
        } else {
            (grad, centered_state, None)
        }
    }
}

/// Momentum buffer for [RmsProp].
#[derive(Record, Clone, new)]
pub struct RmsPropMomentumState<B: Backend, const D: usize> {
    buf: Tensor<B, D>,
}

impl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {
    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.buf = self.buf.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::ops::FloatElem;
    use burn_tensor::{Shape, Tolerance};

    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_rmsprop_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_rmsprop();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_rmsprop"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_rmsprop();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers_basic() {
        let linear = given_linear_layer(
            TensorData::from([
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
                [1., 1., 1., 1., 1., 1.],
            ]),
            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let weights_expected = TensorData::from([
            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],
            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],
            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],
            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],
            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],
            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],
        ]);
        let bias_expected =
            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_rmsprop_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = RmsPropConfig::new()
            .with_alpha(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(WeightDecayConfig::new(0.05).into())
            .with_momentum(0.9)
            .with_centered(false)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [
                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
            ],
            [
                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
            ],
            [
                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
            ],
            [
                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
            ],
            [
                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
            ],
            [
                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    #[allow(dead_code)]
    fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {
        Tensor::<TestAutodiffBackend, 2>::random(
            Shape::new([2, 20]),
            Distribution::Default,
            &Default::default(),
        )
    }

    fn create_rmsprop()
    -> OptimizerAdaptor<RmsProp, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        RmsPropConfig {
            alpha: 0.99,
            epsilon: 1e-9,
            centered: false,
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: 0.9,
            grad_clipping: None,
        }
        .init()
    }
}