use crate::schedulers::LRScheduler;
use crate::Float;
/// Exponential learning-rate decay: the rate at step `n` is `initial_lr * gamma^n`.
///
/// Construct via [`ExponentialLR::new`] or one of the preset/half-life constructors,
/// which validate that `gamma` is strictly positive.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ExponentialLR<F: Float> {
    /// Learning rate returned at step 0.
    pub initial_lr: F,
    /// Multiplicative factor applied once per step (values below 1 decay the rate).
    pub gamma: F,
}
impl<F: Float> ExponentialLR<F> {
    /// Creates a scheduler that multiplies the learning rate by `gamma` every step.
    ///
    /// # Panics
    ///
    /// Panics if `gamma` is not strictly greater than zero (a NaN `gamma` also fails
    /// the check).
    pub fn new(initial_lr: F, gamma: F) -> Self {
        assert!(gamma > F::zero(), "gamma must be positive");
        Self { initial_lr, gamma }
    }

    /// Preset with `gamma = 0.99`, i.e. 1% decay per step.
    pub fn slow_decay(initial_lr: F) -> Self {
        Self::new(
            initial_lr,
            F::from(0.99).expect("Failed to convert constant to float"),
        )
    }

    /// Preset with `gamma = 0.95`, i.e. 5% decay per step.
    pub fn moderate_decay(initial_lr: F) -> Self {
        Self::new(
            initial_lr,
            F::from(0.95).expect("Failed to convert constant to float"),
        )
    }

    /// Preset with `gamma = 0.9`, i.e. 10% decay per step.
    pub fn fast_decay(initial_lr: F) -> Self {
        Self::new(
            initial_lr,
            F::from(0.9).expect("Failed to convert constant to float"),
        )
    }

    /// Creates a scheduler whose learning rate halves every `half_life` steps.
    ///
    /// `gamma` is set to `0.5^(1 / half_life)`, so the rate at step `half_life` is
    /// `initial_lr / 2` up to floating-point rounding.
    ///
    /// # Panics
    ///
    /// Panics if `half_life` is 0.
    pub fn from_half_life(initial_lr: F, half_life: usize) -> Self {
        assert!(half_life > 0, "half_life must be greater than 0");
        let gamma = F::from(0.5)
            .expect("Operation failed")
            .powf(F::one() / F::from(half_life).expect("Failed to convert to float"));
        Self::new(initial_lr, gamma)
    }

    /// Learning rate after `step` steps; identical to [`LRScheduler::get_lr`].
    pub fn lr_after_steps(&self, step: usize) -> F {
        self.get_lr(step)
    }

    /// Returns the first step at which the scheduled learning rate is `<= target_lr`.
    ///
    /// Returns `None` when the target can never be reached:
    /// - `target_lr` is outside `(0, initial_lr]` (a NaN target counts as outside), or
    /// - `gamma >= 1`, so the rate never drops below `initial_lr`, and the target is
    ///   strictly below it.
    ///
    /// Returns `Some(usize::MAX)` if the step count does not fit in `usize`.
    pub fn steps_to_reach(&self, target_lr: F) -> Option<usize> {
        // Written positively so NaN (for which every comparison is false) is rejected.
        let in_range = target_lr > F::zero() && target_lr <= self.initial_lr;
        if !in_range {
            return None;
        }
        if target_lr == self.initial_lr {
            return Some(0);
        }
        // A non-decaying schedule never goes below `initial_lr`. Without this guard,
        // `ln(gamma)` is 0 or positive, and the division below produces -inf, NaN or
        // negative estimates that end up reported as `Some(usize::MAX)`.
        let decays = self.gamma < F::one();
        if !decays {
            return None;
        }

        let estimate = (target_lr / self.initial_lr).ln() / self.gamma.ln();
        let mut steps = estimate.ceil().to_usize().unwrap_or(usize::MAX);
        if steps != usize::MAX {
            // Rounding in `ln` can leave `ceil` one step off. Nudge the estimate so the
            // answer agrees with the rates `get_lr` actually produces.
            if steps > 0 && self.get_lr(steps - 1) <= target_lr {
                steps -= 1;
            } else if self.get_lr(steps) > target_lr {
                steps += 1;
            }
        }
        Some(steps)
    }

    /// Returns the per-step multiplicative factor `gamma`.
    pub fn decay_rate(&self) -> F {
        self.gamma
    }

    /// Returns `(1 - gamma) * 100`, the percentage of the learning rate removed each step.
    /// The result is negative when `gamma > 1`.
    pub fn decay_percentage(&self) -> F {
        (F::one() - self.gamma) * F::from(100.0).expect("Failed to convert constant to float")
    }
}
impl<F: Float> LRScheduler<F> for ExponentialLR<F> {
    /// Learning rate at `step`: `initial_lr * gamma^step`.
    ///
    /// Step 0 short-circuits to `initial_lr` without evaluating the power.
    fn get_lr(&self, step: usize) -> F {
        match step {
            0 => self.initial_lr,
            n => {
                let exponent = F::from(n).expect("Failed to convert to float");
                self.initial_lr * self.gamma.powf(exponent)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With gamma = 0.5 every step halves the rate exactly (powers of two are exact in f32).
    #[test]
    fn test_exponential_lr_basic() {
        let scheduler = ExponentialLR::new(1.0f32, 0.5);
        assert_eq!(scheduler.get_lr(0), 1.0);
        assert_eq!(scheduler.get_lr(1), 0.5);
        assert_eq!(scheduler.get_lr(2), 0.25);
        assert_eq!(scheduler.get_lr(3), 0.125);
    }

    /// The preset constructors pick the documented gamma values.
    #[test]
    fn test_exponential_lr_presets() {
        let slow = ExponentialLR::slow_decay(0.1f32);
        assert!((slow.gamma - 0.99).abs() < 1e-6);
        let moderate = ExponentialLR::moderate_decay(0.1f32);
        assert!((moderate.gamma - 0.95).abs() < 1e-6);
        let fast = ExponentialLR::fast_decay(0.1f32);
        assert!((fast.gamma - 0.9).abs() < 1e-6);
    }

    /// The rate halves after one half-life and quarters after two (within f32 tolerance).
    #[test]
    fn test_exponential_lr_half_life() {
        let scheduler = ExponentialLR::from_half_life(1.0f32, 10);
        let lr_at_half_life = scheduler.get_lr(10);
        assert!((lr_at_half_life - 0.5).abs() < 1e-3);
        let lr_at_double_half_life = scheduler.get_lr(20);
        assert!((lr_at_double_half_life - 0.25).abs() < 1e-3);
    }

    /// Reachable targets give the first qualifying step; out-of-range targets give `None`.
    #[test]
    fn test_exponential_lr_steps_to_reach() {
        let scheduler = ExponentialLR::new(1.0f32, 0.5);
        assert_eq!(scheduler.steps_to_reach(0.5), Some(1));
        assert_eq!(scheduler.steps_to_reach(0.25), Some(2));
        assert_eq!(scheduler.steps_to_reach(0.0), None);
        assert_eq!(scheduler.steps_to_reach(2.0), None);
    }

    /// `decay_rate` exposes gamma; `decay_percentage` is `(1 - gamma) * 100`.
    #[test]
    fn test_exponential_lr_decay_info() {
        let scheduler = ExponentialLR::new(1.0f32, 0.9);
        assert!((scheduler.decay_rate() - 0.9).abs() < 1e-6);
        assert!((scheduler.decay_percentage() - 10.0).abs() < 1e-5);
    }

    /// The generated sequence matches gamma^n for the first five steps.
    #[test]
    fn test_exponential_lr_sequence() {
        let scheduler = ExponentialLR::new(1.0f32, 0.5);
        let sequence = scheduler.get_lr_sequence(0, 5);
        let expected = [1.0, 0.5, 0.25, 0.125, 0.0625];
        // `zip` stops at the shorter input, so without this check a short (or empty)
        // sequence would pass vacuously.
        assert!(
            sequence.len() >= expected.len(),
            "Expected at least {} learning rates, got {}",
            expected.len(),
            sequence.len()
        );
        for (actual, expected) in sequence.iter().zip(expected.iter()) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "Expected {}, got {}",
                expected,
                actual
            );
        }
    }

    #[test]
    #[should_panic(expected = "gamma must be positive")]
    fn test_exponential_lr_negative_gamma() {
        ExponentialLR::new(0.1f32, -0.1);
    }

    #[test]
    #[should_panic(expected = "gamma must be positive")]
    fn test_exponential_lr_zero_gamma() {
        ExponentialLR::new(0.1f32, 0.0);
    }

    #[test]
    #[should_panic(expected = "half_life must be greater than 0")]
    fn test_exponential_lr_zero_half_life() {
        ExponentialLR::from_half_life(0.1f32, 0);
    }
}