entrenar/optim/scheduler/
step_decay.rs1use super::LRScheduler;
4use crate::optim::Optimizer;
5
/// Step-decay learning-rate scheduler.
///
/// The learning rate is `lr_initial * gamma^(current_epoch / step_size)`
/// (integer division), i.e. it is multiplied by `gamma` once every
/// `step_size` epochs, where epochs are counted by calls to `step`.
#[derive(Debug, Clone)]
pub struct StepDecayLR {
    /// Learning rate in effect before the first decay.
    lr_initial: f32,
    /// Multiplicative factor applied at every decay boundary.
    gamma: f32,
    /// Number of epochs between decays; `0` means the rate never decays.
    step_size: usize,
    /// Number of epochs elapsed so far (advanced by `step`).
    current_epoch: usize,
}
17
18impl StepDecayLR {
19 pub fn new(lr_initial: f32, step_size: usize, gamma: f32) -> Self {
26 Self { lr_initial, gamma, step_size, current_epoch: 0 }
27 }
28
29 pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
31 optimizer.set_lr(self.get_lr());
32 }
33}
34
35impl LRScheduler for StepDecayLR {
36 fn get_lr(&self) -> f32 {
37 if self.step_size == 0 {
38 return self.lr_initial;
39 }
40 let num_decays = self.current_epoch / self.step_size;
41 self.lr_initial * self.gamma.powi(num_decays as i32)
42 }
43
44 fn step(&mut self) {
45 self.current_epoch += 1;
46 }
47}