entrenar/optim/scheduler/
warmup_cosine_decay.rs1use super::LRScheduler;
4use crate::optim::Optimizer;
5use std::f32::consts::PI;
6
/// Learning-rate schedule with a linear warmup followed by a cosine decay.
///
/// While `current_step < warmup_steps` the rate rises linearly from `0` toward
/// `lr_max`. After that it follows a half-cosine from `lr_max` down to `lr_min`
/// over the remaining `total_steps - warmup_steps` steps, and then stays at `lr_min`.
#[derive(Debug, Clone)]
pub struct WarmupCosineDecayLR {
    /// Peak learning rate, reached when warmup ends.
    lr_max: f32,
    /// Learning rate at the end of (and after) the decay phase.
    lr_min: f32,
    /// Number of steps in the linear warmup phase.
    warmup_steps: usize,
    /// Total schedule length in steps, warmup included.
    total_steps: usize,
    /// Number of `step()` calls made so far (starts at 0).
    current_step: usize,
}
19
20impl WarmupCosineDecayLR {
21 pub fn new(lr_max: f32, lr_min: f32, warmup_steps: usize, total_steps: usize) -> Self {
29 Self { lr_max, lr_min, warmup_steps, total_steps, current_step: 0 }
30 }
31
32 pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
34 optimizer.set_lr(self.get_lr());
35 }
36}
37
38impl LRScheduler for WarmupCosineDecayLR {
39 fn get_lr(&self) -> f32 {
40 if self.current_step < self.warmup_steps {
41 if self.warmup_steps == 0 {
43 return self.lr_max;
44 }
45 let progress = self.current_step as f32 / self.warmup_steps as f32;
46 return self.lr_max * progress;
47 }
48
49 let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
51 if decay_steps == 0 {
52 return self.lr_min;
53 }
54
55 let decay_step = self.current_step - self.warmup_steps;
56 if decay_step >= decay_steps {
57 return self.lr_min;
58 }
59
60 let progress = decay_step as f32 / decay_steps as f32;
61 let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
62 self.lr_min + (self.lr_max - self.lr_min) * cosine_decay
63 }
64
65 fn step(&mut self) {
66 self.current_step += 1;
67 }
68}