//! Learning Rate Schedulers
//!
//! # File
//! `crates/axonml-optim/src/lr_scheduler.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::optimizer::Optimizer;
18
// =============================================================================
// LRScheduler Trait
// =============================================================================
22
/// Trait for learning rate schedulers.
///
/// A scheduler adjusts an [`Optimizer`]'s learning rate over the course of
/// training; call [`LRScheduler::step`] once per scheduling interval after
/// the optimizer update.
///
/// NOTE: `step` is generic over the optimizer type, so this trait is not
/// object-safe and cannot be used as `dyn LRScheduler`.
pub trait LRScheduler {
    /// Updates the learning rate on `optimizer` for the next interval.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Returns the learning rate most recently applied by this scheduler.
    fn get_last_lr(&self) -> f32;

    /// Returns the current epoch/step count (number of `step` calls).
    fn get_step(&self) -> usize;
}
34
// =============================================================================
// StepLR
// =============================================================================
38
/// Decays learning rate by gamma every `step_size` epochs.
///
/// lr = `initial_lr` * gamma^(epoch // `step_size`)
pub struct StepLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Number of steps between successive decays.
    step_size: usize,
    /// Multiplicative decay factor applied at each decay boundary.
    gamma: f32,
    /// Number of times `step` has been called.
    current_step: usize,
    /// Learning rate most recently written to the optimizer.
    last_lr: f32,
}
49
50impl StepLR {
51    /// Creates a new `StepLR` scheduler.
52    pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
53        let initial_lr = optimizer.get_lr();
54        Self {
55            initial_lr,
56            step_size,
57            gamma,
58            current_step: 0,
59            last_lr: initial_lr,
60        }
61    }
62}
63
64impl LRScheduler for StepLR {
65    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
66        self.current_step += 1;
67        let num_decays = self.current_step / self.step_size;
68        let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
69        optimizer.set_lr(new_lr);
70        self.last_lr = new_lr;
71    }
72
73    fn get_last_lr(&self) -> f32 {
74        self.last_lr
75    }
76
77    fn get_step(&self) -> usize {
78        self.current_step
79    }
80}
81
// =============================================================================
// MultiStepLR
// =============================================================================
85
/// Decays learning rate by gamma at each milestone.
pub struct MultiStepLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Step counts at which the rate decays; kept sorted ascending.
    milestones: Vec<usize>,
    /// Multiplicative decay factor applied at each milestone.
    gamma: f32,
    /// Number of times `step` has been called.
    current_step: usize,
    /// Learning rate most recently written to the optimizer.
    last_lr: f32,
    /// Index of the next milestone to cross; doubles as the decay count.
    milestone_idx: usize,
}
95
96impl MultiStepLR {
97    /// Creates a new `MultiStepLR` scheduler.
98    pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
99        let initial_lr = optimizer.get_lr();
100        milestones.sort_unstable();
101        Self {
102            initial_lr,
103            milestones,
104            gamma,
105            current_step: 0,
106            last_lr: initial_lr,
107            milestone_idx: 0,
108        }
109    }
110}
111
112impl LRScheduler for MultiStepLR {
113    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
114        self.current_step += 1;
115
116        // Check if we've passed any milestones
117        while self.milestone_idx < self.milestones.len()
118            && self.current_step >= self.milestones[self.milestone_idx]
119        {
120            self.milestone_idx += 1;
121        }
122
123        let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
124        optimizer.set_lr(new_lr);
125        self.last_lr = new_lr;
126    }
127
128    fn get_last_lr(&self) -> f32 {
129        self.last_lr
130    }
131
132    fn get_step(&self) -> usize {
133        self.current_step
134    }
135}
136
// =============================================================================
// ExponentialLR
// =============================================================================
140
/// Decays learning rate by gamma every epoch.
///
/// lr = `initial_lr` * gamma^epoch
pub struct ExponentialLR {
    /// Learning rate read from the optimizer at construction time.
    initial_lr: f32,
    /// Multiplicative decay factor applied every step.
    gamma: f32,
    /// Number of times `step` has been called.
    current_step: usize,
    /// Learning rate most recently written to the optimizer.
    last_lr: f32,
}
150
151impl ExponentialLR {
152    /// Creates a new `ExponentialLR` scheduler.
153    pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
154        let initial_lr = optimizer.get_lr();
155        Self {
156            initial_lr,
157            gamma,
158            current_step: 0,
159            last_lr: initial_lr,
160        }
161    }
162}
163
164impl LRScheduler for ExponentialLR {
165    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
166        self.current_step += 1;
167        let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
168        optimizer.set_lr(new_lr);
169        self.last_lr = new_lr;
170    }
171
172    fn get_last_lr(&self) -> f32 {
173        self.last_lr
174    }
175
176    fn get_step(&self) -> usize {
177        self.current_step
178    }
179}
180
// =============================================================================
// CosineAnnealingLR
// =============================================================================
184
/// Cosine annealing learning rate scheduler.
///
/// lr = `eta_min` + (`initial_lr` - `eta_min`) * (1 + cos(pi * epoch / `T_max`)) / 2
pub struct CosineAnnealingLR {
    /// Peak learning rate, read from the optimizer at construction time.
    initial_lr: f32,
    /// Length of the annealing period in steps.
    t_max: usize,
    /// Minimum learning rate reached at the end of the period.
    eta_min: f32,
    /// Number of times `step` has been called.
    current_step: usize,
    /// Learning rate most recently written to the optimizer.
    last_lr: f32,
}
195
196impl CosineAnnealingLR {
197    /// Creates a new `CosineAnnealingLR` scheduler.
198    pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
199        Self::with_eta_min(optimizer, t_max, 0.0)
200    }
201
202    /// Creates a `CosineAnnealingLR` with minimum learning rate.
203    pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
204        let initial_lr = optimizer.get_lr();
205        Self {
206            initial_lr,
207            t_max,
208            eta_min,
209            current_step: 0,
210            last_lr: initial_lr,
211        }
212    }
213}
214
215impl LRScheduler for CosineAnnealingLR {
216    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
217        self.current_step += 1;
218
219        let progress = self.current_step as f32 / self.t_max as f32;
220        let new_lr = self.eta_min
221            + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
222                / 2.0;
223
224        optimizer.set_lr(new_lr);
225        self.last_lr = new_lr;
226    }
227
228    fn get_last_lr(&self) -> f32 {
229        self.last_lr
230    }
231
232    fn get_step(&self) -> usize {
233        self.current_step
234    }
235}
236
// =============================================================================
// ReduceLROnPlateau
// =============================================================================
240
/// Reduces learning rate when a metric has stopped improving.
pub struct ReduceLROnPlateau {
    /// "min" if lower metric values are better; any other value is treated
    /// as maximizing (higher is better).
    mode: String,
    /// Multiplier applied to the learning rate on each reduction.
    factor: f32,
    /// Number of non-improving steps tolerated before reducing.
    patience: usize,
    /// Relative improvement required for a metric to count as better.
    threshold: f32,
    /// Steps to skip bad-epoch counting after a reduction.
    cooldown: usize,
    /// Lower bound for the learning rate.
    min_lr: f32,
    /// Best metric value seen so far.
    best: f32,
    /// Consecutive steps without sufficient improvement.
    num_bad_epochs: usize,
    /// Remaining cooldown steps.
    cooldown_counter: usize,
    /// Number of times the scheduler has been stepped.
    current_step: usize,
    /// Learning rate most recently written to the optimizer.
    last_lr: f32,
}
255
256impl ReduceLROnPlateau {
257    /// Creates a new `ReduceLROnPlateau` scheduler for minimizing.
258    pub fn new<O: Optimizer>(optimizer: &O) -> Self {
259        Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
260    }
261
262    /// Creates a `ReduceLROnPlateau` with options.
263    pub fn with_options<O: Optimizer>(
264        optimizer: &O,
265        mode: &str,
266        factor: f32,
267        patience: usize,
268        threshold: f32,
269        cooldown: usize,
270        min_lr: f32,
271    ) -> Self {
272        let initial_lr = optimizer.get_lr();
273        let best = if mode == "min" {
274            f32::INFINITY
275        } else {
276            f32::NEG_INFINITY
277        };
278        Self {
279            mode: mode.to_string(),
280            factor,
281            patience,
282            threshold,
283            cooldown,
284            min_lr,
285            best,
286            num_bad_epochs: 0,
287            cooldown_counter: 0,
288            current_step: 0,
289            last_lr: initial_lr,
290        }
291    }
292
293    /// Steps the scheduler based on a metric value.
294    pub fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
295        self.current_step += 1;
296
297        // Check if we're in cooldown
298        if self.cooldown_counter > 0 {
299            self.cooldown_counter -= 1;
300            return;
301        }
302
303        // Check if metric improved
304        let improved = if self.mode == "min" {
305            metric < self.best * (1.0 - self.threshold)
306        } else {
307            metric > self.best * (1.0 + self.threshold)
308        };
309
310        if improved {
311            self.best = metric;
312            self.num_bad_epochs = 0;
313        } else {
314            self.num_bad_epochs += 1;
315        }
316
317        // Reduce learning rate if patience exceeded
318        if self.num_bad_epochs > self.patience {
319            let current_lr = optimizer.get_lr();
320            let new_lr = (current_lr * self.factor).max(self.min_lr);
321            optimizer.set_lr(new_lr);
322            self.last_lr = new_lr;
323            self.cooldown_counter = self.cooldown;
324            self.num_bad_epochs = 0;
325        }
326    }
327}
328
329impl LRScheduler for ReduceLROnPlateau {
330    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
331        // This scheduler requires a metric value
332        // Use step_with_metric instead
333        self.current_step += 1;
334    }
335
336    fn get_last_lr(&self) -> f32 {
337        self.last_lr
338    }
339
340    fn get_step(&self) -> usize {
341        self.current_step
342    }
343}
344
// =============================================================================
// OneCycleLR
// =============================================================================
348
/// One-cycle learning rate scheduler.
///
/// Implements the 1cycle policy from "Super-Convergence" paper.
pub struct OneCycleLR {
    /// Peak learning rate reached at the end of the warmup phase.
    max_lr: f32,
    /// Total number of steps in the schedule.
    total_steps: usize,
    /// Fraction of `total_steps` spent in the warmup phase.
    pct_start: f32,
    /// Starting LR is `max_lr / div_factor`.
    div_factor: f32,
    /// Final LR is `max_lr / final_div_factor`.
    final_div_factor: f32,
    /// Number of times `step` has been called.
    current_step: usize,
    /// Most recently computed LR (initialized to `max_lr / div_factor`).
    last_lr: f32,
}
361
362impl OneCycleLR {
363    /// Creates a new `OneCycleLR` scheduler.
364    pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
365        Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
366    }
367
368    /// Creates `OneCycleLR` with options.
369    pub fn with_options<O: Optimizer>(
370        _optimizer: &O,
371        max_lr: f32,
372        total_steps: usize,
373        pct_start: f32,
374        div_factor: f32,
375        final_div_factor: f32,
376    ) -> Self {
377        let initial_lr = max_lr / div_factor;
378        Self {
379            max_lr,
380            total_steps,
381            pct_start,
382            div_factor,
383            final_div_factor,
384            current_step: 0,
385            last_lr: initial_lr,
386        }
387    }
388}
389
390impl LRScheduler for OneCycleLR {
391    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
392        self.current_step += 1;
393
394        let step_ratio = self.current_step as f32 / self.total_steps as f32;
395        let initial_lr = self.max_lr / self.div_factor;
396        let min_lr = self.max_lr / self.final_div_factor;
397
398        let new_lr = if step_ratio <= self.pct_start {
399            // Warmup phase: linear increase from initial_lr to max_lr
400            let phase_ratio = step_ratio / self.pct_start;
401            initial_lr + (self.max_lr - initial_lr) * phase_ratio
402        } else {
403            // Annealing phase: cosine decrease from max_lr to min_lr
404            let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
405            min_lr
406                + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
407        };
408
409        optimizer.set_lr(new_lr);
410        self.last_lr = new_lr;
411    }
412
413    fn get_last_lr(&self) -> f32 {
414        self.last_lr
415    }
416
417    fn get_step(&self) -> usize {
418        self.current_step
419    }
420}
421
// =============================================================================
// WarmupLR
// =============================================================================
425
/// Linear warmup scheduler.
///
/// Linearly increases learning rate from 0 to `initial_lr` over `warmup_steps`.
pub struct WarmupLR {
    /// Target learning rate, read from the optimizer at construction time.
    initial_lr: f32,
    /// Number of steps over which the rate ramps from 0 to `initial_lr`.
    warmup_steps: usize,
    /// Number of times `step` has been called.
    current_step: usize,
    /// Learning rate most recently written (0.0 before the first step).
    last_lr: f32,
}
435
436impl WarmupLR {
437    /// Creates a new `WarmupLR` scheduler.
438    pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
439        let initial_lr = optimizer.get_lr();
440        Self {
441            initial_lr,
442            warmup_steps,
443            current_step: 0,
444            last_lr: 0.0,
445        }
446    }
447}
448
449impl LRScheduler for WarmupLR {
450    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
451        self.current_step += 1;
452
453        let new_lr = if self.current_step <= self.warmup_steps {
454            self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
455        } else {
456            self.initial_lr
457        };
458
459        optimizer.set_lr(new_lr);
460        self.last_lr = new_lr;
461    }
462
463    fn get_last_lr(&self) -> f32 {
464        self.last_lr
465    }
466
467    fn get_step(&self) -> usize {
468        self.current_step
469    }
470}
471
// =============================================================================
// Tests
// =============================================================================
475
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    /// Builds an SGD optimizer over a single 3-element parameter with lr 0.1.
    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // After 10 steps: one decay interval completed, lr = 0.1 * 0.1.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        // After 20 steps: two decay intervals, lr = 0.1 * 0.1^2.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // Crossing the first milestone (step 5) applies one decay.
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        // Crossing the second milestone (step 15) applies the second decay.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        // Each step multiplies by gamma: 0.1 * 0.9, then 0.1 * 0.9^2.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        // At step 50 (halfway), should be at eta_min + (initial - eta_min) * 0.5
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        // At step 100 (end), should be at eta_min
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        // First step of a 10-step warmup: 10% of the target rate.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // After warmup, should stay at initial_lr
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        // At start, should be at initial_lr = max_lr / div_factor
        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        // Step through warmup phase
        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        // Should be at or near max_lr
        assert!(optimizer.get_lr() > 0.08);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        // factor 0.5, patience 2, zero threshold, no cooldown, no min_lr.
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Simulate improving metric
        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Simulate plateau
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        // LR should have been reduced
        assert!(optimizer.get_lr() < initial_lr);
    }
}
613}