burn_core/lr_scheduler/
step.rs1use burn_tensor::backend::Backend;
2
3use crate as burn;
4
5use super::{LrScheduler, String};
6use crate::{LearningRate, config::Config};
7
8#[derive(Config)]
22pub struct StepLrSchedulerConfig {
23 initial_lr: LearningRate,
25 step_size: usize,
28 #[config(default = 0.1)]
30 gamma: f64,
31}
32
33impl StepLrSchedulerConfig {
34 pub fn init(&self) -> Result<StepLrScheduler, String> {
40 if self.step_size == 0 {
41 return Err("Step size must be greater than 0".into());
42 }
43
44 if self.initial_lr <= 0.0 {
47 log::warn!(
48 "Initial learning rate value of {} is not a positive number. Ignore this warning \
49 if it is intended.",
50 self.initial_lr
51 );
52 }
53 if self.gamma <= 0.0 || self.gamma >= 1.0 {
54 log::warn!(
55 "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
56 intended.",
57 self.gamma
58 );
59 }
60
61 Ok(StepLrScheduler {
62 init_lr: self.initial_lr,
63 step_size: self.step_size,
64 gamma: self.gamma,
65 iter_idx: -1,
66 })
67 }
68}
69
70#[derive(Clone, Debug)]
72pub struct StepLrScheduler {
73 init_lr: LearningRate,
74 step_size: usize,
75 gamma: f64,
76 iter_idx: i32,
79}
80
81impl LrScheduler for StepLrScheduler {
82 type Record<B: Backend> = i32;
83
84 fn step(&mut self) -> LearningRate {
85 self.iter_idx = self
86 .iter_idx
87 .checked_add(1)
88 .expect("`.step()` should be called no more than `i32::MAX + 1` times");
89 self.init_lr
91 * self
92 .gamma
93 .powi((self.iter_idx as usize / self.step_size) as i32)
94 }
95
96 fn to_record<B: Backend>(&self) -> Self::Record<B> {
97 self.iter_idx
98 }
99
100 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
101 self.iter_idx = record;
102 self
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    // A step size of 0 must be rejected at construction time.
    #[test]
    fn test_config_step_size_zero() {
        let outcome = StepLrSchedulerConfig::new(1.0, 0).init();
        assert!(outcome.is_err(), "Should return an error");
    }

    // Any non-zero step size is accepted.
    #[test]
    fn test_config_step_size_nonzero() {
        let outcome = StepLrSchedulerConfig::new(1.0, 1).init();
        assert!(outcome.is_ok(), "Should return a success value");
    }

    // Leaving gamma unset must behave exactly like setting it to 0.1.
    #[test]
    fn test_config_default_gamma() {
        const INIT_LR: LearningRate = 0.4;
        const STEP_SIZE: usize = 2;

        let mut implicit_gamma = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
            .init()
            .unwrap();
        let mut explicit_gamma = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
            .with_gamma(0.1)
            .init()
            .unwrap();

        test_utils::compare_steps(&mut implicit_gamma, &mut explicit_gamma, 3 * STEP_SIZE);
    }

    // gamma < 1: the rate drops by a factor of 10 every 3 steps.
    #[test]
    fn test_lr_decreasing() {
        let decaying = StepLrSchedulerConfig::new(0.5, 3)
            .with_gamma(0.1)
            .init()
            .unwrap();
        let expected = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];
        test_utils::check_lr_sequence(decaying, expected);
    }

    // gamma > 1: the rate doubles every 2 steps.
    #[test]
    fn test_lr_increasing() {
        let growing = StepLrSchedulerConfig::new(0.1, 2)
            .with_gamma(2.0)
            .init()
            .unwrap();
        let expected = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];
        test_utils::check_lr_sequence(growing, expected);
    }

    // gamma == 1: the rate never changes.
    #[test]
    fn test_lr_unchanging() {
        let constant = StepLrSchedulerConfig::new(3.1, 1)
            .with_gamma(1.0)
            .init()
            .unwrap();
        let expected = [3.1, 3.1, 3.1];
        test_utils::check_lr_sequence(constant, expected);
    }

    // Exercises the shared save/load check with a count of one and a half
    // step periods.
    #[test]
    fn test_save_and_load() {
        const STEP_SIZE: usize = 10;

        let original = StepLrSchedulerConfig::new(0.007, STEP_SIZE)
            .with_gamma(0.03)
            .init()
            .unwrap();
        test_utils::check_save_load(original, 3 * STEP_SIZE / 2);
    }

    // With the index restored to `i32::MAX - 1`, one more step reaches
    // `i32::MAX` and must not panic.
    #[test]
    fn test_number_of_calls_within_limit() {
        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1);
        scheduler.step();
    }

    // Same setup, but the second step would push the index past `i32::MAX`
    // and is expected to panic.
    #[test]
    #[should_panic = "i32::MAX"]
    fn test_number_of_calls_over_limit() {
        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1);
        scheduler.step();
        scheduler.step();
    }
}