rust_lstm/schedulers.rs

use std::f64::consts::PI;

/// Learning rate scheduler trait for adaptive learning rate adjustment during training
pub trait LearningRateScheduler {
    /// Get the learning rate for the current epoch
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64;

    /// Reset the scheduler state (useful for multiple training runs)
    fn reset(&mut self);

    /// Get the name of the scheduler for logging
    fn name(&self) -> &'static str;
}

/// Constant learning rate (no scheduling)
#[derive(Clone, Debug)]
pub struct ConstantLR;

impl LearningRateScheduler for ConstantLR {
    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
        base_lr
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "ConstantLR"
    }
}

/// Step decay scheduler: multiply LR by gamma every step_size epochs
#[derive(Clone, Debug)]
pub struct StepLR {
    step_size: usize,
    gamma: f64,
}

impl StepLR {
    pub fn new(step_size: usize, gamma: f64) -> Self {
        StepLR { step_size, gamma }
    }
}

impl LearningRateScheduler for StepLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        let steps = epoch / self.step_size;
        base_lr * self.gamma.powi(steps as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "StepLR"
    }
}

/// Multi-step decay: multiply LR by gamma at specific milestones
#[derive(Clone, Debug)]
pub struct MultiStepLR {
    milestones: Vec<usize>,
    gamma: f64,
}

impl MultiStepLR {
    pub fn new(milestones: Vec<usize>, gamma: f64) -> Self {
        MultiStepLR { milestones, gamma }
    }
}

impl LearningRateScheduler for MultiStepLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        let num_reductions = self.milestones.iter()
            .filter(|&&milestone| epoch >= milestone)
            .count();
        base_lr * self.gamma.powi(num_reductions as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "MultiStepLR"
    }
}

/// Exponential decay scheduler: multiply LR by gamma every epoch
#[derive(Clone, Debug)]
pub struct ExponentialLR {
    gamma: f64,
}

impl ExponentialLR {
    pub fn new(gamma: f64) -> Self {
        ExponentialLR { gamma }
    }
}

impl LearningRateScheduler for ExponentialLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        base_lr * self.gamma.powi(epoch as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "ExponentialLR"
    }
}

/// Cosine annealing scheduler: anneals the LR from `base_lr` down to `eta_min`
/// over `t_max` epochs, then repeats the cycle (restart via `epoch % t_max`)
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    t_max: usize,
    eta_min: f64,
    last_epoch: usize,
}

impl CosineAnnealingLR {
    pub fn new(t_max: usize, eta_min: f64) -> Self {
        CosineAnnealingLR {
            t_max,
            eta_min,
            last_epoch: 0,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        self.last_epoch = epoch;
        if epoch == 0 {
            return base_lr;
        }

        let t = epoch % self.t_max;
        self.eta_min + (base_lr - self.eta_min) *
            (1.0 + (PI * t as f64 / self.t_max as f64).cos()) / 2.0
    }

    fn reset(&mut self) {
        self.last_epoch = 0;
    }

    fn name(&self) -> &'static str {
        "CosineAnnealingLR"
    }
}

/// Cosine annealing with warm restarts: each restart period is `t_mult` times
/// longer than the previous one.
///
/// Note: `get_lr` is stateful and assumes epochs are passed in increasing order.
#[derive(Clone, Debug)]
pub struct CosineAnnealingWarmRestarts {
    t_0: usize,
    t_mult: usize,
    eta_min: f64,
    last_restart: usize,
    restart_count: usize,
}

impl CosineAnnealingWarmRestarts {
    pub fn new(t_0: usize, t_mult: usize, eta_min: f64) -> Self {
        CosineAnnealingWarmRestarts {
            t_0,
            t_mult,
            eta_min,
            last_restart: 0,
            restart_count: 0,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingWarmRestarts {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch == 0 {
            return base_lr;
        }

        let t_cur = epoch - self.last_restart;
        let t_i = self.t_0 * self.t_mult.pow(self.restart_count as u32);

        if t_cur >= t_i {
            self.last_restart = epoch;
            self.restart_count += 1;
            return base_lr;
        }

        self.eta_min + (base_lr - self.eta_min) *
            (1.0 + (PI * t_cur as f64 / t_i as f64).cos()) / 2.0
    }

    fn reset(&mut self) {
        self.last_restart = 0;
        self.restart_count = 0;
    }

    fn name(&self) -> &'static str {
        "CosineAnnealingWarmRestarts"
    }
}

/// One cycle learning rate policy (popular for modern deep learning)
#[derive(Clone, Debug)]
pub struct OneCycleLR {
    max_lr: f64,
    total_steps: usize,
    pct_start: f64,
    anneal_strategy: AnnealStrategy,
    div_factor: f64,
    final_div_factor: f64,
}

#[derive(Clone, Debug)]
pub enum AnnealStrategy {
    Cos,
    Linear,
}

impl OneCycleLR {
    pub fn new(max_lr: f64, total_steps: usize) -> Self {
        OneCycleLR {
            max_lr,
            total_steps,
            pct_start: 0.3,
            anneal_strategy: AnnealStrategy::Cos,
            div_factor: 25.0,
            final_div_factor: 10000.0,
        }
    }

    pub fn with_params(
        max_lr: f64,
        total_steps: usize,
        pct_start: f64,
        anneal_strategy: AnnealStrategy,
        div_factor: f64,
        final_div_factor: f64,
    ) -> Self {
        OneCycleLR {
            max_lr,
            total_steps,
            pct_start,
            anneal_strategy,
            div_factor,
            final_div_factor,
        }
    }
}

impl LearningRateScheduler for OneCycleLR {
    fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 {
        if epoch >= self.total_steps {
            return self.max_lr / self.final_div_factor;
        }

        let warmup_steps = (self.total_steps as f64 * self.pct_start) as usize;

        if epoch < warmup_steps {
            // Warmup phase: ramp linearly from max_lr / div_factor up to max_lr
            let warmup_ratio = epoch as f64 / warmup_steps as f64;
            (self.max_lr / self.div_factor) +
                (self.max_lr - self.max_lr / self.div_factor) * warmup_ratio
        } else {
            // Annealing phase: decay from max_lr down to max_lr / final_div_factor
            let anneal_ratio = (epoch - warmup_steps) as f64 /
                (self.total_steps - warmup_steps) as f64;

            match self.anneal_strategy {
                AnnealStrategy::Cos => {
                    let cos_factor = (1.0 + (PI * anneal_ratio).cos()) / 2.0;
                    (self.max_lr / self.final_div_factor) +
                        (self.max_lr - self.max_lr / self.final_div_factor) * cos_factor
                },
                AnnealStrategy::Linear => {
                    self.max_lr - (self.max_lr - self.max_lr / self.final_div_factor) * anneal_ratio
                }
            }
        }
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "OneCycleLR"
    }
}

/// Reduce learning rate on plateau (when validation loss stops improving)
#[derive(Clone, Debug)]
pub struct ReduceLROnPlateau {
    factor: f64,
    patience: usize,
    threshold: f64,
    cooldown: usize,
    min_lr: f64,
    best_loss: f64,
    wait_count: usize,
    cooldown_counter: usize,
    current_lr: f64,
}

impl ReduceLROnPlateau {
    pub fn new(factor: f64, patience: usize) -> Self {
        ReduceLROnPlateau {
            factor,
            patience,
            threshold: 1e-4,
            cooldown: 0,
            min_lr: 0.0,
            best_loss: f64::INFINITY,
            wait_count: 0,
            cooldown_counter: 0,
            current_lr: 0.0,
        }
    }

    pub fn with_params(
        factor: f64,
        patience: usize,
        threshold: f64,
        cooldown: usize,
        min_lr: f64,
    ) -> Self {
        ReduceLROnPlateau {
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best_loss: f64::INFINITY,
            wait_count: 0,
            cooldown_counter: 0,
            current_lr: 0.0,
        }
    }

    /// Update the scheduler with the current validation loss
    pub fn step(&mut self, val_loss: f64, base_lr: f64) -> f64 {
        if self.current_lr == 0.0 {
            self.current_lr = base_lr;
        }

        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return self.current_lr;
        }

        if val_loss < self.best_loss - self.threshold {
            self.best_loss = val_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;

            if self.wait_count >= self.patience {
                let new_lr = self.current_lr * self.factor;
                self.current_lr = new_lr.max(self.min_lr);
                self.wait_count = 0;
                self.cooldown_counter = self.cooldown;
                println!("ReduceLROnPlateau: reducing learning rate to {:.2e}", self.current_lr);
            }
        }

        self.current_lr
    }
}

impl LearningRateScheduler for ReduceLROnPlateau {
    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
        if self.current_lr == 0.0 {
            self.current_lr = base_lr;
        }
        self.current_lr
    }

    fn reset(&mut self) {
        self.best_loss = f64::INFINITY;
        self.wait_count = 0;
        self.cooldown_counter = 0;
        self.current_lr = 0.0;
    }

    fn name(&self) -> &'static str {
        "ReduceLROnPlateau"
    }
}

/// Linear learning rate schedule
#[derive(Clone, Debug)]
pub struct LinearLR {
    start_factor: f64,
    end_factor: f64,
    total_iters: usize,
}

impl LinearLR {
    pub fn new(start_factor: f64, end_factor: f64, total_iters: usize) -> Self {
        LinearLR {
            start_factor,
            end_factor,
            total_iters,
        }
    }
}

impl LearningRateScheduler for LinearLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch >= self.total_iters {
            return base_lr * self.end_factor;
        }

        let progress = epoch as f64 / self.total_iters as f64;
        let factor = self.start_factor +
            (self.end_factor - self.start_factor) * progress;

        base_lr * factor
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "LinearLR"
    }
}

/// Polynomial learning rate decay
#[derive(Clone, Debug)]
pub struct PolynomialLR {
    total_iters: usize,
    power: f64,
    end_lr: f64,
}

impl PolynomialLR {
    pub fn new(total_iters: usize, power: f64, end_lr: f64) -> Self {
        PolynomialLR {
            total_iters,
            power,
            end_lr,
        }
    }
}

impl LearningRateScheduler for PolynomialLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch >= self.total_iters {
            return self.end_lr;
        }

        let factor = (1.0 - epoch as f64 / self.total_iters as f64).powf(self.power);
        self.end_lr + (base_lr - self.end_lr) * factor
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "PolynomialLR"
    }
}

/// Cyclical learning rate policy with different modes
#[derive(Clone, Debug)]
pub struct CyclicalLR {
    base_lr: f64,
    max_lr: f64,
    step_size: usize,
    mode: CyclicalMode,
    gamma: f64,
    scale_mode: ScaleMode,
    last_step: usize,
}

#[derive(Clone, Debug)]
pub enum CyclicalMode {
    Triangular,
    Triangular2,
    ExpRange,
}

#[derive(Clone, Debug)]
pub enum ScaleMode {
    Cycle,
    Iterations,
}

impl CyclicalLR {
    pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        CyclicalLR {
            base_lr,
            max_lr,
            step_size,
            mode: CyclicalMode::Triangular,
            gamma: 1.0,
            scale_mode: ScaleMode::Cycle,
            last_step: 0,
        }
    }

    pub fn with_mode(mut self, mode: CyclicalMode) -> Self {
        self.mode = mode;
        self
    }

    pub fn with_gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    pub fn with_scale_mode(mut self, scale_mode: ScaleMode) -> Self {
        self.scale_mode = scale_mode;
        self
    }
}

impl LearningRateScheduler for CyclicalLR {
    fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 {
        self.last_step = epoch;

        // `cycle` is zero-based: epochs [0, 2*step_size) form cycle 0, and so on.
        let cycle = (epoch as f64 / (2.0 * self.step_size as f64)).floor() as usize;
        let x = (epoch as f64 / self.step_size as f64 - 2.0 * cycle as f64 - 1.0).abs();

        let scale_factor = match self.mode {
            CyclicalMode::Triangular => 1.0,
            // Halve the amplitude each cycle; with a zero-based cycle the first
            // cycle must use 2^0, not 2^-1 (which would overshoot max_lr).
            CyclicalMode::Triangular2 => 1.0 / 2.0_f64.powi(cycle as i32),
            CyclicalMode::ExpRange => self.gamma.powi(epoch as i32),
        };

        let scale_factor = match self.scale_mode {
            ScaleMode::Cycle => scale_factor,
            ScaleMode::Iterations => self.gamma.powi(epoch as i32),
        };

        self.base_lr + (self.max_lr - self.base_lr) * (1.0 - x).max(0.0) * scale_factor
    }

    fn reset(&mut self) {
        self.last_step = 0;
    }

    fn name(&self) -> &'static str {
        "CyclicalLR"
    }
}

/// Warmup scheduler that gradually increases learning rate
#[derive(Clone, Debug)]
pub struct WarmupScheduler<S: LearningRateScheduler> {
    warmup_epochs: usize,
    base_scheduler: S,
    warmup_start_lr: f64,
}

impl<S: LearningRateScheduler> WarmupScheduler<S> {
    pub fn new(warmup_epochs: usize, base_scheduler: S, warmup_start_lr: f64) -> Self {
        WarmupScheduler {
            warmup_epochs,
            base_scheduler,
            warmup_start_lr,
        }
    }
}

impl<S: LearningRateScheduler> LearningRateScheduler for WarmupScheduler<S> {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch < self.warmup_epochs {
            // Linear warmup
            let warmup_factor = epoch as f64 / self.warmup_epochs as f64;
            self.warmup_start_lr + (base_lr - self.warmup_start_lr) * warmup_factor
        } else {
            // Use base scheduler after warmup
            self.base_scheduler.get_lr(epoch - self.warmup_epochs, base_lr)
        }
    }

    fn reset(&mut self) {
        self.base_scheduler.reset();
    }

    fn name(&self) -> &'static str {
        "WarmupScheduler"
    }
}

/// Learning rate schedule visualization helper
pub struct LRScheduleVisualizer;

impl LRScheduleVisualizer {
    /// Generate learning rate values for visualization
    pub fn generate_schedule<S: LearningRateScheduler>(
        mut scheduler: S,
        base_lr: f64,
        epochs: usize,
    ) -> Vec<(usize, f64)> {
        let mut schedule = Vec::new();

        for epoch in 0..epochs {
            let lr = scheduler.get_lr(epoch, base_lr);
            schedule.push((epoch, lr));
        }

        schedule
    }

    /// Print ASCII visualization of learning rate schedule
    pub fn print_schedule<S: LearningRateScheduler>(
        scheduler: S,
        base_lr: f64,
        epochs: usize,
        width: usize,
        height: usize,
    ) {
        let schedule = Self::generate_schedule(scheduler, base_lr, epochs);

        if schedule.is_empty() {
            return;
        }

        let min_lr = schedule.iter().map(|(_, lr)| *lr).fold(f64::INFINITY, f64::min);
        let max_lr = schedule.iter().map(|(_, lr)| *lr).fold(0.0, f64::max);

        println!("Learning Rate Schedule Visualization ({}x{})", width, height);
        println!("Min LR: {:.2e}, Max LR: {:.2e}", min_lr, max_lr);
        println!("┌{}┐", "─".repeat(width));

        for row in 0..height {
            let y_value = max_lr - (max_lr - min_lr) * row as f64 / (height - 1) as f64;
            print!("│");

            for col in 0..width {
                let epoch_idx = col * epochs / width;
                let lr = if epoch_idx < schedule.len() {
                    schedule[epoch_idx].1
                } else {
                    min_lr
                };

                if (lr - y_value).abs() < (max_lr - min_lr) / height as f64 {
                    print!("█");
                } else {
                    print!(" ");
                }
            }

            println!("│ {:.2e}", y_value);
        }

        println!("└{}┘", "─".repeat(width));
        print!(" ");
        for i in 0..=4 {
            let epoch = i * epochs / 4;
            print!("{:>width$}", epoch, width = width / 5);
        }
        println!();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_lr() {
        let mut scheduler = ConstantLR;
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(10, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(100, base_lr), base_lr);
    }

    #[test]
    fn test_step_lr() {
        let mut scheduler = StepLR::new(10, 0.1);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(9, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_exponential_lr() {
        let mut scheduler = ExponentialLR::new(0.9);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(1, base_lr) - base_lr * 0.9).abs() < 1e-10);
        assert!((scheduler.get_lr(2, base_lr) - base_lr * 0.81).abs() < 1e-10);
    }
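
    // Sketch of expected CosineAnnealingLR behavior, assuming the modulo-based
    // restart implemented in `get_lr` above.
    #[test]
    fn test_cosine_annealing_lr() {
        let mut scheduler = CosineAnnealingLR::new(10, 0.0);
        let base_lr = 0.1;

        // Epoch 0 returns the base learning rate unchanged.
        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        // Halfway through the cycle cos(pi/2) = 0, so lr = base_lr / 2.
        assert!((scheduler.get_lr(5, base_lr) - 0.05).abs() < 1e-10);
        // At t_max the cycle restarts (epoch % t_max == 0), giving base_lr again.
        assert!((scheduler.get_lr(10, base_lr) - base_lr).abs() < 1e-10);
    }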

    #[test]
    fn test_multi_step_lr() {
        let mut scheduler = MultiStepLR::new(vec![10, 20], 0.1);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(5, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(15, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }
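
    // Sketch of the warm-restart behavior, assuming `get_lr` is called with
    // increasing epochs as noted in the struct's docs.
    #[test]
    fn test_cosine_annealing_warm_restarts() {
        let mut scheduler = CosineAnnealingWarmRestarts::new(10, 2, 0.0);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        // Mid-period: cos(pi/2) = 0, so lr = base_lr / 2.
        assert!((scheduler.get_lr(5, base_lr) - 0.05).abs() < 1e-10);
        // At epoch 10 the first period (t_0 = 10) ends and the LR restarts at base_lr.
        assert_eq!(scheduler.get_lr(10, base_lr), base_lr);
        // The second period is t_0 * t_mult = 20 epochs long, so epoch 15 is only a
        // quarter of the way through it and the LR is still above base_lr / 2.
        assert!(scheduler.get_lr(15, base_lr) > 0.05);
    }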

    #[test]
    fn test_one_cycle_lr() {
        let mut scheduler = OneCycleLR::new(0.1, 100);
        let base_lr = 0.01;

        let lr_0 = scheduler.get_lr(0, base_lr);
        let lr_30 = scheduler.get_lr(30, base_lr); // Should be close to max
        let lr_100 = scheduler.get_lr(100, base_lr); // Should be very small

        assert!(lr_0 < lr_30);
        assert!(lr_100 < lr_0);
        assert!(lr_30 <= 0.1);
    }
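
    // Sketch of the one-cycle endpoints with the default factors set in
    // `OneCycleLR::new` (div_factor = 25, final_div_factor = 10000).
    #[test]
    fn test_one_cycle_lr_endpoints() {
        let mut scheduler = OneCycleLR::new(0.1, 100);
        let base_lr = 0.01; // ignored by OneCycleLR

        // Warmup starts at max_lr / div_factor = 0.1 / 25.
        assert!((scheduler.get_lr(0, base_lr) - 0.004).abs() < 1e-12);
        // At and beyond total_steps the LR settles at max_lr / final_div_factor.
        assert!((scheduler.get_lr(100, base_lr) - 1e-5).abs() < 1e-12);
        assert!((scheduler.get_lr(150, base_lr) - 1e-5).abs() < 1e-12);
    }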

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut scheduler = ReduceLROnPlateau::new(0.5, 2);
        let base_lr = 0.01;

        // Should not reduce initially
        let lr1 = scheduler.step(1.0, base_lr);
        assert_eq!(lr1, base_lr);

        // Should not reduce with improving loss
        let lr2 = scheduler.step(0.8, base_lr);
        assert_eq!(lr2, base_lr);

        // Should reduce after patience epochs without improvement
        let _lr3 = scheduler.step(0.9, base_lr);
        let _lr4 = scheduler.step(0.9, base_lr);
        let lr5 = scheduler.step(0.9, base_lr);

        assert!(lr5 < base_lr);
        assert!((lr5 - base_lr * 0.5).abs() < 1e-10);
    }
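
    // Sketch exercising `with_params` with a cooldown and a minimum LR floor;
    // the expected values follow directly from the `step` logic above.
    #[test]
    fn test_reduce_lr_on_plateau_cooldown_and_min_lr() {
        let mut scheduler = ReduceLROnPlateau::with_params(0.5, 1, 1e-4, 1, 0.004);
        let base_lr = 0.01;

        // First call establishes the baseline loss without reducing.
        assert_eq!(scheduler.step(1.0, base_lr), base_lr);
        // No improvement with patience 1: reduce to 0.01 * 0.5 = 0.005, start cooldown.
        assert!((scheduler.step(1.0, base_lr) - 0.005).abs() < 1e-12);
        // Cooldown epoch: no further reduction.
        assert!((scheduler.step(1.0, base_lr) - 0.005).abs() < 1e-12);
        // Next reduction would give 0.0025 but is clamped to min_lr = 0.004.
        assert!((scheduler.step(1.0, base_lr) - 0.004).abs() < 1e-12);
    }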

    #[test]
    fn test_linear_lr() {
        let mut scheduler = LinearLR::new(1.0, 0.1, 10);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(5, base_lr) - base_lr * 0.55).abs() < 1e-10);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_polynomial_lr() {
        let mut scheduler = PolynomialLR::new(100, 2.0, 0.01);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), 0.1);
        // At epoch 50: factor = (1 - 50/100)^2 = 0.25
        // lr = 0.01 + (0.1 - 0.01) * 0.25 = 0.01 + 0.0225 = 0.0325
        assert!((scheduler.get_lr(50, base_lr) - 0.0325).abs() < 1e-10);
        assert!((scheduler.get_lr(100, base_lr) - 0.01).abs() < 1e-10);
    }
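
    // Small check of the visualization helper's `generate_schedule`, using StepLR
    // as the input schedule.
    #[test]
    fn test_generate_schedule() {
        let schedule = LRScheduleVisualizer::generate_schedule(StepLR::new(5, 0.1), 0.1, 10);

        assert_eq!(schedule.len(), 10);
        assert_eq!(schedule[0], (0, 0.1));
        // After the first step boundary (epoch 5) the LR has been multiplied by gamma.
        assert!((schedule[5].1 - 0.01).abs() < 1e-12);
    }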

    #[test]
    fn test_cyclical_lr() {
        let mut scheduler = CyclicalLR::new(0.1, 1.0, 10);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), 0.1);
        // At epoch 5: cycle = 0, x = 0.5, halfway up the ramp:
        // lr = 0.1 + (1.0 - 0.1) * (1 - 0.5) = 0.1 + 0.9 * 0.5 = 0.55
        assert!((scheduler.get_lr(5, base_lr) - 0.55).abs() < 1e-10);
        // At epoch 10 (one step_size in): cycle = 0, x = 0, the peak of the triangle:
        // lr = 0.1 + (1.0 - 0.1) * (1 - 0.0) = 0.1 + 0.9 = 1.0
        assert_eq!(scheduler.get_lr(10, base_lr), 1.0);
    }
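
    // Sketch of Triangular2 mode, assuming the zero-based cycle scaling used in
    // `get_lr` (the triangle's amplitude halves on each successive cycle).
    #[test]
    fn test_cyclical_lr_triangular2() {
        let mut scheduler = CyclicalLR::new(0.1, 1.0, 10).with_mode(CyclicalMode::Triangular2);
        let base_lr = 0.1;

        // First cycle peak (epoch 10): full amplitude, lr = 0.1 + 0.9 = 1.0.
        assert!((scheduler.get_lr(10, base_lr) - 1.0).abs() < 1e-10);
        // Second cycle peak (epoch 30): amplitude halved, lr = 0.1 + 0.45 = 0.55.
        assert!((scheduler.get_lr(30, base_lr) - 0.55).abs() < 1e-10);
    }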

    #[test]
    fn test_warmup_scheduler() {
        let base_scheduler = ConstantLR;
        let mut scheduler = WarmupScheduler::new(10, base_scheduler, 0.01);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), 0.01);
        // At epoch 5: warmup_factor = 5/10 = 0.5
        // lr = 0.01 + (0.1 - 0.01) * 0.5 = 0.01 + 0.045 = 0.055
        assert!((scheduler.get_lr(5, base_lr) - 0.055).abs() < 1e-10);
        assert_eq!(scheduler.get_lr(10, base_lr), 0.1);
    }
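
    // Sketch of WarmupScheduler composing with a decaying base scheduler; the epochs
    // passed down to StepLR are offset by the warmup length.
    #[test]
    fn test_warmup_with_step_lr() {
        let mut scheduler = WarmupScheduler::new(5, StepLR::new(10, 0.1), 0.0);
        let base_lr = 0.1;

        // Linear warmup from warmup_start_lr (0.0) toward base_lr.
        assert_eq!(scheduler.get_lr(0, base_lr), 0.0);
        assert!((scheduler.get_lr(4, base_lr) - 0.08).abs() < 1e-10);
        // After warmup, StepLR sees epoch - warmup_epochs, so decay happens at epoch 15.
        assert_eq!(scheduler.get_lr(5, base_lr), base_lr);
        assert!((scheduler.get_lr(15, base_lr) - 0.01).abs() < 1e-10);
    }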
}