burn_core/lr_scheduler/noam.rs

use burn_tensor::backend::Backend;

use crate as burn;

use super::{LrScheduler, String};
use crate::{LearningRate, config::Config};

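/// Configuration to create a [Noam](NoamLrScheduler) learning rate scheduler.
///
/// The schedule follows `init_lr * model_size^-0.5 * min(step^-0.5, step *
/// warmup_steps^-1.5)`: a linear warmup followed by inverse-square-root decay,
/// as introduced in "Attention Is All You Need" (Vaswani et al., 2017).
///
/// # Example
///
/// A minimal sketch using the builder methods generated by `#[derive(Config)]`:
///
/// ```rust,ignore
/// let mut scheduler = NoamLrSchedulerConfig::new(1.0)
///     .with_warmup_steps(4000)
///     .with_model_size(512)
///     .init()
///     .unwrap();
/// let lr = scheduler.step();
/// ```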
#[derive(Config)]
pub struct NoamLrSchedulerConfig {
    /// The initial learning rate.
    init_lr: LearningRate,
    /// The number of steps before the exponential decay starts.
    #[config(default = 4000)]
    warmup_steps: usize,
    /// The size of the model (used as the embedding dimension).
    #[config(default = 512)]
    model_size: usize,
}

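/// Noam learning rate scheduler, created from a [NoamLrSchedulerConfig].
/// It tracks its own step count so it can be saved and restored via records.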
#[derive(Clone, Debug)]
pub struct NoamLrScheduler {
    warmup_steps: f64,
    embedding_size: f64,
    init_lr: LearningRate,
    step: f64,
}

impl NoamLrSchedulerConfig {
    /// Initializes a [NoamLrScheduler], or returns an error message if
    /// `warmup_steps` or `model_size` is zero.
    pub fn init(&self) -> Result<NoamLrScheduler, String> {
        if self.warmup_steps == 0 {
            return Err(
                "Number of steps before exponential decay starts must be greater than 0".into(),
            );
        }
        if self.model_size == 0 {
            return Err("Model size must be greater than 0".into());
        }

        Ok(NoamLrScheduler {
            warmup_steps: self.warmup_steps as f64,
            embedding_size: self.model_size as f64,
            init_lr: self.init_lr,
            step: 0.0,
        })
    }
}

impl LrScheduler for NoamLrScheduler {
    type Record<B: Backend> = usize;

    fn step(&mut self) -> LearningRate {
        self.step += 1.0;

        // During warmup, `step * warmup_steps^-1.5` is the smaller term and the
        // rate grows linearly; past `warmup_steps`, `step^-0.5` takes over and
        // the rate decays with the inverse square root of the step count.
        let arg1 = self.step.powf(-0.5);
        let arg2 = self.step * self.warmup_steps.powf(-1.5);

        self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.step as usize
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.step = record as f64;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_warmup_steps_invalid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_warmup_steps_valid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init();
        assert!(r.is_ok(), "Should return a success value");
    }

    #[test]
    fn test_config_model_size_invalid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init();
        assert!(r.is_err(), "Should return an error");
    }

    #[test]
    fn test_config_model_size_valid() {
        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init();
        assert!(r.is_ok(), "Should return a success value");
    }

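    // Illustrative extra check: the two branches of the `min` cross exactly
    // when `step == warmup_steps`, so the schedule should peak there. Uses
    // only the `step()` API exercised by the surrounding tests.
    #[test]
    fn test_function_peaks_at_warmup_steps() {
        let warmup_steps = 100;
        let mut scheduler = NoamLrSchedulerConfig::new(10.0)
            .with_warmup_steps(warmup_steps)
            .init()
            .unwrap();

        let mut max_lr = 0.0;
        let mut max_step = 0;
        for step in 1..=(2 * warmup_steps) {
            let lr = scheduler.step();
            if lr > max_lr {
                max_lr = lr;
                max_step = step;
            }
        }

        assert_eq!(max_step, warmup_steps, "Peak should occur at warmup_steps");
    }
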
    #[test]
    fn test_function_increase_and_decrease() {
        let warmup_steps = 100;
        let mut scheduler = NoamLrSchedulerConfig::new(10.0)
            .with_warmup_steps(warmup_steps)
            .init()
            .unwrap();
        let mut lr_current = 0.0;

        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr > lr_current,
                "Learning rate should increase until warmup_steps is reached."
            );
            lr_current = lr;
        }

        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr < lr_current,
                "Learning rate should decrease after warmup_steps is reached."
            );
            lr_current = lr;
        }
    }
}