Skip to main content

entrenar/optim/scheduler/
cosine_annealing.rs

1//! Cosine annealing learning rate scheduler
2
3use super::LRScheduler;
4use crate::optim::Optimizer;
5use std::f32::consts::PI;
6
/// Cosine Annealing Learning Rate Scheduler
///
/// Decreases the learning rate along a half cosine wave, starting at `lr_max`
/// and ending at `lr_min`.
///
/// Formula: `lr_t = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T))`
///
/// Where:
/// - `t` is the current step
/// - `T` is the total number of steps
/// - `lr_max` is the initial learning rate
/// - `lr_min` is the minimum learning rate (0 when built with `default_min`)
///
/// Once `t >= T` the schedule stays at `lr_min`.
#[derive(Debug, Clone, PartialEq)]
pub struct CosineAnnealingLR {
    /// Learning rate at step 0 (top of the cosine curve).
    lr_max: f32,
    /// Learning rate reached at the end of the schedule and held afterwards.
    lr_min: f32,
    /// Total number of steps `T` over which the rate is annealed.
    t_max: usize,
    /// Number of steps taken so far (`t` in the formula).
    current_step: usize,
}
24
25impl CosineAnnealingLR {
26    /// Create a new cosine annealing scheduler
27    ///
28    /// # Arguments
29    /// * `lr_max` - Initial (maximum) learning rate
30    /// * `t_max` - Total number of steps for the schedule
31    /// * `lr_min` - Minimum learning rate (default 0)
32    pub fn new(lr_max: f32, t_max: usize, lr_min: f32) -> Self {
33        Self { lr_max, lr_min, t_max, current_step: 0 }
34    }
35
36    /// Create scheduler with lr_min = 0
37    pub fn default_min(lr_max: f32, t_max: usize) -> Self {
38        Self::new(lr_max, t_max, 0.0)
39    }
40
41    /// Apply the current learning rate to an optimizer
42    pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
43        optimizer.set_lr(self.get_lr());
44    }
45}
46
47impl LRScheduler for CosineAnnealingLR {
48    fn get_lr(&self) -> f32 {
49        if self.current_step >= self.t_max {
50            return self.lr_min;
51        }
52
53        let progress = self.current_step as f32 / self.t_max as f32;
54        let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
55        self.lr_min + (self.lr_max - self.lr_min) * cosine_decay
56    }
57
58    fn step(&mut self) {
59        self.current_step += 1;
60    }
61}