entrenar/optim/scheduler/
cosine_annealing.rs1use super::LRScheduler;
4use crate::optim::Optimizer;
5use std::f32::consts::PI;
6
/// Cosine-annealing learning-rate schedule.
///
/// For `t < t_max` the learning rate is
/// `lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(pi * t / t_max))`.
/// It starts at `lr_max` at step 0 and follows half a cosine period down
/// towards `lr_min`. Once `t >= t_max` it is held at `lr_min`.
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR {
    /// Learning rate at the start of the schedule (step 0).
    lr_max: f32,
    /// Floor learning rate, returned once `current_step >= t_max`.
    lr_min: f32,
    /// Number of steps over which the rate is annealed.
    t_max: usize,
    /// Number of steps taken so far; advanced by one per `step()` call.
    current_step: usize,
}

25impl CosineAnnealingLR {
26 pub fn new(lr_max: f32, t_max: usize, lr_min: f32) -> Self {
33 Self { lr_max, lr_min, t_max, current_step: 0 }
34 }
35
36 pub fn default_min(lr_max: f32, t_max: usize) -> Self {
38 Self::new(lr_max, t_max, 0.0)
39 }
40
41 pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
43 optimizer.set_lr(self.get_lr());
44 }
45}
46
47impl LRScheduler for CosineAnnealingLR {
48 fn get_lr(&self) -> f32 {
49 if self.current_step >= self.t_max {
50 return self.lr_min;
51 }
52
53 let progress = self.current_step as f32 / self.t_max as f32;
54 let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
55 self.lr_min + (self.lr_max - self.lr_min) * cosine_decay
56 }
57
58 fn step(&mut self) {
59 self.current_step += 1;
60 }
61}