burn_optim/lr_scheduler/noam.rs

use burn_core as burn;

use burn::config::Config;
use burn::tensor::backend::Backend;

use super::{LrScheduler, String};
use crate::LearningRate;

/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
#[derive(Config, Debug)]
pub struct NoamLrSchedulerConfig {
    /// The overall scale factor for the learning rate decay.
    factor: f64,
    /// The number of steps before the exponential decay starts.
    #[config(default = 4000)]
    warmup_steps: usize,
    /// The size of the model.
    #[config(default = 512)]
    model_size: usize,
}

/// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
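///
/// The learning rate at step `s` follows the schedule from the paper:
///
/// `lr = factor * model_size^(-0.5) * min(s^(-0.5), s * warmup_steps^(-1.5))`
///
/// i.e. it increases linearly over the first `warmup_steps` steps, then decays
/// proportionally to the inverse square root of the step number.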
#[derive(Clone, Debug)]
pub struct NoamLrScheduler {
    warmup_steps: f64,
    embedding_size: f64,
    factor: f64,
    step: f64,
}

impl NoamLrSchedulerConfig {
    /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.
    ///
    /// # Errors
    ///
    /// An error will be returned if any of the following conditions is true:
    ///
    /// * `warmup_steps` is 0
    /// * `model_size` is 0
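    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the builder calls used in this module's
    /// tests (the import path is assumed from this file's location):
    ///
    /// ```rust,ignore
    /// use burn_optim::lr_scheduler::noam::NoamLrSchedulerConfig;
    ///
    /// // `factor` is the only required field; `warmup_steps` and `model_size`
    /// // fall back to their defaults (4000 and 512) unless overridden.
    /// let scheduler = NoamLrSchedulerConfig::new(1.0)
    ///     .with_warmup_steps(1000)
    ///     .init()
    ///     .expect("warmup_steps and model_size are non-zero");
    /// ```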
    pub fn init(&self) -> Result<NoamLrScheduler, String> {
        if self.warmup_steps == 0 {
            return Err(
                "Number of steps before exponential decay starts must be greater than 0".into(),
            );
        }
        if self.model_size == 0 {
            return Err("Model size must be greater than 0".into());
        }

        Ok(NoamLrScheduler {
            warmup_steps: self.warmup_steps as f64,
            embedding_size: self.model_size as f64,
            factor: self.factor,
            step: 0.0,
        })
    }
}

impl LrScheduler for NoamLrScheduler {
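    /// The scheduler's only persistent state is its current step count.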
    type Record<B: Backend> = usize;

    fn step(&mut self) -> LearningRate {
        self.step += 1.0;

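        // Before `warmup_steps` is reached, `arg2` is the smaller term and the
        // rate grows linearly with the step; afterwards `arg1` dominates,
        // giving an inverse square-root decay.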
        let arg1 = self.step.powf(-0.5);
        let arg2 = self.step * self.warmup_steps.powf(-1.5);

        self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.step as usize
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.step = record as f64;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_warmup_steps_invalid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_warmup_steps_valid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init();
        assert!(r.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_model_size_invalid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_model_size_valid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init();
        assert!(r.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_function_increase_and_decrease() {
        let warmup_steps = 100;
        let mut scheduler = NoamLrSchedulerConfig::new(10.0)
            .with_warmup_steps(warmup_steps)
            .init()
            .unwrap();
        let mut lr_current = 0.0;

        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr > lr_current,
                "Learning rate should increase before warmup_steps is reached."
            );
            lr_current = lr;
        }

        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr < lr_current,
                "Learning rate should decrease after warmup_steps is reached."
            );
            lr_current = lr;
        }
    }
}