burn_optim/lr_scheduler/
cosine.rs1use burn_core as burn;
2
3use super::{LrScheduler, String};
4use crate::LearningRate;
5use burn::config::Config;
6use burn::tensor::backend::Backend;
7
8#[derive(Config, Debug)]
15pub struct CosineAnnealingLrSchedulerConfig {
16 initial_lr: LearningRate,
18 #[config(default = 0.0)]
20 min_lr: LearningRate,
21 num_iters: usize,
24}
25
26impl CosineAnnealingLrSchedulerConfig {
27 pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
37 if self.initial_lr <= 0. || self.initial_lr > 1. {
38 return Err("Initial learning rate must be greater than 0 and at most 1".into());
39 }
40 if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
41 return Err(
42 "Minimum learning rate must be at least 0 and at most equal to the initial \
43 learning rate"
44 .into(),
45 );
46 }
47 if self.num_iters == 0 {
48 return Err("Number of iterations must be at least 1".into());
49 }
50
51 Ok(CosineAnnealingLrScheduler {
52 min_lr: self.min_lr,
53 max_lr: self.initial_lr,
54 num_iters: self.num_iters,
55 current_iter: usize::MAX,
56 })
57 }
58}
59
60#[derive(Clone, Copy, Debug)]
66pub struct CosineAnnealingLrScheduler {
67 min_lr: LearningRate,
68 max_lr: LearningRate,
69 num_iters: usize,
70 current_iter: usize,
71}
72
73impl LrScheduler for CosineAnnealingLrScheduler {
74 type Record<B: Backend> = usize;
75
76 fn step(&mut self) -> LearningRate {
77 self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);
81 self.min_lr
82 + 0.5
83 * (self.max_lr - self.min_lr)
84 * (1.0
85 + (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)
86 .cos())
87 }
88
89 fn to_record<B: Backend>(&self) -> Self::Record<B> {
90 self.current_iter
91 }
92
93 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
94 self.current_iter = record;
95 self
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    const INITIAL_LR_ERROR: &str = "Initial learning rate must be greater than 0 and at most 1";
    const MIN_LR_ERROR: &str =
        "Minimum learning rate must be at least 0 and at most equal to the initial learning \
         rate";
    const NUM_ITERS_ERROR: &str = "Number of iterations must be at least 1";

    /// Asserts that initializing `config` fails with exactly the `expected` message.
    fn assert_init_fails(config: CosineAnnealingLrSchedulerConfig, expected: &str) {
        let outcome = config.init();
        assert!(outcome.is_err(), "Should return an error");
        assert_eq!(outcome.unwrap_err(), expected, "Error messages should match");
    }

    #[test]
    fn config_initial_lr_too_low() {
        assert_init_fails(CosineAnnealingLrSchedulerConfig::new(0., 10), INITIAL_LR_ERROR);
    }

    #[test]
    fn config_initial_lr_too_high() {
        assert_init_fails(CosineAnnealingLrSchedulerConfig::new(1.5, 10), INITIAL_LR_ERROR);
    }

    #[test]
    fn config_min_lr_too_low() {
        assert_init_fails(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10).with_min_lr(-0.1),
            MIN_LR_ERROR,
        );
    }

    #[test]
    fn config_min_lr_too_high() {
        assert_init_fails(
            CosineAnnealingLrSchedulerConfig::new(0.5, 10).with_min_lr(0.6),
            MIN_LR_ERROR,
        );
    }

    #[test]
    fn config_num_iters_too_low() {
        assert_init_fails(CosineAnnealingLrSchedulerConfig::new(0.5, 0), NUM_ITERS_ERROR);
    }

    #[test]
    fn test_lr_change() {
        const INITIAL_LR: LearningRate = 0.5;
        const MIN_LR: LearningRate = 0.1;

        let config = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2).with_min_lr(MIN_LR);
        let scheduler = config.init().unwrap();

        // One full cycle (iterations 0, 1, 2) followed by the restart at iteration 0.
        let expected_lrs = [
            INITIAL_LR,                  // cos(0) = 1
            (INITIAL_LR + MIN_LR) * 0.5, // cos(pi / 2) = 0
            MIN_LR,                      // cos(pi) = -1
            INITIAL_LR,                  // cycle restarts
        ];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 9;

        let config = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS);
        let scheduler = config.init().unwrap();
        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}