use crate::error::OptimizeError;
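/// Stochastic gradient descent with optional momentum, Nesterov
/// acceleration, and L2 weight decay.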
#[derive(Debug, Clone)]
pub struct SgdOptimizer {
pub lr: f64,
pub momentum: f64,
pub nesterov: bool,
pub weight_decay: f64,
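    /// Per-parameter velocity buffer, allocated lazily on the first call to `step`.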
velocity: Vec<f64>,
}
impl SgdOptimizer {
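    /// Creates an SGD optimizer with the given learning rate, momentum
    /// factor, Nesterov flag, and weight-decay coefficient.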
pub fn new(lr: f64, momentum: f64, nesterov: bool, weight_decay: f64) -> Self {
Self {
lr,
momentum,
nesterov,
weight_decay,
velocity: Vec::new(),
}
}
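    /// Plain SGD: no momentum, no Nesterov acceleration, no weight decay.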
pub fn vanilla(lr: f64) -> Self {
Self::new(lr, 0.0, false, 0.0)
}
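    /// Applies one update step to `params` in place.
    ///
    /// With weight decay the effective gradient is `g = grad + weight_decay * p`.
    /// With momentum `mu`, the velocity is updated as `v = mu * v + g` and the
    /// parameter moves by `-lr * v` (or `-lr * (g + mu * v)` when Nesterov is
    /// enabled). Returns an error if `params` and `grad` differ in length.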
pub fn step(
&mut self,
params: &mut Vec<f64>,
grad: &[f64],
) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
if self.velocity.len() != n {
self.velocity = vec![0.0; n];
}
for i in 0..n {
let g = grad[i] + self.weight_decay * params[i];
if self.momentum == 0.0 {
params[i] -= self.lr * g;
} else {
self.velocity[i] = self.momentum * self.velocity[i] + g;
if self.nesterov {
params[i] -= self.lr * (g + self.momentum * self.velocity[i]);
} else {
params[i] -= self.lr * self.velocity[i];
}
}
}
Ok(())
}
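    /// Clears the velocity buffer so the next `step` starts from scratch.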
pub fn reset(&mut self) {
self.velocity.clear();
}
}
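/// AdaGrad: scales each parameter's step by the inverse square root of its
/// accumulated squared gradients.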
#[derive(Debug, Clone)]
pub struct AdaGradOptimizer {
pub lr: f64,
pub eps: f64,
pub accum: Vec<f64>,
}
impl AdaGradOptimizer {
pub fn new(lr: f64, eps: f64) -> Self {
Self {
lr,
eps,
accum: Vec::new(),
}
}
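    /// AdaGrad with the conventional epsilon of `1e-8`.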
pub fn default_params(lr: f64) -> Self {
Self::new(lr, 1e-8)
}
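    /// Applies one update step to `params` in place:
    /// `accum += g^2; p -= lr / (sqrt(accum) + eps) * g`.
    /// Returns an error if `params` and `grad` differ in length.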
pub fn step(
&mut self,
params: &mut Vec<f64>,
grad: &[f64],
) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
if self.accum.len() != n {
self.accum = vec![0.0; n];
}
for i in 0..n {
self.accum[i] += grad[i] * grad[i];
params[i] -= self.lr / (self.accum[i].sqrt() + self.eps) * grad[i];
}
Ok(())
}
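    /// Clears the accumulated squared gradients.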
pub fn reset(&mut self) {
self.accum.clear();
}
}
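/// AdaDelta: adapts per-parameter step sizes from running averages of squared
/// gradients and squared updates, so no explicit learning rate is required.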
#[derive(Debug, Clone)]
pub struct AdaDeltaOptimizer {
pub rho: f64,
pub eps: f64,
pub accum_grad: Vec<f64>,
pub accum_update: Vec<f64>,
}
impl AdaDeltaOptimizer {
pub fn new(rho: f64, eps: f64) -> Self {
Self {
rho,
eps,
accum_grad: Vec::new(),
accum_update: Vec::new(),
}
}
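    /// AdaDelta with the widely used defaults `rho = 0.95`, `eps = 1e-6`.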
pub fn default_params() -> Self {
Self::new(0.95, 1e-6)
}
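    /// Applies one update step to `params` in place using the AdaDelta rule:
    /// `E[g^2] = rho * E[g^2] + (1 - rho) * g^2`,
    /// `delta = -sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g`,
    /// `E[dx^2] = rho * E[dx^2] + (1 - rho) * delta^2`.
    /// Returns an error if `params` and `grad` differ in length.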
pub fn step(
&mut self,
params: &mut Vec<f64>,
grad: &[f64],
) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
if self.accum_grad.len() != n {
self.accum_grad = vec![0.0; n];
self.accum_update = vec![0.0; n];
}
for i in 0..n {
self.accum_grad[i] =
self.rho * self.accum_grad[i] + (1.0 - self.rho) * grad[i] * grad[i];
let rms_update = (self.accum_update[i] + self.eps).sqrt();
let rms_grad = (self.accum_grad[i] + self.eps).sqrt();
let delta = -(rms_update / rms_grad) * grad[i];
self.accum_update[i] =
self.rho * self.accum_update[i] + (1.0 - self.rho) * delta * delta;
params[i] += delta;
}
Ok(())
}
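    /// Clears both running accumulators.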
pub fn reset(&mut self) {
self.accum_grad.clear();
self.accum_update.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
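    /// Gradient of `f(x) = sum(x_i^2)`, whose unique minimum is at the origin.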
fn quadratic_grad(x: &[f64]) -> Vec<f64> {
x.iter().map(|&xi| 2.0 * xi).collect()
}
#[test]
fn test_sgd_vanilla_converges() {
let mut opt = SgdOptimizer::vanilla(0.1);
let mut params = vec![1.0, -2.0, 0.5];
for _ in 0..200 {
            let g = quadratic_grad(&params);
opt.step(&mut params, &g).expect("step failed");
}
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-4);
}
}
#[test]
fn test_sgd_momentum_converges() {
let mut opt = SgdOptimizer::new(0.05, 0.9, false, 0.0);
let mut params = vec![2.0, -1.5];
for _ in 0..300 {
            let g = quadratic_grad(&params);
opt.step(&mut params, &g).expect("step failed");
}
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_sgd_nesterov_converges() {
let mut opt = SgdOptimizer::new(0.05, 0.9, true, 0.0);
let mut params = vec![1.5, -1.0];
for _ in 0..300 {
            let g = quadratic_grad(&params);
opt.step(&mut params, &g).expect("step failed");
}
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_sgd_weight_decay() {
let mut opt = SgdOptimizer::new(0.01, 0.0, false, 0.1);
let mut params = vec![1.0];
let init = params[0];
        let g = vec![0.0];
        opt.step(&mut params, &g).expect("step failed");
assert!(params[0] < init, "weight decay should reduce param");
}
#[test]
fn test_sgd_length_mismatch() {
let mut opt = SgdOptimizer::vanilla(0.1);
let mut params = vec![1.0, 2.0];
        let grad = vec![0.1];
        assert!(opt.step(&mut params, &grad).is_err());
}
#[test]
fn test_adagrad_converges() {
let mut opt = AdaGradOptimizer::default_params(0.5);
let mut params = vec![3.0, -2.0];
for _ in 0..500 {
            let g = quadratic_grad(&params);
opt.step(&mut params, &g).expect("step failed");
}
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 0.1);
}
}
#[test]
fn test_adadelta_converges() {
let mut opt = AdaDeltaOptimizer::default_params();
let mut params = vec![2.0, -1.0];
for _ in 0..2000 {
            let g = quadratic_grad(&params);
opt.step(&mut params, &g).expect("step failed");
}
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 0.5);
}
}
#[test]
fn test_adadelta_length_mismatch() {
let mut opt = AdaDeltaOptimizer::default_params();
let mut params = vec![1.0, 2.0];
        let grad = vec![0.1, 0.2, 0.3];
        assert!(opt.step(&mut params, &grad).is_err());
}
}