//! Learning Rate Schedulers
//!
//! # File
//! `crates/axonml-optim/src/lr_scheduler.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::optimizer::Optimizer;

// =============================================================================
// LRScheduler Trait
// =============================================================================

/// Common interface for all learning rate schedulers in this module.
///
/// A scheduler is driven by calling `step()` (or `step_with_metric()`) once
/// per epoch/iteration; it then writes the newly computed learning rate into
/// the given optimizer via `Optimizer::set_lr`.
///
/// NOTE: the generic `step` methods make this trait not dyn-compatible
/// (no `Box<dyn LRScheduler>`); schedulers must be used via static dispatch.
pub trait LRScheduler {
    /// Updates the learning rate (epoch-based schedulers).
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Updates the learning rate based on a metric value.
    ///
    /// Default implementation delegates to `step()`, ignoring the metric.
    /// Override for metric-based schedulers like `ReduceLROnPlateau`.
    fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, _metric: f32) {
        self.step(optimizer);
    }

    /// Returns the current learning rate (the last value written by `step`).
    fn get_last_lr(&self) -> f32;

    /// Returns the current epoch/step count.
    fn get_step(&self) -> usize;
}

// =============================================================================
// StepLR
// =============================================================================

/// Decays learning rate by gamma every `step_size` epochs.
///
/// lr = `initial_lr` * gamma^(epoch // `step_size`)
pub struct StepLR {
    // LR read from the optimizer at construction time (curve starting point).
    initial_lr: f32,
    // Number of steps between successive decays.
    step_size: usize,
    // Multiplicative decay factor applied at each decay boundary.
    gamma: f32,
    // Count of step() calls made so far.
    current_step: usize,
    // Learning rate most recently written to the optimizer.
    last_lr: f32,
}

58impl StepLR {
59    /// Creates a new `StepLR` scheduler.
60    pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
61        let initial_lr = optimizer.get_lr();
62        Self {
63            initial_lr,
64            step_size,
65            gamma,
66            current_step: 0,
67            last_lr: initial_lr,
68        }
69    }
70}
71
72impl LRScheduler for StepLR {
73    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
74        self.current_step += 1;
75        let num_decays = self.current_step / self.step_size;
76        let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
77        optimizer.set_lr(new_lr);
78        self.last_lr = new_lr;
79    }
80
81    fn get_last_lr(&self) -> f32 {
82        self.last_lr
83    }
84
85    fn get_step(&self) -> usize {
86        self.current_step
87    }
88}
89
// =============================================================================
// MultiStepLR
// =============================================================================

/// Decays learning rate by gamma at each milestone.
pub struct MultiStepLR {
    // LR read from the optimizer at construction time.
    initial_lr: f32,
    // Step counts at which a decay occurs; kept sorted ascending.
    milestones: Vec<usize>,
    // Multiplicative decay factor applied once per passed milestone.
    gamma: f32,
    // Count of step() calls made so far.
    current_step: usize,
    // Learning rate most recently written to the optimizer.
    last_lr: f32,
    // Index of the first milestone not yet reached; doubles as the number of
    // decays applied so far (exponent on gamma).
    milestone_idx: usize,
}

104impl MultiStepLR {
105    /// Creates a new `MultiStepLR` scheduler.
106    pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
107        let initial_lr = optimizer.get_lr();
108        milestones.sort_unstable();
109        Self {
110            initial_lr,
111            milestones,
112            gamma,
113            current_step: 0,
114            last_lr: initial_lr,
115            milestone_idx: 0,
116        }
117    }
118}
119
120impl LRScheduler for MultiStepLR {
121    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
122        self.current_step += 1;
123
124        // Check if we've passed any milestones
125        while self.milestone_idx < self.milestones.len()
126            && self.current_step >= self.milestones[self.milestone_idx]
127        {
128            self.milestone_idx += 1;
129        }
130
131        let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
132        optimizer.set_lr(new_lr);
133        self.last_lr = new_lr;
134    }
135
136    fn get_last_lr(&self) -> f32 {
137        self.last_lr
138    }
139
140    fn get_step(&self) -> usize {
141        self.current_step
142    }
143}
144
// =============================================================================
// ExponentialLR
// =============================================================================

/// Decays learning rate by gamma every epoch.
///
/// lr = `initial_lr` * gamma^epoch
pub struct ExponentialLR {
    // LR read from the optimizer at construction time.
    initial_lr: f32,
    // Per-step multiplicative decay factor.
    gamma: f32,
    // Count of step() calls made so far (exponent on gamma).
    current_step: usize,
    // Learning rate most recently written to the optimizer.
    last_lr: f32,
}

159impl ExponentialLR {
160    /// Creates a new `ExponentialLR` scheduler.
161    pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
162        let initial_lr = optimizer.get_lr();
163        Self {
164            initial_lr,
165            gamma,
166            current_step: 0,
167            last_lr: initial_lr,
168        }
169    }
170}
171
172impl LRScheduler for ExponentialLR {
173    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
174        self.current_step += 1;
175        let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
176        optimizer.set_lr(new_lr);
177        self.last_lr = new_lr;
178    }
179
180    fn get_last_lr(&self) -> f32 {
181        self.last_lr
182    }
183
184    fn get_step(&self) -> usize {
185        self.current_step
186    }
187}
188
// =============================================================================
// CosineAnnealingLR
// =============================================================================

/// Cosine annealing learning rate scheduler.
///
/// lr = `eta_min` + (`initial_lr` - `eta_min`) * (1 + cos(pi * epoch / `T_max`)) / 2
pub struct CosineAnnealingLR {
    // LR read from the optimizer at construction (peak of the cosine curve).
    initial_lr: f32,
    // Period of the annealing schedule, in steps.
    t_max: usize,
    // Floor the LR anneals toward as the step count approaches t_max.
    eta_min: f32,
    // Count of step() calls made so far.
    current_step: usize,
    // Learning rate most recently written to the optimizer.
    last_lr: f32,
}

204impl CosineAnnealingLR {
205    /// Creates a new `CosineAnnealingLR` scheduler.
206    pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
207        Self::with_eta_min(optimizer, t_max, 0.0)
208    }
209
210    /// Creates a `CosineAnnealingLR` with minimum learning rate.
211    pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
212        let initial_lr = optimizer.get_lr();
213        Self {
214            initial_lr,
215            t_max,
216            eta_min,
217            current_step: 0,
218            last_lr: initial_lr,
219        }
220    }
221}
222
223impl LRScheduler for CosineAnnealingLR {
224    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
225        self.current_step += 1;
226
227        let progress = self.current_step as f32 / self.t_max as f32;
228        let new_lr = self.eta_min
229            + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
230                / 2.0;
231
232        optimizer.set_lr(new_lr);
233        self.last_lr = new_lr;
234    }
235
236    fn get_last_lr(&self) -> f32 {
237        self.last_lr
238    }
239
240    fn get_step(&self) -> usize {
241        self.current_step
242    }
243}
244
245// =============================================================================
246// ReduceLROnPlateau
247// =============================================================================
248
249/// Reduces learning rate when a metric has stopped improving.
250pub struct ReduceLROnPlateau {
251    mode: String,
252    factor: f32,
253    patience: usize,
254    threshold: f32,
255    cooldown: usize,
256    min_lr: f32,
257    best: f32,
258    num_bad_epochs: usize,
259    cooldown_counter: usize,
260    current_step: usize,
261    last_lr: f32,
262}
263
264impl ReduceLROnPlateau {
265    /// Creates a new `ReduceLROnPlateau` scheduler for minimizing.
266    pub fn new<O: Optimizer>(optimizer: &O) -> Self {
267        Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
268    }
269
270    /// Creates a `ReduceLROnPlateau` with options.
271    pub fn with_options<O: Optimizer>(
272        optimizer: &O,
273        mode: &str,
274        factor: f32,
275        patience: usize,
276        threshold: f32,
277        cooldown: usize,
278        min_lr: f32,
279    ) -> Self {
280        let initial_lr = optimizer.get_lr();
281        let best = if mode == "min" {
282            f32::INFINITY
283        } else {
284            f32::NEG_INFINITY
285        };
286        Self {
287            mode: mode.to_string(),
288            factor,
289            patience,
290            threshold,
291            cooldown,
292            min_lr,
293            best,
294            num_bad_epochs: 0,
295            cooldown_counter: 0,
296            current_step: 0,
297            last_lr: initial_lr,
298        }
299    }
300
301    /// Internal: steps the scheduler based on a metric value.
302    fn step_metric_impl<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
303        self.current_step += 1;
304
305        // Check if we're in cooldown
306        if self.cooldown_counter > 0 {
307            self.cooldown_counter -= 1;
308            return;
309        }
310
311        // Check if metric improved
312        let improved = if self.mode == "min" {
313            metric < self.best * (1.0 - self.threshold)
314        } else {
315            metric > self.best * (1.0 + self.threshold)
316        };
317
318        if improved {
319            self.best = metric;
320            self.num_bad_epochs = 0;
321        } else {
322            self.num_bad_epochs += 1;
323        }
324
325        // Reduce learning rate if patience exceeded
326        if self.num_bad_epochs > self.patience {
327            let current_lr = optimizer.get_lr();
328            let new_lr = (current_lr * self.factor).max(self.min_lr);
329            optimizer.set_lr(new_lr);
330            self.last_lr = new_lr;
331            self.cooldown_counter = self.cooldown;
332            self.num_bad_epochs = 0;
333        }
334    }
335}
336
337impl LRScheduler for ReduceLROnPlateau {
338    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
339        // No-op: this scheduler requires a metric. Use step_with_metric().
340        self.current_step += 1;
341    }
342
343    fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
344        self.step_metric_impl(optimizer, metric);
345    }
346
347    fn get_last_lr(&self) -> f32 {
348        self.last_lr
349    }
350
351    fn get_step(&self) -> usize {
352        self.current_step
353    }
354}
355
// =============================================================================
// OneCycleLR
// =============================================================================

/// One-cycle learning rate scheduler.
///
/// Implements the 1cycle policy from "Super-Convergence" paper.
pub struct OneCycleLR {
    // Peak learning rate, reached at the end of the warmup phase.
    max_lr: f32,
    // Total number of steps in the full cycle.
    total_steps: usize,
    // Fraction of the cycle spent increasing the LR (warmup).
    pct_start: f32,
    // Starting LR is max_lr / div_factor.
    div_factor: f32,
    // Final LR is max_lr / final_div_factor.
    final_div_factor: f32,
    // Count of step() calls made so far.
    current_step: usize,
    // Learning rate most recently written to the optimizer.
    last_lr: f32,
}

373impl OneCycleLR {
374    /// Creates a new `OneCycleLR` scheduler.
375    pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
376        Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
377    }
378
379    /// Creates `OneCycleLR` with options.
380    pub fn with_options<O: Optimizer>(
381        _optimizer: &O,
382        max_lr: f32,
383        total_steps: usize,
384        pct_start: f32,
385        div_factor: f32,
386        final_div_factor: f32,
387    ) -> Self {
388        let initial_lr = max_lr / div_factor;
389        Self {
390            max_lr,
391            total_steps,
392            pct_start,
393            div_factor,
394            final_div_factor,
395            current_step: 0,
396            last_lr: initial_lr,
397        }
398    }
399}
400
401impl LRScheduler for OneCycleLR {
402    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
403        self.current_step += 1;
404
405        let step_ratio = self.current_step as f32 / self.total_steps as f32;
406        let initial_lr = self.max_lr / self.div_factor;
407        let min_lr = self.max_lr / self.final_div_factor;
408
409        let new_lr = if step_ratio <= self.pct_start {
410            // Warmup phase: linear increase from initial_lr to max_lr
411            let phase_ratio = step_ratio / self.pct_start;
412            initial_lr + (self.max_lr - initial_lr) * phase_ratio
413        } else {
414            // Annealing phase: cosine decrease from max_lr to min_lr
415            let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
416            min_lr
417                + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
418        };
419
420        optimizer.set_lr(new_lr);
421        self.last_lr = new_lr;
422    }
423
424    fn get_last_lr(&self) -> f32 {
425        self.last_lr
426    }
427
428    fn get_step(&self) -> usize {
429        self.current_step
430    }
431}
432
// =============================================================================
// WarmupLR
// =============================================================================

/// Linear warmup scheduler.
///
/// Linearly increases learning rate from 0 to `initial_lr` over `warmup_steps`.
pub struct WarmupLR {
    // Target LR, read from the optimizer at construction time.
    initial_lr: f32,
    // Number of steps over which the LR ramps from 0 to initial_lr.
    warmup_steps: usize,
    // Count of step() calls made so far.
    current_step: usize,
    // Learning rate most recently written to the optimizer (starts at 0).
    last_lr: f32,
}

447impl WarmupLR {
448    /// Creates a new `WarmupLR` scheduler.
449    pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
450        let initial_lr = optimizer.get_lr();
451        Self {
452            initial_lr,
453            warmup_steps,
454            current_step: 0,
455            last_lr: 0.0,
456        }
457    }
458}
459
460impl LRScheduler for WarmupLR {
461    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
462        self.current_step += 1;
463
464        let new_lr = if self.current_step <= self.warmup_steps {
465            self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
466        } else {
467            self.initial_lr
468        };
469
470        optimizer.set_lr(new_lr);
471        self.last_lr = new_lr;
472    }
473
474    fn get_last_lr(&self) -> f32 {
475        self.last_lr
476    }
477
478    fn get_step(&self) -> usize {
479        self.current_step
480    }
481}
482
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    /// Builds a minimal SGD optimizer (lr = 0.1) over a single 3-element
    /// parameter; shared fixture for every scheduler test below.
    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    // One decade of decay (gamma = 0.1) every 10 steps.
    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // Decays at step 5 and step 15, one factor of gamma each.
    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // lr = 0.1 * 0.9^step.
    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        // At step 50 (halfway), should be at eta_min + (initial - eta_min) * 0.5
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        // At step 100 (end), should be at eta_min
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    // Linear ramp 0 -> 0.1 over 10 steps, then constant.
    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // After warmup, should stay at initial_lr
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        // At start, should be at initial_lr = max_lr / div_factor
        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        // Step through warmup phase
        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        // Should be at or near max_lr
        assert!(optimizer.get_lr() > 0.08);
    }

    // factor 0.5, patience 2, zero threshold: third bad epoch triggers a cut.
    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Simulate improving metric
        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Simulate plateau
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        // LR should have been reduced
        assert!(optimizer.get_lr() < initial_lr);
    }

    // =========================================================================
    // ReduceLROnPlateau — Comprehensive
    // =========================================================================

    #[test]
    fn test_reduce_lr_on_plateau_max_mode() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "max", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Improving metric (higher is better)
        scheduler.step_with_metric(&mut optimizer, 0.8);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Plateau (metric not improving)
        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);

        assert!(
            optimizer.get_lr() < initial_lr,
            "LR should reduce on plateau in max mode"
        );
    }

    // Zero patience + never-improving metric: reduce every step, but never
    // below the configured min_lr floor.
    #[test]
    fn test_reduce_lr_on_plateau_min_lr_floor() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler =
            ReduceLROnPlateau::with_options(&optimizer, "min", 0.1, 0, 0.0, 0, 0.001);

        // Force many reductions
        for _ in 0..50 {
            scheduler.step_with_metric(&mut optimizer, 999.0); // never improves
        }

        assert!(
            optimizer.get_lr() >= 0.001,
            "LR should not go below min_lr, got {}",
            optimizer.get_lr()
        );
    }

    #[test]
    fn test_reduce_lr_cooldown() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 0, 0.0, 3, 0.0);

        let initial_lr = optimizer.get_lr();

        // Trigger reduction
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        let lr_after_first_reduce = optimizer.get_lr();
        assert!(lr_after_first_reduce < initial_lr);

        // During cooldown (3 steps), LR should not change again
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        assert!(
            (optimizer.get_lr() - lr_after_first_reduce).abs() < 1e-8,
            "LR should not change during cooldown"
        );
    }

    // =========================================================================
    // OneCycleLR — Comprehensive
    // =========================================================================

    #[test]
    fn test_one_cycle_lr_full_cycle() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Should start low, peak around 30%, end very low
        let max_lr = lrs.iter().cloned().fold(f32::MIN, f32::max);
        let final_lr = *lrs.last().unwrap();

        assert!(
            max_lr > 0.08,
            "Peak should be near max_lr=0.1, got {}",
            max_lr
        );
        assert!(
            final_lr < 0.001,
            "Final LR should be very small, got {}",
            final_lr
        );

        // Peak should occur around 30% of total steps
        let peak_idx = lrs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert!(
            peak_idx >= 25 && peak_idx <= 35,
            "Peak should be around step 30, was at step {}",
            peak_idx
        );
    }

    #[test]
    fn test_one_cycle_lr_monotonic_phases() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Warmup phase (steps 1-30): should be monotonically increasing
        for i in 1..29 {
            assert!(
                lrs[i] >= lrs[i - 1] - 1e-6,
                "Warmup should increase: step {} lr={} < step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }

        // Annealing phase (steps 31-100): should be monotonically decreasing
        for i in 32..99 {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Annealing should decrease: step {} lr={} > step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }
    }

    // =========================================================================
    // CosineAnnealingLR — Comprehensive
    // =========================================================================

    #[test]
    fn test_cosine_annealing_with_eta_min() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::with_eta_min(&optimizer, 100, 0.001);

        for _ in 0..100 {
            scheduler.step(&mut optimizer);
        }

        // At end should be at eta_min
        assert!(
            (optimizer.get_lr() - 0.001).abs() < 0.002,
            "Should reach eta_min at end, got {}",
            optimizer.get_lr()
        );
    }

    #[test]
    fn test_cosine_annealing_monotonic_decrease() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Cosine annealing should monotonically decrease
        for i in 1..lrs.len() {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Cosine should decrease: step {} lr={} > step {} lr={}",
                i + 1,
                lrs[i],
                i,
                lrs[i - 1]
            );
        }

        // All LRs should be non-negative
        assert!(
            lrs.iter().all(|lr| *lr >= 0.0),
            "LRs should be non-negative"
        );
    }

    // =========================================================================
    // WarmupLR — Edge Cases
    // =========================================================================

    #[test]
    fn test_warmup_lr_stays_constant_after() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 5);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        let target = optimizer.get_lr();

        // Should stay constant for many more steps
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            assert!(
                (optimizer.get_lr() - target).abs() < 1e-8,
                "LR should stay at {} after warmup, got {}",
                target,
                optimizer.get_lr()
            );
        }
    }
}