use crate::error::{OptimizeError, OptimizeResult};
use std::f64::consts::PI;
/// Learning-rate schedule: maps a 0-based step index to a learning rate.
#[derive(Debug, Clone)]
pub enum LrSchedule {
    /// Fixed learning rate for every step.
    Constant(f64),
    /// lr(step) = initial * decay^step.
    ExponentialDecay {
        initial: f64,
        decay: f64,
    },
    /// Cyclic cosine between `lr_max` and `lr_min` with half-period `t_max`:
    /// the rate descends over `t_max` steps, then climbs back.
    CosineAnnealing {
        lr_max: f64,
        lr_min: f64,
        t_max: usize,
    },
    /// Linear warmup from 0 to `lr_peak` over `warmup_steps`, then a single
    /// cosine decay to `lr_min` at `total_steps`, held at `lr_min` afterwards.
    WarmupCosine {
        warmup_steps: usize,
        lr_peak: f64,
        lr_min: f64,
        total_steps: usize,
    },
    /// Multiply `initial` by `gamma` once every `step_size` steps.
    StepLr {
        initial: f64,
        step_size: usize,
        gamma: f64,
    },
}
impl LrSchedule {
    /// Learning rate at the given step.
    ///
    /// Total for all inputs: degenerate configurations (`t_max == 0`,
    /// `warmup_steps == 0`, `step_size == 0`) are clamped instead of
    /// dividing by zero.
    pub fn lr_at(&self, step: usize) -> f64 {
        match self {
            LrSchedule::Constant(lr) => *lr,
            LrSchedule::ExponentialDecay { initial, decay } => {
                initial * decay.powi(step as i32)
            }
            LrSchedule::CosineAnnealing { lr_max, lr_min, t_max } => {
                // Bug fix: the original clamped t_max only in the modulus but
                // used the raw t_max as the divisor, producing NaN/inf for
                // t_max == 0. Use the same clamped period everywhere.
                let period = (*t_max).max(1);
                let t = (step % (2 * period)) as f64;
                let t_m = period as f64;
                let cos_inner = PI * t / t_m;
                lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cos_inner.cos())
            }
            LrSchedule::WarmupCosine { warmup_steps, lr_peak, lr_min, total_steps } => {
                let ws = *warmup_steps;
                // Guarantee a non-empty decay phase even if total_steps <= ws.
                let ts = (*total_steps).max(ws + 1);
                if step < ws {
                    // Linear ramp 0 -> lr_peak over the warmup phase.
                    lr_peak * step as f64 / ws.max(1) as f64
                } else {
                    // Bug fix: clamp progress at 1.0 so the rate stays at
                    // lr_min beyond total_steps rather than climbing back up
                    // the cosine.
                    let progress = ((step - ws) as f64 / (ts - ws) as f64).min(1.0);
                    lr_min + 0.5 * (lr_peak - lr_min) * (1.0 + (PI * progress).cos())
                }
            }
            LrSchedule::StepLr { initial, step_size, gamma } => {
                let n_decays = step / (*step_size).max(1);
                initial * gamma.powi(n_decays as i32)
            }
        }
    }
}
/// Stochastic gradient descent with optional momentum, Nesterov
/// acceleration, and coupled L2 weight decay.
#[derive(Debug, Clone)]
pub struct Sgd {
    /// Step size applied to each update.
    pub learning_rate: f64,
    /// Momentum coefficient `mu` (0.0 disables momentum).
    pub momentum: f64,
    /// L2 penalty folded into the gradient as `wd * param`.
    pub weight_decay: f64,
    /// Use the Nesterov look-ahead update when true.
    pub nesterov: bool,
    /// Per-parameter momentum buffer; lazily (re)sized inside `step`.
    velocity: Vec<f64>,
}
impl Sgd {
    /// Build an SGD optimizer with no weight decay and plain
    /// (non-Nesterov) momentum; the velocity buffer is allocated lazily.
    pub fn new(learning_rate: f64, momentum: f64) -> Self {
        Self {
            learning_rate,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            velocity: Vec::new(),
        }
    }

    /// Apply one SGD update in place to `params` given `grad`.
    ///
    /// Returns an error when the two slices differ in length. The
    /// velocity buffer is re-initialised to zeros whenever the
    /// parameter count changes.
    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
        if params.len() != grad.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "params length {} != grad length {}",
                params.len(),
                grad.len()
            )));
        }
        let dim = params.len();
        if self.velocity.len() != dim {
            self.velocity = vec![0.0; dim];
        }
        let (lr, mu, wd) = (self.learning_rate, self.momentum, self.weight_decay);
        for idx in 0..dim {
            // Coupled L2 decay: fold wd * p into the raw gradient.
            let g = grad[idx] + wd * params[idx];
            let v = mu * self.velocity[idx] + g;
            self.velocity[idx] = v;
            // Nesterov looks ahead using the freshly updated velocity.
            let delta = if self.nesterov { mu * v + g } else { v };
            params[idx] -= lr * delta;
        }
        Ok(())
    }

    /// Reset the momentum buffer to zeros for `n` parameters.
    pub fn zero_velocity(&mut self, n: usize) {
        self.velocity = vec![0.0; n];
    }
}
/// Adam optimizer (Kingma & Ba) with coupled L2 weight decay.
#[derive(Debug, Clone)]
pub struct Adam {
    /// Base learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first-moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimate.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,
    /// Coupled L2 penalty folded into the gradient (0.0 disables it).
    pub weight_decay: f64,
    /// First-moment (mean) buffer; lazily sized inside `step`.
    m: Vec<f64>,
    /// Second-moment (uncentred variance) buffer.
    v: Vec<f64>,
    /// Number of `step` calls taken; drives the bias correction.
    t: usize,
}
impl Adam {
    /// Adam with the conventional defaults (beta1 = 0.9, beta2 = 0.999,
    /// eps = 1e-8) and no weight decay.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
        }
    }

    /// One Adam update applied in place to `params`.
    ///
    /// Weight decay here is *coupled*: it is added to the gradient before
    /// the moment updates (unlike AdamW). Returns an error when `params`
    /// and `grad` lengths differ; moment buffers are re-zeroed whenever
    /// the parameter count changes.
    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
        if params.len() != grad.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "params length {} != grad length {}",
                params.len(),
                grad.len()
            )));
        }
        let dim = params.len();
        if self.m.len() != dim {
            self.m = vec![0.0; dim];
            self.v = vec![0.0; dim];
        }
        self.t += 1;
        let step_f = self.t as f64;
        // Bias corrections compensate for the zero-initialised moments.
        let corr1 = 1.0 - self.beta1.powf(step_f);
        let corr2 = 1.0 - self.beta2.powf(step_f);
        for idx in 0..dim {
            let g = grad[idx] + self.weight_decay * params[idx];
            self.m[idx] = self.beta1 * self.m[idx] + (1.0 - self.beta1) * g;
            self.v[idx] = self.beta2 * self.v[idx] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[idx] / corr1;
            let v_hat = self.v[idx] / corr2;
            params[idx] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
        Ok(())
    }

    /// Drop both moment buffers and reset the step counter.
    pub fn reset_state(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}
/// AdaGrad optimizer: per-parameter step sizes scaled by the inverse
/// square root of the accumulated squared gradients.
#[derive(Debug, Clone)]
pub struct AdaGrad {
    /// Base learning rate.
    pub lr: f64,
    /// Stability constant added to the denominator.
    pub eps: f64,
    /// Coupled L2 penalty folded into the gradient.
    pub weight_decay: f64,
    /// Running per-parameter sum of squared gradients.
    sum_sq_grad: Vec<f64>,
}
impl AdaGrad {
    /// AdaGrad with eps = 1e-8 and no weight decay.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            eps: 1e-8,
            weight_decay: 0.0,
            sum_sq_grad: Vec::new(),
        }
    }

    /// One AdaGrad update in place: each coordinate's step is scaled by
    /// 1 / (sqrt(accumulated squared grads) + eps).
    ///
    /// Returns an error on a length mismatch; the accumulator is
    /// re-zeroed whenever the parameter count changes.
    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
        if params.len() != grad.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "params/grad length mismatch: {} vs {}",
                params.len(),
                grad.len()
            )));
        }
        let dim = params.len();
        if self.sum_sq_grad.len() != dim {
            self.sum_sq_grad = vec![0.0; dim];
        }
        for (idx, p) in params.iter_mut().enumerate() {
            let g = grad[idx] + self.weight_decay * *p;
            self.sum_sq_grad[idx] += g * g;
            *p -= self.lr * g / (self.sum_sq_grad[idx].sqrt() + self.eps);
        }
        Ok(())
    }

    /// Drop the squared-gradient accumulator.
    pub fn reset_state(&mut self) {
        self.sum_sq_grad.clear();
    }
}
/// RMSProp optimizer with optional momentum.
#[derive(Debug, Clone)]
pub struct RmsProp {
    /// Base learning rate.
    pub lr: f64,
    /// Smoothing factor for the squared-gradient moving average.
    pub alpha: f64,
    /// Stability constant added to the denominator.
    pub eps: f64,
    /// Momentum coefficient (0.0 bypasses the velocity buffer).
    pub momentum: f64,
    /// Exponential moving average of squared gradients.
    sq_avg: Vec<f64>,
    /// Momentum buffer; only consulted when `momentum > 0`.
    velocity: Vec<f64>,
}
impl RmsProp {
    /// RMSProp with alpha = 0.99, eps = 1e-8, and momentum disabled.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            alpha: 0.99,
            eps: 1e-8,
            momentum: 0.0,
            sq_avg: Vec::new(),
            velocity: Vec::new(),
        }
    }

    /// One RMSProp update in place. When `momentum > 0`, the scaled
    /// gradient is first accumulated into a velocity buffer.
    ///
    /// Returns an error on a length mismatch; both state buffers are
    /// re-zeroed whenever the parameter count changes. Note this
    /// optimizer has no weight-decay term.
    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
        if params.len() != grad.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "params/grad length mismatch: {} vs {}",
                params.len(),
                grad.len()
            )));
        }
        let dim = params.len();
        if self.sq_avg.len() != dim {
            self.sq_avg = vec![0.0; dim];
            self.velocity = vec![0.0; dim];
        }
        // Loop-invariant: choose the update style once up front.
        let use_momentum = self.momentum > 0.0;
        for (idx, &g) in grad.iter().enumerate() {
            self.sq_avg[idx] = self.alpha * self.sq_avg[idx] + (1.0 - self.alpha) * g * g;
            let scale = self.sq_avg[idx].sqrt() + self.eps;
            if use_momentum {
                self.velocity[idx] = self.momentum * self.velocity[idx] + self.lr * g / scale;
                params[idx] -= self.velocity[idx];
            } else {
                params[idx] -= self.lr * g / scale;
            }
        }
        Ok(())
    }

    /// Drop all accumulated state.
    pub fn reset_state(&mut self) {
        self.sq_avg.clear();
        self.velocity.clear();
    }
}
/// AdamW optimizer: Adam with *decoupled* weight decay
/// (Loshchilov & Hutter), applied directly to the parameters rather
/// than folded into the gradient.
#[derive(Debug, Clone)]
pub struct AdamW {
    /// Base learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first-moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimate.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,
    /// Decoupled decay coefficient; parameters shrink by lr * wd each step.
    pub weight_decay: f64,
    /// First-moment buffer; lazily sized inside `step`.
    m: Vec<f64>,
    /// Second-moment buffer.
    v: Vec<f64>,
    /// Number of `step` calls taken; drives the bias correction.
    t: usize,
}
impl AdamW {
    /// AdamW with conventional defaults and weight decay 0.01.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
        }
    }

    /// One AdamW update in place: each parameter is first shrunk by the
    /// decoupled decay factor, then a bias-corrected Adam step using the
    /// raw gradient is applied. Returns an error on a length mismatch.
    pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
        if params.len() != grad.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "params/grad length mismatch: {} vs {}",
                params.len(),
                grad.len()
            )));
        }
        let dim = params.len();
        if self.m.len() != dim {
            self.m = vec![0.0; dim];
            self.v = vec![0.0; dim];
        }
        self.t += 1;
        let step_f = self.t as f64;
        let corr1 = 1.0 - self.beta1.powf(step_f);
        let corr2 = 1.0 - self.beta2.powf(step_f);
        // Loop-invariant decoupled-decay multiplier.
        let decay_factor = 1.0 - self.lr * self.weight_decay;
        for idx in 0..dim {
            // Decoupled weight decay acts on the parameter directly,
            // independent of the gradient and the moment statistics.
            params[idx] *= decay_factor;
            let g = grad[idx];
            self.m[idx] = self.beta1 * self.m[idx] + (1.0 - self.beta1) * g;
            self.v[idx] = self.beta2 * self.v[idx] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[idx] / corr1;
            let v_hat = self.v[idx] / corr2;
            params[idx] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
        Ok(())
    }

    /// Drop the moment buffers and reset the step counter.
    pub fn reset_state(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}
/// SVRG (stochastic variance-reduced gradient) optimizer state.
///
/// Holds a snapshot of the parameters and the full-batch gradient taken
/// at that snapshot; inner steps combine a fresh stochastic gradient
/// with the snapshot gradients to reduce variance.
#[derive(Debug, Clone)]
pub struct Svrg {
    /// Learning rate for inner updates.
    pub lr: f64,
    /// Stored sample count — not used by the visible methods; presumably
    /// for callers' bookkeeping. TODO confirm against call sites.
    pub n: usize,
    /// Number of inner steps before a snapshot refresh is due.
    pub update_freq: usize,
    /// Parameters captured at the last snapshot.
    snapshot_params: Vec<f64>,
    /// Full-batch gradient evaluated at `snapshot_params`.
    snapshot_grad: Vec<f64>,
    /// Inner steps taken since the last snapshot.
    inner_t: usize,
}
impl Svrg {
    /// Create an SVRG optimizer; no snapshot exists until
    /// `update_snapshot` has been called.
    pub fn new(lr: f64, n: usize, update_freq: usize) -> Self {
        Self {
            lr,
            n,
            update_freq,
            snapshot_params: Vec::new(),
            snapshot_grad: Vec::new(),
            inner_t: 0,
        }
    }

    /// One variance-reduced inner step:
    /// `p -= lr * (stochastic_grad - snapshot_grad_i + snapshot_grad)`.
    ///
    /// `snapshot_grad_i` is the same sample's gradient evaluated at the
    /// snapshot parameters. Returns an error when the slice lengths
    /// disagree or when no snapshot has been taken yet.
    pub fn step(
        &mut self,
        params: &mut Vec<f64>,
        stochastic_grad: &[f64],
        snapshot_grad_i: &[f64],
    ) -> OptimizeResult<()> {
        let dim = params.len();
        if stochastic_grad.len() != dim || snapshot_grad_i.len() != dim {
            return Err(OptimizeError::InvalidInput(format!(
                "SVRG gradient/param length mismatch: params={}, sg={}, sgi={}",
                dim,
                stochastic_grad.len(),
                snapshot_grad_i.len()
            )));
        }
        if self.snapshot_grad.len() != dim {
            return Err(OptimizeError::InvalidInput(
                "SVRG: snapshot not initialised — call update_snapshot first".to_string(),
            ));
        }
        for (idx, p) in params.iter_mut().enumerate() {
            // Control variate: correct the stochastic gradient using the
            // snapshot's per-sample and full-batch gradients.
            let reduced = stochastic_grad[idx] - snapshot_grad_i[idx] + self.snapshot_grad[idx];
            *p -= self.lr * reduced;
        }
        self.inner_t += 1;
        Ok(())
    }

    /// Record a new snapshot (parameters plus full-batch gradient) and
    /// restart the inner-step counter.
    pub fn update_snapshot(&mut self, params: &[f64], full_grad: &[f64]) {
        self.snapshot_params = params.to_vec();
        self.snapshot_grad = full_grad.to_vec();
        self.inner_t = 0;
    }

    /// True once `update_freq` inner steps have run since the snapshot.
    pub fn needs_snapshot_update(&self) -> bool {
        self.inner_t >= self.update_freq
    }

    /// Parameters captured at the last snapshot.
    pub fn snapshot_params(&self) -> &[f64] {
        &self.snapshot_params
    }
}
// Unit tests: schedules are checked against closed-form values; each
// optimizer is checked for convergence on the convex quadratic
// f(p) = sum(p_i^2), plus state-management and error-path coverage.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Gradient of f(p) = sum(p_i^2), i.e. 2 * p elementwise.
    fn quadratic_grad(params: &[f64]) -> Vec<f64> {
        params.iter().map(|&p| 2.0 * p).collect()
    }

    #[test]
    fn test_constant_schedule() {
        let s = LrSchedule::Constant(0.01);
        assert_abs_diff_eq!(s.lr_at(0), 0.01, epsilon = 1e-14);
        assert_abs_diff_eq!(s.lr_at(1000), 0.01, epsilon = 1e-14);
    }

    #[test]
    fn test_exponential_decay_schedule() {
        let s = LrSchedule::ExponentialDecay { initial: 0.1, decay: 0.9 };
        assert_abs_diff_eq!(s.lr_at(0), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(1), 0.09, epsilon = 1e-10);
        assert_abs_diff_eq!(s.lr_at(10), 0.1 * 0.9_f64.powi(10), epsilon = 1e-10);
    }

    // Cosine schedule must start exactly at lr_max ...
    #[test]
    fn test_cosine_annealing_at_zero() {
        let s = LrSchedule::CosineAnnealing { lr_max: 0.1, lr_min: 0.0, t_max: 100 };
        assert_abs_diff_eq!(s.lr_at(0), 0.1, epsilon = 1e-10);
    }

    // ... and reach lr_min exactly at step t_max.
    #[test]
    fn test_cosine_annealing_at_t_max() {
        let s = LrSchedule::CosineAnnealing { lr_max: 0.1, lr_min: 0.001, t_max: 50 };
        assert_abs_diff_eq!(s.lr_at(50), 0.001, epsilon = 1e-10);
    }

    #[test]
    fn test_warmup_cosine_warmup_phase() {
        let s = LrSchedule::WarmupCosine {
            warmup_steps: 10,
            lr_peak: 0.1,
            lr_min: 0.0,
            total_steps: 110,
        };
        // Halfway through warmup -> half of the peak rate.
        assert_abs_diff_eq!(s.lr_at(5), 0.05, epsilon = 1e-10);
        let lr10 = s.lr_at(10);
        assert!(lr10 >= 0.09 && lr10 <= 0.1 + 1e-9, "lr at warmup end ≈ peak, got {}", lr10);
    }

    #[test]
    fn test_step_lr_schedule() {
        let s = LrSchedule::StepLr { initial: 0.1, step_size: 10, gamma: 0.5 };
        assert_abs_diff_eq!(s.lr_at(0), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(9), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(10), 0.05, epsilon = 1e-12);
        assert_abs_diff_eq!(s.lr_at(20), 0.025, epsilon = 1e-12);
    }

    #[test]
    fn test_sgd_converges_quadratic() {
        let mut opt = Sgd::new(0.1, 0.0);
        let mut p = vec![1.0, -2.0];
        for _ in 0..200 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-4);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_sgd_momentum_converges() {
        let mut opt = Sgd::new(0.05, 0.9);
        let mut p = vec![2.0, -1.5];
        for _ in 0..500 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_sgd_nesterov() {
        let mut opt = Sgd { nesterov: true, ..Sgd::new(0.05, 0.9) };
        let mut p = vec![1.0, 1.0];
        for _ in 0..500 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-3);
    }

    // With a zero gradient, only the weight-decay term moves the param.
    #[test]
    fn test_sgd_weight_decay() {
        let mut opt = Sgd { weight_decay: 0.1, ..Sgd::new(0.01, 0.0) };
        let mut p = vec![1.0];
        opt.step(&mut p, &[0.0]).expect("step failed");
        assert!(p[0] < 1.0, "weight decay should shrink param");
    }

    #[test]
    fn test_sgd_length_mismatch() {
        let mut opt = Sgd::new(0.01, 0.0);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
    }

    #[test]
    fn test_sgd_zero_velocity() {
        let mut opt = Sgd::new(0.01, 0.9);
        opt.zero_velocity(5);
        assert_eq!(opt.velocity.len(), 5);
        assert!(opt.velocity.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_adam_converges() {
        let mut opt = Adam::new(0.01);
        let mut p = vec![3.0, -3.0];
        for _ in 0..1000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert_abs_diff_eq!(p[0], 0.0, epsilon = 1e-2);
        assert_abs_diff_eq!(p[1], 0.0, epsilon = 1e-2);
    }

    #[test]
    fn test_adam_reset_state() {
        let mut opt = Adam::new(0.01);
        let mut p = vec![1.0];
        opt.step(&mut p, &[0.5]).expect("step failed");
        assert_eq!(opt.t, 1);
        opt.reset_state();
        assert_eq!(opt.t, 0);
        assert!(opt.m.is_empty());
        assert!(opt.v.is_empty());
    }

    // Adam's decay is coupled (folded into the gradient), so a zero
    // gradient still produces a parameter shrink.
    #[test]
    fn test_adam_weight_decay_coupled() {
        let mut opt = Adam { weight_decay: 0.01, ..Adam::new(0.001) };
        let mut p = vec![1.0];
        let p_before = p[0];
        opt.step(&mut p, &[0.0]).expect("step failed");
        assert!(p[0] < p_before, "weight decay should reduce param");
    }

    #[test]
    fn test_adagrad_converges() {
        let mut opt = AdaGrad::new(0.5);
        let mut p = vec![3.0, -2.0];
        for _ in 0..2000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.5, "adagrad should converge, p[0]={}", p[0]);
    }

    #[test]
    fn test_adagrad_reset() {
        let mut opt = AdaGrad::new(0.1);
        let mut p = vec![1.0];
        opt.step(&mut p, &[1.0]).expect("step failed");
        assert_eq!(opt.sum_sq_grad.len(), 1);
        opt.reset_state();
        assert!(opt.sum_sq_grad.is_empty());
    }

    #[test]
    fn test_rmsprop_converges() {
        let mut opt = RmsProp::new(0.01);
        let mut p = vec![2.0, -2.0];
        for _ in 0..1000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.1, "rmsprop p[0]={}", p[0]);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let mut opt = RmsProp { momentum: 0.9, ..RmsProp::new(0.01) };
        let mut p = vec![1.0, 1.0];
        for _ in 0..500 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.5, "rmsprop+momentum p[0]={}", p[0]);
    }

    #[test]
    fn test_rmsprop_length_mismatch() {
        let mut opt = RmsProp::new(0.01);
        let mut p = vec![1.0, 2.0];
        assert!(opt.step(&mut p, &[0.1]).is_err());
    }

    // AdamW's decay is decoupled: applied to the parameter directly,
    // independent of the gradient and moment statistics.
    #[test]
    fn test_adamw_decoupled_wd() {
        let mut opt = AdamW { weight_decay: 0.1, ..AdamW::new(0.001) };
        let mut p = vec![1.0];
        let p_before = p[0];
        opt.step(&mut p, &[0.0]).expect("step failed");
        assert!(p[0] < p_before, "decoupled WD should shrink param");
    }

    #[test]
    fn test_adamw_converges() {
        let mut opt = AdamW { weight_decay: 0.0, ..AdamW::new(0.01) };
        let mut p = vec![2.0, -2.0];
        for _ in 0..1000 {
            let g = quadratic_grad(&p);
            opt.step(&mut p, &g).expect("step failed");
        }
        assert!(p[0].abs() < 0.1, "adamw p[0]={}", p[0]);
    }

    #[test]
    fn test_adamw_reset() {
        let mut opt = AdamW::new(0.001);
        let mut p = vec![1.0];
        opt.step(&mut p, &[0.5]).expect("step failed");
        assert_eq!(opt.t, 1);
        opt.reset_state();
        assert_eq!(opt.t, 0);
        assert!(opt.m.is_empty());
    }

    // Stepping before any snapshot has been taken must be rejected.
    #[test]
    fn test_svrg_needs_snapshot() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut p = vec![1.0, 2.0];
        let sg = vec![0.1, 0.2];
        let sgi = vec![0.05, 0.1];
        assert!(svrg.step(&mut p, &sg, &sgi).is_err());
    }

    // Variance-reduced gradient for coord 0: 2.1 - 2.0 + 2.0 = 2.1.
    #[test]
    fn test_svrg_step_after_snapshot() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut p = vec![1.0, 1.0];
        let full_grad = vec![2.0, 2.0];
        svrg.update_snapshot(&p, &full_grad);
        let sg = vec![2.1, 1.9];
        let sgi = vec![2.0, 2.0];
        svrg.step(&mut p, &sg, &sgi).expect("step failed");
        assert_abs_diff_eq!(p[0], 1.0 - 0.01 * 2.1, epsilon = 1e-12);
    }

    // needs_snapshot_update flips only after update_freq inner steps.
    #[test]
    fn test_svrg_update_freq() {
        let mut svrg = Svrg::new(0.01, 100, 3);
        let mut p = vec![1.0];
        svrg.update_snapshot(&p, &[0.0]);
        assert!(!svrg.needs_snapshot_update());
        for _ in 0..3 {
            svrg.step(&mut p, &[0.0], &[0.0]).expect("step");
        }
        assert!(svrg.needs_snapshot_update());
    }

    #[test]
    fn test_svrg_snapshot_params() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let snap = vec![3.0, 4.0];
        svrg.update_snapshot(&snap, &[0.0, 0.0]);
        assert_eq!(svrg.snapshot_params(), &[3.0, 4.0]);
    }

    #[test]
    fn test_svrg_length_mismatch() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut p = vec![1.0, 2.0];
        svrg.update_snapshot(&p, &[0.0, 0.0]);
        assert!(svrg.step(&mut p, &[0.1], &[0.0, 0.0]).is_err());
    }
}