/// Step-decay learning-rate schedule (`f32`).
///
/// Returns `base_lr * gamma^k`, where `k = (max(step, 1) - 1) / step_size` using integer
/// division. Steps are treated as 1-based: steps `0` and `1` both yield `base_lr`, and the
/// rate is multiplied by `gamma` once every `step_size` steps after that.
///
/// Returns `None` when:
/// - `step_size` is `0`,
/// - `base_lr` or `gamma` is not finite or not strictly positive, or
/// - the computed rate is not finite (e.g. `gamma > 1` with enough steps to exceed `f32::MAX`).
///
/// A rate that underflows to `0.0` (very small `gamma^k`) is still returned as `Some(0.0)`.
pub fn step_decay(base_lr: f32, step: u32, step_size: u32, gamma: f32) -> Option<f32> {
    if step_size == 0
        || !base_lr.is_finite()
        || base_lr <= 0.0
        || !gamma.is_finite()
        || gamma <= 0.0
    {
        return None;
    }
    // `max(1)` keeps the subtraction from underflowing when `step == 0`.
    let k = (step.max(1) - 1) / step_size;
    // `k as f32` rounds once `k > 2^24`; by then `gamma^k` has saturated to 0 or infinity
    // unless `gamma` is within ~1e-7 of 1, where the rounding is negligible.
    let lr = base_lr * crate::math::powf(gamma, k as f32);
    // The inputs are finite, but `gamma^k` (or the product) can overflow to infinity when
    // `gamma > 1`; report that as "no valid rate" rather than handing back `Some(inf)`.
    if lr.is_finite() {
        Some(lr)
    } else {
        None
    }
}
/// Step-decay learning-rate schedule (`f64`).
///
/// Returns `base_lr * gamma^k`, where `k = (max(step, 1) - 1) / step_size` using integer
/// division. Steps are treated as 1-based: steps `0` and `1` both yield `base_lr`, and the
/// rate is multiplied by `gamma` once every `step_size` steps after that.
///
/// Returns `None` when:
/// - `step_size` is `0`,
/// - `base_lr` or `gamma` is not finite or not strictly positive, or
/// - the computed rate is not finite (e.g. `gamma > 1` with enough steps to exceed `f64::MAX`).
///
/// A rate that underflows to `0.0` (very small `gamma^k`) is still returned as `Some(0.0)`.
pub fn step_decay_f64(base_lr: f64, step: u32, step_size: u32, gamma: f64) -> Option<f64> {
    if step_size == 0
        || !base_lr.is_finite()
        || base_lr <= 0.0
        || !gamma.is_finite()
        || gamma <= 0.0
    {
        return None;
    }
    // `max(1)` keeps the subtraction from underflowing when `step == 0`.
    let k = (step.max(1) - 1) / step_size;
    // Every `u32` is exactly representable as `f64`, so this cast is lossless.
    let lr = base_lr * crate::math::powd(gamma, k as f64);
    // The inputs are finite, but `gamma^k` (or the product) can overflow to infinity when
    // `gamma > 1`; report that as "no valid rate" rather than handing back `Some(inf)`.
    if lr.is_finite() {
        Some(lr)
    } else {
        None
    }
}
/// Cosine-annealed learning rate (`f32`).
///
/// With `p = clamp(step, 1, total_steps) / total_steps`, returns
/// `base_lr * (min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * p)))`.
/// The rate therefore falls from just below `base_lr` towards `base_lr * min_lr_ratio`.
/// It reaches that floor, up to the rounding of `cos(pi)`, once `step >= total_steps`.
/// Steps below 1 are treated as step 1.
///
/// Returns `None` if `total_steps` is `0`, if `base_lr` is not finite or not strictly positive,
/// or if `min_lr_ratio` lies outside `[0.0, 1.0]` (NaN included).
pub fn cosine_decay(base_lr: f32, step: u32, total_steps: u32, min_lr_ratio: f32) -> Option<f32> {
    let lr_valid = base_lr.is_finite() && base_lr > 0.0;
    // The closed range already rejects NaN and ±infinity, so no separate finiteness test.
    let ratio_valid = (0.0..=1.0).contains(&min_lr_ratio);
    if total_steps == 0 || !lr_valid || !ratio_valid {
        return None;
    }

    // NOTE(review): in `step_decay`, step 1 is undecayed. Here, step 1 already has
    // progress `1 / total_steps`, so it is below `base_lr`. Confirm this convention is intended.
    let clamped_step = core::cmp::min(core::cmp::max(step, 1), total_steps);
    let fraction = clamped_step as f32 / total_steps as f32;
    let half_cosine = 0.5 * (1.0 + crate::math::cosf(core::f32::consts::PI * fraction));
    let decay_span = 1.0 - min_lr_ratio;
    Some(base_lr * (min_lr_ratio + decay_span * half_cosine))
}
/// Cosine-annealed learning rate (`f64`).
///
/// With `p = clamp(step, 1, total_steps) / total_steps`, returns
/// `base_lr * (min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * p)))`.
/// The rate therefore falls from just below `base_lr` towards `base_lr * min_lr_ratio`.
/// It reaches that floor, up to the rounding of `cos(pi)`, once `step >= total_steps`.
/// Steps below 1 are treated as step 1.
///
/// Returns `None` if `total_steps` is `0`, if `base_lr` is not finite or not strictly positive,
/// or if `min_lr_ratio` lies outside `[0.0, 1.0]` (NaN included).
pub fn cosine_decay_f64(
    base_lr: f64,
    step: u32,
    total_steps: u32,
    min_lr_ratio: f64,
) -> Option<f64> {
    let lr_valid = base_lr.is_finite() && base_lr > 0.0;
    // The closed range already rejects NaN and ±infinity, so no separate finiteness test.
    let ratio_valid = (0.0..=1.0).contains(&min_lr_ratio);
    if total_steps == 0 || !lr_valid || !ratio_valid {
        return None;
    }

    // NOTE(review): in `step_decay_f64`, step 1 is undecayed. Here, step 1 already has
    // progress `1 / total_steps`, so it is below `base_lr`. Confirm this convention is intended.
    let clamped_step = core::cmp::min(core::cmp::max(step, 1), total_steps);
    let fraction = clamped_step as f64 / total_steps as f64;
    let half_cosine = 0.5 * (1.0 + crate::math::cosd(core::f64::consts::PI * fraction));
    let decay_span = 1.0 - min_lr_ratio;
    Some(base_lr * (min_lr_ratio + decay_span * half_cosine))
}