// burn_core/lr_scheduler/step.rs

use burn_tensor::backend::Backend;

use crate as burn;

use super::{LrScheduler, String};
use crate::{LearningRate, config::Config};

8/// The configuration for create a [step learning rate scheduler](StepLrScheduler).
9///
10/// This scheduler returns the learning rate `initial_lr` from the start, and keeps doing so until
11/// the same value has been given for `step_size` times. Then it multiplies the learning rate by
12/// `gamma` before repeating the process.
13///
14/// Gamma values out of range (0.0, 1.0) and non-positive initial learning rates are acceptable, but
15/// a warning log will be output for such a value in case of mistyping.
16///
17/// ## Notes
18///
19/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than
20/// `i32::MAX + 1` times.
21#[derive(Config)]
22pub struct StepLrSchedulerConfig {
23    // The learning rate at the initial step.
24    initial_lr: LearningRate,
25    // The number of iterations over which the learning rate remains unchanged before the next
26    // update.
27    step_size: usize,
28    /// The factor by which the learning rate is multiplied with each update. Default: 0.1.
29    #[config(default = 0.1)]
30    gamma: f64,
31}
32
33impl StepLrSchedulerConfig {
34    /// Initializes a [step learning rate scheduler](StepLrScheduler).
35    ///
36    /// # Errors
37    ///
38    /// An error will be returned if `step_size` is 0.
39    pub fn init(&self) -> Result<StepLrScheduler, String> {
40        if self.step_size == 0 {
41            return Err("Step size must be greater than 0".into());
42        }
43
44        // Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful
45        // in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).
46        if self.initial_lr <= 0.0 {
47            log::warn!(
48                "Initial learning rate value of {} is not a positive number. Ignore this warning \
49                 if it is intended.",
50                self.initial_lr
51            );
52        }
53        if self.gamma <= 0.0 || self.gamma >= 1.0 {
54            log::warn!(
55                "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
56                 intended.",
57                self.gamma
58            );
59        }
60
61        Ok(StepLrScheduler {
62            init_lr: self.initial_lr,
63            step_size: self.step_size,
64            gamma: self.gamma,
65            iter_idx: -1,
66        })
67    }
68}
69
70/// Step learning rate scheduler.
71#[derive(Clone, Debug)]
72pub struct StepLrScheduler {
73    init_lr: LearningRate,
74    step_size: usize,
75    gamma: f64,
76    // The index of the current iteration.
77    // `i32` is used for avoiding truncating the exponent when taking powers of `gamma`.
78    iter_idx: i32,
79}
80
81impl LrScheduler for StepLrScheduler {
82    type Record<B: Backend> = i32;
83
84    fn step(&mut self) -> LearningRate {
85        self.iter_idx = self
86            .iter_idx
87            .checked_add(1)
88            .expect("`.step()` should be called no more than `i32::MAX + 1` times");
89        // Type casting below causes no truncation, as all the values fall within the ranges.
90        self.init_lr
91            * self
92                .gamma
93                .powi((self.iter_idx as usize / self.step_size) as i32)
94    }
95
96    fn to_record<B: Backend>(&self) -> Self::Record<B> {
97        self.iter_idx
98    }
99
100    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
101        self.iter_idx = record;
102        self
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    // Warning logs for initial LR and gamma are not tested because there seems no straightforward
    // way to do it.
    //
    // Creating a mock logger that collects logs into `String` for later examination seems a possible
    // solution, but unit tests run in the same process in parallel, where the single logger would
    // be shared by multiple tests, so logs from different tests would be mixed up with no easy way
    // to separate them.
    // Using "--test-threads=1" could prevent mixup, but whether the ability to test logging is
    // worth the slowdown would be a question. Also, using a primitive provided by `std` to
    // synchronize the logger across tests is not an option since we need to support `no-std`.
    // Maybe the mocking approach can be reconsidered after we are given an option to run tests in
    // separate processes like what the issue below is proposing:
    //     https://github.com/rust-lang/rust/issues/47506
    //
    // As a side note, a helper crate exists for the exact purpose:
    //     https://crates.io/crates/testing_logger
    // but the crate has been unmaintained and using it would introduce another dependency.

    #[test]
    fn test_config_step_size_zero() {
        let r = StepLrSchedulerConfig::new(1.0, 0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_step_size_nonzero() {
        let r = StepLrSchedulerConfig::new(1.0, 1).init();
        assert!(r.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_default_gamma() {
        const INIT_LR: LearningRate = 0.4;
        const STEP_SIZE: usize = 2;

        let mut default = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
            .init()
            .unwrap();
        let mut explicit = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)
            .with_gamma(0.1)
            .init()
            .unwrap();
        test_utils::compare_steps(&mut default, &mut explicit, 3 * STEP_SIZE);
    }

    #[test]
    fn test_lr_decreasing() {
        let scheduler = StepLrSchedulerConfig::new(0.5, 3)
            .with_gamma(0.1)
            .init()
            .unwrap();
        let expected_lrs = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_increasing() {
        let scheduler = StepLrSchedulerConfig::new(0.1, 2)
            .with_gamma(2.0)
            .init()
            .unwrap();
        let expected_lrs = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_unchanging() {
        let scheduler = StepLrSchedulerConfig::new(3.1, 1)
            .with_gamma(1.0)
            .init()
            .unwrap();
        let expected_lrs = [3.1, 3.1, 3.1];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    // The exponent is computed with `usize` division, so a step size that does not fit in `i32`
    // must still work and keep the learning rate at its initial value.
    #[test]
    fn test_step_size_larger_than_i32_max() {
        let scheduler = StepLrSchedulerConfig::new(0.5, usize::MAX)
            .with_gamma(0.1)
            .init()
            .unwrap();
        let expected_lrs = [0.5, 0.5, 0.5];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const STEP_SIZE: usize = 10;

        let scheduler = StepLrSchedulerConfig::new(0.007, STEP_SIZE)
            .with_gamma(0.03)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2);
    }

    // It's too time consuming to actually run a scheduler `i32::MAX` steps, so an approach that
    // depends on private fields is used to implement the test.
    #[test]
    fn test_number_of_calls_within_limit() {
        // Create a scheduler that has already run `i32::MAX` steps
        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap();
        scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
        scheduler.step();
    }

    #[test]
    #[should_panic = "i32::MAX"]
    fn test_number_of_calls_over_limit() {
        // Create a scheduler that has already run `i32::MAX` steps
        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap();
        scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
        scheduler.step();
        scheduler.step();
    }
}