use super::*;
use crate::optim::Optimizer;
use approx::assert_abs_diff_eq;
/// A freshly constructed cosine schedule reports its base learning rate before any step.
#[test]
fn test_cosine_annealing_initial_lr() {
    let sched = CosineAnnealingLR::new(1.0, 100, 0.0);
    let lr = sched.get_lr();
    assert_abs_diff_eq!(lr, 1.0, epsilon = 1e-6);
}
/// After exactly `t_max` steps the schedule has reached its minimum LR
/// (the third constructor argument, 0 here).
#[test]
fn test_cosine_annealing_final_lr() {
    let t_max = 100;
    let mut sched = CosineAnnealingLR::new(1.0, t_max, 0.0);
    (0..t_max).for_each(|_| {
        sched.step();
    });
    assert_abs_diff_eq!(sched.get_lr(), 0.0, epsilon = 1e-6);
}
/// Halfway through the window the cosine term is zero, so the LR sits midway
/// between the base (1.0) and the minimum (0.0).
#[test]
fn test_cosine_annealing_midpoint() {
    let mut sched = CosineAnnealingLR::new(1.0, 100, 0.0);
    for _step in 0..50 {
        sched.step();
    }
    // Looser tolerance than the endpoint tests: this value comes from a cosine evaluation.
    assert_abs_diff_eq!(sched.get_lr(), 0.5, epsilon = 1e-4);
}
/// A non-zero minimum leaves the starting LR untouched and becomes the floor
/// reached at the end of the window.
#[test]
fn test_cosine_annealing_with_min() {
    let mut sched = CosineAnnealingLR::new(1.0, 100, 0.1);
    assert_abs_diff_eq!(sched.get_lr(), 1.0, epsilon = 1e-6);

    (0..100).for_each(|_| {
        sched.step();
    });
    assert_abs_diff_eq!(sched.get_lr(), 0.1, epsilon = 1e-6);
}
/// Across the whole annealing window the LR never increases from one step to the next.
#[test]
fn test_cosine_annealing_decreases_monotonically() {
    let mut sched = CosineAnnealingLR::new(1.0, 100, 0.0);
    // Thread the previous LR through a fold; each iteration steps once and
    // checks the new value against the previous one.
    let _final_lr = (0..100).fold(sched.get_lr(), |prev_lr, _| {
        sched.step();
        let current_lr = sched.get_lr();
        assert!(
            current_lr <= prev_lr,
            "Learning rate should decrease monotonically: prev={prev_lr}, current={current_lr}"
        );
        current_lr
    });
}
/// `apply` pushes the scheduler's LR into an SGD optimizer. At step 0 the LR
/// is unchanged; after one step it is strictly below the base.
#[test]
fn test_cosine_annealing_with_optimizer() {
    use crate::optim::SGD;

    let mut sgd = SGD::new(1.0, 0.0);
    let mut sched = CosineAnnealingLR::default_min(1.0, 10);
    assert_abs_diff_eq!(sgd.lr(), 1.0, epsilon = 1e-6);

    // Step 0: applying the schedule should leave the base LR in place.
    sched.apply(&mut sgd);
    assert_abs_diff_eq!(sgd.lr(), 1.0, epsilon = 1e-6);

    // Step 1: the annealed LR must already be below the base.
    sched.step();
    sched.apply(&mut sgd);
    assert!(sgd.lr() < 1.0);
}
/// Stepping past `t_max` (20 steps on a 10-step window) must keep the LR at
/// the minimum rather than letting the cosine swing back up.
#[test]
fn test_cosine_annealing_past_t_max() {
    let t_max = 10;
    let mut sched = CosineAnnealingLR::new(1.0, t_max, 0.0);
    for _ in 0..2 * t_max {
        sched.step();
    }
    assert_abs_diff_eq!(sched.get_lr(), 0.0, epsilon = 1e-6);
}
/// Linear warmup starts from an LR of zero.
#[test]
fn test_linear_warmup_initial() {
    let warmup = LinearWarmupLR::new(0.001, 100);
    let lr = warmup.get_lr();
    assert_abs_diff_eq!(lr, 0.0, epsilon = 1e-8);
}
/// Halfway through warmup the LR is half the target (0.001 * 50 / 100).
#[test]
fn test_linear_warmup_midpoint() {
    let mut warmup = LinearWarmupLR::new(0.001, 100);
    (0..50).for_each(|_| {
        warmup.step();
    });
    assert_abs_diff_eq!(warmup.get_lr(), 0.0005, epsilon = 1e-7);
}
/// Once every warmup step has been taken, the LR equals the target.
#[test]
fn test_linear_warmup_complete() {
    let warmup_steps = 100;
    let mut warmup = LinearWarmupLR::new(0.001, warmup_steps);
    for _step in 0..warmup_steps {
        warmup.step();
    }
    assert_abs_diff_eq!(warmup.get_lr(), 0.001, epsilon = 1e-7);
}
/// Stepping beyond the warmup horizon keeps the LR capped at the target.
#[test]
fn test_linear_warmup_past_warmup() {
    let mut warmup = LinearWarmupLR::new(0.001, 100);
    (0..200).for_each(|_| {
        warmup.step();
    });
    assert_abs_diff_eq!(warmup.get_lr(), 0.001, epsilon = 1e-7);
}
/// During warmup the LR must never decrease from one step to the next.
#[test]
fn test_linear_warmup_increases_monotonically() {
    let mut warmup = LinearWarmupLR::new(0.001, 100);
    let mut prev_lr = warmup.get_lr();
    let mut remaining = 100;
    while remaining > 0 {
        warmup.step();
        let current_lr = warmup.get_lr();
        assert!(
            current_lr >= prev_lr,
            "LR should increase during warmup: prev={prev_lr}, current={current_lr}"
        );
        prev_lr = current_lr;
        remaining -= 1;
    }
}
/// Before any step, step decay reports its initial LR.
#[test]
fn test_step_decay_initial() {
    let decay = StepDecayLR::new(0.1, 10, 0.1);
    let lr = decay.get_lr();
    assert_abs_diff_eq!(lr, 0.1, epsilon = 1e-7);
}
/// After one full period (10 steps) the LR has been multiplied by the decay factor once.
#[test]
fn test_step_decay_first_decay() {
    let period = 10;
    let mut decay = StepDecayLR::new(0.1, period, 0.1);
    (0..period).for_each(|_| {
        decay.step();
    });
    assert_abs_diff_eq!(decay.get_lr(), 0.01, epsilon = 1e-7);
}
/// After two full periods (20 steps) the decay factor has been applied twice.
#[test]
fn test_step_decay_second_decay() {
    let mut decay = StepDecayLR::new(0.1, 10, 0.1);
    for _step in 0..20 {
        decay.step();
    }
    assert_abs_diff_eq!(decay.get_lr(), 0.001, epsilon = 1e-8);
}
/// Partway through the first period (5 of 10 steps) no decay has happened yet.
#[test]
fn test_step_decay_between_steps() {
    let mut decay = StepDecayLR::new(0.1, 10, 0.1);
    (0..5).for_each(|_| {
        decay.step();
    });
    assert_abs_diff_eq!(decay.get_lr(), 0.1, epsilon = 1e-7);
}
/// Warmup-then-cosine starts at zero when a warmup phase is configured.
#[test]
fn test_warmup_cosine_initial() {
    let sched = WarmupCosineDecayLR::new(0.001, 0.0, 10, 100);
    let lr = sched.get_lr();
    assert_abs_diff_eq!(lr, 0.0, epsilon = 1e-8);
}
/// Halfway through the 10-step warmup the LR is half the peak.
#[test]
fn test_warmup_cosine_warmup_midpoint() {
    let mut sched = WarmupCosineDecayLR::new(0.001, 0.0, 10, 100);
    for _step in 0..5 {
        sched.step();
    }
    assert_abs_diff_eq!(sched.get_lr(), 0.0005, epsilon = 1e-7);
}
/// At the end of warmup the LR has reached its peak value.
#[test]
fn test_warmup_cosine_warmup_complete() {
    let warmup_steps = 10;
    let mut sched = WarmupCosineDecayLR::new(0.001, 0.0, warmup_steps, 100);
    (0..warmup_steps).for_each(|_| {
        sched.step();
    });
    assert_abs_diff_eq!(sched.get_lr(), 0.001, epsilon = 1e-7);
}
/// After the total step count the cosine phase has decayed to the minimum (0 here).
#[test]
fn test_warmup_cosine_decay_complete() {
    let total_steps = 100;
    let mut sched = WarmupCosineDecayLR::new(0.001, 0.0, 10, total_steps);
    for _step in 0..total_steps {
        sched.step();
    }
    assert_abs_diff_eq!(sched.get_lr(), 0.0, epsilon = 1e-7);
}
/// The LR rises (non-strictly) through the warmup phase, then falls
/// (non-strictly) through the cosine phase until the total step count.
#[test]
fn test_warmup_cosine_warmup_increases_then_decreases() {
    let warmup_steps = 10;
    let total_steps = 100;
    let mut sched = WarmupCosineDecayLR::new(0.001, 0.0, warmup_steps, total_steps);
    let mut prev_lr = sched.get_lr();

    // Phase 1: warmup.
    for _ in 0..warmup_steps {
        sched.step();
        let current_lr = sched.get_lr();
        assert!(
            current_lr >= prev_lr,
            "LR should increase during warmup: prev={prev_lr}, current={current_lr}"
        );
        prev_lr = current_lr;
    }

    // Phase 2: cosine decay over the remaining steps.
    for _ in warmup_steps..total_steps {
        sched.step();
        let current_lr = sched.get_lr();
        assert!(
            current_lr <= prev_lr,
            "LR should decrease during decay: prev={prev_lr}, current={current_lr}"
        );
        prev_lr = current_lr;
    }
}
/// `apply` writes the warmup LR into the optimizer.
///
/// Linear warmup scales the target LR by `steps_taken / warmup_steps` (see the
/// warmup midpoint test: 0.001 * 50/100 = 0.0005). One step into a 10-step
/// warmup toward 0.01 must therefore give exactly 0.001. Pinning the value,
/// rather than only checking it is positive, catches a wrong value being
/// applied.
#[test]
fn test_linear_warmup_apply() {
    use crate::optim::SGD;
    let mut optimizer = SGD::new(0.0, 0.0);
    let mut scheduler = LinearWarmupLR::new(0.01, 10);
    scheduler.step();
    scheduler.apply(&mut optimizer);
    assert_abs_diff_eq!(optimizer.lr(), 0.001, epsilon = 1e-8);
}
/// With zero warmup steps there is nothing to ramp, so the target LR applies immediately.
#[test]
fn test_linear_warmup_zero_steps() {
    let warmup = LinearWarmupLR::new(0.01, 0);
    let lr = warmup.get_lr();
    assert_abs_diff_eq!(lr, 0.01, epsilon = 1e-8);
}
/// Applying an un-stepped step-decay schedule sets the optimizer to the initial LR.
#[test]
fn test_step_decay_apply() {
    use crate::optim::SGD;

    let decay = StepDecayLR::new(0.1, 10, 0.1);
    let mut sgd = SGD::new(0.0, 0.0);
    decay.apply(&mut sgd);
    assert_abs_diff_eq!(sgd.lr(), 0.1, epsilon = 1e-8);
}
/// A zero step size must not break construction or LR lookup; the initial LR is reported.
#[test]
fn test_step_decay_zero_step_size() {
    let decay = StepDecayLR::new(0.1, 0, 0.1);
    let lr = decay.get_lr();
    assert_abs_diff_eq!(lr, 0.1, epsilon = 1e-8);
}
/// At the end of warmup, `apply` sets the optimizer to the peak LR.
#[test]
fn test_warmup_cosine_apply() {
    use crate::optim::SGD;

    let mut sgd = SGD::new(0.0, 0.0);
    let mut sched = WarmupCosineDecayLR::new(0.01, 0.0, 10, 100);
    (0..10).for_each(|_| {
        sched.step();
    });
    sched.apply(&mut sgd);
    assert_abs_diff_eq!(sgd.lr(), 0.01, epsilon = 1e-8);
}
/// Without a warmup phase the schedule starts directly at the peak LR.
#[test]
fn test_warmup_cosine_zero_warmup_steps() {
    let sched = WarmupCosineDecayLR::new(0.01, 0.0, 0, 100);
    let lr = sched.get_lr();
    assert_abs_diff_eq!(lr, 0.01, epsilon = 1e-8);
}
/// With zero warmup and zero total steps the schedule is already finished and
/// reports its minimum LR.
#[test]
fn test_warmup_cosine_zero_total_steps() {
    let sched = WarmupCosineDecayLR::new(0.01, 0.001, 0, 0);
    let lr = sched.get_lr();
    assert_abs_diff_eq!(lr, 0.001, epsilon = 1e-8);
}
/// Stepping well past the total (100 steps on a 50-step schedule) keeps the LR
/// clamped at the minimum.
#[test]
fn test_warmup_cosine_past_total() {
    let total_steps = 50;
    let mut sched = WarmupCosineDecayLR::new(0.01, 0.001, 10, total_steps);
    for _step in 0..2 * total_steps {
        sched.step();
    }
    assert_abs_diff_eq!(sched.get_lr(), 0.001, epsilon = 1e-8);
}