//! Cosine annealing learning rate scheduler with warm restarts
//! (`burn_optim/lr_scheduler/cosine.rs`).

1use burn_core as burn;
2
3use super::{LrScheduler, String};
4use crate::LearningRate;
5use burn::config::Config;
6use burn::tensor::backend::Backend;
7
8/// The configuration for creating a [Cosine Annealing learning rate scheduler with warm
9/// restarts](CosineAnnealingLrScheduler).
10///
11/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by
12/// following a cosine function. After `num_iters` iterations, the learning rate is reset to
13/// `initial_lr`.
14#[derive(Config, Debug)]
15pub struct CosineAnnealingLrSchedulerConfig {
16    // The initial learning rate.
17    initial_lr: LearningRate,
18    // The final learning rate.
19    #[config(default = 0.0)]
20    min_lr: LearningRate,
21    // The number of iterations between two restarts. The two restart iterations themselves are not
22    // included.
23    num_iters: usize,
24}
25
26impl CosineAnnealingLrSchedulerConfig {
27    /// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler).
28    ///
29    /// # Errors
30    ///
31    /// An error will be returned if any of the following conditions is true:
32    ///
33    /// * `initial_lr` is out of range (0.0, 1.0]
34    /// * `min_lr` is out of range [0.0, `initial_lr`]
35    /// * `num_iters` is 0
36    pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
37        if self.initial_lr <= 0. || self.initial_lr > 1. {
38            return Err("Initial learning rate must be greater than 0 and at most 1".into());
39        }
40        if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
41            return Err(
42                "Minimum learning rate must be at least 0 and at most equal to the initial \
43                 learning rate"
44                    .into(),
45            );
46        }
47        if self.num_iters == 0 {
48            return Err("Number of iterations must be at least 1".into());
49        }
50
51        Ok(CosineAnnealingLrScheduler {
52            min_lr: self.min_lr,
53            max_lr: self.initial_lr,
54            num_iters: self.num_iters,
55            current_iter: usize::MAX,
56        })
57    }
58}
59
60/// A Cosine Annealing learning rate scheduler.
61///
62/// This scheduler is described in [SGDR: Stochastic Gradient Descent with Warm
63/// Restarts](https://arxiv.org/abs/1608.03983). See [CosineAnnealingLrSchedulerConfig] for more
64/// information.
65#[derive(Clone, Copy, Debug)]
66pub struct CosineAnnealingLrScheduler {
67    min_lr: LearningRate,
68    max_lr: LearningRate,
69    num_iters: usize,
70    current_iter: usize,
71}
72
73impl LrScheduler for CosineAnnealingLrScheduler {
74    type Record<B: Backend> = usize;
75
76    fn step(&mut self) -> LearningRate {
77        // Make current_iter overflow from usize::MAX to 0 to get the initial learning rate on the
78        // first call. We could've used i64 with an initial value -1, but keeping it in usize saves
79        // us from some type casting here.
80        self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);
81        self.min_lr
82            + 0.5
83                * (self.max_lr - self.min_lr)
84                * (1.0
85                    + (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)
86                        .cos())
87    }
88
89    fn to_record<B: Backend>(&self) -> Self::Record<B> {
90        self.current_iter
91    }
92
93    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
94        self.current_iter = record;
95        self
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    // Shared assertion for the config validation tests: `init` must fail with `expected`.
    fn expect_init_error(result: Result<CosineAnnealingLrScheduler, String>, expected: &str) {
        assert!(result.is_err(), "Should return an error");
        assert_eq!(result.unwrap_err(), expected, "Error messages should match");
    }

    #[test]
    fn config_initial_lr_too_low() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0., 10).init(),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(1.5, 10).init(),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_min_lr_too_low() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10)
                .with_min_lr(-0.1)
                .init(),
            "Minimum learning rate must be at least 0 and at most equal to the initial learning \
             rate",
        );
    }

    #[test]
    fn config_min_lr_too_high() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10)
                .with_min_lr(0.6)
                .init(),
            "Minimum learning rate must be at least 0 and at most equal to the initial learning \
             rate",
        );
    }

    #[test]
    fn config_num_iters_too_low() {
        expect_init_error(
            CosineAnnealingLrSchedulerConfig::new(0.5, 0).init(),
            "Number of iterations must be at least 1",
        );
    }

    #[test]
    fn test_lr_change() {
        const INITIAL_LR: LearningRate = 0.5;
        const MIN_LR: LearningRate = 0.1;

        let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2)
            .with_min_lr(MIN_LR)
            .init()
            .unwrap();
        // One full cycle followed by the restart value.
        let expected_lrs = [
            INITIAL_LR,                  // cos(0)
            (INITIAL_LR + MIN_LR) * 0.5, // cos(PI/2)
            MIN_LR,                      // cos(PI)
            INITIAL_LR,                  // restart
        ];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 9;
        // Save/restore mid-cycle and check the sequence resumes where it left off.
        let scheduler = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}