burn_optim/lr_scheduler/linear.rs

use burn_core as burn;

use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::tensor::backend::Backend;

/// The configuration for creating a [linear learning rate scheduler](LinearLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by a
/// constant amount on each iteration until reaching a final learning rate `final_lr`. The
/// `num_iters` parameter controls how many iterations are needed to go from `initial_lr` to
/// `final_lr`.
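///
/// # Examples
///
/// A minimal usage sketch, marked `ignore` since the exact import path depends on how the crate
/// re-exports these items:
///
/// ```ignore
/// // Decay linearly from 0.9 to 0.1 over 100 iterations.
/// let mut scheduler = LinearLrSchedulerConfig::new(0.9, 0.1, 100)
///     .init()
///     .expect("the configuration is valid");
///
/// let lr = scheduler.step(); // The first step returns `initial_lr` (0.9).
/// ```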
#[derive(Config, Debug)]
pub struct LinearLrSchedulerConfig {
    // The initial learning rate.
    initial_lr: LearningRate,
    // The final learning rate.
    final_lr: LearningRate,
    // The number of iterations before reaching the final learning rate.
    num_iters: usize,
}

impl LinearLrSchedulerConfig {
    /// Initializes a [linear learning rate scheduler](LinearLrScheduler).
    ///
    /// # Errors
    ///
    /// An error will be returned if any of the following conditions is true:
    ///
    /// * `initial_lr` is out of range (0.0, 1.0]
    /// * `final_lr` is out of range [0.0, 1.0]
    /// * `num_iters` is 0 (see the sketch below)
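    ///
    /// For instance, a zero iteration count is rejected (an `ignore`d sketch; the import path
    /// may vary):
    ///
    /// ```ignore
    /// let result = LinearLrSchedulerConfig::new(0.5, 0.1, 0).init();
    /// assert!(result.is_err());
    /// ```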
    pub fn init(&self) -> Result<LinearLrScheduler, String> {
        if self.initial_lr <= 0. || self.initial_lr > 1. {
            return Err("Initial learning rate must be greater than 0 and at most 1".into());
        }
        if self.final_lr < 0. || self.final_lr > 1. {
            return Err("Final learning rate must be at least 0 and at most 1".into());
        }
        if self.num_iters == 0 {
            return Err("Number of iterations must be at least 1".into());
        }

        Ok(LinearLrScheduler {
            final_lr: self.final_lr,
            // The signed per-iteration change; negative when the learning rate decreases.
            step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,
            // Store one extra iteration so that the first call to `step`, which decrements
            // before computing, returns exactly `initial_lr`.
            remaining_iters: self.num_iters + 1,
        })
    }
}

/// A linear learning rate scheduler.
///
/// See [LinearLrSchedulerConfig] for more information.
#[derive(Clone, Copy, Debug)]
pub struct LinearLrScheduler {
    // The learning rate held once the linear change completes.
    final_lr: LearningRate,
    // The amount that the learning rate changes by on each iteration.
    step_size: f64,
    // The number of iterations left before reaching the final learning rate.
    remaining_iters: usize,
}

impl LrScheduler for LinearLrScheduler {
    type Record<B: Backend> = usize;

    fn step(&mut self) -> LearningRate {
        // Branchless saturating decrement: once the countdown hits zero it stays there,
        // so the learning rate remains fixed at `final_lr` thereafter.
        self.remaining_iters -= (self.remaining_iters != 0) as usize;
        self.final_lr - self.step_size * self.remaining_iters as f64
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.remaining_iters
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.remaining_iters = record;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    #[test]
    fn config_initial_lr_too_low() {
        let r = LinearLrSchedulerConfig::new(0., 0.5, 100).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Initial learning rate must be greater than 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        let r = LinearLrSchedulerConfig::new(1.5, 0.5, 100).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Initial learning rate must be greater than 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_final_lr_too_low() {
        let r = LinearLrSchedulerConfig::new(0.5, -0.5, 100).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Final learning rate must be at least 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_final_lr_too_high() {
        let r = LinearLrSchedulerConfig::new(0.5, 1.5, 100).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Final learning rate must be at least 0 and at most 1",
            "Error messages should match",
        );
    }

    #[test]
    fn config_num_iters_too_low() {
        let r = LinearLrSchedulerConfig::new(0.9, 0.1, 0).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(
            r.unwrap_err(),
            "Number of iterations must be at least 1",
            "Error messages should match",
        );
    }

    #[test]
    fn test_lr_decreasing() {
        let scheduler = LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap();
        let expected_lrs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.5];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_increasing() {
        let scheduler = LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap();
        let expected_lrs = [0.01, 0.02, 0.03, 0.04, 0.04];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_unchanging() {
        let scheduler = LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap();
        let expected_lrs = [0.3, 0.3, 0.3, 0.3];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 6;
        let scheduler = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}