/// Learning-rate schedule selector exposed by this wrapper layer.
///
/// Every variant converts field-for-field into the identically named variant
/// of `native_neural_network::schedulers::LrSchedule`.
///
/// Only `PartialEq` (not `Eq`) is derived because some variants carry `f32`
/// fields.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LrScheduleStd {
    /// Schedule that takes no parameters.
    Constant,
    /// Step-decay schedule configured by `step_size` and `gamma`.
    // NOTE(review): presumably the rate is scaled by `gamma` every `step_size`
    // steps; the exact formula lives upstream, so confirm there.
    StepDecay {
        step_size: u32,
        gamma: f32,
    },
    /// Cosine schedule configured by `total_steps` and `min_lr_ratio`.
    // NOTE(review): `min_lr_ratio` looks like a floor relative to the base
    // rate; confirm against the upstream implementation.
    Cosine {
        total_steps: u32,
        min_lr_ratio: f32,
    },
    /// Linear warm-up schedule configured by `warmup_steps`.
    LinearWarmup {
        warmup_steps: u32,
    },
}
/// Error type returned by the schedule helpers in this wrapper layer.
///
/// It is produced by converting `native_neural_network::schedulers::ScheduleError`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScheduleStdError {
    /// Counterpart of the upstream `ScheduleError::InvalidConfig` variant.
    InvalidConfig,
}
/// Translates the upstream scheduler error into this layer's error type.
impl From<native_neural_network::schedulers::ScheduleError> for ScheduleStdError {
    fn from(e: native_neural_network::schedulers::ScheduleError) -> Self {
        // Local alias keeps the match arm short.
        type Upstream = native_neural_network::schedulers::ScheduleError;

        // Deliberately no wildcard arm: if the upstream enum gains a variant,
        // this match stops compiling and forces an explicit mapping.
        match e {
            Upstream::InvalidConfig => Self::InvalidConfig,
        }
    }
}
impl From<LrScheduleStd> for native_neural_network::schedulers::LrSchedule {
fn from(s: LrScheduleStd) -> Self {
match s {
LrScheduleStd::Constant => native_neural_network::schedulers::LrSchedule::Constant,
LrScheduleStd::StepDecay { step_size, gamma } => {
native_neural_network::schedulers::LrSchedule::StepDecay { step_size, gamma }
}
LrScheduleStd::Cosine {
total_steps,
min_lr_ratio,
} => native_neural_network::schedulers::LrSchedule::Cosine {
total_steps,
min_lr_ratio,
},
LrScheduleStd::LinearWarmup { warmup_steps } => {
native_neural_network::schedulers::LrSchedule::LinearWarmup { warmup_steps }
}
}
}
}
/// Computes the learning rate for `step` by delegating to
/// `native_neural_network::schedulers::compute_learning_rate`.
///
/// `schedule` is first converted into the upstream `LrSchedule`; `base_lr` and
/// `step` are passed through unchanged.
///
/// # Errors
///
/// Any error reported by the upstream call is converted into
/// [`ScheduleStdError`] and returned.
pub fn compute_learning_rate(
    base_lr: f32,
    step: u32,
    schedule: LrScheduleStd,
) -> Result<f32, ScheduleStdError> {
    let upstream_schedule = native_neural_network::schedulers::LrSchedule::from(schedule);
    // `?` performs the upstream-error -> `ScheduleStdError` conversion via `From`.
    let lr = native_neural_network::schedulers::compute_learning_rate(
        base_lr,
        step,
        upstream_schedule,
    )?;
    Ok(lr)
}