burn_core/lr_scheduler/
linear.rs

1use super::{LrScheduler, String};
2use crate as burn;
3use crate::{LearningRate, config::Config};
4use burn_tensor::backend::Backend;
5
6/// The configuration for creating a [linear learning rate scheduler](LinearLrScheduler).
7///
8/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by a
9/// constant amount on each iteration until reaching a final learning rate `final_lr`. The
10/// `num_iters` parameter controls how many iterations are needed to go from `initial_lr` to
11/// `final_lr`.
12#[derive(Config)]
13pub struct LinearLrSchedulerConfig {
14    // The initial learning rate.
15    initial_lr: LearningRate,
16    // The final learning rate.
17    final_lr: LearningRate,
18    // The number of iterations before reaching the final learning rate.
19    num_iters: usize,
20}
21
22impl LinearLrSchedulerConfig {
23    /// Initializes a [linear learning rate scheduler](LinearLrScheduler).
24    ///
25    /// # Errors
26    ///
27    /// An error will be returned if any of the following conditions is true:
28    ///
29    /// * `initial_lr` is out of range (0.0, 1.0]
30    /// * `final_lr` is out of range [0.0, 1.0]
31    /// * `num_iters` is 0
32    pub fn init(&self) -> Result<LinearLrScheduler, String> {
33        if self.initial_lr <= 0. || self.initial_lr > 1. {
34            return Err("Initial learning rate must be greater than 0 and at most 1".into());
35        }
36        if self.final_lr < 0. || self.final_lr > 1. {
37            return Err("Final learning rate must be at least 0 and at most 1".into());
38        }
39        if self.num_iters == 0 {
40            return Err("Number of iterations must be at least 1".into());
41        }
42
43        Ok(LinearLrScheduler {
44            final_lr: self.final_lr,
45            step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,
46            remaining_iters: self.num_iters + 1,
47        })
48    }
49}
50
51/// A linear learning rate scheduler.
52///
53/// See [LinearLrSchedulerConfig] for more information.
54#[derive(Clone, Copy, Debug)]
55pub struct LinearLrScheduler {
56    // The final learning rate after the linear changing process stops.
57    final_lr: LearningRate,
58    // The amount that the learning rate changes by on each iteration.
59    step_size: f64,
60    // The number of iterations left before reaching the final learning rate.
61    remaining_iters: usize,
62}
63
64impl LrScheduler for LinearLrScheduler {
65    type Record<B: Backend> = usize;
66
67    fn step(&mut self) -> LearningRate {
68        self.remaining_iters -= (self.remaining_iters != 0) as usize;
69        self.final_lr - self.step_size * self.remaining_iters as f64
70    }
71
72    fn to_record<B: Backend>(&self) -> Self::Record<B> {
73        self.remaining_iters
74    }
75
76    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
77        self.remaining_iters = record;
78        self
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    const INITIAL_LR_ERR: &str = "Initial learning rate must be greater than 0 and at most 1";
    const FINAL_LR_ERR: &str = "Final learning rate must be at least 0 and at most 1";
    const NUM_ITERS_ERR: &str = "Number of iterations must be at least 1";

    /// Checks that initializing `config` fails with exactly `expected` as the error message.
    fn assert_init_fails(config: LinearLrSchedulerConfig, expected: &str) {
        let outcome = config.init();
        assert!(outcome.is_err(), "Should return an error");
        assert_eq!(outcome.unwrap_err(), expected, "Error messages should match");
    }

    /// Builds a scheduler from parameters that are expected to pass validation.
    fn valid_scheduler(
        initial_lr: LearningRate,
        final_lr: LearningRate,
        num_iters: usize,
    ) -> LinearLrScheduler {
        LinearLrSchedulerConfig::new(initial_lr, final_lr, num_iters)
            .init()
            .unwrap()
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_fails(LinearLrSchedulerConfig::new(0., 0.5, 100), INITIAL_LR_ERR);
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_fails(LinearLrSchedulerConfig::new(1.5, 0.5, 100), INITIAL_LR_ERR);
    }

    #[test]
    fn config_final_lr_too_low() {
        assert_init_fails(LinearLrSchedulerConfig::new(0.5, -0.5, 100), FINAL_LR_ERR);
    }

    #[test]
    fn config_final_lr_too_high() {
        assert_init_fails(LinearLrSchedulerConfig::new(0.5, 1.5, 100), FINAL_LR_ERR);
    }

    #[test]
    fn config_num_iters_too_low() {
        assert_init_fails(LinearLrSchedulerConfig::new(0.9, 0.1, 0), NUM_ITERS_ERR);
    }

    #[test]
    fn test_lr_decreasing() {
        let scheduler = valid_scheduler(0.9, 0.5, 4);
        test_utils::check_lr_sequence(scheduler, [0.9, 0.8, 0.7, 0.6, 0.5, 0.5]);
    }

    #[test]
    fn test_lr_increasing() {
        let scheduler = valid_scheduler(0.01, 0.04, 3);
        test_utils::check_lr_sequence(scheduler, [0.01, 0.02, 0.03, 0.04, 0.04]);
    }

    #[test]
    fn test_lr_unchanging() {
        let scheduler = valid_scheduler(0.3, 0.3, 2);
        test_utils::check_lr_sequence(scheduler, [0.3, 0.3, 0.3, 0.3]);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 6;
        let scheduler = valid_scheduler(1.0, 0.01, NUM_ITERS);
        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}