1use crate::optimizer::Optimizer;
18
/// Common interface for learning-rate schedulers.
///
/// A scheduler adjusts the learning rate of an [`Optimizer`] over the course
/// of training. Call [`LRScheduler::step`] once per scheduling interval
/// (per epoch or per batch, depending on the scheduler).
pub trait LRScheduler {
    /// Advances the schedule by one step and writes the new learning rate
    /// into `optimizer`.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Returns the learning rate most recently applied by this scheduler.
    fn get_last_lr(&self) -> f32;

    /// Returns how many times the schedule has been stepped.
    fn get_step(&self) -> usize;
}
34
/// Decays the learning rate by `gamma` once every `step_size` steps.
pub struct StepLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Number of steps between successive decays.
    step_size: usize,
    /// Multiplicative decay factor.
    gamma: f32,
    /// Number of `step` calls made so far.
    current_step: usize,
    /// Most recently applied learning rate.
    last_lr: f32,
}
49
50impl StepLR {
51 pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
53 let initial_lr = optimizer.get_lr();
54 Self {
55 initial_lr,
56 step_size,
57 gamma,
58 current_step: 0,
59 last_lr: initial_lr,
60 }
61 }
62}
63
64impl LRScheduler for StepLR {
65 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
66 self.current_step += 1;
67 let num_decays = self.current_step / self.step_size;
68 let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
69 optimizer.set_lr(new_lr);
70 self.last_lr = new_lr;
71 }
72
73 fn get_last_lr(&self) -> f32 {
74 self.last_lr
75 }
76
77 fn get_step(&self) -> usize {
78 self.current_step
79 }
80}
81
/// Decays the learning rate by `gamma` each time the step count reaches one
/// of the configured milestones.
pub struct MultiStepLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Step counts at which a decay occurs (kept sorted ascending).
    milestones: Vec<usize>,
    /// Multiplicative decay factor applied per milestone passed.
    gamma: f32,
    /// Number of `step` calls made so far.
    current_step: usize,
    /// Most recently applied learning rate.
    last_lr: f32,
    /// Index of the first milestone not yet reached (= milestones passed).
    milestone_idx: usize,
}
95
96impl MultiStepLR {
97 pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
99 let initial_lr = optimizer.get_lr();
100 milestones.sort_unstable();
101 Self {
102 initial_lr,
103 milestones,
104 gamma,
105 current_step: 0,
106 last_lr: initial_lr,
107 milestone_idx: 0,
108 }
109 }
110}
111
112impl LRScheduler for MultiStepLR {
113 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
114 self.current_step += 1;
115
116 while self.milestone_idx < self.milestones.len()
118 && self.current_step >= self.milestones[self.milestone_idx]
119 {
120 self.milestone_idx += 1;
121 }
122
123 let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
124 optimizer.set_lr(new_lr);
125 self.last_lr = new_lr;
126 }
127
128 fn get_last_lr(&self) -> f32 {
129 self.last_lr
130 }
131
132 fn get_step(&self) -> usize {
133 self.current_step
134 }
135}
136
/// Decays the learning rate by `gamma` on every step.
pub struct ExponentialLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Per-step multiplicative decay factor.
    gamma: f32,
    /// Number of `step` calls made so far.
    current_step: usize,
    /// Most recently applied learning rate.
    last_lr: f32,
}
150
151impl ExponentialLR {
152 pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
154 let initial_lr = optimizer.get_lr();
155 Self {
156 initial_lr,
157 gamma,
158 current_step: 0,
159 last_lr: initial_lr,
160 }
161 }
162}
163
164impl LRScheduler for ExponentialLR {
165 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
166 self.current_step += 1;
167 let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
168 optimizer.set_lr(new_lr);
169 self.last_lr = new_lr;
170 }
171
172 fn get_last_lr(&self) -> f32 {
173 self.last_lr
174 }
175
176 fn get_step(&self) -> usize {
177 self.current_step
178 }
179}
180
/// Anneals the learning rate along a half-cosine curve from the initial rate
/// down to `eta_min` over `t_max` steps.
pub struct CosineAnnealingLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Number of steps over which one half-cosine is traversed.
    t_max: usize,
    /// Minimum learning rate reached at the end of the annealing period.
    eta_min: f32,
    /// Number of `step` calls made so far.
    current_step: usize,
    /// Most recently applied learning rate.
    last_lr: f32,
}
195
196impl CosineAnnealingLR {
197 pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
199 Self::with_eta_min(optimizer, t_max, 0.0)
200 }
201
202 pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
204 let initial_lr = optimizer.get_lr();
205 Self {
206 initial_lr,
207 t_max,
208 eta_min,
209 current_step: 0,
210 last_lr: initial_lr,
211 }
212 }
213}
214
215impl LRScheduler for CosineAnnealingLR {
216 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
217 self.current_step += 1;
218
219 let progress = self.current_step as f32 / self.t_max as f32;
220 let new_lr = self.eta_min
221 + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
222 / 2.0;
223
224 optimizer.set_lr(new_lr);
225 self.last_lr = new_lr;
226 }
227
228 fn get_last_lr(&self) -> f32 {
229 self.last_lr
230 }
231
232 fn get_step(&self) -> usize {
233 self.current_step
234 }
235}
236
/// Reduces the learning rate when a monitored metric stops improving.
///
/// NOTE(review): `mode` is stringly-typed; `"min"` means lower is better and
/// any other value is treated as "higher is better". An enum would be safer,
/// but changing it would alter the public constructor signature.
pub struct ReduceLROnPlateau {
    /// `"min"` if lower metric values are better; anything else means max.
    mode: String,
    /// Factor the learning rate is multiplied by on each reduction.
    factor: f32,
    /// Number of bad epochs tolerated before reducing.
    patience: usize,
    /// Relative improvement threshold used when comparing against `best`.
    threshold: f32,
    /// Number of steps skipped (no reductions) after each reduction.
    cooldown: usize,
    /// Lower bound on the learning rate.
    min_lr: f32,
    /// Best metric value seen so far.
    best: f32,
    /// Consecutive steps without sufficient improvement.
    num_bad_epochs: usize,
    /// Steps remaining in the current cooldown period.
    cooldown_counter: usize,
    /// Number of step calls made so far.
    current_step: usize,
    /// Most recently applied learning rate.
    last_lr: f32,
}
255
256impl ReduceLROnPlateau {
257 pub fn new<O: Optimizer>(optimizer: &O) -> Self {
259 Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
260 }
261
262 pub fn with_options<O: Optimizer>(
264 optimizer: &O,
265 mode: &str,
266 factor: f32,
267 patience: usize,
268 threshold: f32,
269 cooldown: usize,
270 min_lr: f32,
271 ) -> Self {
272 let initial_lr = optimizer.get_lr();
273 let best = if mode == "min" {
274 f32::INFINITY
275 } else {
276 f32::NEG_INFINITY
277 };
278 Self {
279 mode: mode.to_string(),
280 factor,
281 patience,
282 threshold,
283 cooldown,
284 min_lr,
285 best,
286 num_bad_epochs: 0,
287 cooldown_counter: 0,
288 current_step: 0,
289 last_lr: initial_lr,
290 }
291 }
292
293 pub fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
295 self.current_step += 1;
296
297 if self.cooldown_counter > 0 {
299 self.cooldown_counter -= 1;
300 return;
301 }
302
303 let improved = if self.mode == "min" {
305 metric < self.best * (1.0 - self.threshold)
306 } else {
307 metric > self.best * (1.0 + self.threshold)
308 };
309
310 if improved {
311 self.best = metric;
312 self.num_bad_epochs = 0;
313 } else {
314 self.num_bad_epochs += 1;
315 }
316
317 if self.num_bad_epochs > self.patience {
319 let current_lr = optimizer.get_lr();
320 let new_lr = (current_lr * self.factor).max(self.min_lr);
321 optimizer.set_lr(new_lr);
322 self.last_lr = new_lr;
323 self.cooldown_counter = self.cooldown;
324 self.num_bad_epochs = 0;
325 }
326 }
327}
328
impl LRScheduler for ReduceLROnPlateau {
    /// Advances the step counter only. This scheduler reacts to metric
    /// values, not step counts, so the optimizer is intentionally untouched
    /// here — use [`ReduceLROnPlateau::step_with_metric`] to drive
    /// learning-rate reductions.
    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
        self.current_step += 1;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}
344
/// One-cycle schedule: a linear warm-up to `max_lr`, followed by cosine
/// annealing down to `max_lr / final_div_factor`.
pub struct OneCycleLR {
    /// Peak learning rate reached at the end of the warm-up phase.
    max_lr: f32,
    /// Total number of steps in the whole cycle.
    total_steps: usize,
    /// Fraction of `total_steps` spent warming up.
    pct_start: f32,
    /// The warm-up starts at `max_lr / div_factor`.
    div_factor: f32,
    /// The anneal ends at `max_lr / final_div_factor`.
    final_div_factor: f32,
    /// Number of `step` calls made so far.
    current_step: usize,
    /// Most recently applied learning rate.
    last_lr: f32,
}
361
362impl OneCycleLR {
363 pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
365 Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
366 }
367
368 pub fn with_options<O: Optimizer>(
370 _optimizer: &O,
371 max_lr: f32,
372 total_steps: usize,
373 pct_start: f32,
374 div_factor: f32,
375 final_div_factor: f32,
376 ) -> Self {
377 let initial_lr = max_lr / div_factor;
378 Self {
379 max_lr,
380 total_steps,
381 pct_start,
382 div_factor,
383 final_div_factor,
384 current_step: 0,
385 last_lr: initial_lr,
386 }
387 }
388}
389
390impl LRScheduler for OneCycleLR {
391 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
392 self.current_step += 1;
393
394 let step_ratio = self.current_step as f32 / self.total_steps as f32;
395 let initial_lr = self.max_lr / self.div_factor;
396 let min_lr = self.max_lr / self.final_div_factor;
397
398 let new_lr = if step_ratio <= self.pct_start {
399 let phase_ratio = step_ratio / self.pct_start;
401 initial_lr + (self.max_lr - initial_lr) * phase_ratio
402 } else {
403 let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
405 min_lr
406 + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
407 };
408
409 optimizer.set_lr(new_lr);
410 self.last_lr = new_lr;
411 }
412
413 fn get_last_lr(&self) -> f32 {
414 self.last_lr
415 }
416
417 fn get_step(&self) -> usize {
418 self.current_step
419 }
420}
421
/// Linearly ramps the learning rate from 0 up to the optimizer's initial
/// rate over a fixed number of warm-up steps, then holds it constant.
pub struct WarmupLR {
    /// Target learning rate, read from the optimizer at construction time.
    initial_lr: f32,
    /// Number of steps over which the ramp runs.
    warmup_steps: usize,
    /// Number of `step` calls made so far.
    current_step: usize,
    /// Most recently applied learning rate (0.0 before the first step).
    last_lr: f32,
}
435
436impl WarmupLR {
437 pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
439 let initial_lr = optimizer.get_lr();
440 Self {
441 initial_lr,
442 warmup_steps,
443 current_step: 0,
444 last_lr: 0.0,
445 }
446 }
447}
448
449impl LRScheduler for WarmupLR {
450 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
451 self.current_step += 1;
452
453 let new_lr = if self.current_step <= self.warmup_steps {
454 self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
455 } else {
456 self.initial_lr
457 };
458
459 optimizer.set_lr(new_lr);
460 self.last_lr = new_lr;
461 }
462
463 fn get_last_lr(&self) -> f32 {
464 self.last_lr
465 }
466
467 fn get_step(&self) -> usize {
468 self.current_step
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    /// Builds a single-parameter SGD optimizer with lr = 0.1 for the tests.
    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // After step_size (10) steps the rate decays once: 0.1 * 0.1.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        // Another 10 steps decays it a second time: 0.01 * 0.1.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // First milestone at step 5: one decay.
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        // Second milestone at step 15: a second decay.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        // Each step multiplies by gamma: 0.1 -> 0.09 -> 0.081.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        // Halfway through t_max the cosine is at its midpoint: ~0.05.
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        // At t_max the rate has annealed to (near) eta_min = 0.
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        // Step 1 of 10: one tenth of the target rate.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        // After all 10 warm-up steps the full rate is restored.
        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // Further steps hold the rate constant.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        // Initial rate is max_lr / div_factor = 0.1 / 25 = 0.004.
        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        // By the end of the warm-up (30% of 100 steps) the rate is at peak.
        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        assert!(optimizer.get_lr() > 0.08);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Improving metrics: no reduction.
        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Three non-improving steps exceed patience (2) and trigger a cut.
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        assert!(optimizer.get_lr() < initial_lr);
    }
}