burn_optim/lr_scheduler/
exponential.rs

1use burn_core as burn;
2
3use super::{LrScheduler, String};
4use crate::LearningRate;
5use burn::config::Config;
6use burn::tensor::backend::Backend;
7
8/// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler).
9///
10/// This scheduler returns the learning rate `initial_lr` at the first step, then multiplies it by
11/// a constant `gamma` at every iteration. At any iteration `i` (which starts from 0), the learning
12/// rate is given by `initial_lr * gamma^i`.
13#[derive(Config, Debug)]
14pub struct ExponentialLrSchedulerConfig {
15    // The initial learning rate.
16    initial_lr: LearningRate,
17    // The constant that the learning rate is multiplied by on each iteration.
18    gamma: f64,
19}
20
21impl ExponentialLrSchedulerConfig {
22    /// Initializes a [exponential learning rate scheduler](ExponentialLrScheduler).
23    ///
24    /// # Errors
25    ///
26    /// An error will be returned if any of the following conditions is true:
27    ///
28    /// * `initial_lr` is out of range (0.0, 1.0]
29    /// * `gamma` is out of range (0.0, 1.0]
30    pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
31        if self.initial_lr <= 0. || self.initial_lr > 1. {
32            return Err("Initial learning rate must be greater than 0 and at most 1".into());
33        }
34        if self.gamma <= 0. || self.gamma > 1. {
35            return Err("Gamma must be greater than 0 and at most 1".into());
36        }
37
38        Ok(ExponentialLrScheduler {
39            // Such an initial value eliminates the need for special-case handling of the first
40            // learning rate.
41            previous_lr: self.initial_lr / self.gamma,
42            gamma: self.gamma,
43        })
44    }
45}
46
47/// A exponential learning rate scheduler.
48///
49/// See [ExponentialLrSchedulerConfig] for more information.
50#[derive(Clone, Copy, Debug)]
51pub struct ExponentialLrScheduler {
52    // The previous iteration's learning rate.
53    previous_lr: LearningRate,
54    // The constant that the learning rate is multiplied by on each iteration.
55    gamma: f64,
56}
57
58impl LrScheduler for ExponentialLrScheduler {
59    type Record<B: Backend> = LearningRate;
60
61    fn step(&mut self) -> LearningRate {
62        self.previous_lr *= self.gamma;
63        self.previous_lr
64    }
65
66    fn to_record<B: Backend>(&self) -> Self::Record<B> {
67        self.previous_lr
68    }
69
70    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
71        self.previous_lr = record;
72        self
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    const INITIAL_LR_ERROR: &str = "Initial learning rate must be greater than 0 and at most 1";
    const GAMMA_ERROR: &str = "Gamma must be greater than 0 and at most 1";

    /// Asserts that initializing a scheduler with `initial_lr` and `gamma` fails with `expected`.
    fn assert_init_error(initial_lr: LearningRate, gamma: f64, expected: &str) {
        let r = ExponentialLrSchedulerConfig::new(initial_lr, gamma).init();
        assert!(r.is_err(), "Should return an error");
        assert_eq!(r.unwrap_err(), expected, "Error messages should match");
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_error(0., 0.5, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_error(1.5, 0.5, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_gamma_too_low() {
        assert_init_error(0.5, 0.0, GAMMA_ERROR);
    }

    #[test]
    fn config_gamma_too_high() {
        assert_init_error(0.5, 1.5, GAMMA_ERROR);
    }

    #[test]
    fn config_upper_bounds_are_inclusive() {
        let r = ExponentialLrSchedulerConfig::new(1., 1.).init();
        assert!(r.is_ok(), "Initial learning rate and gamma of exactly 1 should be accepted");
    }

    #[test]
    fn test_lr_change() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
        let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_constant_lr_when_gamma_is_one() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.5, 1.).init().unwrap();
        let expected_lrs = [0.5, 0.5, 0.5, 0.5, 0.5];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, 7);
    }
}