use crate::error::AutogradError;
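/// Common interface for the first-order optimizers in this module.
///
/// `step` applies one in-place parameter update from a flat gradient slice
/// and fails with `AutogradError::ShapeMismatch` when the lengths disagree.
/// Note that `zero_grad` resets the optimizer's *internal* state (momentum
/// and moment buffers, timestep), not a gradient buffer: gradients are owned
/// by the caller and passed in on every `step`.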
pub trait Optimizer {
fn step(&mut self, params: &mut Vec<f64>, grads: &[f64]) -> Result<(), AutogradError>;
fn zero_grad(&mut self);
fn learning_rate(&self) -> f64;
}
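/// Stochastic gradient descent with optional momentum.
///
/// Uses the PyTorch-style momentum update (no dampening):
/// `v = momentum * v + g; p = p - lr * v`.
/// With `momentum = 0.0` this reduces to vanilla SGD.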
#[derive(Debug, Clone)]
pub struct Sgd {
pub lr: f64,
pub momentum: f64,
velocity: Vec<f64>,
}
impl Sgd {
pub fn new(lr: f64, momentum: f64) -> Self {
Self {
lr,
momentum,
velocity: Vec::new(),
}
}
}
impl Optimizer for Sgd {
fn step(&mut self, params: &mut Vec<f64>, grads: &[f64]) -> Result<(), AutogradError> {
if params.len() != grads.len() {
return Err(AutogradError::ShapeMismatch(format!(
"SGD: params length {} != grads length {}",
params.len(),
grads.len()
)));
}
let n = params.len();
if self.velocity.len() != n {
self.velocity = vec![0.0f64; n];
}
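        // Momentum update: v <- momentum * v + g, then p <- p - lr * v.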
for i in 0..n {
self.velocity[i] = self.momentum * self.velocity[i] + grads[i];
params[i] -= self.lr * self.velocity[i];
}
Ok(())
}
fn zero_grad(&mut self) {
for v in self.velocity.iter_mut() {
*v = 0.0;
}
}
fn learning_rate(&self) -> f64 {
self.lr
}
}
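/// Adam (adaptive moment estimation).
///
/// Maintains exponential moving averages of the gradient (`m`) and of the
/// squared gradient (`v`), corrects their zero-initialization bias, and
/// updates `p = p - lr * m_hat / (sqrt(v_hat) + eps)`. A nonzero
/// `weight_decay` is folded into the gradient (classic coupled L2
/// regularization, not the decoupled AdamW variant).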
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
pub lr: f64,
pub beta1: f64,
pub beta2: f64,
pub eps: f64,
pub weight_decay: f64,
m: Vec<f64>,
v: Vec<f64>,
t: usize,
}
impl AdamOptimizer {
pub fn new(lr: f64) -> Self {
Self {
lr,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.0,
m: Vec::new(),
v: Vec::new(),
t: 0,
}
}
pub fn with_params(lr: f64, beta1: f64, beta2: f64, eps: f64) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay: 0.0,
m: Vec::new(),
v: Vec::new(),
t: 0,
}
}
}
impl Optimizer for AdamOptimizer {
fn step(&mut self, params: &mut Vec<f64>, grads: &[f64]) -> Result<(), AutogradError> {
if params.len() != grads.len() {
return Err(AutogradError::ShapeMismatch(format!(
"Adam: params length {} != grads length {}",
params.len(),
grads.len()
)));
}
let n = params.len();
if self.m.len() != n {
self.m = vec![0.0f64; n];
self.v = vec![0.0f64; n];
}
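        // m and v are zero-initialized EMAs, so their early values are biased
        // toward zero; dividing by (1 - beta^t) corrects for this.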
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
for i in 0..n {
let mut g = grads[i];
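            // Coupled L2 weight decay: fold weight_decay * p into the gradient.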
if self.weight_decay != 0.0 {
g += self.weight_decay * params[i];
}
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
let m_hat = self.m[i] / bc1;
let v_hat = self.v[i] / bc2;
params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
}
Ok(())
}
fn zero_grad(&mut self) {
for m in self.m.iter_mut() {
*m = 0.0;
}
for v in self.v.iter_mut() {
*v = 0.0;
}
self.t = 0;
}
fn learning_rate(&self) -> f64 {
self.lr
}
}
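/// RMSprop: scales each step by the root of an exponential moving average
/// of squared gradients, `cache = alpha * cache + (1 - alpha) * g^2`, giving
/// `p = p - lr * g / sqrt(cache + eps)`. A nonzero `momentum` additionally
/// smooths the updates themselves.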
#[derive(Debug, Clone)]
pub struct RmsProp {
pub lr: f64,
pub alpha: f64,
pub eps: f64,
pub momentum: f64,
cache: Vec<f64>,
velocity: Vec<f64>,
}
impl RmsProp {
pub fn new(lr: f64) -> Self {
Self {
lr,
alpha: 0.99,
eps: 1e-8,
momentum: 0.0,
cache: Vec::new(),
velocity: Vec::new(),
}
}
}
impl Optimizer for RmsProp {
fn step(&mut self, params: &mut Vec<f64>, grads: &[f64]) -> Result<(), AutogradError> {
if params.len() != grads.len() {
return Err(AutogradError::ShapeMismatch(format!(
"RMSprop: params length {} != grads length {}",
params.len(),
grads.len()
)));
}
let n = params.len();
if self.cache.len() != n {
self.cache = vec![0.0f64; n];
self.velocity = vec![0.0f64; n];
}
for i in 0..n {
let g = grads[i];
self.cache[i] = self.alpha * self.cache[i] + (1.0 - self.alpha) * g * g;
let rms = (self.cache[i] + self.eps).sqrt();
let update = self.lr * g / rms;
if self.momentum != 0.0 {
self.velocity[i] = self.momentum * self.velocity[i] + update;
params[i] -= self.velocity[i];
} else {
params[i] -= update;
}
}
Ok(())
}
fn zero_grad(&mut self) {
for c in self.cache.iter_mut() {
*c = 0.0;
}
for v in self.velocity.iter_mut() {
*v = 0.0;
}
}
fn learning_rate(&self) -> f64 {
self.lr
}
}
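/// Adagrad: accumulates the sum of squared gradients and scales the learning
/// rate by its inverse square root, `p = p - lr * g / (sqrt(sum_g2) + eps)`,
/// so frequently-updated coordinates take progressively smaller steps.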
#[derive(Debug, Clone)]
pub struct Adagrad {
pub lr: f64,
pub eps: f64,
sum_sq_grad: Vec<f64>,
}
impl Adagrad {
pub fn new(lr: f64) -> Self {
Self {
lr,
eps: 1e-8,
sum_sq_grad: Vec::new(),
}
}
}
impl Optimizer for Adagrad {
fn step(&mut self, params: &mut Vec<f64>, grads: &[f64]) -> Result<(), AutogradError> {
if params.len() != grads.len() {
return Err(AutogradError::ShapeMismatch(format!(
"Adagrad: params length {} != grads length {}",
params.len(),
grads.len()
)));
}
let n = params.len();
if self.sum_sq_grad.len() != n {
self.sum_sq_grad = vec![0.0f64; n];
}
for i in 0..n {
let g = grads[i];
self.sum_sq_grad[i] += g * g;
let lr_scaled = self.lr / (self.sum_sq_grad[i].sqrt() + self.eps);
params[i] -= lr_scaled * g;
}
Ok(())
}
fn zero_grad(&mut self) {
for s in self.sum_sq_grad.iter_mut() {
*s = 0.0;
}
}
fn learning_rate(&self) -> f64 {
self.lr
}
}
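/// Cosine annealing: decays the learning rate from `lr_max` to `lr_min`
/// over `t_max` steps via
/// `lr(t) = lr_min + (lr_max - lr_min) * (1 + cos(pi * t / t_max)) / 2`.
/// The `step` field holds the current step index; the `step()` method
/// advances it (fields and methods live in separate namespaces, so this
/// compiles as written).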
#[derive(Debug, Clone)]
pub struct CosineAnnealingSchedule {
pub lr_max: f64,
pub lr_min: f64,
pub t_max: usize,
pub step: usize,
}
impl CosineAnnealingSchedule {
pub fn new(lr_max: f64, lr_min: f64, t_max: usize) -> Self {
Self {
lr_max,
lr_min,
t_max,
step: 0,
}
}
    pub fn get_lr(&self) -> f64 {
        if self.t_max == 0 {
            return self.lr_max;
        }
        // Clamp so the schedule holds at lr_min once t_max is exceeded,
        // instead of climbing back up the cosine curve.
        let t = self.step.min(self.t_max) as f64;
        let t_max = self.t_max as f64;
        let cos_val = (std::f64::consts::PI * t / t_max).cos();
        self.lr_min + (self.lr_max - self.lr_min) * (1.0 + cos_val) / 2.0
    }
pub fn step(&mut self) {
self.step += 1;
}
}
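/// One-cycle schedule: linear warmup from `max_lr / div_factor` up to
/// `max_lr` over the first `pct_start` fraction of `total_steps`, then
/// cosine annealing down to `max_lr / final_div_factor`.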
#[derive(Debug, Clone)]
pub struct OneCycleSchedule {
pub max_lr: f64,
pub total_steps: usize,
pub pct_start: f64,
pub step: usize,
div_factor: f64,
final_div_factor: f64,
}
impl OneCycleSchedule {
pub fn new(max_lr: f64, total_steps: usize) -> Self {
Self {
max_lr,
total_steps,
pct_start: 0.3,
step: 0,
div_factor: 25.0,
final_div_factor: 1e4,
}
}
pub fn get_lr(&self) -> f64 {
if self.total_steps == 0 {
return self.max_lr;
}
let warmup_steps = (self.pct_start * self.total_steps as f64) as usize;
let start_lr = self.max_lr / self.div_factor;
let min_lr = self.max_lr / self.final_div_factor;
let t = self.step;
if t <= warmup_steps {
if warmup_steps == 0 {
return self.max_lr;
}
let progress = t as f64 / warmup_steps as f64;
start_lr + (self.max_lr - start_lr) * progress
} else {
let anneal_steps = self.total_steps.saturating_sub(warmup_steps);
if anneal_steps == 0 {
return min_lr;
}
let progress = (t - warmup_steps) as f64 / anneal_steps as f64;
let cos_val = (std::f64::consts::PI * progress.min(1.0)).cos();
min_lr + (self.max_lr - min_lr) * (1.0 + cos_val) / 2.0
}
}
pub fn step(&mut self) {
self.step += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
const TOL: f64 = 1e-6;
#[test]
fn test_sgd_single_step_vanilla() {
let mut sgd = Sgd::new(0.1, 0.0);
let mut p = vec![1.0, -1.0];
sgd.step(&mut p, &[2.0, -4.0]).expect("sgd step");
assert!((p[0] - (1.0 - 0.1 * 2.0)).abs() < TOL);
assert!((p[1] - (-1.0 - 0.1 * (-4.0))).abs() < TOL);
}
#[test]
fn test_sgd_reduces_loss_quadratic() {
let mut sgd = Sgd::new(0.01, 0.9);
let mut p = vec![5.0];
for _ in 0..500 {
let g = vec![2.0 * p[0]];
sgd.step(&mut p, &g).expect("sgd step");
}
assert!(p[0].abs() < 0.5, "SGD did not converge, p[0] = {}", p[0]);
}
#[test]
fn test_sgd_dimension_mismatch_error() {
let mut sgd = Sgd::new(0.1, 0.0);
let mut p = vec![1.0, 2.0];
let result = sgd.step(&mut p, &[1.0]);
assert!(result.is_err());
}
#[test]
fn test_sgd_zero_grad_resets_velocity() {
let mut sgd = Sgd::new(0.1, 0.9);
let mut p = vec![1.0];
sgd.step(&mut p, &[1.0]).expect("step");
assert_ne!(sgd.velocity, vec![0.0]);
sgd.zero_grad();
assert_eq!(sgd.velocity, vec![0.0]);
}
#[test]
fn test_adam_single_step() {
let mut adam = AdamOptimizer::new(0.01);
let mut p = vec![1.0];
let p_before = p[0];
adam.step(&mut p, &[1.0]).expect("adam step");
assert!(p[0] < p_before, "Adam should decrease p, got {}", p[0]);
}
#[test]
fn test_adam_converges_on_quadratic() {
let mut adam = AdamOptimizer::new(0.05);
let mut p = vec![3.0, -3.0];
for _ in 0..500 {
let g = vec![2.0 * p[0], 2.0 * p[1]];
adam.step(&mut p, &g).expect("adam step");
}
assert!(p[0].abs() < 0.1, "Adam p[0] did not converge: {}", p[0]);
assert!(p[1].abs() < 0.1, "Adam p[1] did not converge: {}", p[1]);
}
#[test]
fn test_adam_converges_faster_than_sgd_on_rosenbrock() {
let rosenbrock_grad = |x: &[f64]| -> Vec<f64> {
let dx = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
let dy = 200.0 * (x[1] - x[0] * x[0]);
vec![dx, dy]
};
let start = vec![-1.0, 1.0];
let mut adam = AdamOptimizer::new(0.001);
let mut pa = start.clone();
for _ in 0..200 {
let g = rosenbrock_grad(&pa);
adam.step(&mut pa, &g).expect("adam step");
}
let adam_dist = (pa[0] - 1.0).abs() + (pa[1] - 1.0).abs();
let mut sgd = Sgd::new(0.001, 0.0);
let mut ps = start.clone();
for _ in 0..200 {
let g = rosenbrock_grad(&ps);
sgd.step(&mut ps, &g).expect("sgd step");
}
let sgd_dist = (ps[0] - 1.0).abs() + (ps[1] - 1.0).abs();
assert!(
adam_dist < sgd_dist,
"Adam dist {adam_dist:.4} should be < SGD dist {sgd_dist:.4}"
);
}
#[test]
fn test_adam_dimension_mismatch_error() {
let mut adam = AdamOptimizer::new(0.01);
let mut p = vec![1.0, 2.0];
let result = adam.step(&mut p, &[1.0]);
assert!(result.is_err());
}
#[test]
fn test_adam_with_params() {
let adam = AdamOptimizer::with_params(0.001, 0.95, 0.9999, 1e-7);
assert_eq!(adam.beta1, 0.95);
assert_eq!(adam.beta2, 0.9999);
assert_eq!(adam.eps, 1e-7);
}
#[test]
fn test_adam_zero_grad_resets_state() {
let mut adam = AdamOptimizer::new(0.01);
let mut p = vec![1.0];
adam.step(&mut p, &[1.0]).expect("step");
assert_eq!(adam.t, 1);
adam.zero_grad();
assert_eq!(adam.t, 0);
assert_eq!(adam.m, vec![0.0]);
assert_eq!(adam.v, vec![0.0]);
}
#[test]
fn test_rmsprop_single_step() {
let mut rms = RmsProp::new(0.01);
let mut p = vec![2.0];
let p_before = p[0];
rms.step(&mut p, &[4.0]).expect("rmsprop step");
assert!(p[0] < p_before);
}
#[test]
fn test_rmsprop_converges() {
let mut rms = RmsProp::new(0.01);
let mut p = vec![3.0];
for _ in 0..500 {
let grad = 2.0 * p[0];
rms.step(&mut p, &[grad]).expect("step");
}
assert!(p[0].abs() < 1.0, "RMSprop did not converge: {}", p[0]);
}
#[test]
fn test_rmsprop_dimension_mismatch_error() {
let mut rms = RmsProp::new(0.01);
let mut p = vec![1.0];
let result = rms.step(&mut p, &[1.0, 2.0]);
assert!(result.is_err());
}
#[test]
fn test_adagrad_single_step() {
let mut ada = Adagrad::new(0.1);
let mut p = vec![1.0];
ada.step(&mut p, &[2.0]).expect("adagrad step");
let expected = 1.0 - 0.1 / (4.0_f64.sqrt() + 1e-8) * 2.0;
assert!((p[0] - expected).abs() < 1e-6);
}
#[test]
    fn test_adagrad_accumulates_squared_grads() {
        // With a constant gradient the accumulated squared gradients grow,
        // so each successive effective step must shrink.
        let mut ada = Adagrad::new(0.1);
        let mut p = vec![5.0];
        ada.step(&mut p, &[1.0]).expect("step 1");
        let step1 = 5.0 - p[0];
        let before = p[0];
        ada.step(&mut p, &[1.0]).expect("step 2");
        let step2 = before - p[0];
        assert!(step2 < step1, "step1={step1}, step2={step2}");
    }
#[test]
fn test_adagrad_dimension_mismatch_error() {
let mut ada = Adagrad::new(0.1);
let mut p = vec![1.0, 2.0];
let result = ada.step(&mut p, &[1.0]);
assert!(result.is_err());
}
#[test]
fn test_cosine_annealing_starts_at_max() {
let sched = CosineAnnealingSchedule::new(0.1, 0.001, 100);
assert!((sched.get_lr() - 0.1).abs() < TOL);
}
#[test]
fn test_cosine_annealing_ends_at_min() {
let mut sched = CosineAnnealingSchedule::new(0.1, 0.001, 100);
for _ in 0..100 {
sched.step();
}
assert!((sched.get_lr() - 0.001).abs() < 1e-6);
}
#[test]
fn test_cosine_annealing_is_monotone_decreasing() {
let mut sched = CosineAnnealingSchedule::new(0.1, 0.001, 100);
let mut prev = sched.get_lr();
for _ in 0..100 {
sched.step();
let cur = sched.get_lr();
assert!(
cur <= prev + 1e-12,
"LR increased from {prev} to {cur} at step {}",
sched.step
);
prev = cur;
}
}
#[test]
    fn test_cosine_annealing_midpoint() {
        // At t = t_max / 2, cos(pi / 2) = 0, so lr = (lr_max + lr_min) / 2.
        let mut s = CosineAnnealingSchedule::new(1.0, 0.0, 100);
        for _ in 0..50 {
            s.step();
        }
        assert!((s.get_lr() - 0.5).abs() < 1e-6, "got {}", s.get_lr());
    }
#[test]
fn test_one_cycle_starts_below_max() {
let sched = OneCycleSchedule::new(0.1, 100);
assert!(sched.get_lr() < sched.max_lr);
}
#[test]
fn test_one_cycle_peaks_near_max_at_warmup_end() {
let mut sched = OneCycleSchedule::new(0.1, 100);
for _ in 0..30 {
sched.step();
}
let peak = sched.get_lr();
assert!(
(peak - 0.1).abs() < 1e-9,
"peak lr should be max_lr, got {}",
peak
);
}
#[test]
fn test_one_cycle_decreases_after_warmup() {
let mut sched = OneCycleSchedule::new(0.1, 100);
for _ in 0..30 {
sched.step();
}
let peak = sched.get_lr();
for _ in 0..40 {
sched.step();
}
let later = sched.get_lr();
assert!(later < peak, "lr should decrease after peak: {} vs {}", later, peak);
}
#[test]
fn test_one_cycle_increases_during_warmup() {
let mut sched = OneCycleSchedule::new(0.1, 100);
let lr0 = sched.get_lr();
sched.step();
let lr1 = sched.get_lr();
sched.step();
let lr2 = sched.get_lr();
assert!(lr1 > lr0, "LR should increase during warmup");
assert!(lr2 > lr1, "LR should increase during warmup");
}
#[test]
fn test_optimizer_learning_rate_accessor() {
let sgd = Sgd::new(0.05, 0.0);
assert_eq!(sgd.learning_rate(), 0.05);
let adam = AdamOptimizer::new(0.003);
assert_eq!(adam.learning_rate(), 0.003);
let rms = RmsProp::new(0.01);
assert_eq!(rms.learning_rate(), 0.01);
let ada = Adagrad::new(0.1);
assert_eq!(ada.learning_rate(), 0.1);
}
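    // Usage sketch: the schedules are standalone, so a training loop copies
    // the scheduled rate into the optimizer's public `lr` field each step.
    #[test]
    fn test_schedule_drives_optimizer_lr() {
        let mut sgd = Sgd::new(0.1, 0.0);
        let mut sched = CosineAnnealingSchedule::new(0.1, 0.001, 50);
        let mut p = vec![3.0];
        for _ in 0..50 {
            sgd.lr = sched.get_lr();
            let g = vec![2.0 * p[0]];
            sgd.step(&mut p, &g).expect("sgd step");
            sched.step();
        }
        // Each update multiplies p by (1 - 2 * lr) with lr in [0.001, 0.1],
        // so p strictly decreases but stays positive.
        assert!(
            p[0] > 0.0 && p[0] < 3.0,
            "p should shrink toward 0, got {}",
            p[0]
        );
    }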
}