burn_core/lr_scheduler/noam.rs

use burn_tensor::backend::Backend;

use crate as burn;

use super::{LrScheduler, String};
use crate::{LearningRate, config::Config};
7
/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
#[derive(Config)]
pub struct NoamLrSchedulerConfig {
    /// The initial learning rate.
    init_lr: LearningRate,
    /// The number of steps before the learning rate starts to decay.
    #[config(default = 4000)]
    warmup_steps: usize,
    /// The size of the model.
    #[config(default = 512)]
    model_size: usize,
}
20
/// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
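///
/// The learning rate at step `s` is
/// `init_lr * model_size^(-0.5) * min(s^(-0.5), s * warmup_steps^(-1.5))`:
/// it grows linearly for the first `warmup_steps` steps, then decays
/// proportionally to the inverse square root of the step number.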
#[derive(Clone, Debug)]
pub struct NoamLrScheduler {
    warmup_steps: f64,
    embedding_size: f64,
    init_lr: LearningRate,
    step: f64,
}
29
impl NoamLrSchedulerConfig {
    /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.
    ///
    /// # Errors
    ///
    /// An error will be returned if any of the following conditions is true:
    ///
    /// * `warmup_steps` is 0
    /// * `model_size` is 0
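    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Usage sketch: the import path is assumed to follow the module
    /// // layout above (`lr_scheduler::noam`) and may differ downstream.
    /// let scheduler = NoamLrSchedulerConfig::new(0.01)
    ///     .with_warmup_steps(1000)
    ///     .with_model_size(512)
    ///     .init()
    ///     .expect("warmup_steps and model_size must be non-zero");
    /// ```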
    pub fn init(&self) -> Result<NoamLrScheduler, String> {
        if self.warmup_steps == 0 {
            return Err(
                "Number of steps before the learning rate decay starts must be greater than 0"
                    .into(),
            );
        }
        if self.model_size == 0 {
            return Err("Model size must be greater than 0".into());
        }

        Ok(NoamLrScheduler {
            warmup_steps: self.warmup_steps as f64,
            embedding_size: self.model_size as f64,
            init_lr: self.init_lr,
            step: 0.0,
        })
    }
}
57
impl LrScheduler for NoamLrScheduler {
    type Record<B: Backend> = usize;

    fn step(&mut self) -> LearningRate {
        self.step += 1.0;

        // Inverse square root decay; `min` selects this branch after warmup.
        let arg1 = self.step.powf(-0.5);
        // Linear warmup; `min` selects this branch for the first `warmup_steps` steps.
        let arg2 = self.step * self.warmup_steps.powf(-1.5);

        self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        // Only the step counter needs to be persisted to resume the schedule.
        self.step as usize
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.step = record as f64;
        self
    }
}
79
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_warmup_steps_invalid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_warmup_steps_valid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init();
        assert!(r.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_model_size_invalid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_model_size_valid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init();
        assert!(r.is_ok(), "Should return a success value");
    }
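
    #[test]
    fn test_peak_lr_at_warmup_boundary() {
        // A sketch of a closed-form check: at `step == warmup_steps` both
        // arguments of the `min` coincide, so the peak learning rate is
        // `init_lr * model_size^(-0.5) * warmup_steps^(-0.5)`.
        let (init_lr, warmup_steps, model_size) = (1.0, 100, 512);
        let mut scheduler = NoamLrSchedulerConfig::new(init_lr)
            .with_warmup_steps(warmup_steps)
            .with_model_size(model_size)
            .init()
            .unwrap();
        let mut lr = 0.0;
        for _ in 0..warmup_steps {
            lr = scheduler.step();
        }
        let expected =
            init_lr * (model_size as f64).powf(-0.5) * (warmup_steps as f64).powf(-0.5);
        assert!(
            (lr - expected).abs() < 1e-12,
            "Peak learning rate should match the closed form"
        );
    }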

    #[test]
    fn test_function_increase_and_decrease() {
        let warmup_steps = 100;
        let mut scheduler = NoamLrSchedulerConfig::new(10.0)
            .with_warmup_steps(warmup_steps)
            .init()
            .unwrap();
        let mut lr_current = 0.0;

        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr > lr_current,
                "Learning rate should increase before warmup_steps is reached."
            );
            lr_current = lr;
        }

        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr < lr_current,
                "Learning rate should decrease after warmup_steps is reached."
            );
            lr_current = lr;
        }
    }
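
    #[test]
    fn test_record_round_trip() {
        // A sketch of a save/restore check, assuming the crate exposes a
        // `TestBackend` alias for unit tests as its other modules do. The
        // record is just the step counter, so a restored scheduler should
        // continue the schedule exactly where the original left off.
        let mut scheduler = NoamLrSchedulerConfig::new(10.0).init().unwrap();
        for _ in 0..10 {
            scheduler.step();
        }
        let record = scheduler.to_record::<crate::TestBackend>();
        let mut restored = NoamLrSchedulerConfig::new(10.0)
            .init()
            .unwrap()
            .load_record::<crate::TestBackend>(record);
        assert_eq!(
            scheduler.step(),
            restored.step(),
            "Restored scheduler should resume at the same step"
        );
    }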
}