burn_optim/lr_scheduler/
step.rs1use burn_core as burn;
2
3use burn::config::Config;
4use burn::tensor::backend::Backend;
5
6use super::{LrScheduler, String};
7use crate::LearningRate;
8
/// Configuration for creating a [step learning rate scheduler](StepLrScheduler).
///
/// The resulting scheduler returns `initial_lr` for the first `step_size` calls to
/// [`step`](StepLrScheduler::step), then multiplies the learning rate by `gamma` and repeats.
/// In other words, the `i`-th call (0-based) returns `initial_lr * gamma^(i / step_size)`,
/// where `/` is integer division.
///
/// Build it with `StepLrSchedulerConfig::new(initial_lr, step_size)`. Override the default
/// `gamma` with `with_gamma`.
///
/// [`init`](StepLrSchedulerConfig::init) rejects a `step_size` of 0. It logs a warning, but
/// still succeeds, for a non-positive `initial_lr` or a `gamma` outside `(0.0, 1.0)`.
#[derive(Config, Debug)]
pub struct StepLrSchedulerConfig {
    /// The learning rate returned for the first `step_size` iterations.
    initial_lr: LearningRate,
    /// The number of consecutive iterations that share the same learning rate. Must be non-zero.
    step_size: usize,
    /// The factor by which the learning rate is multiplied every `step_size` iterations.
    /// Default: 0.1.
    #[config(default = 0.1)]
    gamma: f64,
}
33
34impl StepLrSchedulerConfig {
35 pub fn init(&self) -> Result<StepLrScheduler, String> {
41 if self.step_size == 0 {
42 return Err("Step size must be greater than 0".into());
43 }
44
45 if self.initial_lr <= 0.0 {
48 log::warn!(
49 "Initial learning rate value of {} is not a positive number. Ignore this warning \
50 if it is intended.",
51 self.initial_lr
52 );
53 }
54 if self.gamma <= 0.0 || self.gamma >= 1.0 {
55 log::warn!(
56 "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
57 intended.",
58 self.gamma
59 );
60 }
61
62 Ok(StepLrScheduler {
63 init_lr: self.initial_lr,
64 step_size: self.step_size,
65 gamma: self.gamma,
66 iter_idx: -1,
67 })
68 }
69}
70
/// Step learning rate scheduler, created by [`StepLrSchedulerConfig::init`].
///
/// Multiplies the initial learning rate by `gamma` once every `step_size` iterations.
#[derive(Clone, Debug)]
pub struct StepLrScheduler {
    // Learning rate returned for the first `step_size` iterations.
    init_lr: LearningRate,
    // Number of iterations between successive multiplications by `gamma`. It is never 0 when
    // built through `StepLrSchedulerConfig::init`, which rejects that value.
    step_size: usize,
    // Factor applied to the learning rate every `step_size` iterations.
    gamma: f64,
    // Index of the most recent iteration. It starts at -1, so the first `step()` uses index 0.
    // `i32` matches the exponent type of `f64::powi`, so the exponent derived from it is never
    // truncated. This is also the value saved and restored as the scheduler's record.
    iter_idx: i32,
}
81
82impl LrScheduler for StepLrScheduler {
83 type Record<B: Backend> = i32;
84
85 fn step(&mut self) -> LearningRate {
86 self.iter_idx = self
87 .iter_idx
88 .checked_add(1)
89 .expect("`.step()` should be called no more than `i32::MAX + 1` times");
90 self.init_lr
92 * self
93 .gamma
94 .powi((self.iter_idx as usize / self.step_size) as i32)
95 }
96
97 fn to_record<B: Backend>(&self) -> Self::Record<B> {
98 self.iter_idx
99 }
100
101 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
102 self.iter_idx = record;
103 self
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    // --- Configuration validation ---

    #[test]
    fn test_config_step_size_zero() {
        // A zero step size would make the exponent computation divide by zero.
        let outcome = StepLrSchedulerConfig::new(1.0, 0).init();
        assert!(outcome.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_step_size_nonzero() {
        let outcome = StepLrSchedulerConfig::new(1.0, 1).init();
        assert!(outcome.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_default_gamma() {
        const INIT_LR: LearningRate = 0.4;
        const STEP_SIZE: usize = 2;

        // Leaving `gamma` unset must behave exactly like passing 0.1 explicitly.
        let mut implicit_gamma = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
            .init()
            .unwrap();
        let mut explicit_gamma = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
            .with_gamma(0.1)
            .init()
            .unwrap();
        test_utils::compare_steps(&mut implicit_gamma, &mut explicit_gamma, 3 * STEP_SIZE);
    }

    // --- Learning-rate sequences ---

    #[test]
    fn test_lr_decreasing() {
        let decaying = StepLrSchedulerConfig::new(0.5, 3)
            .with_gamma(0.1)
            .init()
            .unwrap();
        test_utils::check_lr_sequence(
            decaying,
            [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005],
        );
    }

    #[test]
    fn test_lr_increasing() {
        // A `gamma` above 1.0 is unusual but accepted, and grows the learning rate.
        let growing = StepLrSchedulerConfig::new(0.1, 2)
            .with_gamma(2.0)
            .init()
            .unwrap();
        test_utils::check_lr_sequence(growing, [0.1, 0.1, 0.2, 0.2, 0.4, 0.4]);
    }

    #[test]
    fn test_lr_unchanging() {
        let constant = StepLrSchedulerConfig::new(3.1, 1)
            .with_gamma(1.0)
            .init()
            .unwrap();
        test_utils::check_lr_sequence(constant, [3.1, 3.1, 3.1]);
    }

    // --- Records ---

    #[test]
    fn test_save_and_load() {
        const STEP_SIZE: usize = 10;

        let original = StepLrSchedulerConfig::new(0.007, STEP_SIZE)
            .with_gamma(0.03)
            .init()
            .unwrap();
        test_utils::check_save_load(original, 3 * STEP_SIZE / 2);
    }

    // Calling `step()` `i32::MAX` times would be far too slow. Instead, these tests load the
    // record a scheduler would have after that many calls.

    #[test]
    fn test_number_of_calls_within_limit() {
        let mut near_limit = StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1);
        // This is call number `i32::MAX + 1`, the last one allowed.
        near_limit.step();
    }

    #[test]
    #[should_panic(expected = "i32::MAX")]
    fn test_number_of_calls_over_limit() {
        let mut near_limit = StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1);
        near_limit.step();
        // This is call number `i32::MAX + 2`, one past the limit.
        near_limit.step();
    }
}