// aprender/nn/scheduler.rs

1//! Learning rate schedulers for training neural networks.
2//!
3//! Schedulers adjust the learning rate during training to improve convergence
4//! and final model quality.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use aprender::nn::optim::{Adam, Optimizer};
10//! use aprender::nn::scheduler::{StepLR, LRScheduler};
11//!
12//! let mut optimizer = Adam::new(params, 0.1);
13//! let mut scheduler = StepLR::new(10, 0.1);  // Decay by 0.1 every 10 epochs
14//!
15//! for epoch in 0..100 {
16//!     // Training loop...
17//!
18//!     // Update learning rate at end of epoch
19//!     scheduler.step(&mut optimizer);
20//! }
21//! ```
22//!
23//! # References
24//!
25//! - Loshchilov, I., & Hutter, F. (2017). SGDR: Stochastic gradient descent
26//!   with warm restarts. ICLR.
27//! - Goyal, P., et al. (2017). Accurate, large minibatch SGD: Training
28//!   `ImageNet` in 1 hour. arXiv.
29
30use super::optim::Optimizer;
31
/// Common trait for learning rate schedulers.
///
/// A scheduler is advanced once per epoch (or step) and mutates the wrapped
/// optimizer's learning rate in place via `Optimizer::set_lr`.
pub trait LRScheduler {
    /// Advance the schedule by one epoch and update the optimizer's learning rate.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Get the current learning rate as last computed by the scheduler.
    fn get_lr(&self) -> f32;

    /// Get the current epoch/step count (number of `step` calls so far).
    fn last_epoch(&self) -> usize;
}
43
/// Step decay scheduler.
///
/// Decays learning rate by `gamma` every `step_size` epochs.
///
/// ```text
/// lr = initial_lr * gamma^(epoch // step_size)
/// ```
#[derive(Debug, Clone)]
pub struct StepLR {
    // Learning rate at epoch 0; 0.0 is a sentinel meaning "read from the
    // optimizer on the first `step` call".
    initial_lr: f32,
    // Number of epochs between successive decays.
    step_size: usize,
    // Multiplicative decay factor applied at each boundary.
    gamma: f32,
    // Most recently computed learning rate.
    current_lr: f32,
    // Number of `step` calls so far.
    last_epoch: usize,
}
59
60impl StepLR {
61    /// Create a new `StepLR` scheduler.
62    ///
63    /// # Arguments
64    ///
65    /// * `step_size` - Number of epochs between LR decays
66    /// * `gamma` - Multiplicative factor of LR decay (e.g., 0.1)
67    #[must_use]
68    pub fn new(step_size: usize, gamma: f32) -> Self {
69        Self {
70            initial_lr: 0.0, // Will be set on first step
71            step_size,
72            gamma,
73            current_lr: 0.0,
74            last_epoch: 0,
75        }
76    }
77
78    /// Create with initial learning rate already known.
79    #[must_use]
80    pub fn with_lr(initial_lr: f32, step_size: usize, gamma: f32) -> Self {
81        Self {
82            initial_lr,
83            step_size,
84            gamma,
85            current_lr: initial_lr,
86            last_epoch: 0,
87        }
88    }
89}
90
91impl LRScheduler for StepLR {
92    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
93        // Initialize on first step
94        if self.last_epoch == 0 && self.initial_lr == 0.0 {
95            self.initial_lr = optimizer.lr();
96            self.current_lr = self.initial_lr;
97        }
98
99        self.last_epoch += 1;
100
101        // Decay at step boundaries
102        if self.last_epoch % self.step_size == 0 {
103            self.current_lr *= self.gamma;
104            optimizer.set_lr(self.current_lr);
105        }
106    }
107
108    fn get_lr(&self) -> f32 {
109        self.current_lr
110    }
111
112    fn last_epoch(&self) -> usize {
113        self.last_epoch
114    }
115}
116
/// Exponential decay scheduler.
///
/// Decays learning rate by `gamma` every epoch.
///
/// ```text
/// lr = initial_lr * gamma^epoch
/// ```
#[derive(Debug, Clone)]
pub struct ExponentialLR {
    // Learning rate at epoch 0; 0.0 is a sentinel meaning "read from the
    // optimizer on the first `step` call".
    initial_lr: f32,
    // Per-epoch multiplicative decay factor.
    gamma: f32,
    // Most recently computed learning rate.
    current_lr: f32,
    // Number of `step` calls so far.
    last_epoch: usize,
}
131
132impl ExponentialLR {
133    /// Create a new `ExponentialLR` scheduler.
134    ///
135    /// # Arguments
136    ///
137    /// * `gamma` - Multiplicative factor (e.g., 0.99)
138    #[must_use]
139    pub fn new(gamma: f32) -> Self {
140        Self {
141            initial_lr: 0.0,
142            gamma,
143            current_lr: 0.0,
144            last_epoch: 0,
145        }
146    }
147
148    #[must_use]
149    pub fn with_lr(initial_lr: f32, gamma: f32) -> Self {
150        Self {
151            initial_lr,
152            gamma,
153            current_lr: initial_lr,
154            last_epoch: 0,
155        }
156    }
157}
158
159impl LRScheduler for ExponentialLR {
160    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
161        if self.last_epoch == 0 && self.initial_lr == 0.0 {
162            self.initial_lr = optimizer.lr();
163            self.current_lr = self.initial_lr;
164        }
165
166        self.last_epoch += 1;
167        self.current_lr *= self.gamma;
168        optimizer.set_lr(self.current_lr);
169    }
170
171    fn get_lr(&self) -> f32 {
172        self.current_lr
173    }
174
175    fn last_epoch(&self) -> usize {
176        self.last_epoch
177    }
178}
179
/// Cosine annealing scheduler (Loshchilov &amp; Hutter, 2017).
///
/// Anneals learning rate following a cosine curve from initial to minimum.
///
/// ```text
/// lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(π * epoch / T_max))
/// ```
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR {
    // Learning rate at epoch 0; 0.0 is a sentinel meaning "read from the
    // optimizer on the first `step` call".
    initial_lr: f32,
    // Floor the LR anneals down to at epoch `t_max`.
    min_lr: f32,
    // Length of the annealing horizon in epochs.
    t_max: usize,
    // Most recently computed learning rate.
    current_lr: f32,
    // Number of `step` calls so far.
    last_epoch: usize,
}
195
196impl CosineAnnealingLR {
197    /// Create a new `CosineAnnealingLR` scheduler.
198    ///
199    /// # Arguments
200    ///
201    /// * `t_max` - Maximum number of epochs
202    /// * `min_lr` - Minimum learning rate (default: 0)
203    #[must_use]
204    pub fn new(t_max: usize) -> Self {
205        Self {
206            initial_lr: 0.0,
207            min_lr: 0.0,
208            t_max,
209            current_lr: 0.0,
210            last_epoch: 0,
211        }
212    }
213
214    #[must_use]
215    pub fn with_min_lr(t_max: usize, min_lr: f32) -> Self {
216        Self {
217            initial_lr: 0.0,
218            min_lr,
219            t_max,
220            current_lr: 0.0,
221            last_epoch: 0,
222        }
223    }
224
225    #[must_use]
226    pub fn with_lr(initial_lr: f32, t_max: usize, min_lr: f32) -> Self {
227        Self {
228            initial_lr,
229            min_lr,
230            t_max,
231            current_lr: initial_lr,
232            last_epoch: 0,
233        }
234    }
235}
236
237impl LRScheduler for CosineAnnealingLR {
238    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
239        if self.last_epoch == 0 && self.initial_lr == 0.0 {
240            self.initial_lr = optimizer.lr();
241            self.current_lr = self.initial_lr;
242        }
243
244        self.last_epoch += 1;
245
246        // Cosine annealing formula
247        let progress = self.last_epoch as f32 / self.t_max as f32;
248        let cosine = (std::f32::consts::PI * progress).cos();
249        self.current_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + cosine);
250
251        optimizer.set_lr(self.current_lr);
252    }
253
254    fn get_lr(&self) -> f32 {
255        self.current_lr
256    }
257
258    fn last_epoch(&self) -> usize {
259        self.last_epoch
260    }
261}
262
/// Linear warmup scheduler.
///
/// Linearly increases learning rate from 0 to `initial_lr` over `warmup_steps`.
///
/// ```text
/// if epoch < warmup_steps:
///     lr = initial_lr * epoch / warmup_steps
/// else:
///     lr = initial_lr
/// ```
#[derive(Debug, Clone)]
pub struct LinearWarmup {
    // Target learning rate reached at the end of warmup; 0.0 is a sentinel
    // meaning "read from the optimizer on the first `step` call".
    initial_lr: f32,
    // Number of epochs over which the LR ramps from 0 to `initial_lr`.
    warmup_steps: usize,
    // Most recently computed learning rate (0.0 before the first step).
    current_lr: f32,
    // Number of `step` calls so far.
    last_epoch: usize,
}
280
281impl LinearWarmup {
282    /// Create a new `LinearWarmup` scheduler.
283    ///
284    /// # Arguments
285    ///
286    /// * `warmup_steps` - Number of warmup epochs
287    #[must_use]
288    pub fn new(warmup_steps: usize) -> Self {
289        Self {
290            initial_lr: 0.0,
291            warmup_steps,
292            current_lr: 0.0,
293            last_epoch: 0,
294        }
295    }
296
297    #[must_use]
298    pub fn with_lr(initial_lr: f32, warmup_steps: usize) -> Self {
299        Self {
300            initial_lr,
301            warmup_steps,
302            current_lr: 0.0,
303            last_epoch: 0,
304        }
305    }
306}
307
308impl LRScheduler for LinearWarmup {
309    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
310        if self.last_epoch == 0 && self.initial_lr == 0.0 {
311            self.initial_lr = optimizer.lr();
312        }
313
314        self.last_epoch += 1;
315
316        if self.last_epoch <= self.warmup_steps {
317            // Linear warmup
318            self.current_lr = self.initial_lr * (self.last_epoch as f32 / self.warmup_steps as f32);
319        } else {
320            self.current_lr = self.initial_lr;
321        }
322
323        optimizer.set_lr(self.current_lr);
324    }
325
326    fn get_lr(&self) -> f32 {
327        self.current_lr
328    }
329
330    fn last_epoch(&self) -> usize {
331        self.last_epoch
332    }
333}
334
/// Warmup + Cosine decay scheduler.
///
/// Combines linear warmup with cosine annealing, commonly used in modern
/// transformer training.
///
/// ```text
/// if epoch < warmup_steps:
///     lr = initial_lr * epoch / warmup_steps
/// else:
///     lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(π * (epoch - warmup) / (total - warmup)))
/// ```
#[derive(Debug, Clone)]
pub struct WarmupCosineScheduler {
    // Peak learning rate reached at the end of warmup; 0.0 is a sentinel
    // meaning "read from the optimizer on the first `step` call".
    initial_lr: f32,
    // Floor the LR decays down to at `total_steps`.
    min_lr: f32,
    // Number of epochs spent linearly warming up.
    warmup_steps: usize,
    // Total schedule length (warmup + cosine decay) in epochs.
    total_steps: usize,
    // Most recently computed learning rate.
    current_lr: f32,
    // Number of `step` calls so far.
    last_epoch: usize,
}
355
356impl WarmupCosineScheduler {
357    /// Create a new `WarmupCosineScheduler`.
358    ///
359    /// # Arguments
360    ///
361    /// * `warmup_steps` - Number of warmup epochs
362    /// * `total_steps` - Total number of training epochs
363    #[must_use]
364    pub fn new(warmup_steps: usize, total_steps: usize) -> Self {
365        Self {
366            initial_lr: 0.0,
367            min_lr: 0.0,
368            warmup_steps,
369            total_steps,
370            current_lr: 0.0,
371            last_epoch: 0,
372        }
373    }
374
375    #[must_use]
376    pub fn with_min_lr(warmup_steps: usize, total_steps: usize, min_lr: f32) -> Self {
377        Self {
378            initial_lr: 0.0,
379            min_lr,
380            warmup_steps,
381            total_steps,
382            current_lr: 0.0,
383            last_epoch: 0,
384        }
385    }
386}
387
388impl LRScheduler for WarmupCosineScheduler {
389    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
390        if self.last_epoch == 0 && self.initial_lr == 0.0 {
391            self.initial_lr = optimizer.lr();
392        }
393
394        self.last_epoch += 1;
395
396        if self.last_epoch <= self.warmup_steps {
397            // Linear warmup
398            self.current_lr = self.initial_lr * (self.last_epoch as f32 / self.warmup_steps as f32);
399        } else {
400            // Cosine decay
401            let decay_steps = self.total_steps - self.warmup_steps;
402            let decay_epoch = self.last_epoch - self.warmup_steps;
403            let progress = decay_epoch as f32 / decay_steps as f32;
404            let cosine = (std::f32::consts::PI * progress).cos();
405            self.current_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + cosine);
406        }
407
408        optimizer.set_lr(self.current_lr);
409    }
410
411    fn get_lr(&self) -> f32 {
412        self.current_lr
413    }
414
415    fn last_epoch(&self) -> usize {
416        self.last_epoch
417    }
418}
419
/// Reduce LR on plateau scheduler.
///
/// Reduces learning rate when a metric has stopped improving.
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateau {
    // Multiplicative factor applied to the LR when the metric plateaus.
    factor: f32,
    // Number of non-improving epochs tolerated before reducing.
    patience: usize,
    // Lower bound the LR is clamped to when reducing.
    min_lr: f32,
    // Minimum absolute change in the metric that counts as an improvement.
    threshold: f32,
    // Current learning rate (0.0 until the first `step_with_metric` call).
    current_lr: f32,
    // Best metric value observed so far (±infinity before any metric is seen).
    best_metric: f32,
    // Consecutive epochs without improvement.
    num_bad_epochs: usize,
    // Number of scheduler updates so far.
    last_epoch: usize,
    // Whether lower or higher metric values are better.
    mode: PlateauMode,
}
435
/// Mode for plateau detection: direction in which the monitored metric
/// is considered to improve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlateauMode {
    /// Lower metric is better (e.g., loss)
    Min,
    /// Higher metric is better (e.g., accuracy)
    Max,
}
444
445impl ReduceLROnPlateau {
446    /// Create a new `ReduceLROnPlateau` scheduler.
447    ///
448    /// # Arguments
449    ///
450    /// * `mode` - Whether to minimize or maximize the metric
451    /// * `factor` - Factor to reduce LR by (e.g., 0.1)
452    /// * `patience` - Number of epochs with no improvement before reducing
453    #[must_use]
454    pub fn new(mode: PlateauMode, factor: f32, patience: usize) -> Self {
455        let best_metric = match mode {
456            PlateauMode::Min => f32::INFINITY,
457            PlateauMode::Max => f32::NEG_INFINITY,
458        };
459
460        Self {
461            factor,
462            patience,
463            min_lr: 1e-8,
464            threshold: 1e-4,
465            current_lr: 0.0,
466            best_metric,
467            num_bad_epochs: 0,
468            last_epoch: 0,
469            mode,
470        }
471    }
472
473    /// Set minimum learning rate.
474    #[must_use]
475    pub fn min_lr(mut self, min_lr: f32) -> Self {
476        self.min_lr = min_lr;
477        self
478    }
479
480    /// Set threshold for measuring improvement.
481    #[must_use]
482    pub fn threshold(mut self, threshold: f32) -> Self {
483        self.threshold = threshold;
484        self
485    }
486
487    /// Update scheduler with current metric value.
488    pub fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
489        if self.last_epoch == 0 && self.current_lr == 0.0 {
490            self.current_lr = optimizer.lr();
491        }
492
493        self.last_epoch += 1;
494
495        // Check if metric improved
496        let is_better = match self.mode {
497            PlateauMode::Min => metric < self.best_metric - self.threshold,
498            PlateauMode::Max => metric > self.best_metric + self.threshold,
499        };
500
501        if is_better {
502            self.best_metric = metric;
503            self.num_bad_epochs = 0;
504        } else {
505            self.num_bad_epochs += 1;
506        }
507
508        // Reduce LR if patience exceeded
509        if self.num_bad_epochs >= self.patience {
510            let new_lr = (self.current_lr * self.factor).max(self.min_lr);
511            if new_lr < self.current_lr {
512                self.current_lr = new_lr;
513                optimizer.set_lr(self.current_lr);
514                self.num_bad_epochs = 0;
515            }
516        }
517    }
518}
519
impl LRScheduler for ReduceLROnPlateau {
    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
        // This scheduler needs a metric to decide anything; a plain `step`
        // only advances the epoch counter. Use `step_with_metric` instead.
        self.last_epoch += 1;
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    // Mock optimizer for testing: stores only a learning rate so schedulers
    // can be driven without real parameters or gradients.
    struct MockOptimizer {
        lr: f32,
    }

    impl MockOptimizer {
        fn new(lr: f32) -> Self {
            Self { lr }
        }
    }

    impl Optimizer for MockOptimizer {
        fn step(&mut self) {}
        fn zero_grad(&mut self) {}
        fn lr(&self) -> f32 {
            self.lr
        }
        fn set_lr(&mut self, lr: f32) {
            self.lr = lr;
        }
    }

    #[test]
    fn test_step_lr() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = StepLR::new(3, 0.1);

        // Epochs 1-2: lr stays at 0.1; the decay fires at epoch 3.
        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);
        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);
        scheduler.step(&mut optimizer);
        // After step 3: lr = 0.1 * 0.1 = 0.01
        assert!((optimizer.lr() - 0.01).abs() < 1e-6);

        // Next 3 epochs
        scheduler.step(&mut optimizer);
        scheduler.step(&mut optimizer);
        scheduler.step(&mut optimizer);
        // After step 6: lr = 0.01 * 0.1 = 0.001
        assert!((optimizer.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = ExponentialLR::new(0.9);

        // Each step multiplies by gamma: 0.1 -> 0.09 -> 0.081.
        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = CosineAnnealingLR::new(10);

        // At epoch 0 (before step): lr = 0.1
        scheduler.step(&mut optimizer);
        // At epoch 1: should be close to initial (cosine starts at 1)
        assert!(optimizer.lr() < 0.1);
        assert!(optimizer.lr() > 0.09);

        // At epoch 5 (halfway): should be around 0.05
        for _ in 0..4 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.lr() - 0.05).abs() < 0.01);

        // At epoch 10: should be close to 0
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.lr() < 0.01);
    }

    #[test]
    fn test_linear_warmup() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = LinearWarmup::new(5);

        // During warmup
        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.02).abs() < 1e-6); // 0.1 * 1/5

        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.04).abs() < 1e-6); // 0.1 * 2/5

        // After warmup
        for _ in 0..3 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.1).abs() < 1e-6); // Stays at initial
    }

    #[test]
    fn test_warmup_cosine() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = WarmupCosineScheduler::new(5, 20);

        // Warmup phase: first step reaches 0.1 * 1/5.
        scheduler.step(&mut optimizer);
        assert!((optimizer.lr() - 0.02).abs() < 1e-6);

        // Complete warmup
        for _ in 0..4 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);

        // Decay phase starts
        scheduler.step(&mut optimizer);
        assert!(optimizer.lr() < 0.1);
    }

    #[test]
    fn test_reduce_on_plateau() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 3);

        // Improving
        scheduler.step_with_metric(&mut optimizer, 1.0);
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);

        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);

        // Plateau (no improvement for 3 epochs)
        scheduler.step_with_metric(&mut optimizer, 0.9);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        scheduler.step_with_metric(&mut optimizer, 0.9);

        // LR should be reduced
        assert!((optimizer.lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_reduce_on_plateau_max_mode() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Max, 0.5, 2);

        // Improving
        scheduler.step_with_metric(&mut optimizer, 0.5);
        scheduler.step_with_metric(&mut optimizer, 0.6);
        assert!((optimizer.lr() - 0.1).abs() < 1e-6);

        // Plateau
        scheduler.step_with_metric(&mut optimizer, 0.6);
        scheduler.step_with_metric(&mut optimizer, 0.6);

        // LR should be reduced
        assert!((optimizer.lr() - 0.05).abs() < 1e-6);
    }

    // Additional tests for coverage

    #[test]
    fn test_step_lr_with_lr() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = StepLR::with_lr(0.2, 2, 0.5);

        assert_eq!(scheduler.get_lr(), 0.2);
        assert_eq!(scheduler.last_epoch(), 0);

        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.last_epoch(), 1);
        scheduler.step(&mut optimizer);
        // After 2 steps: 0.2 * 0.5 = 0.1
        assert!((scheduler.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr_with_lr() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = ExponentialLR::with_lr(0.5, 0.8);

        assert_eq!(scheduler.get_lr(), 0.5);
        assert_eq!(scheduler.last_epoch(), 0);

        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.4).abs() < 1e-6);
        assert_eq!(scheduler.last_epoch(), 1);
    }

    #[test]
    fn test_cosine_annealing_with_min_lr() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = CosineAnnealingLR::with_min_lr(10, 0.01);

        scheduler.step(&mut optimizer);
        assert!(scheduler.get_lr() > 0.01);
        assert!(scheduler.get_lr() < 0.1);
    }

    #[test]
    fn test_cosine_annealing_with_lr() {
        let mut optimizer = MockOptimizer::new(0.05);
        let mut scheduler = CosineAnnealingLR::with_lr(0.2, 10, 0.02);

        assert_eq!(scheduler.get_lr(), 0.2);
        scheduler.step(&mut optimizer);
        // Should use initial_lr of 0.2, not optimizer's 0.05
        assert!(scheduler.get_lr() < 0.2);
        assert!(scheduler.get_lr() > 0.02);
    }

    #[test]
    fn test_linear_warmup_with_lr() {
        let mut optimizer = MockOptimizer::new(0.05);
        let mut scheduler = LinearWarmup::with_lr(0.2, 4);

        assert_eq!(scheduler.get_lr(), 0.0); // before any step
        scheduler.step(&mut optimizer);
        assert!((scheduler.get_lr() - 0.05).abs() < 1e-6); // 0.2 * 1/4
        assert_eq!(scheduler.last_epoch(), 1);
    }

    #[test]
    fn test_warmup_cosine_with_min_lr() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = WarmupCosineScheduler::with_min_lr(5, 20, 0.001);

        // Complete warmup
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((scheduler.get_lr() - 0.1).abs() < 1e-6);

        // Start decay
        scheduler.step(&mut optimizer);
        assert!(scheduler.get_lr() < 0.1);
        assert!(scheduler.get_lr() > 0.001);
        assert_eq!(scheduler.last_epoch(), 6);
    }

    #[test]
    fn test_reduce_on_plateau_min_lr_builder() {
        let scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 3).min_lr(0.0001);
        assert!((scheduler.min_lr - 0.0001).abs() < 1e-8);
    }

    #[test]
    fn test_reduce_on_plateau_threshold_builder() {
        let scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 3).threshold(0.001);
        assert!((scheduler.threshold - 0.001).abs() < 1e-8);
    }

    #[test]
    fn test_reduce_on_plateau_step_without_metric() {
        let mut optimizer = MockOptimizer::new(0.1);
        let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 2);

        // Call step without metric (should just increment epoch)
        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.last_epoch(), 1);
        scheduler.step(&mut optimizer);
        assert_eq!(scheduler.last_epoch(), 2);
    }

    #[test]
    fn test_reduce_on_plateau_min_lr_clamp() {
        let mut optimizer = MockOptimizer::new(0.001);
        let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 1).min_lr(0.0005);

        // First metric establishes baseline
        scheduler.step_with_metric(&mut optimizer, 1.0);
        // No improvement triggers reduction
        scheduler.step_with_metric(&mut optimizer, 1.0);
        // LR should be clamped at min_lr
        assert!(scheduler.get_lr() >= 0.0005);
    }

    #[test]
    fn test_step_lr_getters() {
        let scheduler = StepLR::with_lr(0.1, 5, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);
        assert_eq!(scheduler.last_epoch(), 0);
    }

    #[test]
    fn test_exponential_lr_getters() {
        let scheduler = ExponentialLR::with_lr(0.1, 0.9);
        assert_eq!(scheduler.get_lr(), 0.1);
        assert_eq!(scheduler.last_epoch(), 0);
    }

    #[test]
    fn test_cosine_annealing_getters() {
        let scheduler = CosineAnnealingLR::with_lr(0.1, 10, 0.01);
        assert_eq!(scheduler.get_lr(), 0.1);
        assert_eq!(scheduler.last_epoch(), 0);
    }

    #[test]
    fn test_linear_warmup_getters() {
        let scheduler = LinearWarmup::with_lr(0.1, 5);
        assert_eq!(scheduler.get_lr(), 0.0);
        assert_eq!(scheduler.last_epoch(), 0);
    }

    #[test]
    fn test_warmup_cosine_getters() {
        let scheduler = WarmupCosineScheduler::with_min_lr(5, 20, 0.01);
        assert_eq!(scheduler.get_lr(), 0.0);
        assert_eq!(scheduler.last_epoch(), 0);
    }

    #[test]
    fn test_reduce_on_plateau_getters() {
        let scheduler = ReduceLROnPlateau::new(PlateauMode::Max, 0.5, 3);
        assert_eq!(scheduler.get_lr(), 0.0);
        assert_eq!(scheduler.last_epoch(), 0);
    }

    #[test]
    fn test_plateau_mode_eq() {
        assert_eq!(PlateauMode::Min, PlateauMode::Min);
        assert_eq!(PlateauMode::Max, PlateauMode::Max);
        assert_ne!(PlateauMode::Min, PlateauMode::Max);
    }

    #[test]
    fn test_scheduler_clone() {
        let scheduler = StepLR::with_lr(0.1, 5, 0.9);
        let cloned = scheduler.clone();
        assert_eq!(scheduler.get_lr(), cloned.get_lr());
    }

    #[test]
    fn test_scheduler_debug() {
        let scheduler = StepLR::with_lr(0.1, 5, 0.9);
        let debug = format!("{scheduler:?}");
        assert!(debug.contains("StepLR"));
    }
}