use crate::{
Result,
error::{EmptyInputPayload, Error, InvariantViolationPayload, LengthMismatchPayload},
};
/// A schedule: a heap-allocated closure that maps a step index (`usize`) to an `f32` value.
///
/// Every constructor in this module returns this type, and `join_schedules` consumes a
/// list of them.
pub type Schedule = Box<dyn Fn(usize) -> f32>;
/// Builds a schedule that evaluates `init * decay_rate^step`.
///
/// The exponent is handed to [`f32::powi`], which takes an `i32`. Step indices above
/// `i32::MAX` are saturated to `i32::MAX` rather than narrowed with a bare `as` cast:
/// a bare cast wraps, so a very large step could turn into a small or even negative
/// exponent and make the schedule jump back up instead of continuing to decay.
pub fn exponential_decay(init: f32, decay_rate: f32) -> Schedule {
    Box::new(move |step: usize| {
        // Clamp first so the narrowing cast below can never wrap.
        let exponent = step.min(i32::MAX as usize) as i32;
        init * decay_rate.powi(exponent)
    })
}
/// Builds a staircase schedule: `init * decay_rate^(step / step_size)` using integer
/// division, so the value is constant within each window of `step_size` steps.
///
/// The window index is saturated to `i32::MAX` before being passed to [`f32::powi`];
/// a bare `as i32` cast would wrap for very large window indices and could produce a
/// negative exponent.
///
/// # Errors
///
/// Returns `Error::InvariantViolation` when `step_size` is `0` (the window index would
/// require a division by zero).
pub fn step_decay(init: f32, decay_rate: f32, step_size: usize) -> Result<Schedule> {
    if step_size == 0 {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "step_decay: step_size",
            "must be > 0",
        )));
    }
    Ok(Box::new(move |step: usize| {
        // Clamp first so the narrowing cast below can never wrap.
        let window = (step / step_size).min(i32::MAX as usize) as i32;
        init * decay_rate.powi(window)
    }))
}
pub fn cosine_decay(init: f32, decay_steps: usize, end: f32) -> Result<Schedule> {
if decay_steps == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"cosine_decay: decay_steps",
"must be > 0",
)));
}
let pi = std::f32::consts::PI;
let decay_steps_f = decay_steps as f32;
Ok(Box::new(move |step| {
let s = (step as f32).min(decay_steps_f);
let decay = 0.5 * (1.0 + (pi * s / decay_steps_f).cos());
end + decay * (init - end)
}))
}
pub fn linear_schedule(init: f32, end: f32, steps: usize) -> Result<Schedule> {
if steps == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"linear_schedule: steps",
"must be > 0",
)));
}
let steps_f = steps as f32;
let slope = (end - init) / steps_f;
Ok(Box::new(move |step| {
let s = (step as f32).min(steps_f);
s * slope + init
}))
}
/// Chains several schedules into one, switching between them at the given boundaries.
///
/// `schedules[0]` is used (with the raw step) until the first boundary is reached. Once
/// `step >= boundaries[i]`, `schedules[i + 1]` takes over and is called with
/// `step - boundaries[i]`, i.e. each later segment sees a step count that restarts at 0
/// at its own boundary. When several boundaries have been reached, the one latest in the
/// `boundaries` list wins; boundaries are not checked for ordering, so callers should
/// pass them in ascending order.
///
/// Only the active segment is evaluated on each call.
///
/// # Errors
///
/// * `Error::EmptyInput` when `schedules` is empty.
/// * `Error::LengthMismatch` when `boundaries.len()` is not `schedules.len() - 1`.
pub fn join_schedules(schedules: Vec<Schedule>, boundaries: Vec<usize>) -> Result<Schedule> {
    if schedules.is_empty() {
        return Err(Error::EmptyInput(EmptyInputPayload::new(
            "join_schedules: schedules",
        )));
    }
    if schedules.len() != boundaries.len() + 1 {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            "join_schedules: boundaries (must equal schedules - 1)",
            schedules.len() - 1,
            boundaries.len(),
        )));
    }
    Ok(Box::new(move |step: usize| {
        // Search from the back: the last reached boundary in list order selects the
        // segment, which matches a front-to-back scan where later matches overwrite
        // earlier ones, without evaluating the segments that would be discarded.
        match boundaries.iter().rposition(|&boundary| step >= boundary) {
            // `i + 1` is in range because `schedules.len() == boundaries.len() + 1`.
            Some(i) => schedules[i + 1](step - boundaries[i]),
            None => schedules[0](step),
        }
    }))
}
// Unit tests for the schedule constructors. Float results are compared with a small
// absolute tolerance rather than for exact equality.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exponential_decay_at_step_0_returns_init() {
        let s = exponential_decay(0.1, 0.9);
        assert!((s(0) - 0.1).abs() < 1e-6);
    }

    #[test]
    fn exponential_decay_at_step_5_matches_formula() {
        // 0.1 * 0.9^5 = 0.059049
        let s = exponential_decay(0.1, 0.9);
        assert!((s(5) - 0.059_049).abs() < 1e-6, "got {}", s(5));
    }

    #[test]
    fn step_decay_holds_within_one_size_then_drops() -> Result<()> {
        let s = step_decay(0.1, 0.5, 10)?;
        assert!((s(0) - 0.1).abs() < 1e-6);
        assert!((s(9) - 0.1).abs() < 1e-6);
        assert!((s(10) - 0.05).abs() < 1e-6);
        assert!((s(19) - 0.05).abs() < 1e-6);
        assert!((s(20) - 0.025).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn step_decay_rejects_zero_step_size() {
        assert!(step_decay(0.1, 0.5, 0).is_err());
    }

    #[test]
    fn cosine_decay_at_t0_t_half_t_end_matches_formula() -> Result<()> {
        let s = cosine_decay(0.1, 1000, 0.0)?;
        assert!((s(0) - 0.1).abs() < 1e-6);
        assert!((s(500) - 0.05).abs() < 1e-5);
        assert!((s(1000)).abs() < 1e-5);
        // Past `decay_steps` the value stays at the end point.
        assert!((s(2000)).abs() < 1e-5);
        Ok(())
    }

    #[test]
    fn cosine_decay_rejects_zero_decay_steps() {
        assert!(cosine_decay(0.1, 0, 0.0).is_err());
    }

    #[test]
    fn linear_schedule_at_endpoints_matches_formula() -> Result<()> {
        let s = linear_schedule(0.0, 0.1, 100)?;
        assert!((s(0) - 0.0).abs() < 1e-6);
        assert!((s(100) - 0.1).abs() < 1e-6);
        assert!((s(50) - 0.05).abs() < 1e-6);
        // Past `steps` the value stays at the end point.
        assert!((s(150) - 0.1).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn linear_schedule_rejects_zero_steps() {
        assert!(linear_schedule(0.0, 0.1, 0).is_err());
    }

    #[test]
    fn join_schedules_switches_at_boundary() -> Result<()> {
        let a = linear_schedule(0.0, 0.1, 10)?;
        let b = cosine_decay(0.1, 100, 0.0)?;
        let joined = join_schedules(vec![a, b], vec![10])?;
        assert!((joined(5) - 0.05).abs() < 1e-6);
        assert!((joined(10) - 0.1).abs() < 1e-6);
        assert!((joined(110)).abs() < 1e-3);
        Ok(())
    }

    #[test]
    fn join_schedules_restarts_step_count_after_boundary() -> Result<()> {
        let a = exponential_decay(1.0, 0.5);
        let b = exponential_decay(2.0, 0.5);
        let joined = join_schedules(vec![a, b], vec![3])?;
        // Before the boundary the first schedule sees the raw step.
        assert!((joined(2) - 0.25).abs() < 1e-6);
        // From the boundary on, the second schedule sees `step - boundary`.
        assert!((joined(3) - 2.0).abs() < 1e-6);
        assert!((joined(4) - 1.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn join_schedules_rejects_wrong_boundary_count() {
        let a = exponential_decay(0.1, 0.9);
        let b = exponential_decay(0.05, 0.9);
        let res = join_schedules(vec![a, b], vec![]);
        assert!(res.is_err());
    }

    #[test]
    fn join_schedules_rejects_empty_schedules() {
        assert!(join_schedules(vec![], vec![]).is_err());
    }
}