use crate::optimizer::Optimizer;

/// Common interface implemented by all learning-rate schedulers in this module.
pub trait LRScheduler {
    /// Advances the schedule by one step and writes the new learning rate into `optimizer`.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Returns the learning rate produced by the most recent call to `step`.
    fn get_last_lr(&self) -> f32;

    /// Returns the number of scheduler steps taken so far.
    fn get_step(&self) -> usize;
}
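/// Decays the learning rate by a factor of `gamma` every `step_size` steps.
///
/// A minimal usage sketch, marked `ignore` because it assumes an already
/// constructed optimizer implementing `Optimizer` (such as this crate's `SGD`,
/// used in the tests below) with an initial learning rate of 0.1:
///
/// ```ignore
/// // Divide the LR by 10 every 10 steps.
/// let mut scheduler = StepLR::new(&optimizer, 10, 0.1);
/// for _epoch in 0..20 {
///     // ... run one epoch of training ...
///     scheduler.step(&mut optimizer);
/// }
/// // Two decays have fired: 0.1 -> 0.01 -> 0.001.
/// assert!((scheduler.get_last_lr() - 0.001).abs() < 1e-6);
/// ```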
pub struct StepLR {
    initial_lr: f32,
    step_size: usize,
    gamma: f32,
    current_step: usize,
    last_lr: f32,
}

impl StepLR {
    /// Creates a new `StepLR`, reading the initial learning rate from `optimizer`.
    pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            step_size,
            gamma,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for StepLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;
        // One decay for every completed `step_size` interval.
        let num_decays = self.current_step / self.step_size;
        let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

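/// Decays the learning rate by `gamma` each time the step count reaches one of
/// the given milestones.
///
/// A minimal usage sketch (`ignore`; assumes an optimizer whose learning rate
/// starts at 0.1):
///
/// ```ignore
/// let mut scheduler = MultiStepLR::new(&optimizer, vec![30, 80], 0.1);
/// for _ in 0..100 {
///     scheduler.step(&mut optimizer);
/// }
/// // LR: 0.1 for steps 1..30, 0.01 from step 30, 0.001 from step 80 onwards.
/// ```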
pub struct MultiStepLR {
    initial_lr: f32,
    milestones: Vec<usize>,
    gamma: f32,
    current_step: usize,
    last_lr: f32,
    milestone_idx: usize,
}

impl MultiStepLR {
    /// Creates a new `MultiStepLR`. Milestones are sorted at construction, so
    /// they may be passed in any order.
    pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        milestones.sort_unstable();
        Self {
            initial_lr,
            milestones,
            gamma,
            current_step: 0,
            last_lr: initial_lr,
            milestone_idx: 0,
        }
    }
}

impl LRScheduler for MultiStepLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        // Advance past every milestone reached so far; the number of passed
        // milestones is the number of decays to apply.
        while self.milestone_idx < self.milestones.len()
            && self.current_step >= self.milestones[self.milestone_idx]
        {
            self.milestone_idx += 1;
        }

        let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

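/// Decays the learning rate by `gamma` on every step:
/// `lr = initial_lr * gamma^step`.
///
/// A minimal usage sketch (`ignore`; assumes an optimizer whose learning rate
/// starts at 0.1):
///
/// ```ignore
/// let mut scheduler = ExponentialLR::new(&optimizer, 0.9);
/// scheduler.step(&mut optimizer); // LR becomes 0.1 * 0.9   = 0.09
/// scheduler.step(&mut optimizer); // LR becomes 0.1 * 0.9^2 = 0.081
/// ```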
pub struct ExponentialLR {
    initial_lr: f32,
    gamma: f32,
    current_step: usize,
    last_lr: f32,
}

impl ExponentialLR {
    /// Creates a new `ExponentialLR`, reading the initial learning rate from `optimizer`.
    pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            gamma,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for ExponentialLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;
        let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

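/// Anneals the learning rate from its initial value down to `eta_min` along a
/// cosine curve over `t_max` steps:
/// `lr = eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2`.
///
/// A minimal usage sketch (`ignore`; assumes an optimizer whose learning rate
/// starts at 0.1):
///
/// ```ignore
/// let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);
/// for _ in 0..50 {
///     scheduler.step(&mut optimizer);
/// }
/// // Halfway through the cycle cos(pi/2) = 0, so the LR is about 0.05.
/// ```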
pub struct CosineAnnealingLR {
    initial_lr: f32,
    t_max: usize,
    eta_min: f32,
    current_step: usize,
    last_lr: f32,
}

impl CosineAnnealingLR {
    /// Creates a new `CosineAnnealingLR` that anneals down to zero.
    pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
        Self::with_eta_min(optimizer, t_max, 0.0)
    }

    /// Creates a new `CosineAnnealingLR` that anneals down to `eta_min`.
    pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            t_max,
            eta_min,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for CosineAnnealingLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        // Cosine interpolation from initial_lr (progress = 0) to eta_min (progress = 1).
        let progress = self.current_step as f32 / self.t_max as f32;
        let new_lr = self.eta_min
            + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
                / 2.0;

        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

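/// Reduces the learning rate by `factor` once a monitored metric has stopped
/// improving for more than `patience` consecutive steps. Unlike the other
/// schedulers it is driven by `step_with_metric`; the plain `LRScheduler::step`
/// only advances the step counter.
///
/// A minimal usage sketch (`ignore`; `validate()` is a hypothetical function
/// standing in for however you compute the validation loss):
///
/// ```ignore
/// // Halve the LR after the validation loss fails to improve for 2 epochs.
/// let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 1e-4, 0, 0.0);
/// for _epoch in 0..num_epochs {
///     let val_loss = validate(); // hypothetical helper
///     scheduler.step_with_metric(&mut optimizer, val_loss);
/// }
/// ```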
pub struct ReduceLROnPlateau {
    mode: String,
    factor: f32,
    patience: usize,
    threshold: f32,
    cooldown: usize,
    min_lr: f32,
    best: f32,
    num_bad_epochs: usize,
    cooldown_counter: usize,
    current_step: usize,
    last_lr: f32,
}

impl ReduceLROnPlateau {
    /// Creates a scheduler with the defaults: "min" mode, factor 0.1,
    /// patience 10, threshold 1e-4, no cooldown, and no lower bound on the LR.
    pub fn new<O: Optimizer>(optimizer: &O) -> Self {
        Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
    }

    /// Creates a fully configured scheduler. `mode` is "min" if smaller metric
    /// values are better, "max" if larger values are better.
    pub fn with_options<O: Optimizer>(
        optimizer: &O,
        mode: &str,
        factor: f32,
        patience: usize,
        threshold: f32,
        cooldown: usize,
        min_lr: f32,
    ) -> Self {
        let initial_lr = optimizer.get_lr();
        let best = if mode == "min" {
            f32::INFINITY
        } else {
            f32::NEG_INFINITY
        };
        Self {
            mode: mode.to_string(),
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best,
            num_bad_epochs: 0,
            cooldown_counter: 0,
            current_step: 0,
            last_lr: initial_lr,
        }
    }

    /// Records `metric` and reduces the learning rate by `factor` if the metric
    /// has not improved for more than `patience` consecutive calls.
    pub fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
        self.current_step += 1;

        // During cooldown, ignore the metric entirely.
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return;
        }

        // Relative-threshold improvement test, direction given by `mode`.
        let improved = if self.mode == "min" {
            metric < self.best * (1.0 - self.threshold)
        } else {
            metric > self.best * (1.0 + self.threshold)
        };

        if improved {
            self.best = metric;
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;
        }

        if self.num_bad_epochs > self.patience {
            let current_lr = optimizer.get_lr();
            let new_lr = (current_lr * self.factor).max(self.min_lr);
            optimizer.set_lr(new_lr);
            self.last_lr = new_lr;
            self.cooldown_counter = self.cooldown;
            self.num_bad_epochs = 0;
        }
    }
}

impl LRScheduler for ReduceLROnPlateau {
    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
        // Without a metric there is nothing to react to, so only the step
        // counter advances. Use `step_with_metric` to drive this scheduler.
        self.current_step += 1;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

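/// One-cycle schedule: the learning rate ramps linearly from
/// `max_lr / div_factor` up to `max_lr` over the first `pct_start` fraction of
/// `total_steps`, then follows a cosine curve down to `max_lr / final_div_factor`.
///
/// A minimal usage sketch (`ignore`; `total_steps` should match the number of
/// `step` calls you intend to make):
///
/// ```ignore
/// let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);
/// for _ in 0..100 {
///     // ... one optimization step ...
///     scheduler.step(&mut optimizer);
/// }
/// // With the defaults, the LR starts at 0.1 / 25 = 0.004, peaks at 0.1
/// // around step 30, and ends near 0.1 / 1e4 = 1e-5.
/// ```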
pub struct OneCycleLR {
    max_lr: f32,
    total_steps: usize,
    pct_start: f32,
    div_factor: f32,
    final_div_factor: f32,
    current_step: usize,
    last_lr: f32,
}

impl OneCycleLR {
    /// Creates a scheduler with the defaults: 30% warm-up, an initial LR of
    /// `max_lr / 25`, and a final LR of `max_lr / 1e4`.
    pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
        Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
    }

    /// Creates a fully configured scheduler.
    pub fn with_options<O: Optimizer>(
        _optimizer: &O,
        max_lr: f32,
        total_steps: usize,
        pct_start: f32,
        div_factor: f32,
        final_div_factor: f32,
    ) -> Self {
        let initial_lr = max_lr / div_factor;
        Self {
            max_lr,
            total_steps,
            pct_start,
            div_factor,
            final_div_factor,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for OneCycleLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        let step_ratio = self.current_step as f32 / self.total_steps as f32;
        let initial_lr = self.max_lr / self.div_factor;
        let min_lr = self.max_lr / self.final_div_factor;

        let new_lr = if step_ratio <= self.pct_start {
            // Warm-up phase: linear ramp from initial_lr up to max_lr.
            let phase_ratio = step_ratio / self.pct_start;
            initial_lr + (self.max_lr - initial_lr) * phase_ratio
        } else {
            // Annealing phase: cosine decay from max_lr down to min_lr.
            let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
            min_lr
                + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
        };

        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

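/// Linearly warms the learning rate up from zero to the optimizer's initial
/// value over `warmup_steps` steps, then holds it constant.
///
/// A minimal usage sketch (`ignore`; assumes an optimizer whose learning rate
/// starts at 0.1):
///
/// ```ignore
/// let mut scheduler = WarmupLR::new(&optimizer, 10);
/// scheduler.step(&mut optimizer); // LR = 0.1 * 1/10 = 0.01
/// for _ in 0..9 {
///     scheduler.step(&mut optimizer);
/// }
/// // From step 10 onwards the LR stays at the full 0.1.
/// ```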
pub struct WarmupLR {
    initial_lr: f32,
    warmup_steps: usize,
    current_step: usize,
    last_lr: f32,
}

impl WarmupLR {
    /// Creates a new `WarmupLR`, reading the target learning rate from `optimizer`.
    pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            warmup_steps,
            current_step: 0,
            last_lr: 0.0,
        }
    }
}

impl LRScheduler for WarmupLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        let new_lr = if self.current_step <= self.warmup_steps {
            // Linear ramp from 0 up to the full initial_lr.
            self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
        } else {
            self.initial_lr
        };

        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        assert!(optimizer.get_lr() > 0.08);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        assert!(optimizer.get_lr() < initial_lr);
    }
}