burn_core/lr_scheduler/
exponential.rs

1use super::{LrScheduler, String};
2use crate as burn;
3use crate::{LearningRate, config::Config};
4use burn_tensor::backend::Backend;
5
/// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at the first step, then multiplies it by
/// a constant `gamma` at every iteration. At any iteration `i` (which starts from 0), the learning
/// rate is given by `initial_lr * gamma^i`.
#[derive(Config)]
pub struct ExponentialLrSchedulerConfig {
    // The initial learning rate. Must lie in (0.0, 1.0]; validated by `init`.
    initial_lr: LearningRate,
    // The constant that the learning rate is multiplied by on each iteration.
    // Must lie in (0.0, 1.0]; validated by `init`.
    gamma: f64,
}
18
19impl ExponentialLrSchedulerConfig {
20    /// Initializes a [exponential learning rate scheduler](ExponentialLrScheduler).
21    ///
22    /// # Errors
23    ///
24    /// An error will be returned if any of the following conditions is true:
25    ///
26    /// * `initial_lr` is out of range (0.0, 1.0]
27    /// * `gamma` is out of range (0.0, 1.0]
28    pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
29        if self.initial_lr <= 0. || self.initial_lr > 1. {
30            return Err("Initial learning rate must be greater than 0 and at most 1".into());
31        }
32        if self.gamma <= 0. || self.gamma > 1. {
33            return Err("Gamma must be greater than 0 and at most 1".into());
34        }
35
36        Ok(ExponentialLrScheduler {
37            // Such an initial value eliminates the need for special-case handling of the first
38            // learning rate.
39            previous_lr: self.initial_lr / self.gamma,
40            gamma: self.gamma,
41        })
42    }
43}
44
/// An exponential learning rate scheduler.
///
/// See [ExponentialLrSchedulerConfig] for more information.
#[derive(Clone, Copy, Debug)]
pub struct ExponentialLrScheduler {
    // The previous iteration's learning rate.
    // Initialized to `initial_lr / gamma` so that the first `step()` yields `initial_lr`.
    previous_lr: LearningRate,
    // The constant that the learning rate is multiplied by on each iteration.
    gamma: f64,
}
55
impl LrScheduler for ExponentialLrScheduler {
    type Record<B: Backend> = LearningRate;

    // Advances the schedule by one iteration: multiplies the previous learning
    // rate by `gamma` and returns the result. The first call returns exactly
    // `initial_lr` because `previous_lr` was seeded with `initial_lr / gamma`.
    fn step(&mut self) -> LearningRate {
        self.previous_lr *= self.gamma;
        self.previous_lr
    }

    // The last emitted learning rate is the only mutable state, so it alone is
    // persisted; `gamma` is reconstructed from the config on restore.
    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.previous_lr
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.previous_lr = record;
        self
    }
}
73
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    // Asserts that `result` is an `Err` whose message is exactly `expected`.
    fn assert_config_error(result: Result<ExponentialLrScheduler, String>, expected: &str) {
        let message = result.expect_err("Should return an error");
        assert_eq!(message, expected, "Error messages should match");
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_config_error(
            ExponentialLrSchedulerConfig::new(0., 0.5).init(),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_config_error(
            ExponentialLrSchedulerConfig::new(1.5, 0.5).init(),
            "Initial learning rate must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_gamma_too_low() {
        assert_config_error(
            ExponentialLrSchedulerConfig::new(0.5, 0.0).init(),
            "Gamma must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn config_gamma_too_high() {
        assert_config_error(
            ExponentialLrSchedulerConfig::new(0.5, 1.5).init(),
            "Gamma must be greater than 0 and at most 1",
        );
    }

    #[test]
    fn test_lr_change() {
        // First step yields `initial_lr`; every later step multiplies by gamma.
        let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
        test_utils::check_lr_sequence(scheduler, [0.8, 0.08, 0.008, 0.0008, 0.00008]);
    }

    #[test]
    fn test_save_and_load() {
        let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)
            .init()
            .unwrap();
        test_utils::check_save_load(scheduler, 7);
    }
}
137}