burn_core/lr_scheduler/
cosine.rs

1use super::{LrScheduler, String};
2use crate as burn;
3use crate::{LearningRate, config::Config};
4use burn_tensor::backend::Backend;
5
/// The configuration for creating a [Cosine Annealing learning rate scheduler with warm
/// restarts](CosineAnnealingLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by
/// following a cosine function. After `num_iters` iterations, the learning rate is reset to
/// `initial_lr`.
#[derive(Config)]
pub struct CosineAnnealingLrSchedulerConfig {
    // The initial learning rate. Must lie in (0.0, 1.0]; validated by `init`.
    initial_lr: LearningRate,
    // The final learning rate reached at the end of each cosine cycle.
    // Must lie in [0.0, `initial_lr`]; validated by `init`.
    #[config(default = 0.0)]
    min_lr: LearningRate,
    // The number of iterations between two restarts. The two restart iterations themselves are not
    // included. Must be at least 1; validated by `init`.
    num_iters: usize,
}
23
24impl CosineAnnealingLrSchedulerConfig {
25    /// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler).
26    ///
27    /// # Errors
28    ///
29    /// An error will be returned if any of the following conditions is true:
30    ///
31    /// * `initial_lr` is out of range (0.0, 1.0]
32    /// * `min_lr` is out of range [0.0, `initial_lr`]
33    /// * `num_iters` is 0
34    pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
35        if self.initial_lr <= 0. || self.initial_lr > 1. {
36            return Err("Initial learning rate must be greater than 0 and at most 1".into());
37        }
38        if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
39            return Err(
40                "Minimum learning rate must be at least 0 and at most equal to the initial \
41                 learning rate"
42                    .into(),
43            );
44        }
45        if self.num_iters == 0 {
46            return Err("Number of iterations must be at least 1".into());
47        }
48
49        Ok(CosineAnnealingLrScheduler {
50            min_lr: self.min_lr,
51            max_lr: self.initial_lr,
52            num_iters: self.num_iters,
53            current_iter: usize::MAX,
54        })
55    }
56}
57
/// A Cosine Annealing learning rate scheduler.
///
/// This scheduler is described in [SGDR: Stochastic Gradient Descent with Warm
/// Restarts](https://arxiv.org/abs/1608.03983). See [CosineAnnealingLrSchedulerConfig] for more
/// information.
#[derive(Clone, Copy, Debug)]
pub struct CosineAnnealingLrScheduler {
    // Lower bound of the cosine cycle (`min_lr` from the config).
    min_lr: LearningRate,
    // Upper bound of the cosine cycle (`initial_lr` from the config).
    max_lr: LearningRate,
    // Number of iterations per cycle, excluding the restart iterations themselves.
    num_iters: usize,
    // Index of the current iteration within the cycle. Initialized to usize::MAX
    // so the first `step` call wraps around to 0 and returns `max_lr`.
    current_iter: usize,
}
70
71impl LrScheduler for CosineAnnealingLrScheduler {
72    type Record<B: Backend> = usize;
73
74    fn step(&mut self) -> LearningRate {
75        // Make current_iter overflow from usize::MAX to 0 to get the initial learning rate on the
76        // first call. We could've used i64 with an initial value -1, but keeping it in usize saves
77        // us from some type casting here.
78        self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);
79        self.min_lr
80            + 0.5
81                * (self.max_lr - self.min_lr)
82                * (1.0
83                    + (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)
84                        .cos())
85    }
86
87    fn to_record<B: Backend>(&self) -> Self::Record<B> {
88        self.current_iter
89    }
90
91    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
92        self.current_iter = record;
93        self
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    /// Asserts that an `init` call failed with exactly `expected_msg`.
    fn expect_init_error(result: Result<CosineAnnealingLrScheduler, String>, expected_msg: &str) {
        assert!(result.is_err(), "Should return an error");
        assert_eq!(result.unwrap_err(), expected_msg, "Error messages should match");
    }

    #[test]
    fn config_initial_lr_too_low() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0., 10).init(),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(1.5, 10).init(),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_min_lr_too_low() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10)
                .with_min_lr(-0.1)
                .init(),
            "Minimum learning rate must be at least 0 and at most equal to the initial learning \
             rate",
        );
    }

    #[test]
    fn config_min_lr_too_high() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10)
                .with_min_lr(0.6)
                .init(),
            "Minimum learning rate must be at least 0 and at most equal to the initial learning \
             rate",
        );
    }

    #[test]
    fn config_num_iters_too_low() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 0).init(),
            "Number of iterations must be at least 1",
        );
    }

    #[test]
    fn test_lr_change() {
        const INITIAL_LR: LearningRate = 0.5;
        const MIN_LR: LearningRate = 0.1;

        let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2)
            .with_min_lr(MIN_LR)
            .init()
            .unwrap();
        // One full cycle — cos(0), cos(PI/2), cos(PI) — then a restart.
        let expected_lrs = [
            INITIAL_LR,
            (INITIAL_LR + MIN_LR) * 0.5,
            MIN_LR,
            INITIAL_LR,
        ];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 9;
        let scheduler = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS)
            .init()
            .unwrap();
        // Step partway through a cycle before the save/load round-trip.
        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}