burn_core/lr_scheduler/
cosine.rs1use super::{LrScheduler, String};
2use crate as burn;
3use crate::{LearningRate, config::Config};
4use burn_tensor::backend::Backend;
5
6#[derive(Config)]
13pub struct CosineAnnealingLrSchedulerConfig {
14 initial_lr: LearningRate,
16 #[config(default = 0.0)]
18 min_lr: LearningRate,
19 num_iters: usize,
22}
23
24impl CosineAnnealingLrSchedulerConfig {
25 pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
35 if self.initial_lr <= 0. || self.initial_lr > 1. {
36 return Err("Initial learning rate must be greater than 0 and at most 1".into());
37 }
38 if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
39 return Err(
40 "Minimum learning rate must be at least 0 and at most equal to the initial \
41 learning rate"
42 .into(),
43 );
44 }
45 if self.num_iters == 0 {
46 return Err("Number of iterations must be at least 1".into());
47 }
48
49 Ok(CosineAnnealingLrScheduler {
50 min_lr: self.min_lr,
51 max_lr: self.initial_lr,
52 num_iters: self.num_iters,
53 current_iter: usize::MAX,
54 })
55 }
56}
57
58#[derive(Clone, Copy, Debug)]
64pub struct CosineAnnealingLrScheduler {
65 min_lr: LearningRate,
66 max_lr: LearningRate,
67 num_iters: usize,
68 current_iter: usize,
69}
70
71impl LrScheduler for CosineAnnealingLrScheduler {
72 type Record<B: Backend> = usize;
73
74 fn step(&mut self) -> LearningRate {
75 self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);
79 self.min_lr
80 + 0.5
81 * (self.max_lr - self.min_lr)
82 * (1.0
83 + (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)
84 .cos())
85 }
86
87 fn to_record<B: Backend>(&self) -> Self::Record<B> {
88 self.current_iter
89 }
90
91 fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
92 self.current_iter = record;
93 self
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;

    // Messages expected from `CosineAnnealingLrSchedulerConfig::init` for each failed check.
    const INITIAL_LR_ERROR: &str = "Initial learning rate must be greater than 0 and at most 1";
    const MIN_LR_ERROR: &str = "Minimum learning rate must be at least 0 and at most equal to \
                                the initial learning rate";
    const NUM_ITERS_ERROR: &str = "Number of iterations must be at least 1";

    /// Asserts that initialising `config` fails with exactly `expected_message`.
    fn assert_init_error(config: CosineAnnealingLrSchedulerConfig, expected_message: &str) {
        let outcome = config.init();
        assert!(outcome.is_err(), "Should return an error");
        assert_eq!(
            outcome.unwrap_err(),
            expected_message,
            "Error messages should match",
        );
    }

    #[test]
    fn config_initial_lr_too_low() {
        let config = CosineAnnealingLrSchedulerConfig::new(0., 10);
        assert_init_error(config, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_initial_lr_too_high() {
        let config = CosineAnnealingLrSchedulerConfig::new(1.5, 10);
        assert_init_error(config, INITIAL_LR_ERROR);
    }

    #[test]
    fn config_min_lr_too_low() {
        let config = CosineAnnealingLrSchedulerConfig::new(0.5, 10).with_min_lr(-0.1);
        assert_init_error(config, MIN_LR_ERROR);
    }

    #[test]
    fn config_min_lr_too_high() {
        let config = CosineAnnealingLrSchedulerConfig::new(0.5, 10).with_min_lr(0.6);
        assert_init_error(config, MIN_LR_ERROR);
    }

    #[test]
    fn config_num_iters_too_low() {
        let config = CosineAnnealingLrSchedulerConfig::new(0.5, 0);
        assert_init_error(config, NUM_ITERS_ERROR);
    }

    #[test]
    fn test_lr_change() {
        const INITIAL_LR: LearningRate = 0.5;
        const MIN_LR: LearningRate = 0.1;

        let config = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2).with_min_lr(MIN_LR);
        let scheduler = config.init().unwrap();
        let expected_lrs = [
            INITIAL_LR,                  // iteration 0: cos(0)
            (INITIAL_LR + MIN_LR) * 0.5, // iteration 1: cos(PI / 2)
            MIN_LR,                      // iteration 2: cos(PI)
            INITIAL_LR,                  // restart at iteration 0
        ];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const NUM_ITERS: usize = 9;

        let config = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS);
        let scheduler = config.init().unwrap();
        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
    }
}