use super::decay::{cosine_decay, step_decay};
use super::decay::{cosine_decay_f64, step_decay_f64};
use super::policy::{LrSchedule, ScheduleError};
use super::warmup::linear_warmup;
use super::warmup::linear_warmup_f64;
/// Evaluates `schedule` at `step` and returns the resulting single-precision learning rate.
///
/// `LrSchedule::Constant` yields `base_lr` unchanged; every other variant is delegated to
/// the matching helper (`step_decay`, `cosine_decay`, or `linear_warmup`) together with
/// the variant's own parameters.
///
/// # Errors
///
/// Returns [`ScheduleError::InvalidConfig`] when the selected helper produces no value
/// (`None`), or when the computed rate is NaN, infinite, zero, or negative.
pub fn compute_learning_rate(
    base_lr: f32,
    step: u32,
    schedule: LrSchedule,
) -> Result<f32, ScheduleError> {
    let candidate = match schedule {
        LrSchedule::Constant => Some(base_lr),
        LrSchedule::StepDecay { step_size, gamma } => step_decay(base_lr, step, step_size, gamma),
        LrSchedule::Cosine {
            total_steps,
            min_lr_ratio,
        } => cosine_decay(base_lr, step, total_steps, min_lr_ratio),
        LrSchedule::LinearWarmup { warmup_steps } => linear_warmup(base_lr, step, warmup_steps),
    };

    // A missing value and an unusable value (non-finite or not strictly positive) are
    // reported identically, so both cases collapse into one `None` before conversion.
    candidate
        .filter(|rate| rate.is_finite() && *rate > 0.0)
        .ok_or(ScheduleError::InvalidConfig)
}
/// Evaluates `schedule` at `step` and returns the resulting double-precision learning rate.
///
/// The `Constant` variant yields `base_lr` unchanged; every other variant is delegated to
/// the matching `f64` helper (`step_decay_f64`, `cosine_decay_f64`, or
/// `linear_warmup_f64`) together with the variant's own parameters.
///
/// # Errors
///
/// Returns [`ScheduleError::InvalidConfig`] when the selected helper produces no value
/// (`None`), or when the computed rate is NaN, infinite, zero, or negative.
pub fn compute_learning_rate_f64(
    base_lr: f64,
    step: u32,
    schedule: super::policy::LrScheduleF64,
) -> Result<f64, ScheduleError> {
    let computed = match schedule {
        super::policy::LrScheduleF64::Constant => Some(base_lr),
        super::policy::LrScheduleF64::StepDecay { step_size, gamma } => {
            step_decay_f64(base_lr, step, step_size, gamma)
        }
        super::policy::LrScheduleF64::Cosine {
            total_steps,
            min_lr_ratio,
        } => cosine_decay_f64(base_lr, step, total_steps, min_lr_ratio),
        super::policy::LrScheduleF64::LinearWarmup { warmup_steps } => {
            linear_warmup_f64(base_lr, step, warmup_steps)
        }
    };

    // Accept only a finite, strictly positive rate; a missing value and an unusable
    // value both map to the same configuration error.
    match computed {
        Some(rate) if rate.is_finite() && rate > 0.0 => Ok(rate),
        _ => Err(ScheduleError::InvalidConfig),
    }
}