// rust_lstm/optimizers.rs

use ndarray::Array2;
use std::collections::HashMap;
use crate::schedulers::LearningRateScheduler;

/// Optimizer trait for parameter updates during training
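///
/// `param_id` names a parameter matrix so that stateful optimizers (Adam, RMSprop)
/// can keep per-parameter moment estimates; `update` applies one gradient step in
/// place.
///
/// A minimal usage sketch (the `rust_lstm::optimizers` path is an assumption based
/// on the crate and file names):
///
/// ```rust,ignore
/// use ndarray::arr2;
/// use rust_lstm::optimizers::{Optimizer, SGD};
///
/// let mut opt = SGD::new(0.1);
/// let mut weights = arr2(&[[1.0, 2.0]]);
/// let grads = arr2(&[[0.5, 0.5]]);
/// opt.update("weights", &mut weights, &grads); // updates `weights` in place
/// ```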
pub trait Optimizer {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>);
    fn reset(&mut self);

    /// Set the learning rate dynamically (for compatibility with schedulers)
    fn set_learning_rate(&mut self, lr: f64);

    /// Get the current learning rate
    fn get_learning_rate(&self) -> f64;
}

/// Stochastic Gradient Descent optimizer
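///
/// Applies the plain update `param -= learning_rate * gradient` and keeps no
/// per-parameter state.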
pub struct SGD {
    learning_rate: f64,
}

impl SGD {
    pub fn new(learning_rate: f64) -> Self {
        SGD { learning_rate }
    }
}

impl Optimizer for SGD {
    fn update(&mut self, _param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        *param = &*param - self.learning_rate * gradient;
    }

    fn reset(&mut self) {
        // SGD has no state to reset
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }
}

/// Adam optimizer with adaptive learning rates
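///
/// Keeps per-parameter exponential moving averages of the gradient (`m`) and the
/// squared gradient (`v`) and applies the bias-corrected update
/// `param -= lr * m_hat / (sqrt(v_hat) + epsilon)`.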
pub struct Adam {
    learning_rate: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    // Per-parameter timestep so bias correction stays correct when several
    // parameters are updated within a single training step.
    t: HashMap<String, i32>,
    m: HashMap<String, Array2<f64>>,
    v: HashMap<String, Array2<f64>>,
}

impl Adam {
    pub fn new(learning_rate: f64) -> Self {
        Adam::with_params(learning_rate, 0.9, 0.999, 1e-8)
    }

    pub fn with_params(learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        Adam {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            t: HashMap::new(),
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }
}

impl Optimizer for Adam {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        // Advance this parameter's timestep (t starts at 1 on the first update).
        let t = {
            let t = self.t.entry(param_id.to_string()).or_insert(0);
            *t += 1;
            *t
        };

        if !self.m.contains_key(param_id) {
            self.m.insert(param_id.to_string(), Array2::zeros(param.raw_dim()));
            self.v.insert(param_id.to_string(), Array2::zeros(param.raw_dim()));
        }

        let m_t = self.m.get_mut(param_id).unwrap();
        let v_t = self.v.get_mut(param_id).unwrap();

        // Exponential moving averages of the gradient and the squared gradient
        *m_t = self.beta1 * &*m_t + (1.0 - self.beta1) * gradient;
        *v_t = self.beta2 * &*v_t + (1.0 - self.beta2) * gradient * gradient;

        // Bias-corrected first and second moment estimates
        let m_hat = &*m_t / (1.0 - self.beta1.powi(t));
        let v_hat = &*v_t / (1.0 - self.beta2.powi(t));

        let update = self.learning_rate * m_hat / (v_hat.map(|x| x.sqrt()) + self.epsilon);
        *param = &*param - update;
    }

    fn reset(&mut self) {
        self.t.clear();
        self.m.clear();
        self.v.clear();
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }
}

/// RMSprop optimizer
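///
/// Keeps a per-parameter running average of the squared gradient and updates with
/// `param -= lr * gradient / (sqrt(v) + epsilon)`.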
pub struct RMSprop {
    learning_rate: f64,
    alpha: f64,
    epsilon: f64,
    v: HashMap<String, Array2<f64>>,
}

impl RMSprop {
    pub fn new(learning_rate: f64) -> Self {
        RMSprop::with_params(learning_rate, 0.99, 1e-8)
    }

    pub fn with_params(learning_rate: f64, alpha: f64, epsilon: f64) -> Self {
        RMSprop {
            learning_rate,
            alpha,
            epsilon,
            v: HashMap::new(),
        }
    }
}

impl Optimizer for RMSprop {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        if !self.v.contains_key(param_id) {
            self.v.insert(param_id.to_string(), Array2::zeros(param.raw_dim()));
        }

        let v_t = self.v.get_mut(param_id).unwrap();

        *v_t = self.alpha * &*v_t + (1.0 - self.alpha) * gradient * gradient;

        let update = self.learning_rate * gradient / (v_t.map(|x| x.sqrt()) + self.epsilon);
        *param = &*param - update;
    }

    fn reset(&mut self) {
        self.v.clear();
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.learning_rate = lr;
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }
}

/// Wrapper that combines an optimizer with a learning rate scheduler
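///
/// Call `update` for each parameter during training, then `step()` once at the end
/// of every epoch so the scheduler can pick the next learning rate.
///
/// A minimal training-loop sketch (assuming `StepLR::new(10, 0.5)` halves the rate
/// every 10 epochs; paths assume the `rust_lstm` crate layout):
///
/// ```rust,ignore
/// use rust_lstm::optimizers::{Adam, Optimizer, ScheduledOptimizer};
///
/// let mut opt = ScheduledOptimizer::step_lr(Adam::new(0.001), 0.001, 10, 0.5);
/// for _epoch in 0..30 {
///     // ... opt.update(param_id, &mut param, &gradient) for every parameter ...
///     opt.step(); // end of epoch: scheduler adjusts the learning rate
/// }
/// ```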
pub struct ScheduledOptimizer<O: Optimizer, S: LearningRateScheduler> {
    optimizer: O,
    scheduler: S,
    base_lr: f64,
    current_epoch: usize,
}

impl<O: Optimizer, S: LearningRateScheduler> ScheduledOptimizer<O, S> {
    pub fn new(optimizer: O, scheduler: S, base_lr: f64) -> Self {
        ScheduledOptimizer {
            optimizer,
            scheduler,
            base_lr,
            current_epoch: 0,
        }
    }

    /// Step the scheduler (should be called at the end of each epoch)
    pub fn step(&mut self) {
        self.current_epoch += 1;
        let new_lr = self.scheduler.get_lr(self.current_epoch, self.base_lr);
        self.optimizer.set_learning_rate(new_lr);
    }

    /// Step with a validation loss (intended for ReduceLROnPlateau).
    ///
    /// Note: `scheduler_as_plateau_mut` currently always returns `None`, so this
    /// falls back to the scheduler's regular `get_lr` and ignores `val_loss`.
    pub fn step_with_val_loss(&mut self, val_loss: f64) {
        self.current_epoch += 1;
        let base_lr = self.base_lr; // copy the value before the mutable borrow below
        // Finish the mutable borrow before falling back to the regular `get_lr`.
        let plateau_lr = self
            .scheduler_as_plateau_mut()
            .map(|plateau| plateau.step(val_loss, base_lr));
        let new_lr = match plateau_lr {
            Some(lr) => lr,
            None => self.scheduler.get_lr(self.current_epoch, self.base_lr),
        };
        self.optimizer.set_learning_rate(new_lr);
    }

    /// Get the current learning rate
    pub fn get_current_lr(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }

    /// Get the current epoch
    pub fn get_current_epoch(&self) -> usize {
        self.current_epoch
    }

    /// Reset both optimizer and scheduler
    pub fn reset(&mut self) {
        self.optimizer.reset();
        self.scheduler.reset();
        self.current_epoch = 0;
        self.optimizer.set_learning_rate(self.base_lr);
    }

    /// Get the scheduler name for logging
    pub fn scheduler_name(&self) -> &'static str {
        self.scheduler.name()
    }

    /// Helper that would downcast the scheduler to `ReduceLROnPlateau`.
    ///
    /// Rust offers no easy way to downcast through a generic trait bound, so this
    /// currently always returns `None` and `step_with_val_loss` falls back to the
    /// scheduler's regular `get_lr` behaviour.
    fn scheduler_as_plateau_mut(&mut self) -> Option<&mut crate::schedulers::ReduceLROnPlateau> {
        None
    }
}

impl<O: Optimizer, S: LearningRateScheduler> Optimizer for ScheduledOptimizer<O, S> {
    fn update(&mut self, param_id: &str, param: &mut Array2<f64>, gradient: &Array2<f64>) {
        self.optimizer.update(param_id, param, gradient);
    }

    fn reset(&mut self) {
        // Resolves to the inherent `reset`, which resets both optimizer and scheduler
        self.reset();
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.base_lr = lr;
        self.optimizer.set_learning_rate(lr);
    }

    fn get_learning_rate(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }
}

/// Helper functions to create common optimizer-scheduler combinations
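/// (`constant`, `step_lr`, `exponential`, `cosine_annealing`, `polynomial`,
/// `cyclical`, `one_cycle`); each pairs the given optimizer with the matching
/// scheduler from `crate::schedulers`.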
impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::ConstantLR> {
    pub fn constant(optimizer: O, lr: f64) -> Self {
        Self::new(optimizer, crate::schedulers::ConstantLR, lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::StepLR> {
    pub fn step_lr(optimizer: O, lr: f64, step_size: usize, gamma: f64) -> Self {
        Self::new(optimizer, crate::schedulers::StepLR::new(step_size, gamma), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::ExponentialLR> {
    pub fn exponential(optimizer: O, lr: f64, gamma: f64) -> Self {
        Self::new(optimizer, crate::schedulers::ExponentialLR::new(gamma), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::CosineAnnealingLR> {
    pub fn cosine_annealing(optimizer: O, lr: f64, t_max: usize, eta_min: f64) -> Self {
        Self::new(optimizer, crate::schedulers::CosineAnnealingLR::new(t_max, eta_min), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::PolynomialLR> {
    pub fn polynomial(optimizer: O, lr: f64, total_iters: usize, power: f64, end_lr: f64) -> Self {
        Self::new(optimizer, crate::schedulers::PolynomialLR::new(total_iters, power, end_lr), lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::CyclicalLR> {
    pub fn cyclical(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        Self::new(optimizer, crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size), base_lr)
    }

    pub fn cyclical_triangular2(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        let scheduler = crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size)
            .with_mode(crate::schedulers::CyclicalMode::Triangular2);
        Self::new(optimizer, scheduler, base_lr)
    }

    pub fn cyclical_exp_range(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize, gamma: f64) -> Self {
        let scheduler = crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size)
            .with_mode(crate::schedulers::CyclicalMode::ExpRange)
            .with_gamma(gamma);
        Self::new(optimizer, scheduler, base_lr)
    }
}

impl<O: Optimizer> ScheduledOptimizer<O, crate::schedulers::OneCycleLR> {
    pub fn one_cycle(optimizer: O, max_lr: f64, total_steps: usize) -> Self {
        Self::new(optimizer, crate::schedulers::OneCycleLR::new(max_lr, total_steps), max_lr)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_sgd_optimizer() {
        let mut optimizer = SGD::new(0.1);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        let expected = &original_param - 0.1 * &gradient;
        assert!((param - expected).map(|x| x.abs()).sum() < 1e-10);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut optimizer = Adam::new(0.001);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        assert!((param - original_param).map(|x| x.abs()).sum() > 1e-10);
    }

    #[test]
    fn test_rmsprop_optimizer() {
        let mut optimizer = RMSprop::new(0.01);
        let mut param = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let gradient = arr2(&[[0.1, 0.2], [0.3, 0.4]]);

        let original_param = param.clone();
        optimizer.update("test_param", &mut param, &gradient);

        assert!((param - original_param).map(|x| x.abs()).sum() > 1e-10);
    }
348}