Skip to main content

entrenar/optim/scheduler/warmup_cosine_decay.rs

1//! Warmup + cosine decay learning rate scheduler
2
3use super::LRScheduler;
4use crate::optim::Optimizer;
5use std::f32::consts::PI;
6
/// Warmup + Cosine Decay Learning Rate Scheduler
///
/// Combines linear warmup with cosine annealing decay.
/// - Phase 1 (warmup, steps `0..warmup_steps`): linear increase from 0 toward `lr_max`
/// - Phase 2 (decay, steps `warmup_steps..total_steps`): cosine decay from `lr_max` to `lr_min`
/// - After the schedule ends: held at `lr_min`
#[derive(Debug, Clone)]
pub struct WarmupCosineDecayLR {
    /// Peak learning rate, reached on the first step after warmup.
    lr_max: f32,
    /// Final learning rate, held once the schedule has finished.
    lr_min: f32,
    /// Number of steps in the linear warmup phase.
    warmup_steps: usize,
    /// Total schedule length in steps, warmup included.
    total_steps: usize,
    /// Index of the current step; starts at 0 and is advanced by `step()`.
    current_step: usize,
}
19
20impl WarmupCosineDecayLR {
21    /// Create a new warmup + cosine decay scheduler
22    ///
23    /// # Arguments
24    /// * `lr_max` - Maximum learning rate (after warmup)
25    /// * `lr_min` - Minimum learning rate (at end)
26    /// * `warmup_steps` - Number of warmup steps
27    /// * `total_steps` - Total training steps (including warmup)
28    pub fn new(lr_max: f32, lr_min: f32, warmup_steps: usize, total_steps: usize) -> Self {
29        Self { lr_max, lr_min, warmup_steps, total_steps, current_step: 0 }
30    }
31
32    /// Apply the current learning rate to an optimizer
33    pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
34        optimizer.set_lr(self.get_lr());
35    }
36}
37
38impl LRScheduler for WarmupCosineDecayLR {
39    fn get_lr(&self) -> f32 {
40        if self.current_step < self.warmup_steps {
41            // Warmup phase: linear increase
42            if self.warmup_steps == 0 {
43                return self.lr_max;
44            }
45            let progress = self.current_step as f32 / self.warmup_steps as f32;
46            return self.lr_max * progress;
47        }
48
49        // Cosine decay phase
50        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
51        if decay_steps == 0 {
52            return self.lr_min;
53        }
54
55        let decay_step = self.current_step - self.warmup_steps;
56        if decay_step >= decay_steps {
57            return self.lr_min;
58        }
59
60        let progress = decay_step as f32 / decay_steps as f32;
61        let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
62        self.lr_min + (self.lr_max - self.lr_min) * cosine_decay
63    }
64
65    fn step(&mut self) {
66        self.current_step += 1;
67    }
68}