burn_core/lr_scheduler/
exponential.rs1use super::{LrScheduler, String};
2use crate as burn;
3use crate::{LearningRate, config::Config};
4use burn_tensor::backend::Backend;
5
/// Configuration to create an [exponential learning rate scheduler](ExponentialLrScheduler).
///
/// The resulting scheduler returns `initial_lr` on the first step, then multiplies the
/// learning rate by `gamma` on every subsequent step.
#[derive(Config)]
pub struct ExponentialLrSchedulerConfig {
    /// The initial learning rate. Must be greater than 0 and at most 1 (enforced by `init`).
    initial_lr: LearningRate,
    /// The multiplicative decay factor applied on each step.
    /// Must be greater than 0 and at most 1 (enforced by `init`).
    gamma: f64,
}
18
19impl ExponentialLrSchedulerConfig {
20 pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
29 if self.initial_lr <= 0. || self.initial_lr > 1. {
30 return Err("Initial learning rate must be greater than 0 and at most 1".into());
31 }
32 if self.gamma <= 0. || self.gamma > 1. {
33 return Err("Gamma must be greater than 0 and at most 1".into());
34 }
35
36 Ok(ExponentialLrScheduler {
37 previous_lr: self.initial_lr / self.gamma,
40 gamma: self.gamma,
41 })
42 }
43}
44
/// A learning rate scheduler that multiplies the learning rate by a constant
/// factor (`gamma`) on every call to [`step`](LrScheduler::step).
#[derive(Clone, Copy, Debug)]
pub struct ExponentialLrScheduler {
    // The learning rate returned by the previous `step` call. Seeded by `init` with
    // `initial_lr / gamma` so the first step yields `initial_lr`.
    previous_lr: LearningRate,
    // The constant multiplicative decay factor.
    gamma: f64,
}
55
56impl LrScheduler for ExponentialLrScheduler {
57 type Record<B: Backend> = LearningRate;
58
59 fn step(&mut self) -> LearningRate {
60 self.previous_lr *= self.gamma;
61 self.previous_lr
62 }
63
64 fn to_record<B: Backend>(&self) -> Self::Record<B> {
65 self.previous_lr
66 }
67
68 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
69 self.previous_lr = record;
70 self
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    /// Asserts that `config.init()` fails with exactly `expected` as its error message.
    fn assert_init_fails(config: ExponentialLrSchedulerConfig, expected: &str) {
        let result = config.init();
        assert!(result.is_err(), "Should return an error");
        assert_eq!(result.unwrap_err(), expected, "Error messages should match",);
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_fails(
            ExponentialLrSchedulerConfig::new(0., 0.5),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_fails(
            ExponentialLrSchedulerConfig::new(1.5, 0.5),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_gamma_too_low() {
        assert_init_fails(
            ExponentialLrSchedulerConfig::new(0.5, 0.0),
            "Gamma must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_gamma_too_high() {
        assert_init_fails(
            ExponentialLrSchedulerConfig::new(0.5, 1.5),
            "Gamma must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn test_lr_change() {
        // First value is `initial_lr` itself; each later value is the previous one times gamma.
        let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
        let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, 7);
    }
}