ghostflow_optim/
scheduler.rs1use crate::optimizer::Optimizer;
4
/// Common interface for learning-rate schedules.
///
/// A scheduler is advanced explicitly with [`LRScheduler::step`]; the caller
/// decides what one step means (for example one epoch or one batch).
pub trait LRScheduler {
    /// Advances the schedule by one step.
    fn step(&mut self);

    /// Returns the learning rate currently in effect.
    fn get_lr(&self) -> f32;
}
10
11pub struct StepLR<O: Optimizer> {
13 optimizer: O,
14 step_size: usize,
15 gamma: f32,
16 current_step: usize,
17 base_lr: f32,
18}
19
20impl<O: Optimizer> StepLR<O> {
21 pub fn new(optimizer: O, step_size: usize, gamma: f32) -> Self {
22 let base_lr = optimizer.get_lr();
23 StepLR {
24 optimizer,
25 step_size,
26 gamma,
27 current_step: 0,
28 base_lr,
29 }
30 }
31
32 pub fn optimizer(&self) -> &O {
33 &self.optimizer
34 }
35
36 pub fn optimizer_mut(&mut self) -> &mut O {
37 &mut self.optimizer
38 }
39}
40
41impl<O: Optimizer> LRScheduler for StepLR<O> {
42 fn step(&mut self) {
43 self.current_step += 1;
44 let num_decays = self.current_step / self.step_size;
45 let new_lr = self.base_lr * self.gamma.powi(num_decays as i32);
46 self.optimizer.set_lr(new_lr);
47 }
48
49 fn get_lr(&self) -> f32 {
50 self.optimizer.get_lr()
51 }
52}
53
54pub struct ExponentialLR<O: Optimizer> {
56 optimizer: O,
57 gamma: f32,
58}
59
60impl<O: Optimizer> ExponentialLR<O> {
61 pub fn new(optimizer: O, gamma: f32) -> Self {
62 ExponentialLR { optimizer, gamma }
63 }
64
65 pub fn optimizer(&self) -> &O {
66 &self.optimizer
67 }
68
69 pub fn optimizer_mut(&mut self) -> &mut O {
70 &mut self.optimizer
71 }
72}
73
74impl<O: Optimizer> LRScheduler for ExponentialLR<O> {
75 fn step(&mut self) {
76 let current_lr = self.optimizer.get_lr();
77 self.optimizer.set_lr(current_lr * self.gamma);
78 }
79
80 fn get_lr(&self) -> f32 {
81 self.optimizer.get_lr()
82 }
83}
84
85pub struct CosineAnnealingLR<O: Optimizer> {
87 optimizer: O,
88 t_max: usize,
89 eta_min: f32,
90 base_lr: f32,
91 current_step: usize,
92}
93
94impl<O: Optimizer> CosineAnnealingLR<O> {
95 pub fn new(optimizer: O, t_max: usize, eta_min: f32) -> Self {
96 let base_lr = optimizer.get_lr();
97 CosineAnnealingLR {
98 optimizer,
99 t_max,
100 eta_min,
101 base_lr,
102 current_step: 0,
103 }
104 }
105
106 pub fn optimizer(&self) -> &O {
107 &self.optimizer
108 }
109
110 pub fn optimizer_mut(&mut self) -> &mut O {
111 &mut self.optimizer
112 }
113}
114
115impl<O: Optimizer> LRScheduler for CosineAnnealingLR<O> {
116 fn step(&mut self) {
117 self.current_step += 1;
118 let t = self.current_step % self.t_max;
119 let cos_val = (std::f32::consts::PI * t as f32 / self.t_max as f32).cos();
120 let new_lr = self.eta_min + (self.base_lr - self.eta_min) * (1.0 + cos_val) / 2.0;
121 self.optimizer.set_lr(new_lr);
122 }
123
124 fn get_lr(&self) -> f32 {
125 self.optimizer.get_lr()
126 }
127}
128
129pub struct LinearWarmup<O: Optimizer> {
131 optimizer: O,
132 warmup_steps: usize,
133 target_lr: f32,
134 current_step: usize,
135}
136
137impl<O: Optimizer> LinearWarmup<O> {
138 pub fn new(mut optimizer: O, warmup_steps: usize) -> Self {
139 let target_lr = optimizer.get_lr();
140 optimizer.set_lr(0.0);
141 LinearWarmup {
142 optimizer,
143 warmup_steps,
144 target_lr,
145 current_step: 0,
146 }
147 }
148
149 pub fn optimizer(&self) -> &O {
150 &self.optimizer
151 }
152
153 pub fn optimizer_mut(&mut self) -> &mut O {
154 &mut self.optimizer
155 }
156}
157
158impl<O: Optimizer> LRScheduler for LinearWarmup<O> {
159 fn step(&mut self) {
160 self.current_step += 1;
161 if self.current_step <= self.warmup_steps {
162 let new_lr = self.target_lr * (self.current_step as f32 / self.warmup_steps as f32);
163 self.optimizer.set_lr(new_lr);
164 }
165 }
166
167 fn get_lr(&self) -> f32 {
168 self.optimizer.get_lr()
169 }
170}
171
172pub struct ReduceLROnPlateau<O: Optimizer> {
174 optimizer: O,
175 factor: f32,
176 patience: usize,
177 min_lr: f32,
178 best_loss: f32,
179 num_bad_epochs: usize,
180}
181
182impl<O: Optimizer> ReduceLROnPlateau<O> {
183 pub fn new(optimizer: O, factor: f32, patience: usize) -> Self {
184 ReduceLROnPlateau {
185 optimizer,
186 factor,
187 patience,
188 min_lr: 1e-8,
189 best_loss: f32::INFINITY,
190 num_bad_epochs: 0,
191 }
192 }
193
194 pub fn min_lr(mut self, min_lr: f32) -> Self {
195 self.min_lr = min_lr;
196 self
197 }
198
199 pub fn step_with_loss(&mut self, loss: f32) {
200 if loss < self.best_loss {
201 self.best_loss = loss;
202 self.num_bad_epochs = 0;
203 } else {
204 self.num_bad_epochs += 1;
205
206 if self.num_bad_epochs >= self.patience {
207 let current_lr = self.optimizer.get_lr();
208 let new_lr = (current_lr * self.factor).max(self.min_lr);
209 self.optimizer.set_lr(new_lr);
210 self.num_bad_epochs = 0;
211 }
212 }
213 }
214
215 pub fn optimizer(&self) -> &O {
216 &self.optimizer
217 }
218
219 pub fn optimizer_mut(&mut self) -> &mut O {
220 &mut self.optimizer
221 }
222}