Skip to main content

entrenar/optim/scheduler/
step_decay.rs

1//! Step decay learning rate scheduler
2
3use super::LRScheduler;
4use crate::optim::Optimizer;
5
/// Step Decay Learning Rate Scheduler
///
/// Multiplies the learning rate by `gamma` once every `step_size` epochs, where
/// one "epoch" is one call to the scheduler's `step` method.
///
/// Formula: `lr_t = lr_initial * gamma^(floor(epoch / step_size))`
///
/// A `step_size` of 0 disables decay: the learning rate stays at `lr_initial`.
#[derive(Debug, Clone)]
pub struct StepDecayLR {
    /// Learning rate before any decay has been applied.
    lr_initial: f32,
    /// Multiplicative factor applied at each decay boundary.
    gamma: f32,
    /// Number of epochs between decays (0 means "never decay").
    step_size: usize,
    /// Number of `step` calls made so far.
    current_epoch: usize,
}
17
18impl StepDecayLR {
19    /// Create a new step decay scheduler
20    ///
21    /// # Arguments
22    /// * `lr_initial` - Initial learning rate
23    /// * `step_size` - Decay LR every step_size epochs
24    /// * `gamma` - Multiplicative factor (e.g., 0.1 for 10x reduction)
25    pub fn new(lr_initial: f32, step_size: usize, gamma: f32) -> Self {
26        Self { lr_initial, gamma, step_size, current_epoch: 0 }
27    }
28
29    /// Apply the current learning rate to an optimizer
30    pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
31        optimizer.set_lr(self.get_lr());
32    }
33}
34
35impl LRScheduler for StepDecayLR {
36    fn get_lr(&self) -> f32 {
37        if self.step_size == 0 {
38            return self.lr_initial;
39        }
40        let num_decays = self.current_epoch / self.step_size;
41        self.lr_initial * self.gamma.powi(num_decays as i32)
42    }
43
44    fn step(&mut self) {
45        self.current_epoch += 1;
46    }
47}