use super::decay::{cosine_decay, step_decay};
use super::policy::{LrSchedule, ScheduleError};
use super::warmup::linear_warmup;
/// Resolve the effective learning rate for `step` under `schedule`.
///
/// Dispatches to the schedule-specific helper (`step_decay`, `cosine_decay`,
/// `linear_warmup`); a `Constant` schedule passes `base_lr` through untouched.
/// Helpers signal a misconfiguration by returning `None`.
///
/// # Errors
///
/// Returns [`ScheduleError::InvalidConfig`] when the helper rejects its
/// parameters (`None`), or when the computed rate is non-finite or not
/// strictly positive.
///
/// NOTE(review): a rate of exactly `0.0` is rejected here; if `linear_warmup`
/// can yield `0.0` at `step == 0`, that case maps to an error — confirm this
/// is intended against the warmup helper.
pub fn compute_learning_rate(base_lr: f32, step: u32, schedule: LrSchedule) -> Result<f32, ScheduleError> {
    // Each arm yields Option<f32>; None means the schedule config is invalid.
    let candidate = match schedule {
        LrSchedule::Constant => Some(base_lr),
        LrSchedule::StepDecay { step_size, gamma } => step_decay(base_lr, step, step_size, gamma),
        LrSchedule::Cosine { total_steps, min_lr_ratio } => cosine_decay(base_lr, step, total_steps, min_lr_ratio),
        LrSchedule::LinearWarmup { warmup_steps } => linear_warmup(base_lr, step, warmup_steps),
    };
    // Accept only finite, strictly positive rates (NaN fails the guard too);
    // everything else — including a None from the helper — is a config error.
    match candidate {
        Some(rate) if rate.is_finite() && rate > 0.0 => Ok(rate),
        _ => Err(ScheduleError::InvalidConfig),
    }
}