Skip to main content

axonml_optim/
lr_scheduler.rs

1//! Learning rate schedulers — seven strategies for LR annealing.
2//!
3//! `StepLR` (decay every N epochs), `MultiStepLR` (decay at milestones),
4//! `ExponentialLR` (multiplicative decay per epoch), `CosineAnnealingLR`
5//! (cosine anneal to eta_min), `OneCycleLR` (1-cycle super-convergence
6//! with warmup/annealing), `WarmupLR` (linear warmup then constant),
7//! `ReduceLROnPlateau` (reduce on metric stall, with cooldown + min_lr).
8//! Each `.step()` updates the optimizer's LR.
9//!
10//! # File
11//! `crates/axonml-optim/src/lr_scheduler.rs`
12//!
13//! # Author
14//! Andrew Jewell Sr. — AutomataNexus LLC
15//! ORCID: 0009-0005-2158-7060
16//!
17//! # Updated
18//! April 14, 2026 11:15 PM EST
19//!
20//! # Disclaimer
21//! Use at own risk. This software is provided "as is", without warranty of any
22//! kind, express or implied. The author and AutomataNexus shall not be held
23//! liable for any damages arising from the use of this software.
24
25use crate::optimizer::Optimizer;
26
27// =============================================================================
28// LRScheduler Trait
29// =============================================================================
30
/// Trait for learning rate schedulers.
///
/// Implementors adjust an `Optimizer`'s learning rate over time, either per
/// epoch/step via `step()` or from a validation metric via
/// `step_with_metric()`.
pub trait LRScheduler {
    /// Updates the learning rate (epoch-based schedulers).
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Updates the learning rate based on a metric value.
    ///
    /// Default implementation delegates to `step()`, ignoring the metric.
    /// Override for metric-based schedulers like `ReduceLROnPlateau`.
    fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, _metric: f32) {
        self.step(optimizer);
    }

    /// Returns the current learning rate.
    ///
    /// NOTE(review): implementations cache the LR they last computed, so this
    /// can diverge from the optimizer's LR if the latter is changed externally.
    fn get_last_lr(&self) -> f32;

    /// Returns the current epoch/step count.
    fn get_step(&self) -> usize;
}
50
51// =============================================================================
52// StepLR
53// =============================================================================
54
/// Decays learning rate by gamma every `step_size` epochs.
///
/// lr = `initial_lr` * gamma^(epoch // `step_size`)
pub struct StepLR {
    // LR read from the optimizer at construction; basis for every decay.
    initial_lr: f32,
    // Number of `step()` calls between successive decays.
    step_size: usize,
    // Multiplicative decay factor.
    gamma: f32,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Most recently computed learning rate.
    last_lr: f32,
}
65
66impl StepLR {
67    /// Creates a new `StepLR` scheduler.
68    pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
69        let initial_lr = optimizer.get_lr();
70        Self {
71            initial_lr,
72            step_size,
73            gamma,
74            current_step: 0,
75            last_lr: initial_lr,
76        }
77    }
78}
79
80impl LRScheduler for StepLR {
81    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
82        self.current_step += 1;
83        let num_decays = self.current_step / self.step_size;
84        let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
85        optimizer.set_lr(new_lr);
86        self.last_lr = new_lr;
87    }
88
89    fn get_last_lr(&self) -> f32 {
90        self.last_lr
91    }
92
93    fn get_step(&self) -> usize {
94        self.current_step
95    }
96}
97
98// =============================================================================
99// MultiStepLR
100// =============================================================================
101
/// Decays learning rate by gamma at each milestone.
pub struct MultiStepLR {
    // LR read from the optimizer at construction.
    initial_lr: f32,
    // Sorted (ascending) step counts at which a decay is applied.
    milestones: Vec<usize>,
    // Multiplicative decay factor applied per passed milestone.
    gamma: f32,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Most recently computed learning rate.
    last_lr: f32,
    // Index of the next unpassed milestone; doubles as the decay count.
    milestone_idx: usize,
}
111
112impl MultiStepLR {
113    /// Creates a new `MultiStepLR` scheduler.
114    pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
115        let initial_lr = optimizer.get_lr();
116        milestones.sort_unstable();
117        Self {
118            initial_lr,
119            milestones,
120            gamma,
121            current_step: 0,
122            last_lr: initial_lr,
123            milestone_idx: 0,
124        }
125    }
126}
127
128impl LRScheduler for MultiStepLR {
129    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
130        self.current_step += 1;
131
132        // Check if we've passed any milestones
133        while self.milestone_idx < self.milestones.len()
134            && self.current_step >= self.milestones[self.milestone_idx]
135        {
136            self.milestone_idx += 1;
137        }
138
139        let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
140        optimizer.set_lr(new_lr);
141        self.last_lr = new_lr;
142    }
143
144    fn get_last_lr(&self) -> f32 {
145        self.last_lr
146    }
147
148    fn get_step(&self) -> usize {
149        self.current_step
150    }
151}
152
153// =============================================================================
154// ExponentialLR
155// =============================================================================
156
/// Decays learning rate by gamma every epoch.
///
/// lr = `initial_lr` * gamma^epoch
pub struct ExponentialLR {
    // LR read from the optimizer at construction.
    initial_lr: f32,
    // Per-step multiplicative decay factor.
    gamma: f32,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Most recently computed learning rate.
    last_lr: f32,
}
166
167impl ExponentialLR {
168    /// Creates a new `ExponentialLR` scheduler.
169    pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
170        let initial_lr = optimizer.get_lr();
171        Self {
172            initial_lr,
173            gamma,
174            current_step: 0,
175            last_lr: initial_lr,
176        }
177    }
178}
179
180impl LRScheduler for ExponentialLR {
181    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
182        self.current_step += 1;
183        let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
184        optimizer.set_lr(new_lr);
185        self.last_lr = new_lr;
186    }
187
188    fn get_last_lr(&self) -> f32 {
189        self.last_lr
190    }
191
192    fn get_step(&self) -> usize {
193        self.current_step
194    }
195}
196
197// =============================================================================
198// CosineAnnealingLR
199// =============================================================================
200
/// Cosine annealing learning rate scheduler.
///
/// lr = `eta_min` + (`initial_lr` - `eta_min`) * (1 + cos(pi * epoch / `T_max`)) / 2
pub struct CosineAnnealingLR {
    // LR read from the optimizer at construction; the annealing maximum.
    initial_lr: f32,
    // Number of steps over which to anneal from `initial_lr` to `eta_min`.
    t_max: usize,
    // Minimum learning rate reached at the end of the schedule.
    eta_min: f32,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Most recently computed learning rate.
    last_lr: f32,
}
211
212impl CosineAnnealingLR {
213    /// Creates a new `CosineAnnealingLR` scheduler.
214    pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
215        Self::with_eta_min(optimizer, t_max, 0.0)
216    }
217
218    /// Creates a `CosineAnnealingLR` with minimum learning rate.
219    pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
220        let initial_lr = optimizer.get_lr();
221        Self {
222            initial_lr,
223            t_max,
224            eta_min,
225            current_step: 0,
226            last_lr: initial_lr,
227        }
228    }
229}
230
231impl LRScheduler for CosineAnnealingLR {
232    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
233        self.current_step += 1;
234
235        let progress = self.current_step as f32 / self.t_max as f32;
236        let new_lr = self.eta_min
237            + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
238                / 2.0;
239
240        optimizer.set_lr(new_lr);
241        self.last_lr = new_lr;
242    }
243
244    fn get_last_lr(&self) -> f32 {
245        self.last_lr
246    }
247
248    fn get_step(&self) -> usize {
249        self.current_step
250    }
251}
252
253// =============================================================================
254// ReduceLROnPlateau
255// =============================================================================
256
/// Reduces learning rate when a metric has stopped improving.
pub struct ReduceLROnPlateau {
    // "min" (lower metric is better) or any other value for max mode.
    mode: String,
    // Multiplier applied to the LR on each reduction (e.g. 0.1).
    factor: f32,
    // Number of non-improving steps tolerated before reducing.
    patience: usize,
    // Relative improvement required for a metric to count as "better".
    threshold: f32,
    // Steps to skip (no tracking, no reduction) after each reduction.
    cooldown: usize,
    // Floor below which the LR is never reduced.
    min_lr: f32,
    // Best metric seen so far (starts at +/-infinity depending on mode).
    best: f32,
    // Consecutive steps without improvement.
    num_bad_epochs: usize,
    // Remaining cooldown steps; counts down after a reduction.
    cooldown_counter: usize,
    // Number of metric steps made so far.
    current_step: usize,
    // Most recently applied learning rate.
    last_lr: f32,
}
271
impl ReduceLROnPlateau {
    /// Creates a new `ReduceLROnPlateau` scheduler for minimizing.
    ///
    /// Defaults: mode "min", factor 0.1, patience 10, threshold 1e-4,
    /// no cooldown, min_lr 0.0.
    pub fn new<O: Optimizer>(optimizer: &O) -> Self {
        Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
    }

    /// Creates a `ReduceLROnPlateau` with options.
    ///
    /// * `mode` — "min" means lower metric is better; any other string is
    ///   treated as max mode (see the comparison in `step_metric_impl`).
    /// * `factor` — LR multiplier applied on each reduction.
    /// * `patience` — non-improving steps tolerated before reducing.
    /// * `threshold` — relative improvement required to count as "better".
    /// * `cooldown` — steps skipped entirely after a reduction.
    /// * `min_lr` — floor below which the LR is never reduced.
    pub fn with_options<O: Optimizer>(
        optimizer: &O,
        mode: &str,
        factor: f32,
        patience: usize,
        threshold: f32,
        cooldown: usize,
        min_lr: f32,
    ) -> Self {
        let initial_lr = optimizer.get_lr();
        // Start `best` at the worst possible value so the first metric always
        // registers as an improvement.
        let best = if mode == "min" {
            f32::INFINITY
        } else {
            f32::NEG_INFINITY
        };
        Self {
            mode: mode.to_string(),
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best,
            num_bad_epochs: 0,
            cooldown_counter: 0,
            current_step: 0,
            last_lr: initial_lr,
        }
    }

    /// Internal: steps the scheduler based on a metric value.
    ///
    /// Tracks the best metric seen; after `patience` consecutive
    /// non-improving steps, multiplies the optimizer's LR by `factor`
    /// (floored at `min_lr`) and enters cooldown.
    fn step_metric_impl<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
        self.current_step += 1;

        // Check if we're in cooldown
        // NOTE(review): during cooldown the metric is ignored entirely, so
        // `best` is not updated for these steps — confirm this is intended.
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return;
        }

        // Check if metric improved (relative threshold).
        // NOTE(review): `best * (1 - threshold)` assumes non-negative metrics;
        // for negative `best` the scaled bound moves the wrong way — verify
        // against expected metric ranges.
        let improved = if self.mode == "min" {
            metric < self.best * (1.0 - self.threshold)
        } else {
            metric > self.best * (1.0 + self.threshold)
        };

        if improved {
            self.best = metric;
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;
        }

        // Reduce learning rate if patience exceeded
        if self.num_bad_epochs > self.patience {
            let current_lr = optimizer.get_lr();
            let new_lr = (current_lr * self.factor).max(self.min_lr);
            optimizer.set_lr(new_lr);
            self.last_lr = new_lr;
            self.cooldown_counter = self.cooldown;
            self.num_bad_epochs = 0;
        }
    }
}
344
345impl LRScheduler for ReduceLROnPlateau {
346    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
347        // No-op: this scheduler requires a metric. Use step_with_metric().
348        self.current_step += 1;
349    }
350
351    fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
352        self.step_metric_impl(optimizer, metric);
353    }
354
355    fn get_last_lr(&self) -> f32 {
356        self.last_lr
357    }
358
359    fn get_step(&self) -> usize {
360        self.current_step
361    }
362}
363
364// =============================================================================
365// OneCycleLR
366// =============================================================================
367
/// One-cycle learning rate scheduler.
///
/// Implements the 1cycle policy from "Super-Convergence" paper.
pub struct OneCycleLR {
    // Peak learning rate reached at the end of the warmup phase.
    max_lr: f32,
    // Total number of steps in the full cycle (warmup + annealing).
    total_steps: usize,
    // Fraction of `total_steps` spent in the warmup phase (e.g. 0.3).
    pct_start: f32,
    // initial_lr = max_lr / div_factor.
    div_factor: f32,
    // final (minimum) lr = max_lr / final_div_factor.
    final_div_factor: f32,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Most recently computed learning rate.
    last_lr: f32,
}
380
381impl OneCycleLR {
382    /// Creates a new `OneCycleLR` scheduler.
383    pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
384        Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
385    }
386
387    /// Creates `OneCycleLR` with options.
388    pub fn with_options<O: Optimizer>(
389        _optimizer: &O,
390        max_lr: f32,
391        total_steps: usize,
392        pct_start: f32,
393        div_factor: f32,
394        final_div_factor: f32,
395    ) -> Self {
396        let initial_lr = max_lr / div_factor;
397        Self {
398            max_lr,
399            total_steps,
400            pct_start,
401            div_factor,
402            final_div_factor,
403            current_step: 0,
404            last_lr: initial_lr,
405        }
406    }
407}
408
409impl LRScheduler for OneCycleLR {
410    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
411        self.current_step += 1;
412
413        let step_ratio = self.current_step as f32 / self.total_steps as f32;
414        let initial_lr = self.max_lr / self.div_factor;
415        let min_lr = self.max_lr / self.final_div_factor;
416
417        let new_lr = if step_ratio <= self.pct_start {
418            // Warmup phase: linear increase from initial_lr to max_lr
419            let phase_ratio = step_ratio / self.pct_start;
420            initial_lr + (self.max_lr - initial_lr) * phase_ratio
421        } else {
422            // Annealing phase: cosine decrease from max_lr to min_lr
423            let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
424            min_lr
425                + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
426        };
427
428        optimizer.set_lr(new_lr);
429        self.last_lr = new_lr;
430    }
431
432    fn get_last_lr(&self) -> f32 {
433        self.last_lr
434    }
435
436    fn get_step(&self) -> usize {
437        self.current_step
438    }
439}
440
441// =============================================================================
442// WarmupLR
443// =============================================================================
444
/// Linear warmup scheduler.
///
/// Linearly increases learning rate from 0 to `initial_lr` over `warmup_steps`.
pub struct WarmupLR {
    // Target LR (the optimizer's LR at construction time).
    initial_lr: f32,
    // Number of steps over which to ramp up linearly.
    warmup_steps: usize,
    // Number of `step()` calls made so far.
    current_step: usize,
    // Most recently computed learning rate; starts at 0.0 (pre-warmup).
    last_lr: f32,
}
454
455impl WarmupLR {
456    /// Creates a new `WarmupLR` scheduler.
457    pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
458        let initial_lr = optimizer.get_lr();
459        Self {
460            initial_lr,
461            warmup_steps,
462            current_step: 0,
463            last_lr: 0.0,
464        }
465    }
466}
467
468impl LRScheduler for WarmupLR {
469    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
470        self.current_step += 1;
471
472        let new_lr = if self.current_step <= self.warmup_steps {
473            self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
474        } else {
475            self.initial_lr
476        };
477
478        optimizer.set_lr(new_lr);
479        self.last_lr = new_lr;
480    }
481
482    fn get_last_lr(&self) -> f32 {
483        self.last_lr
484    }
485
486    fn get_step(&self) -> usize {
487        self.current_step
488    }
489}
490
491// =============================================================================
492// Tests
493// =============================================================================
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    // Shared fixture: SGD over a single 3-element parameter with lr = 0.1.
    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    // StepLR: gamma=0.1 every 10 steps => 0.1 -> 0.01 -> 0.001.
    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // MultiStepLR: decay once at step 5 and again at step 15.
    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // ExponentialLR: lr = 0.1 * 0.9^step.
    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        // At step 50 (halfway), should be at eta_min + (initial - eta_min) * 0.5
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        // At step 100 (end), should be at eta_min
        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    // WarmupLR: linear ramp to 0.1 over 10 steps, then constant.
    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // After warmup, should stay at initial_lr
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        // At start, should be at initial_lr = max_lr / div_factor
        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        // Step through warmup phase
        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        // Should be at or near max_lr
        assert!(optimizer.get_lr() > 0.08);
    }

    // Plateau detection in min mode: two improving metrics, then a stall
    // that exceeds patience=2 and triggers a reduction.
    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Simulate improving metric
        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Simulate plateau
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        // LR should have been reduced
        assert!(optimizer.get_lr() < initial_lr);
    }

    // =========================================================================
    // ReduceLROnPlateau — Comprehensive
    // =========================================================================

    #[test]
    fn test_reduce_lr_on_plateau_max_mode() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "max", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Improving metric (higher is better)
        scheduler.step_with_metric(&mut optimizer, 0.8);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Plateau (metric not improving)
        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);

        assert!(
            optimizer.get_lr() < initial_lr,
            "LR should reduce on plateau in max mode"
        );
    }

    #[test]
    fn test_reduce_lr_on_plateau_min_lr_floor() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler =
            ReduceLROnPlateau::with_options(&optimizer, "min", 0.1, 0, 0.0, 0, 0.001);

        // Force many reductions
        for _ in 0..50 {
            scheduler.step_with_metric(&mut optimizer, 999.0); // never improves
        }

        assert!(
            optimizer.get_lr() >= 0.001,
            "LR should not go below min_lr, got {}",
            optimizer.get_lr()
        );
    }

    #[test]
    fn test_reduce_lr_cooldown() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 0, 0.0, 3, 0.0);

        let initial_lr = optimizer.get_lr();

        // Trigger reduction
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        let lr_after_first_reduce = optimizer.get_lr();
        assert!(lr_after_first_reduce < initial_lr);

        // During cooldown (3 steps), LR should not change again
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        assert!(
            (optimizer.get_lr() - lr_after_first_reduce).abs() < 1e-8,
            "LR should not change during cooldown"
        );
    }

    // =========================================================================
    // OneCycleLR — Comprehensive
    // =========================================================================

    #[test]
    fn test_one_cycle_lr_full_cycle() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Should start low, peak around 30%, end very low
        let max_lr = lrs.iter().copied().fold(f32::MIN, f32::max);
        let final_lr = *lrs.last().unwrap();

        assert!(
            max_lr > 0.08,
            "Peak should be near max_lr=0.1, got {}",
            max_lr
        );
        assert!(
            final_lr < 0.001,
            "Final LR should be very small, got {}",
            final_lr
        );

        // Peak should occur around 30% of total steps
        let peak_idx = lrs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert!(
            (25..=35).contains(&peak_idx),
            "Peak should be around step 30, was at step {}",
            peak_idx
        );
    }

    #[test]
    fn test_one_cycle_lr_monotonic_phases() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Warmup phase (steps 1-30): should be monotonically increasing
        for i in 1..29 {
            assert!(
                lrs[i] >= lrs[i - 1] - 1e-6,
                "Warmup should increase: step {} lr={} < step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }

        // Annealing phase (steps 31-100): should be monotonically decreasing
        for i in 32..99 {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Annealing should decrease: step {} lr={} > step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }
    }

    // =========================================================================
    // CosineAnnealingLR — Comprehensive
    // =========================================================================

    #[test]
    fn test_cosine_annealing_with_eta_min() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::with_eta_min(&optimizer, 100, 0.001);

        for _ in 0..100 {
            scheduler.step(&mut optimizer);
        }

        // At end should be at eta_min
        assert!(
            (optimizer.get_lr() - 0.001).abs() < 0.002,
            "Should reach eta_min at end, got {}",
            optimizer.get_lr()
        );
    }

    #[test]
    fn test_cosine_annealing_monotonic_decrease() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Cosine annealing should monotonically decrease
        for i in 1..lrs.len() {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Cosine should decrease: step {} lr={} > step {} lr={}",
                i + 1,
                lrs[i],
                i,
                lrs[i - 1]
            );
        }

        // All LRs should be non-negative
        assert!(
            lrs.iter().all(|lr| *lr >= 0.0),
            "LRs should be non-negative"
        );
    }

    // =========================================================================
    // WarmupLR — Edge Cases
    // =========================================================================

    #[test]
    fn test_warmup_lr_stays_constant_after() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 5);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        let target = optimizer.get_lr();

        // Should stay constant for many more steps
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            assert!(
                (optimizer.get_lr() - target).abs() < 1e-8,
                "LR should stay at {} after warmup, got {}",
                target,
                optimizer.get_lr()
            );
        }
    }
}
857}