use crate::schedulers::LRScheduler;
use crate::Float;
/// Step-decay learning-rate scheduler.
///
/// The learning rate starts at `initial_lr` and is multiplied by `gamma` once
/// every `step_size` steps, i.e. `lr(step) = initial_lr * gamma^(step / step_size)`
/// using integer division.
///
/// Prefer [`StepLR::new`] over a struct literal: `new` rejects a zero
/// `step_size`, whereas a literal with `step_size == 0` causes a
/// division-by-zero panic as soon as `num_decays`/`get_lr` is called.
#[derive(Debug, Clone)]
pub struct StepLR<F: Float> {
    /// Learning rate used before the first decay (steps `0..step_size`).
    pub initial_lr: F,
    /// Number of steps between consecutive decays; must be non-zero.
    pub step_size: usize,
    /// Multiplicative factor applied to the learning rate at each decay.
    pub gamma: F,
}
impl<F: Float> StepLR<F> {
    /// Creates a scheduler that multiplies `initial_lr` by `gamma` once every
    /// `step_size` steps.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero.
    pub fn new(initial_lr: F, step_size: usize, gamma: F) -> Self {
        assert!(step_size != 0, "step_size must be greater than 0");
        Self {
            initial_lr,
            gamma,
            step_size,
        }
    }

    /// Preset: decay by a factor of `0.1` every 30 steps.
    ///
    /// # Panics
    ///
    /// Panics if the constant `0.1` cannot be converted to `F`.
    pub fn default_decay(initial_lr: F) -> Self {
        Self::preset(initial_lr, 30, 0.1)
    }

    /// Preset for fine-tuning: halve the learning rate every 10 steps.
    ///
    /// # Panics
    ///
    /// Panics if the constant `0.5` cannot be converted to `F`.
    pub fn for_fine_tuning(initial_lr: F) -> Self {
        Self::preset(initial_lr, 10, 0.5)
    }

    /// Preset: decay by a factor of `0.01` every 20 steps.
    ///
    /// # Panics
    ///
    /// Panics if the constant `0.01` cannot be converted to `F`.
    pub fn aggressive_decay(initial_lr: F) -> Self {
        Self::preset(initial_lr, 20, 0.01)
    }

    /// Shared body of the presets: converts the `f64` decay constant to `F`
    /// and delegates to [`StepLR::new`].
    fn preset(initial_lr: F, step_size: usize, gamma: f64) -> Self {
        let gamma = F::from(gamma).expect("Failed to convert constant to float");
        Self::new(initial_lr, step_size, gamma)
    }

    /// Number of decays that have taken effect by `step`
    /// (`step / step_size`, integer division).
    pub fn num_decays(&self, step: usize) -> usize {
        step / self.step_size
    }

    /// Returns `true` when `step` is a positive multiple of `step_size`,
    /// i.e. the step at which a new decay kicks in. Step `0` never decays.
    pub fn should_decay_at_step(&self, step: usize) -> bool {
        if step == 0 {
            return false;
        }
        // `is_multiple_of` (rather than `%`) avoids a panic if `step_size` was
        // set to 0 through the public field instead of via `new`.
        step.is_multiple_of(self.step_size)
    }
}
impl<F: Float> LRScheduler<F> for StepLR<F> {
    /// Learning rate at `step`: `initial_lr * gamma^(step / step_size)`.
    ///
    /// While no decay has happened yet (steps `0..step_size`), the stored
    /// `initial_lr` is returned as-is without calling `powf`.
    ///
    /// # Panics
    ///
    /// Panics if the decay count cannot be converted to `F`.
    fn get_lr(&self, step: usize) -> F {
        match self.num_decays(step) {
            0 => self.initial_lr,
            decays => {
                let exponent = F::from(decays).expect("Failed to convert to float");
                self.initial_lr * self.gamma.powf(exponent)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// LR stays flat inside each `step_size` window and drops by `gamma` at each boundary.
    #[test]
    fn test_step_lr_basic() {
        let scheduler = StepLR::new(0.1f32, 10, 0.1);
        assert_eq!(scheduler.get_lr(0), 0.1);
        assert_eq!(scheduler.get_lr(5), 0.1);
        assert_eq!(scheduler.get_lr(9), 0.1);
        assert!((scheduler.get_lr(10) - 0.01).abs() < 1e-6);
        assert!((scheduler.get_lr(15) - 0.01).abs() < 1e-6);
        assert!((scheduler.get_lr(19) - 0.01).abs() < 1e-6);
        assert!((scheduler.get_lr(20) - 0.001).abs() < 1e-6);
        assert!((scheduler.get_lr(25) - 0.001).abs() < 1e-6);
    }

    /// `num_decays` is integer division of the step by `step_size`.
    #[test]
    fn test_step_lr_num_decays() {
        let scheduler = StepLR::new(0.1f32, 5, 0.5);
        assert_eq!(scheduler.num_decays(0), 0);
        assert_eq!(scheduler.num_decays(4), 0);
        assert_eq!(scheduler.num_decays(5), 1);
        assert_eq!(scheduler.num_decays(9), 1);
        assert_eq!(scheduler.num_decays(10), 2);
        assert_eq!(scheduler.num_decays(15), 3);
    }

    /// Decay happens exactly at positive multiples of `step_size`, never at step 0.
    #[test]
    fn test_step_lr_should_decay() {
        let scheduler = StepLR::new(0.1f32, 5, 0.5);
        assert!(!scheduler.should_decay_at_step(0));
        assert!(!scheduler.should_decay_at_step(4));
        assert!(scheduler.should_decay_at_step(5));
        assert!(!scheduler.should_decay_at_step(6));
        assert!(scheduler.should_decay_at_step(10));
        assert!(scheduler.should_decay_at_step(15));
    }

    /// Preset constructors wire up the documented step sizes and decay factors.
    #[test]
    fn test_step_lr_presets() {
        let default_scheduler = StepLR::default_decay(0.1f32);
        assert_eq!(default_scheduler.step_size, 30);
        assert!((default_scheduler.gamma - 0.1).abs() < 1e-6);
        let fine_tune_scheduler = StepLR::for_fine_tuning(0.01f32);
        assert_eq!(fine_tune_scheduler.step_size, 10);
        assert!((fine_tune_scheduler.gamma - 0.5).abs() < 1e-6);
        let aggressive_scheduler = StepLR::aggressive_decay(0.1f32);
        assert_eq!(aggressive_scheduler.step_size, 20);
        assert!((aggressive_scheduler.gamma - 0.01).abs() < 1e-6);
    }

    /// `get_lr_sequence` yields the per-step learning rates starting at step 0.
    #[test]
    fn test_step_lr_sequence() {
        let scheduler = StepLR::new(1.0f32, 3, 0.5);
        let sequence = scheduler.get_lr_sequence(0, 10);
        let expected = [1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.25, 0.25, 0.25, 0.125];
        // `zip` stops at the shorter input, so without this check a short (or
        // empty) sequence would make the comparison loop below pass vacuously.
        assert!(
            sequence.len() >= expected.len(),
            "Expected at least {} values, got {}",
            expected.len(),
            sequence.len()
        );
        for (actual, expected) in sequence.iter().zip(expected.iter()) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "Expected {}, got {}",
                expected,
                actual
            );
        }
    }

    /// `new` rejects a zero step size, which would otherwise divide by zero.
    #[test]
    #[should_panic(expected = "step_size must be greater than 0")]
    fn test_step_lr_zero_step_size() {
        StepLR::new(0.1f32, 0, 0.1);
    }
}