burn_optim/lr_scheduler/
noam.rs1use burn_core as burn;
2
3use burn::config::Config;
4use burn::tensor::backend::Backend;
5
6use super::{LrScheduler, String};
7use crate::LearningRate;
8
/// Configuration for building a [`NoamLrScheduler`].
///
/// Defaults follow the common transformer setup: 4000 warmup steps and a
/// model size of 512. Both values are validated in [`NoamLrSchedulerConfig::init`].
#[derive(Config, Debug)]
pub struct NoamLrSchedulerConfig {
    /// Constant multiplier applied to every learning rate the scheduler emits.
    factor: f64,
    /// Number of steps of linear warmup before the decay phase begins.
    /// Must be greater than 0 (checked by `init`).
    #[config(default = 4000)]
    warmup_steps: usize,
    /// Model (embedding) size used to scale the learning rate.
    /// Must be greater than 0 (checked by `init`).
    #[config(default = 512)]
    model_size: usize,
}
21
/// Noam learning rate scheduler: the rate grows during a warmup phase and then
/// decays proportionally to the inverse square root of the step count
/// (see the `min` of the two terms computed in `step`).
///
/// Built from a [`NoamLrSchedulerConfig`] via its `init` method.
#[derive(Clone, Debug)]
pub struct NoamLrScheduler {
    // Warmup step count, stored as f64 so `step` can use powf directly.
    warmup_steps: f64,
    // Model size from the config, stored as f64 for the same reason.
    embedding_size: f64,
    // Constant multiplier applied to each computed rate.
    factor: f64,
    // Number of `step` calls so far; starts at 0.0 and increments by 1.0.
    step: f64,
}
30
31impl NoamLrSchedulerConfig {
32 pub fn init(&self) -> Result<NoamLrScheduler, String> {
41 if self.warmup_steps == 0 {
42 return Err(
43 "Number of steps before exponential decay starts must be greater than 0".into(),
44 );
45 }
46 if self.model_size == 0 {
47 return Err("Model size must be greater than 0".into());
48 }
49
50 Ok(NoamLrScheduler {
51 warmup_steps: self.warmup_steps as f64,
52 embedding_size: self.model_size as f64,
53 factor: self.factor,
54 step: 0.0,
55 })
56 }
57}
58
59impl LrScheduler for NoamLrScheduler {
60 type Record<B: Backend> = usize;
61
62 fn step(&mut self) -> LearningRate {
63 self.step += 1.0;
64
65 let arg1 = self.step.powf(-0.5);
66 let arg2 = self.step * self.warmup_steps.powf(-1.5);
67
68 self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
69 }
70
71 fn to_record<B: Backend>(&self) -> Self::Record<B> {
72 self.step as usize
73 }
74
75 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
76 self.step = record as f64;
77 self
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_warmup_steps_invalid() {
        // Zero warmup steps must be rejected at construction time.
        assert!(
            NoamLrSchedulerConfig::new(0.1)
                .with_warmup_steps(0)
                .init()
                .is_err(),
            "Should return an error"
        );
    }

    #[test]
    fn test_config_warmup_steps_valid() {
        assert!(
            NoamLrSchedulerConfig::new(0.1)
                .with_warmup_steps(1)
                .init()
                .is_ok(),
            "Should return a success value"
        );
    }

    #[test]
    fn test_config_model_size_invalid() {
        // Zero model size must be rejected at construction time.
        assert!(
            NoamLrSchedulerConfig::new(0.1)
                .with_model_size(0)
                .init()
                .is_err(),
            "Should return an error"
        );
    }

    #[test]
    fn test_config_model_size_valid() {
        assert!(
            NoamLrSchedulerConfig::new(0.1)
                .with_model_size(1)
                .init()
                .is_ok(),
            "Should return a success value"
        );
    }

    #[test]
    fn test_function_increase_and_decrease() {
        let warmup_steps = 100;
        let mut scheduler = NoamLrSchedulerConfig::new(10.0)
            .with_warmup_steps(warmup_steps)
            .init()
            .unwrap();
        let mut previous = 0.0;

        // Phase 1: the rate must be strictly increasing during warmup.
        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr > previous,
                "Learning rate should increase before the warmup_steps is reached."
            );
            previous = lr;
        }

        // Phase 2: the rate must be strictly decreasing once warmup is over.
        for _ in 0..warmup_steps {
            let lr = scheduler.step();
            assert!(
                lr < previous,
                "Learning rate should decrease after the warmup_steps is reached."
            );
            previous = lr;
        }
    }
}