use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// State and hyper-parameters of an AdamW optimizer.
///
/// The struct keeps the step size, the two exponential-decay rates, the
/// numerical-stability term and the weight-decay coefficient, together with
/// the running first/second moment buffers and the step counter that the
/// update rule reads and writes.
#[derive(Debug, Clone)]
pub struct AdamW<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Step size applied to the bias-corrected update.
    learning_rate: A,
    /// Decay rate of the first-moment (mean of gradients) running average.
    beta1: A,
    /// Decay rate of the second-moment (mean of squared gradients) running average.
    beta2: A,
    /// Small constant added to the denominator of the update.
    epsilon: A,
    /// Weight-decay coefficient; applied to the parameters directly rather
    /// than being folded into the gradient.
    weight_decay: A,
    /// First-moment buffers; `None` until the first optimisation step.
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Second-moment buffers; `None` until the first optimisation step.
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Number of optimisation steps taken since the state was last cleared.
    t: usize,
}
impl<A> AdamW<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    /// Builds an optimizer with the given learning rate and the default
    /// hyper-parameters `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`
    /// and `weight_decay = 0.01`.
    ///
    /// # Panics
    ///
    /// Panics if one of the default constants cannot be represented in `A`.
    pub fn new(learning_rate: A) -> Self {
        let default_beta1 = A::from(0.9).expect("unwrap failed");
        let default_beta2 = A::from(0.999).expect("unwrap failed");
        let default_epsilon = A::from(1e-8).expect("unwrap failed");
        let default_weight_decay = A::from(0.01).expect("unwrap failed");
        Self::new_with_config(
            learning_rate,
            default_beta1,
            default_beta2,
            default_epsilon,
            default_weight_decay,
        )
    }

    /// Builds an optimizer with every hyper-parameter supplied by the caller.
    ///
    /// The moment buffers start out unallocated and the step counter at zero.
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            t: 0,
            m: None,
            v: None,
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
        }
    }

    /// Replaces the first-moment decay rate; returns `self` so calls can be chained.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the first-moment decay rate.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Replaces the second-moment decay rate; returns `self` so calls can be chained.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the second-moment decay rate.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Replaces the numerical-stability constant; returns `self` so calls can be chained.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the numerical-stability constant.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Replaces the weight-decay coefficient; returns `self` so calls can be chained.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Overwrites the learning rate.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Discards the accumulated moment buffers and rewinds the step counter
    /// to zero; the hyper-parameters are left untouched.
    pub fn reset(&mut self) {
        self.t = 0;
        self.v = None;
        self.m = None;
    }
}
/// `Optimizer` implementation performing AdamW updates on a single
/// parameter array of arbitrary dimensionality.
impl<A, D> Optimizer<A, D> for AdamW<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Performs one AdamW update and returns the new parameter values.
    ///
    /// The running moments are updated from `gradients`, bias-corrected with
    /// the step counter, and the resulting step is subtracted from `params`
    /// after `params` has been scaled by `1 - learning_rate * weight_decay`.
    ///
    /// The moment buffers are (re)created as zeros whenever they are missing
    /// or their shape differs from `params`. The step counter is restarted
    /// at the same time, so that the bias correction always matches the
    /// number of gradients actually accumulated in the buffers.
    ///
    /// No explicit check is made that `gradients` has the same shape as
    /// `params`; the element-wise array arithmetic below is relied upon for
    /// that (NOTE(review): ndarray is expected to broadcast or panic on a
    /// mismatch — confirm whether an error should be returned instead).
    ///
    /// # Panics
    ///
    /// Panics if the dynamically-dimensioned result cannot be converted back
    /// to `D`.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();
        let dim = params_dyn.raw_dim();

        let m = self.m.get_or_insert_with(Vec::new);
        let v = self.v.get_or_insert_with(Vec::new);

        // The state is usable only if both buffers exist with the parameter shape.
        let state_is_current = match (m.first(), v.first()) {
            (Some(m0), Some(v0)) => m0.raw_dim() == dim && v0.raw_dim() == dim,
            _ => false,
        };
        if !state_is_current {
            let zeros: Array<A, _> = Array::zeros(dim);
            match m.first_mut() {
                Some(slot) => *slot = zeros.clone(),
                None => m.push(zeros.clone()),
            }
            match v.first_mut() {
                Some(slot) => *slot = zeros,
                None => v.push(zeros),
            }
            // Fresh moments carry no history, so bias correction must restart too.
            self.t = 0;
        }

        self.t += 1;
        let one = A::one();
        let (beta1, beta2) = (self.beta1, self.beta2);

        // Exponential moving averages of the gradient and the squared gradient.
        m[0] = &m[0] * beta1 + &(&gradients_dyn * (one - beta1));
        v[0] = &v[0] * beta2 + &(&gradients_dyn * &gradients_dyn * (one - beta2));

        // Saturate rather than wrap: a plain `as i32` cast would turn a very
        // large step count into a negative exponent.
        let t = self.t.min(i32::MAX as usize) as i32;
        let m_hat = &m[0] / (one - beta1.powi(t));
        let v_hat = &v[0] / (one - beta2.powi(t));
        let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());

        // Decoupled weight decay: shrink the parameters directly instead of
        // adding a penalty term to the gradient.
        let weight_decay_factor = one - self.learning_rate * self.weight_decay;
        let weight_decayed_params = &params_dyn * weight_decay_factor;
        let step = &m_hat / &(&v_hat_sqrt + self.epsilon) * self.learning_rate;
        let updated_params = &weight_decayed_params - step;

        // The array originated from an `Array<A, D>`, so converting its
        // dimensionality back to `D` is expected to succeed.
        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("unwrap failed"))
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Overwrites the learning rate.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A single step from zero parameters with positive gradients must move
    /// every parameter strictly below zero.
    #[test]
    fn test_adamw_step() {
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut optimizer = AdamW::new(0.01);
        let new_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert!(new_params.iter().all(|&x| x != 0.0));
        for param in new_params.iter() {
            assert!(*param < 0.0);
        }
    }

    /// After ten steps with constant positive gradients every parameter is
    /// negative, and a larger gradient yields a more negative parameter.
    #[test]
    fn test_adamw_multiple_steps() {
        let mut params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut optimizer = AdamW::new_with_config(0.01, 0.9, 0.999, 1e-8, 0.1);
        for _ in 0..10 {
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        for (i, param) in params.iter().enumerate() {
            assert!(*param < 0.0);
            if i > 0 {
                assert!(param < &params[i - 1]);
            }
        }
    }

    /// `reset` drops the moment buffers and rewinds the step counter.
    #[test]
    fn test_adamw_reset() {
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut optimizer = AdamW::new(0.01);
        optimizer.step(&params, &gradients).expect("unwrap failed");
        assert_eq!(optimizer.t, 1);
        assert!(optimizer.m.is_some());
        assert!(optimizer.v.is_some());
        optimizer.reset();
        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_none());
        assert!(optimizer.v.is_none());
    }
}