use ndarray::Array2;
use std::collections::HashMap;
use crate::schedulers::LearningRateScheduler;

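/// Common interface for gradient-based optimizers. The `param_id` key lets
/// stateful optimizers (Adam, RMSprop) track per-parameter state such as
/// moment estimates, so a single optimizer instance can update several
/// parameter tensors independently.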
pub trait Optimizer {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>);
    fn reset(&mut self);

    fn set_learning_rate(&mut self, lr: f64);

    fn get_learning_rate(&self) -> f64;
}

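/// Plain stochastic gradient descent: `param -= learning_rate * gradient`.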
pub struct SGD {
    learning_rate: f64,
}

impl SGD {
    pub fn new(learning_rate: f64) -> Self {
        SGD { learning_rate }
    }
}

impl Optimizer for SGD {
    fn update(&mut self, _param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        *param = &*param - self.learning_rate * gradient;
    }

    fn reset(&mut self) {
        // SGD keeps no per-parameter state, so there is nothing to reset.
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }
}

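/// Adam optimizer: per-parameter first and second moment estimates with
/// bias correction (Kingma & Ba, 2015).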
pub struct Adam {
    learning_rate: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    t: i32,
    m: HashMap<String, Array2<f64>>,
    v: HashMap<String, Array2<f64>>,
}

impl Adam {
    pub fn new(learning_rate: f64) -> Self {
        // Defaults from the Adam paper: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
        Adam::with_params(learning_rate, 0.9, 0.999, 1e-8)
    }

    pub fn with_params(learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Adam {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            t: 0,
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }
}

impl Optimizer for Adam {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        // The step counter advances on every call, so bias correction assumes
        // each parameter is updated exactly once per optimization step.
        self.t += 1;

        if !self.m.contains_key(param_id) {
            self.m.insert(param_id.to_string(), Array2::zeros(param.raw_dim()));
            self.v.insert(param_id.to_string(), Array2::zeros(param.raw_dim()));
        }

        let m_t = self.m.get_mut(param_id).unwrap();
        let v_t = self.v.get_mut(param_id).unwrap();

        // Exponential moving averages of the gradient and its element-wise square.
        *m_t = self.beta1 * &*m_t + (1.0 - self.beta1) * gradient;
        *v_t = self.beta2 * &*v_t + (1.0 - self.beta2) * gradient * gradient;

        // Bias-corrected moment estimates.
        let m_hat = &*m_t / (1.0 - self.beta1.powi(self.t));
        let v_hat = &*v_t / (1.0 - self.beta2.powi(self.t));

        let update = self.learning_rate * m_hat / (v_hat.map(|x| x.sqrt()) + self.epsilon);
        *param = &*param - update;
    }

    fn reset(&mut self) {
        self.t = 0;
        self.m.clear();
        self.v.clear();
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }
}

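/// RMSprop: scales the update by a running average of recent squared
/// gradients, tracked per parameter id.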
pub struct RMSprop {
    learning_rate: f64,
    alpha: f64,
    epsilon: f64,
    v: HashMap<String, Array2<f64>>,
}

impl RMSprop {
    pub fn new(learning_rate: f64) -> Self {
        RMSprop::with_params(learning_rate, 0.99, 1e-8)
    }

    pub fn with_params(learning_rate: f64, alpha: f64, epsilon: f64) -> Self {
        RMSprop {
            learning_rate,
            alpha,
            epsilon,
            v: HashMap::new(),
        }
    }
}

impl Optimizer for RMSprop {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        if !self.v.contains_key(param_id) {
            self.v.insert(param_id.to_string(), Array2::zeros(param.raw_dim()));
        }

        let v_t = self.v.get_mut(param_id).unwrap();

        // Exponential moving average of the squared gradient.
        *v_t = self.alpha * &*v_t + (1.0 - self.alpha) * gradient * gradient;

        let update = self.learning_rate * gradient / (v_t.map(|x| x.sqrt()) + self.epsilon);
        *param = &*param - update;
    }

    fn reset(&mut self) {
        self.v.clear();
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }
}

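/// Couples an optimizer with a `LearningRateScheduler`: call `step()` (or
/// `step_with_val_loss()`) once per epoch to refresh the wrapped optimizer's
/// learning rate from the schedule.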
pub struct ScheduledOptimizer<O: Optimizer, S: LearningRateScheduler> {
    optimizer: O,
    scheduler: S,
    base_lr: f64,
    current_epoch: usize,
}

impl<O: Optimizer, S: LearningRateScheduler> ScheduledOptimizer<O, S> {
    pub fn new(optimizer: O, scheduler: S, base_lr: f64) -> Self {
        ScheduledOptimizer {
            optimizer,
            scheduler,
            base_lr,
            current_epoch: 0,
        }
    }

    pub fn step(&mut self) {
        self.current_epoch += 1;
        let new_lr = self.scheduler.get_lr(self.current_epoch, self.base_lr);
        self.optimizer.set_learning_rate(new_lr);
    }

    pub fn step_with_val_loss(&mut self, val_loss: f64) {
        self.current_epoch += 1;
        let base_lr = self.base_lr;
        // Prefer a validation-loss-driven step when a plateau scheduler is
        // available; otherwise fall back to the epoch-based schedule.
        let new_lr = if let Some(plateau_scheduler) = self.scheduler_as_plateau_mut() {
            plateau_scheduler.step(val_loss, base_lr)
        } else {
            self.scheduler.get_lr(self.current_epoch, self.base_lr)
        };
        self.optimizer.set_learning_rate(new_lr);
    }

    pub fn get_current_lr(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }

    pub fn get_current_epoch(&self) -> usize {
        self.current_epoch
    }

    pub fn reset(&mut self) {
        self.optimizer.reset();
        self.scheduler.reset();
        self.current_epoch = 0;
        self.optimizer.set_learning_rate(self.base_lr);
    }

    pub fn scheduler_name(&self) -> &'static str {
        self.scheduler.name()
    }

    // Without trait downcasting, the generic `S` cannot be identified as
    // `ReduceLROnPlateau` here, so this always returns `None` and
    // `step_with_val_loss` falls back to the epoch-based schedule.
    fn scheduler_as_plateau_mut(&mut self) -> Option<&mut crate::schedulers::ReduceLROnPlateau> {
        None
    }
}

impl<O: Optimizer, S: LearningRateScheduler> Optimizer for ScheduledOptimizer<O, S> {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        self.optimizer.update(param_id, param, gradient);
    }

    fn reset(&mut self) {
        // Delegate to the inherent `reset`, which also resets the scheduler,
        // the epoch counter, and the learning rate.
        ScheduledOptimizer::reset(self);
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.base_lr = lr;
        self.optimizer.set_learning_rate(lr);
    }

    fn get_learning_rate(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }
}

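// Convenience constructors pairing an optimizer with a concrete scheduler.
// Usage sketch (scheduler behavior itself is defined in crate::schedulers):
//
//     let mut opt = ScheduledOptimizer::step_lr(SGD::new(0.1), 0.1, 10, 0.5);
//     // ... run one epoch of `opt.update(id, &mut w, &grad)` calls ...
//     opt.step(); // advance the schedule and refresh the learning rate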
impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::ConstantLR> {
    pub fn constant(optimizer: O, lr: f64) -> Self {
        Self::new(optimizer, crate::schedulers::ConstantLR, lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::StepLR> {
    pub fn step_lr(optimizer: O, lr: f64, step_size: usize, gamma: f64) -> Self {
        Self::new(optimizer, crate::schedulers::StepLR::new(step_size, gamma), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::ExponentialLR> {
    pub fn exponential(optimizer: O, lr: f64, gamma: f64) -> Self {
        Self::new(optimizer, crate::schedulers::ExponentialLR::new(gamma), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::CosineAnnealingLR> {
    pub fn cosine_annealing(optimizer: O, lr: f64, t_max: usize, eta_min: f64) -> Self {
        Self::new(optimizer, crate::schedulers::CosineAnnealingLR::new(t_max, eta_min), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::PolynomialLR> {
    pub fn polynomial(optimizer: O, lr: f64, total_iters: usize, power: f64, end_lr: f64) -> Self {
        Self::new(optimizer, crate::schedulers::PolynomialLR::new(total_iters, power, end_lr), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::CyclicalLR> {
    pub fn cyclical(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        Self::new(optimizer, crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size), base_lr)
    }

    pub fn cyclical_triangular2(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        let scheduler = crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size)
            .with_mode(crate::schedulers::CyclicalMode::Triangular2);
        Self::new(optimizer, scheduler, base_lr)
    }

    pub fn cyclical_exp_range(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize, gamma: f64) -> Self {
        let scheduler = crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size)
            .with_mode(crate::schedulers::CyclicalMode::ExpRange)
            .with_gamma(gamma);
        Self::new(optimizer, scheduler, base_lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::OneCycleLR> {
    pub fn one_cycle(optimizer: O, max_lr: f64, total_steps: usize) -> Self {
        Self::new(optimizer, crate::schedulers::OneCycleLR::new(max_lr, total_steps), max_lr)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_sgd_optimizer() {
        let mut optimizer = SGD::new(0.1);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        let expected = &original_param - 0.1 * &gradient;
        assert!((param - expected).map(|x| x.abs()).sum() < 1e-10);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut optimizer = Adam::new(0.001);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        assert!((param - original_param).map(|x| x.abs()).sum() > 1e-10);
    }

    #[test]
    fn test_rmsprop_optimizer() {
        let mut optimizer = RMSprop::new(0.01);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        assert!((param - original_param).map(|x| x.abs()).sum() > 1e-10);
    }
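
    // Additional sketch tests covering behavior implemented entirely in this
    // file. They assume only that `ConstantLR` (already used by
    // `ScheduledOptimizer::constant` above) implements `LearningRateScheduler`.
    #[test]
    fn test_adam_first_step_magnitude() {
        // On the first update, bias correction gives m_hat == gradient and
        // v_hat == gradient^2, so each element moves by roughly lr in magnitude.
        let mut optimizer = Adam::new(0.001);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        let step = (&original_param - &param).map(|x| x.abs());
        assert!(step.iter().all(|&s| (s - 0.001).abs() < 1e-6));
    }

    #[test]
    fn test_scheduled_optimizer_epoch_and_reset() {
        let mut optimizer = ScheduledOptimizer::constant(SGD::new(0.1), 0.1);

        optimizer.step();
        optimizer.step();
        assert_eq!(optimizer.get_current_epoch(), 2);

        optimizer.reset();
        assert_eq!(optimizer.get_current_epoch(), 0);
        // `reset` restores the base learning rate regardless of the scheduler.
        assert!((optimizer.get_current_lr() - 0.1).abs() < 1e-12);
    }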
}