// ghostflow_optim/scheduler.rs

//! Learning rate schedulers

use crate::optimizer::Optimizer;

/// Learning rate scheduler trait.
pub trait LRScheduler {
    /// Advance the schedule by one step, updating the wrapped optimizer's learning rate.
    fn step(&mut self);
    /// Current learning rate of the wrapped optimizer.
    fn get_lr(&self) -> f32;
}
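
// The schedulers below touch the optimizer only through its learning rate.
// A minimal sketch of the `Optimizer` surface they assume (the real trait
// lives in crate::optimizer and may carry more methods):
//
//     pub trait Optimizer {
//         fn get_lr(&self) -> f32;
//         fn set_lr(&mut self, lr: f32);
//     }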

/// Step decay scheduler.
///
/// Multiplies the learning rate by `gamma` once every `step_size` steps:
/// `lr = base_lr * gamma^floor(current_step / step_size)`.
pub struct StepLR<O: Optimizer> {
    optimizer: O,
    step_size: usize,
    gamma: f32,
    current_step: usize,
    base_lr: f32,
}

impl<O: Optimizer> StepLR<O> {
    pub fn new(optimizer: O, step_size: usize, gamma: f32) -> Self {
        let base_lr = optimizer.get_lr();
        StepLR {
            optimizer,
            step_size,
            gamma,
            current_step: 0,
            base_lr,
        }
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
}

impl<O: Optimizer> LRScheduler for StepLR<O> {
    fn step(&mut self) {
        self.current_step += 1;
        // Number of completed decay intervals so far.
        let num_decays = self.current_step / self.step_size;
        let new_lr = self.base_lr * self.gamma.powi(num_decays as i32);
        self.optimizer.set_lr(new_lr);
    }

    fn get_lr(&self) -> f32 {
        self.optimizer.get_lr()
    }
}
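
// Usage sketch (`sgd` is a placeholder for any value implementing
// `Optimizer`, constructed here with an initial learning rate of 0.1):
//
//     let mut sched = StepLR::new(sgd, 30, 0.1);
//     for _epoch in 0..90 {
//         // ... train for one epoch, stepping the optimizer as usual ...
//         sched.step(); // lr: 0.1 -> 0.01 after 30 steps -> 0.001 after 60
//     }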

/// Exponential decay scheduler.
///
/// Multiplies the current learning rate by `gamma` on every step, so after
/// `n` steps `lr = initial_lr * gamma^n`.
pub struct ExponentialLR<O: Optimizer> {
    optimizer: O,
    gamma: f32,
}

impl<O: Optimizer> ExponentialLR<O> {
    pub fn new(optimizer: O, gamma: f32) -> Self {
        ExponentialLR { optimizer, gamma }
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
}

impl<O: Optimizer> LRScheduler for ExponentialLR<O> {
    fn step(&mut self) {
        let current_lr = self.optimizer.get_lr();
        self.optimizer.set_lr(current_lr * self.gamma);
    }

    fn get_lr(&self) -> f32 {
        self.optimizer.get_lr()
    }
}
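
// Rule of thumb for picking `gamma`: the rate halves roughly every
// ln(0.5) / ln(gamma) steps; e.g. gamma = 0.95 halves it about every
// 14 steps (0.95^14 ≈ 0.49).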

/// Cosine annealing scheduler.
///
/// Anneals the learning rate from its initial value down to `eta_min`
/// over `t_max` steps, following half a cosine wave.
pub struct CosineAnnealingLR<O: Optimizer> {
    optimizer: O,
    t_max: usize,
    eta_min: f32,
    base_lr: f32,
    current_step: usize,
}

impl<O: Optimizer> CosineAnnealingLR<O> {
    pub fn new(optimizer: O, t_max: usize, eta_min: f32) -> Self {
        let base_lr = optimizer.get_lr();
        CosineAnnealingLR {
            optimizer,
            t_max,
            eta_min,
            base_lr,
            current_step: 0,
        }
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
}

impl<O: Optimizer> LRScheduler for CosineAnnealingLR<O> {
    fn step(&mut self) {
        self.current_step += 1;
        // Clamp to t_max so the schedule bottoms out at eta_min and holds
        // there, rather than wrapping around and snapping back to base_lr.
        let t = self.current_step.min(self.t_max);
        let cos_val = (std::f32::consts::PI * t as f32 / self.t_max as f32).cos();
        let new_lr = self.eta_min + (self.base_lr - self.eta_min) * (1.0 + cos_val) / 2.0;
        self.optimizer.set_lr(new_lr);
    }

    fn get_lr(&self) -> f32 {
        self.optimizer.get_lr()
    }
}
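
// Shape check: at t = 0 the cosine term is 1 and lr = base_lr; at
// t = t_max / 2 it is 0 and lr = (base_lr + eta_min) / 2; from t = t_max
// on, it is -1 and lr stays at eta_min.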

/// Linear warmup scheduler.
///
/// Ramps the learning rate linearly from 0 up to the optimizer's initial
/// rate over `warmup_steps` steps, then leaves it at that target.
pub struct LinearWarmup<O: Optimizer> {
    optimizer: O,
    warmup_steps: usize,
    target_lr: f32,
    current_step: usize,
}

impl<O: Optimizer> LinearWarmup<O> {
    pub fn new(mut optimizer: O, warmup_steps: usize) -> Self {
        let target_lr = optimizer.get_lr();
        // Start from zero and ramp up. With a zero-length warmup the rate
        // would never be restored, so treat that case as a no-op.
        if warmup_steps > 0 {
            optimizer.set_lr(0.0);
        }
        LinearWarmup {
            optimizer,
            warmup_steps,
            target_lr,
            current_step: 0,
        }
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
}

impl<O: Optimizer> LRScheduler for LinearWarmup<O> {
    fn step(&mut self) {
        self.current_step += 1;
        if self.current_step <= self.warmup_steps {
            let new_lr = self.target_lr * (self.current_step as f32 / self.warmup_steps as f32);
            self.optimizer.set_lr(new_lr);
        }
        // After warmup, the rate simply stays at target_lr.
    }

    fn get_lr(&self) -> f32 {
        self.optimizer.get_lr()
    }
}
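
// Usage sketch (`adam` stands in for any `Optimizer` built with the
// target learning rate, here 3e-4):
//
//     let mut warmup = LinearWarmup::new(adam, 1000);
//     for _ in 0..1000 {
//         // ... one training step via warmup.optimizer_mut() ...
//         warmup.step(); // lr climbs 3e-7, 6e-7, ..., up to 3e-4
//     }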

/// Reduce-on-plateau scheduler.
///
/// Driven by a monitored metric rather than the step count: when the loss
/// fails to improve for `patience` consecutive calls, the learning rate is
/// multiplied by `factor`, but never pushed below `min_lr`. Because it needs
/// the loss value, it exposes `step_with_loss` instead of implementing
/// `LRScheduler`.
pub struct ReduceLROnPlateau<O: Optimizer> {
    optimizer: O,
    factor: f32,
    patience: usize,
    min_lr: f32,
    best_loss: f32,
    num_bad_epochs: usize,
}

impl<O: Optimizer> ReduceLROnPlateau<O> {
    pub fn new(optimizer: O, factor: f32, patience: usize) -> Self {
        ReduceLROnPlateau {
            optimizer,
            factor,
            patience,
            min_lr: 1e-8,
            best_loss: f32::INFINITY,
            num_bad_epochs: 0,
        }
    }

    /// Builder-style override for the learning-rate floor (default 1e-8).
    pub fn min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

    pub fn step_with_loss(&mut self, loss: f32) {
        if loss < self.best_loss {
            // New best: reset the patience counter.
            self.best_loss = loss;
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;

            if self.num_bad_epochs >= self.patience {
                // Plateau detected: decay the rate, clamped to the floor.
                let current_lr = self.optimizer.get_lr();
                let new_lr = (current_lr * self.factor).max(self.min_lr);
                self.optimizer.set_lr(new_lr);
                self.num_bad_epochs = 0;
            }
        }
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }
}
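
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal mock for exercising the schedulers. This assumes the
    // `Optimizer` trait can be satisfied with just `get_lr`/`set_lr`, the
    // only methods the schedulers call; if the real trait has further
    // required methods, the mock needs stubs for those too.
    struct MockOpt {
        lr: f32,
    }

    impl Optimizer for MockOpt {
        fn get_lr(&self) -> f32 {
            self.lr
        }
        fn set_lr(&mut self, lr: f32) {
            self.lr = lr;
        }
    }

    #[test]
    fn step_lr_decays_every_step_size() {
        let mut sched = StepLR::new(MockOpt { lr: 1.0 }, 2, 0.5);
        sched.step(); // step 1: no complete decay interval yet
        assert!((sched.get_lr() - 1.0).abs() < 1e-6);
        sched.step(); // step 2: one decay interval completed
        assert!((sched.get_lr() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn plateau_reduces_after_patience() {
        let mut sched = ReduceLROnPlateau::new(MockOpt { lr: 1.0 }, 0.1, 2);
        sched.step_with_loss(1.0); // best so far
        sched.step_with_loss(1.0); // bad epoch 1
        sched.step_with_loss(1.0); // bad epoch 2 -> decay
        assert!((sched.optimizer().get_lr() - 0.1).abs() < 1e-6);
    }
}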