use super::*;
/// Drives a warmup-cosine schedule over its full horizon.
///
/// During the first 5 steps the reported learning rate must equal
/// `0.1 * step / 5`. Over the following 10 steps it must never increase,
/// and it must finish at or below 0.01.
#[test]
fn test_warmup_cosine_full_training() {
    let mut optimizer = MockOptimizer::new(0.1);
    let mut scheduler = WarmupCosineScheduler::with_min_lr(5, 15, 0.001);

    // Warmup phase: linear ramp toward the initial learning rate.
    for step in 1..=5 {
        scheduler.step(&mut optimizer);
        let target = 0.1 * (step as f32 / 5.0);
        let diff = (scheduler.get_lr() - target).abs();
        assert!(diff < 1e-5);
    }

    // Decay phase: each step may keep or lower the rate, never raise it.
    let mut last = scheduler.get_lr();
    for _ in 0..10 {
        scheduler.step(&mut optimizer);
        let current = scheduler.get_lr();
        assert!(current <= last);
        last = current;
    }

    assert!(scheduler.get_lr() <= 0.01);
}
/// In `PlateauMode::Min`, a strictly decreasing metric sequence must leave
/// the optimizer's learning rate at its initial value of 0.1.
#[test]
fn test_reduce_on_plateau_continuous_improvement() {
    let mut optimizer = MockOptimizer::new(0.1);
    let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.5, 3);

    // Every reported metric is lower than the one before it.
    for metric in [1.0, 0.9, 0.8, 0.7, 0.6] {
        scheduler.step_with_metric(&mut optimizer, metric);
    }

    let drift = (optimizer.lr() - 0.1).abs();
    assert!(drift < 1e-6);
}
/// In `PlateauMode::Max`, a strictly increasing metric sequence must leave
/// the optimizer's learning rate at its initial value of 0.1.
#[test]
fn test_reduce_on_plateau_max_mode_improvement() {
    let mut optimizer = MockOptimizer::new(0.1);
    let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Max, 0.5, 3);

    // Every reported metric is higher than the one before it.
    for metric in [0.5, 0.6, 0.7, 0.8] {
        scheduler.step_with_metric(&mut optimizer, metric);
    }

    let drift = (optimizer.lr() - 0.1).abs();
    assert!(drift < 1e-6);
}
/// After ten `StepLR::new(2, 0.5)` steps from a rate of 1.0, the rate must be
/// 0.03125 (= 0.5^5), and the scheduler must report epoch 10.
#[test]
fn test_step_lr_many_steps() {
    let mut optimizer = MockOptimizer::new(1.0);
    let mut scheduler = StepLR::new(2, 0.5);

    for _epoch in 0..10 {
        scheduler.step(&mut optimizer);
    }

    let expected_lr = 0.03125;
    assert!((optimizer.lr() - expected_lr).abs() < 1e-6);
    assert_eq!(scheduler.last_epoch(), 10);
}
/// Five `ExponentialLR::new(0.9)` steps from a rate of 1.0 must yield
/// approximately 0.59049 (= 0.9^5).
#[test]
fn test_exponential_lr_many_steps() {
    let mut optimizer = MockOptimizer::new(1.0);
    let mut scheduler = ExponentialLR::new(0.9);

    for _epoch in 0..5 {
        scheduler.step(&mut optimizer);
    }

    // Looser tolerance than the other tests to absorb accumulated rounding.
    let expected_lr = 0.59049;
    assert!((optimizer.lr() - expected_lr).abs() < 1e-4);
}
/// With patience 2 and factor 0.5, a metric stuck at 1.0 must halve the rate
/// twice. It drops to 0.05 after the third report and to 0.025 after the fifth.
#[test]
fn test_reduce_on_plateau_multiple_reductions() {
    let mut optimizer = MockOptimizer::new(0.1);
    let mut scheduler = ReduceLROnPlateau::new(PlateauMode::Min, 0.5, 2).min_lr(0.001);

    // Three identical reports: the first reduction has happened (0.1 -> 0.05).
    for _ in 0..3 {
        scheduler.step_with_metric(&mut optimizer, 1.0);
    }
    assert!((optimizer.lr() - 0.05).abs() < 1e-6);

    // Two more identical reports: a second reduction (0.05 -> 0.025).
    for _ in 0..2 {
        scheduler.step_with_metric(&mut optimizer, 1.0);
    }
    assert!((optimizer.lr() - 0.025).abs() < 1e-6);
}