axonml_optim/lr_scheduler.rs

//! Learning Rate Schedulers
//!
//! Provides learning rate scheduling strategies for training.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use crate::optimizer::Optimizer;

// =============================================================================
// LRScheduler Trait
// =============================================================================

/// Trait for learning rate schedulers.
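///
/// Schedulers are constructed from a reference to an optimizer and then drive
/// its learning rate through `Optimizer::set_lr`. A typical training loop (a
/// sketch based on the tests at the bottom of this file; optimizer construction
/// is omitted) calls `step` once per epoch, after the parameters have been
/// updated:
///
/// ```ignore
/// let mut scheduler = StepLR::new(&optimizer, 10, 0.1);
/// for _epoch in 0..30 {
///     // ... forward pass, backward pass, parameter update ...
///     scheduler.step(&mut optimizer);
///     println!("lr = {}", scheduler.get_last_lr());
/// }
/// ```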
pub trait LRScheduler {
    /// Updates the learning rate.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Returns the current learning rate.
    fn get_last_lr(&self) -> f32;

    /// Returns the current epoch/step count.
    fn get_step(&self) -> usize;
}

// =============================================================================
// StepLR
// =============================================================================

/// Decays learning rate by gamma every `step_size` epochs.
///
/// lr = `initial_lr` * gamma^(epoch // `step_size`)
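///
/// # Example
///
/// An illustrative schedule (values follow from the formula above), assuming
/// the optimizer starts at lr = 0.1 with `step_size` = 10 and gamma = 0.1:
/// epochs 1-9 keep lr = 0.1, epochs 10-19 use 0.01, and epochs 20-29 use 0.001.
///
/// ```ignore
/// let mut scheduler = StepLR::new(&optimizer, 10, 0.1);
/// for _ in 0..10 {
///     scheduler.step(&mut optimizer);
/// }
/// assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
/// ```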
pub struct StepLR {
    initial_lr: f32,
    step_size: usize,
    gamma: f32,
    current_step: usize,
    last_lr: f32,
}

impl StepLR {
    /// Creates a new `StepLR` scheduler.
    pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            step_size,
            gamma,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for StepLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;
        let num_decays = self.current_step / self.step_size;
        let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// MultiStepLR
// =============================================================================

/// Decays learning rate by gamma at each milestone.
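///
/// lr = `initial_lr` * gamma^(number of milestones already reached)
///
/// # Example
///
/// A sketch mirroring the test below: with milestones [5, 15], gamma = 0.1,
/// and an initial lr of 0.1, the lr drops to 0.01 at step 5 and to 0.001 at
/// step 15. Unsorted milestones are sorted at construction time.
///
/// ```ignore
/// let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);
/// ```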
pub struct MultiStepLR {
    initial_lr: f32,
    milestones: Vec<usize>,
    gamma: f32,
    current_step: usize,
    last_lr: f32,
    milestone_idx: usize,
}

impl MultiStepLR {
    /// Creates a new `MultiStepLR` scheduler.
    pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        milestones.sort_unstable();
        Self {
            initial_lr,
            milestones,
            gamma,
            current_step: 0,
            last_lr: initial_lr,
            milestone_idx: 0,
        }
    }
}

impl LRScheduler for MultiStepLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        // Check if we've passed any milestones
        while self.milestone_idx < self.milestones.len()
            && self.current_step >= self.milestones[self.milestone_idx]
        {
            self.milestone_idx += 1;
        }

        let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// ExponentialLR
// =============================================================================

/// Decays learning rate by gamma every epoch.
///
/// lr = `initial_lr` * gamma^epoch
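///
/// # Example
///
/// Illustrative values matching the test below: starting from lr = 0.1 with
/// gamma = 0.9, the lr is 0.09 after one step and 0.081 after two.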
pub struct ExponentialLR {
    initial_lr: f32,
    gamma: f32,
    current_step: usize,
    last_lr: f32,
}

impl ExponentialLR {
    /// Creates a new `ExponentialLR` scheduler.
    pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            gamma,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for ExponentialLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;
        let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// CosineAnnealingLR
// =============================================================================

/// Cosine annealing learning rate scheduler.
///
/// lr = `eta_min` + (`initial_lr` - `eta_min`) * (1 + cos(pi * epoch / `T_max`)) / 2
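///
/// # Example
///
/// Worked values from the formula above with `initial_lr` = 0.1, `T_max` = 100,
/// and `eta_min` = 0: lr is 0.1 at step 0, about 0.05 at step 50 (cos(pi/2) = 0),
/// and 0 at step 100 (cos(pi) = -1).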
pub struct CosineAnnealingLR {
    initial_lr: f32,
    t_max: usize,
    eta_min: f32,
    current_step: usize,
    last_lr: f32,
}

impl CosineAnnealingLR {
    /// Creates a new `CosineAnnealingLR` scheduler.
    pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
        Self::with_eta_min(optimizer, t_max, 0.0)
    }

    /// Creates a `CosineAnnealingLR` with minimum learning rate.
    pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            t_max,
            eta_min,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for CosineAnnealingLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        let progress = self.current_step as f32 / self.t_max as f32;
        let new_lr = self.eta_min
            + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
                / 2.0;

        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// ReduceLROnPlateau
// =============================================================================

/// Reduces learning rate when a metric has stopped improving.
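///
/// Unlike the other schedulers, this one is driven by a validation metric via
/// `step_with_metric`; the `LRScheduler::step` implementation only advances the
/// step counter and leaves the learning rate untouched.
///
/// # Example
///
/// A sketch mirroring the test below ("min" mode, halve the lr after two
/// epochs without improvement):
///
/// ```ignore
/// let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);
/// scheduler.step_with_metric(&mut optimizer, val_loss);
/// ```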
pub struct ReduceLROnPlateau {
    mode: String,
    factor: f32,
    patience: usize,
    threshold: f32,
    cooldown: usize,
    min_lr: f32,
    best: f32,
    num_bad_epochs: usize,
    cooldown_counter: usize,
    current_step: usize,
    last_lr: f32,
}

impl ReduceLROnPlateau {
    /// Creates a new `ReduceLROnPlateau` scheduler for minimizing, using the
    /// defaults: factor 0.1, patience 10, threshold 1e-4, no cooldown, `min_lr` 0.
    pub fn new<O: Optimizer>(optimizer: &O) -> Self {
        Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
    }

    /// Creates a `ReduceLROnPlateau` with options.
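    ///
    /// * `mode` - "min" (metric should decrease) or "max" (metric should increase)
    /// * `factor` - multiplicative factor applied to the lr when a plateau is detected
    /// * `patience` - number of non-improving epochs tolerated before reducing
    /// * `threshold` - relative improvement required to count as an improvement
    /// * `cooldown` - epochs to wait after a reduction before resuming plateau checks
    /// * `min_lr` - lower bound on the learning rate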
    pub fn with_options<O: Optimizer>(
        optimizer: &O,
        mode: &str,
        factor: f32,
        patience: usize,
        threshold: f32,
        cooldown: usize,
        min_lr: f32,
    ) -> Self {
        let initial_lr = optimizer.get_lr();
        let best = if mode == "min" {
            f32::INFINITY
        } else {
            f32::NEG_INFINITY
        };
        Self {
            mode: mode.to_string(),
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best,
            num_bad_epochs: 0,
            cooldown_counter: 0,
            current_step: 0,
            last_lr: initial_lr,
        }
    }

    /// Steps the scheduler based on a metric value.
    pub fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
        self.current_step += 1;

        // Check if we're in cooldown
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return;
        }

        // Check if metric improved
        let improved = if self.mode == "min" {
            metric < self.best * (1.0 - self.threshold)
        } else {
            metric > self.best * (1.0 + self.threshold)
        };

        if improved {
            self.best = metric;
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;
        }

        // Reduce learning rate if patience exceeded
        if self.num_bad_epochs > self.patience {
            let current_lr = optimizer.get_lr();
            let new_lr = (current_lr * self.factor).max(self.min_lr);
            optimizer.set_lr(new_lr);
            self.last_lr = new_lr;
            self.cooldown_counter = self.cooldown;
            self.num_bad_epochs = 0;
        }
    }
}

impl LRScheduler for ReduceLROnPlateau {
    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
        // This scheduler requires a metric value
        // Use step_with_metric instead
        self.current_step += 1;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// OneCycleLR
// =============================================================================

/// One-cycle learning rate scheduler.
///
/// Implements the 1cycle policy from "Super-Convergence: Very Fast Training of
/// Neural Networks Using Large Learning Rates" (Smith & Topin).
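///
/// The learning rate ramps up linearly from `max_lr` / `div_factor` to `max_lr`
/// over the first `pct_start` fraction of `total_steps`, then follows a cosine
/// curve down to `max_lr` / `final_div_factor` for the remaining steps.
///
/// # Example
///
/// A sketch with the default options (as in the test below): `max_lr` = 0.1 over
/// 100 steps gives an initial lr of 0.1 / 25 = 0.004, a peak of 0.1 around step
/// 30, and a cosine decay toward 0.1 / 1e4 = 1e-5 by step 100.
///
/// ```ignore
/// let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);
/// ```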
pub struct OneCycleLR {
    max_lr: f32,
    total_steps: usize,
    pct_start: f32,
    div_factor: f32,
    final_div_factor: f32,
    current_step: usize,
    last_lr: f32,
}

impl OneCycleLR {
    /// Creates a new `OneCycleLR` scheduler.
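    ///
    /// Uses the default options: `pct_start` 0.3, `div_factor` 25.0, and
    /// `final_div_factor` 1e4.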
    pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
        Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
    }

    /// Creates `OneCycleLR` with options.
    pub fn with_options<O: Optimizer>(
        _optimizer: &O,
        max_lr: f32,
        total_steps: usize,
        pct_start: f32,
        div_factor: f32,
        final_div_factor: f32,
    ) -> Self {
        let initial_lr = max_lr / div_factor;
        Self {
            max_lr,
            total_steps,
            pct_start,
            div_factor,
            final_div_factor,
            current_step: 0,
            last_lr: initial_lr,
        }
    }
}

impl LRScheduler for OneCycleLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        let step_ratio = self.current_step as f32 / self.total_steps as f32;
        let initial_lr = self.max_lr / self.div_factor;
        let min_lr = self.max_lr / self.final_div_factor;

        let new_lr = if step_ratio <= self.pct_start {
            // Warmup phase: linear increase from initial_lr to max_lr
            let phase_ratio = step_ratio / self.pct_start;
            initial_lr + (self.max_lr - initial_lr) * phase_ratio
        } else {
            // Annealing phase: cosine decrease from max_lr to min_lr
            let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
            min_lr
                + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
        };

        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// WarmupLR
// =============================================================================

/// Linear warmup scheduler.
///
/// Linearly increases learning rate from 0 to `initial_lr` over `warmup_steps`.
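///
/// # Example
///
/// With `warmup_steps` = 10 and an optimizer lr of 0.1 (matching the test
/// below), the lr is 0.01 after one step, 0.1 after ten steps, and stays at
/// 0.1 for every step after that.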
pub struct WarmupLR {
    initial_lr: f32,
    warmup_steps: usize,
    current_step: usize,
    last_lr: f32,
}

impl WarmupLR {
    /// Creates a new `WarmupLR` scheduler.
    pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
        let initial_lr = optimizer.get_lr();
        Self {
            initial_lr,
            warmup_steps,
            current_step: 0,
            last_lr: 0.0,
        }
    }
}

impl LRScheduler for WarmupLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        self.current_step += 1;

        let new_lr = if self.current_step <= self.warmup_steps {
            self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
        } else {
            self.initial_lr
        };

        optimizer.set_lr(new_lr);
        self.last_lr = new_lr;
    }

    fn get_last_lr(&self) -> f32 {
        self.last_lr
    }

    fn get_step(&self) -> usize {
        self.current_step
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        // At step 50 (halfway), should be at eta_min + (initial - eta_min) * 0.5
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        // At step 100 (end), should be at eta_min
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // After warmup, should stay at initial_lr
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        // At start, should be at initial_lr = max_lr / div_factor
        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        // Step through warmup phase
        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        // Should be at or near max_lr
        assert!(optimizer.get_lr() > 0.08);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Simulate improving metric
        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Simulate plateau
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        // LR should have been reduced
        assert!(optimizer.get_lr() < initial_lr);
    }
}