use crate::error::OptimizeError;
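/// Adam optimizer (Kingma & Ba, 2015) with optional L2 weight decay and AMSGrad.
///
/// Per-element update, with bias-corrected step size
/// `lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)`:
///
/// ```text
/// g      = grad + weight_decay * theta
/// m      = beta1 * m + (1 - beta1) * g
/// v      = beta2 * v + (1 - beta2) * g^2
/// theta -= lr_t * m / (sqrt(v) + eps)   // v_max replaces v when amsgrad = true
/// ```
///
/// # Example
///
/// A minimal sketch (marked `ignore` because the import path depends on how the
/// crate exposes this module), minimizing `f(x) = x^2` element-wise:
///
/// ```ignore
/// let mut opt = AdamOptimizer::default_params(0.01);
/// let mut params = vec![3.0, -2.0];
/// for _ in 0..1000 {
///     let grad: Vec<f64> = params.iter().map(|&x| 2.0 * x).collect();
///     opt.step(&mut params, &grad).unwrap();
/// }
/// ```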
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
pub lr: f64,
pub beta1: f64,
pub beta2: f64,
pub eps: f64,
pub weight_decay: f64,
pub amsgrad: bool,
pub m: Vec<f64>,
pub v: Vec<f64>,
pub v_max: Vec<f64>,
pub t: usize,
}
impl AdamOptimizer {
pub fn new(lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64, amsgrad: bool) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay,
amsgrad,
m: Vec::new(),
v: Vec::new(),
v_max: Vec::new(),
t: 0,
}
}
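    /// Construct with the defaults from the Adam paper: `beta1 = 0.9`,
    /// `beta2 = 0.999`, `eps = 1e-8`, no weight decay, no AMSGrad.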
pub fn default_params(lr: f64) -> Self {
Self::new(lr, 0.9, 0.999, 1e-8, 0.0, false)
}
fn ensure_init(&mut self, n: usize) {
if self.m.len() != n {
self.m = vec![0.0; n];
self.v = vec![0.0; n];
self.v_max = vec![0.0; n];
self.t = 0;
}
}
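    /// Perform one Adam update of `params` in place, given the gradient `grad`.
    ///
    /// Returns an error if `params` and `grad` have different lengths; the moment
    /// buffers are re-initialized whenever the parameter length changes.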
    pub fn step(&mut self, params: &mut [f64], grad: &[f64]) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
self.ensure_init(n);
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
let lr_t = self.lr * bc2.sqrt() / bc1;
        for i in 0..n {
            // L2 regularization: fold the weight-decay term into the gradient
            // (AdamW below applies decoupled decay instead).
            let g = grad[i] + self.weight_decay * params[i];
            // Biased first- and second-moment estimates.
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            // AMSGrad keeps the running maximum of the second moment.
            let v_hat = if self.amsgrad {
                self.v_max[i] = self.v_max[i].max(self.v[i]);
                self.v_max[i]
            } else {
                self.v[i]
            };
            params[i] -= lr_t * self.m[i] / (v_hat.sqrt() + self.eps);
        }
Ok(())
}
pub fn reset(&mut self) {
self.m.clear();
self.v.clear();
self.v_max.clear();
self.t = 0;
}
}
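/// AdamW optimizer (Loshchilov & Hutter, 2019): Adam with decoupled weight decay.
///
/// Unlike `AdamOptimizer`, the decay is not folded into the gradient; each step
/// first shrinks the weights by `1 - lr * weight_decay` and then applies the
/// bias-corrected Adam update.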
#[derive(Debug, Clone)]
pub struct AdamWOptimizer {
pub lr: f64,
pub beta1: f64,
pub beta2: f64,
pub eps: f64,
pub weight_decay: f64,
pub amsgrad: bool,
m: Vec<f64>,
v: Vec<f64>,
v_max: Vec<f64>,
t: usize,
}
impl AdamWOptimizer {
pub fn new(lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64, amsgrad: bool) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay,
amsgrad,
m: Vec::new(),
v: Vec::new(),
v_max: Vec::new(),
t: 0,
}
}
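    /// Construct with common AdamW defaults: `beta1 = 0.9`, `beta2 = 0.999`,
    /// `eps = 1e-8`, `weight_decay = 0.01`, no AMSGrad.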
pub fn default_params(lr: f64) -> Self {
Self::new(lr, 0.9, 0.999, 1e-8, 0.01, false)
}
fn ensure_init(&mut self, n: usize) {
if self.m.len() != n {
self.m = vec![0.0; n];
self.v = vec![0.0; n];
self.v_max = vec![0.0; n];
self.t = 0;
}
}
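    /// Perform one AdamW update of `params` in place: decoupled weight decay
    /// followed by the bias-corrected Adam step. Returns an error if `params`
    /// and `grad` have different lengths.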
    pub fn step(&mut self, params: &mut [f64], grad: &[f64]) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
self.ensure_init(n);
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
let step_size = self.lr * bc2.sqrt() / bc1;
for i in 0..n {
            // Decoupled weight decay: shrink the weights before the adaptive step.
            params[i] *= 1.0 - self.lr * self.weight_decay;
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
let v_hat = if self.amsgrad {
self.v_max[i] = self.v_max[i].max(self.v[i]);
self.v_max[i]
} else {
self.v[i]
};
params[i] -= step_size * self.m[i] / (v_hat.sqrt() + self.eps);
}
Ok(())
}
pub fn reset(&mut self) {
self.m.clear();
self.v.clear();
self.v_max.clear();
self.t = 0;
}
}
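/// NAdam optimizer (Dozat, 2016): Adam with Nesterov momentum.
///
/// The bias-corrected momentum term looks one step ahead:
///
/// ```text
/// m_hat  = beta1 * m_t / (1 - beta1^(t+1)) + (1 - beta1) * g_t / (1 - beta1^t)
/// theta -= lr * m_hat / (sqrt(v_t / (1 - beta2^t)) + eps)
/// ```
///
/// Weight decay is applied as an L2 term added to the gradient. This simplified
/// variant omits the momentum decay schedule from the original paper.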
#[derive(Debug, Clone)]
pub struct NAdamOptimizer {
pub lr: f64,
pub beta1: f64,
pub beta2: f64,
pub eps: f64,
pub weight_decay: f64,
m: Vec<f64>,
v: Vec<f64>,
t: usize,
}
impl NAdamOptimizer {
pub fn new(lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay,
m: Vec::new(),
v: Vec::new(),
t: 0,
}
}
pub fn default_params(lr: f64) -> Self {
Self::new(lr, 0.9, 0.999, 1e-8, 0.0)
}
fn ensure_init(&mut self, n: usize) {
if self.m.len() != n {
self.m = vec![0.0; n];
self.v = vec![0.0; n];
self.t = 0;
}
}
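    /// Perform one NAdam update of `params` in place. Returns an error if
    /// `params` and `grad` have different lengths.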
    pub fn step(&mut self, params: &mut [f64], grad: &[f64]) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
self.ensure_init(n);
self.t += 1;
let bc1_t = 1.0 - self.beta1.powi(self.t as i32);
let bc1_t1 = 1.0 - self.beta1.powi(self.t as i32 + 1);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
for i in 0..n {
let g = grad[i] + self.weight_decay * params[i];
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
let m_hat_n = self.beta1 * self.m[i] / bc1_t1 + (1.0 - self.beta1) * g / bc1_t;
let v_hat = self.v[i] / bc2;
params[i] -= self.lr * m_hat_n / (v_hat.sqrt() + self.eps);
}
Ok(())
}
pub fn reset(&mut self) {
self.m.clear();
self.v.clear();
self.t = 0;
}
}
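/// RAdam optimizer (Liu et al., 2020): Adam with a rectified adaptive learning rate.
///
/// Each step tracks `rho_t`, the length of the approximated SMA. While
/// `rho_t <= 4` the variance of the adaptive term is intractable and the update
/// falls back to bias-corrected momentum; otherwise the adaptive step is scaled
/// by the rectification factor from the paper.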
#[derive(Debug, Clone)]
pub struct RAdamOptimizer {
pub lr: f64,
pub beta1: f64,
pub beta2: f64,
pub eps: f64,
pub weight_decay: f64,
m: Vec<f64>,
v: Vec<f64>,
t: usize,
}
impl RAdamOptimizer {
pub fn new(lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay,
m: Vec::new(),
v: Vec::new(),
t: 0,
}
}
pub fn default_params(lr: f64) -> Self {
Self::new(lr, 0.9, 0.999, 1e-8, 0.0)
}
fn ensure_init(&mut self, n: usize) {
if self.m.len() != n {
self.m = vec![0.0; n];
self.v = vec![0.0; n];
self.t = 0;
}
}
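    /// Perform one RAdam update of `params` in place. Returns an error if
    /// `params` and `grad` have different lengths.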
    pub fn step(&mut self, params: &mut [f64], grad: &[f64]) -> Result<(), OptimizeError> {
let n = params.len();
if grad.len() != n {
return Err(OptimizeError::ValueError(format!(
"params length {} != grad length {}",
n,
grad.len()
)));
}
self.ensure_init(n);
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
        // rho_max is the maximum SMA length; rho_t is its approximation at step t.
        let rho_max = 2.0 / (1.0 - self.beta2) - 1.0;
        let rho_t = rho_max - 2.0 * self.t as f64 * self.beta2.powi(self.t as i32) / bc2;
for i in 0..n {
let g = grad[i] + self.weight_decay * params[i];
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
let m_hat = self.m[i] / bc1;
if rho_t > 4.0 {
let rect = ((rho_t - 4.0) * (rho_t - 2.0) * rho_max
/ ((rho_max - 4.0) * (rho_max - 2.0) * rho_t))
.sqrt();
let v_hat = (self.v[i] / bc2).sqrt();
params[i] -= self.lr * rect * m_hat / (v_hat + self.eps);
} else {
params[i] -= self.lr * m_hat;
}
}
Ok(())
}
pub fn reset(&mut self) {
self.m.clear();
self.v.clear();
self.t = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
fn quadratic_grad(x: &[f64]) -> Vec<f64> {
x.iter().map(|&xi| 2.0 * xi).collect()
}
#[test]
fn test_adam_converges_quadratic() {
let mut opt = AdamOptimizer::default_params(0.01);
let mut params = vec![3.0, -2.0, 1.5];
for _ in 0..1000 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_adam_amsgrad_converges() {
let mut opt = AdamOptimizer::new(0.01, 0.9, 0.999, 1e-8, 0.0, true);
let mut params = vec![2.0, -1.0];
for _ in 0..1000 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_adamw_decoupled_decay() {
let mut opt = AdamWOptimizer::new(0.01, 0.9, 0.999, 1e-8, 0.1, false);
let mut params = vec![1.0];
let g = vec![0.0];
for _ in 0..100 {
opt.step(&mut params, &g).expect("step failed");
}
assert!(params[0] < 1.0, "weight decay should reduce params");
}
#[test]
fn test_adamw_converges_quadratic() {
let mut opt = AdamWOptimizer::default_params(0.01);
let mut params = vec![2.0, -1.5];
for _ in 0..1000 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 0.05);
}
}
#[test]
fn test_nadam_converges_quadratic() {
let mut opt = NAdamOptimizer::default_params(0.01);
let mut params = vec![2.0, -1.0];
for _ in 0..1000 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_radam_converges_quadratic() {
let mut opt = RAdamOptimizer::default_params(0.01);
let mut params = vec![3.0, -2.0];
for _ in 0..1000 {
            let g = quadratic_grad(&params);
            opt.step(&mut params, &g).expect("step failed");
        }
        for &p in &params {
assert_abs_diff_eq!(p, 0.0, epsilon = 1e-3);
}
}
#[test]
fn test_adam_length_mismatch() {
let mut opt = AdamOptimizer::default_params(0.01);
let mut params = vec![1.0, 2.0];
        let grad = vec![0.1];
        assert!(opt.step(&mut params, &grad).is_err());
}
#[test]
fn test_adam_reset() {
let mut opt = AdamOptimizer::default_params(0.01);
let mut params = vec![1.0, -1.0];
        let g = quadratic_grad(&params);
opt.step(&mut params, &g).expect("step failed");
assert_eq!(opt.t, 1);
opt.reset();
assert_eq!(opt.t, 0);
assert!(opt.m.is_empty());
}
}