burn_optim/lr_scheduler/
step.rs

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::tensor::backend::Backend;
5
6use super::{LrScheduler, String};
7use crate::LearningRate;
8
/// The configuration for creating a [step learning rate scheduler](StepLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` from the start, and keeps doing so until
/// the same value has been given for `step_size` times. Then it multiplies the learning rate by
/// `gamma` before repeating the process.
///
/// Gamma values out of range (0.0, 1.0) and non-positive initial learning rates are acceptable, but
/// a warning log will be output for such a value in case of mistyping.
///
/// ## Notes
///
/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than
/// `i32::MAX + 1` times.
#[derive(Config, Debug)]
pub struct StepLrSchedulerConfig {
    /// The learning rate at the initial step.
    initial_lr: LearningRate,
    /// The number of iterations over which the learning rate remains unchanged before the next
    /// update. A value of 0 is rejected by [init](StepLrSchedulerConfig::init).
    step_size: usize,
    /// The factor by which the learning rate is multiplied with each update. Default: 0.1.
    #[config(default = 0.1)]
    gamma: f64,
}
33
34impl StepLrSchedulerConfig {
35    /// Initializes a [step learning rate scheduler](StepLrScheduler).
36    ///
37    /// # Errors
38    ///
39    /// An error will be returned if `step_size` is 0.
40    pub fn init(&self) -> Result<StepLrScheduler, String> {
41        if self.step_size == 0 {
42            return Err("Step size must be greater than 0".into());
43        }
44
45        // Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful
46        // in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).
47        if self.initial_lr <= 0.0 {
48            log::warn!(
49                "Initial learning rate value of {} is not a positive number. Ignore this warning \
50                 if it is intended.",
51                self.initial_lr
52            );
53        }
54        if self.gamma <= 0.0 || self.gamma >= 1.0 {
55            log::warn!(
56                "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
57                 intended.",
58                self.gamma
59            );
60        }
61
62        Ok(StepLrScheduler {
63            init_lr: self.initial_lr,
64            step_size: self.step_size,
65            gamma: self.gamma,
66            iter_idx: -1,
67        })
68    }
69}
70
/// Step learning rate scheduler.
///
/// Instances are created through [StepLrSchedulerConfig::init].
#[derive(Clone, Debug)]
pub struct StepLrScheduler {
    /// The learning rate returned before any decay has been applied.
    init_lr: LearningRate,
    /// The number of consecutive steps that share the same learning rate.
    step_size: usize,
    /// The factor applied to the learning rate once per `step_size` steps.
    gamma: f64,
    /// The index of the current iteration; `-1` until the first call to `step`.
    ///
    /// `i32` is used for avoiding truncating the exponent when taking powers of `gamma`.
    iter_idx: i32,
}
81
82impl LrScheduler for StepLrScheduler {
83    type Record<B: Backend> = i32;
84
85    fn step(&mut self) -> LearningRate {
86        self.iter_idx = self
87            .iter_idx
88            .checked_add(1)
89            .expect("`.step()` should be called no more than `i32::MAX + 1` times");
90        // Type casting below causes no truncation, as all the values fall within the ranges.
91        self.init_lr
92            * self
93                .gamma
94                .powi((self.iter_idx as usize / self.step_size) as i32)
95    }
96
97    fn to_record<B: Backend>(&self) -> Self::Record<B> {
98        self.iter_idx
99    }
100
101    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
102        self.iter_idx = record;
103        self
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    // The warning logs emitted for atypical initial LR and gamma values are not covered here,
    // because there seems to be no straightforward way to test them.
    //
    // One possible solution is a mock logger that collects logs into a `String` for later
    // examination. However, unit tests run in parallel within a single process, so the one logger
    // would be shared by multiple tests and their logs would be mixed up with no easy way to
    // separate them.
    // Running with "--test-threads=1" could prevent the mixup, but whether the ability to test
    // logging is worth the slowdown is questionable. Synchronizing the logger across tests with a
    // primitive from `std` is not an option either, since `no-std` must be supported.
    // The mocking approach may be worth reconsidering once tests can be run in separate processes,
    // as proposed in the issue below:
    //     https://github.com/rust-lang/rust/issues/47506
    //
    // As a side note, a helper crate exists for this exact purpose:
    //     https://crates.io/crates/testing_logger
    // but it is unmaintained, and using it would introduce another dependency.

    /// Builds a scheduler that looks as if it has already run `i32::MAX` steps.
    ///
    /// Actually stepping that many times would be too time consuming, so the iteration index is
    /// injected through `load_record` instead.
    fn scheduler_after_i32_max_steps() -> StepLrScheduler {
        StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1)
    }

    #[test]
    fn test_config_step_size_zero() {
        let result = StepLrSchedulerConfig::new(1.0, 0).init();
        assert!(result.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_step_size_nonzero() {
        let result = StepLrSchedulerConfig::new(1.0, 1).init();
        assert!(result.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_default_gamma() {
        const INIT_LR: LearningRate = 0.4;
        const STEP_SIZE: usize = 2;

        // A config left at its default gamma must behave like one with gamma set to 0.1.
        let base_config = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE);
        let mut with_default_gamma = base_config.init().unwrap();

        let explicit_config = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE).with_gamma(0.1);
        let mut with_explicit_gamma = explicit_config.init().unwrap();

        test_utils::compare_steps(
            &mut with_default_gamma,
            &mut with_explicit_gamma,
            3 * STEP_SIZE,
        );
    }

    #[test]
    fn test_lr_decreasing() {
        let config = StepLrSchedulerConfig::new(0.5, 3).with_gamma(0.1);
        let scheduler = config.init().unwrap();
        test_utils::check_lr_sequence(
            scheduler,
            [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005],
        );
    }

    #[test]
    fn test_lr_increasing() {
        let config = StepLrSchedulerConfig::new(0.1, 2).with_gamma(2.0);
        let scheduler = config.init().unwrap();
        test_utils::check_lr_sequence(scheduler, [0.1, 0.1, 0.2, 0.2, 0.4, 0.4]);
    }

    #[test]
    fn test_lr_unchanging() {
        let config = StepLrSchedulerConfig::new(3.1, 1).with_gamma(1.0);
        let scheduler = config.init().unwrap();
        test_utils::check_lr_sequence(scheduler, [3.1, 3.1, 3.1]);
    }

    #[test]
    fn test_save_and_load() {
        const STEP_SIZE: usize = 10;

        let config = StepLrSchedulerConfig::new(0.007, STEP_SIZE).with_gamma(0.03);
        let scheduler = config.init().unwrap();
        test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2);
    }

    #[test]
    fn test_number_of_calls_within_limit() {
        // One more step after `i32::MAX` steps is still within the limit.
        let mut scheduler = scheduler_after_i32_max_steps();
        scheduler.step();
    }

    #[test]
    #[should_panic(expected = "i32::MAX")]
    fn test_number_of_calls_over_limit() {
        // The first call reaches the limit; the second one goes over it and must panic.
        let mut scheduler = scheduler_after_i32_max_steps();
        scheduler.step();
        scheduler.step();
    }
}
218}