use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// State and hyperparameters of the RAdam (rectified Adam) optimizer.
///
/// The hyperparameters are set through [`RAdam::new`], [`RAdam::new_with_config`]
/// or the individual setters. The moment buffers and the step counter are
/// created lazily by `step` and cleared by `reset`.
#[derive(Debug, Clone)]
pub struct RAdam<A: Float + ScalarOperand + Debug> {
/// Step size multiplied into every parameter update.
learning_rate: A,
/// Decay factor of the first-moment (gradient) moving average.
beta1: A,
/// Decay factor of the second-moment (squared gradient) moving average.
beta2: A,
/// Small constant added to the square root of the second moment before dividing.
epsilon: A,
/// Coefficient of the `weight_decay * params` term added to the gradients;
/// only applied by `step` when it is strictly positive.
weight_decay: A,
/// First-moment buffers; `None` until the first `step` call and again after
/// `reset`. `step` only ever reads and writes element 0.
m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
/// Second-moment buffers, managed in lockstep with `m`.
v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
/// Number of `step` calls performed since construction or the last `reset`.
t: usize,
/// Cached value of `2 / (1 - beta2) - 1`; recomputed by `set_beta2`.
rho_inf: A,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> RAdam<A> {
    /// Creates an optimizer with the given learning rate and the default
    /// hyperparameters `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8` and
    /// no weight decay.
    ///
    /// # Panics
    ///
    /// Panics if one of the default constants cannot be represented in `A`.
    pub fn new(learning_rate: A) -> Self {
        Self::new_with_config(
            learning_rate,
            A::from(0.9).expect("unwrap failed"),
            A::from(0.999).expect("unwrap failed"),
            A::from(1e-8).expect("unwrap failed"),
            A::zero(),
        )
    }

    /// Creates an optimizer with every hyperparameter supplied by the caller.
    ///
    /// The optimizer starts without moment buffers and with a step count of
    /// zero; `rho_inf` is derived from `beta2`.
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        let rho_inf = Self::rho_inf_for(beta2);
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: None,
            v: None,
            t: 0,
            rho_inf,
        }
    }

    /// Computes `2 / (1 - beta2) - 1`, the value cached in `rho_inf`.
    fn rho_inf_for(beta2: A) -> A {
        let two = A::from(2.0).expect("unwrap failed");
        two / (A::one() - beta2) - A::one()
    }

    /// Sets the first-moment decay factor and returns `self` for chaining.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the first-moment decay factor.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the second-moment decay factor, refreshes the cached `rho_inf`
    /// that depends on it, and returns `self` for chaining.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.rho_inf = Self::rho_inf_for(beta2);
        self.beta2 = beta2;
        self
    }

    /// Returns the second-moment decay factor.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the numerical-stability constant and returns `self` for chaining.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the numerical-stability constant.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight-decay coefficient and returns `self` for chaining.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Drops both moment buffers and rewinds the step counter to zero.
    /// Hyperparameters are left untouched.
    pub fn reset(&mut self) {
        self.t = 0;
        self.m = None;
        self.v = None;
    }
}
impl<A, D> Optimizer<A, D> for RAdam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync + std::convert::From<f64>,
    D: Dimension,
{
    /// Performs one RAdam update and returns the new parameters.
    ///
    /// The update follows these steps:
    /// 1. If `weight_decay > 0`, `weight_decay * params` is added to the gradients.
    /// 2. The moment buffers are (re)created as zeros when they are missing or
    ///    their shape differs from `params`; the step counter restarts with
    ///    them so the bias corrections match the age of the buffers.
    /// 3. The first and second moment moving averages are updated.
    /// 4. `rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t)` is evaluated. When
    ///    `rho_t > 4`, the bias-corrected Adam direction is scaled by the
    ///    rectifier
    ///    `sqrt((rho_t - 4)(rho_t - 2) rho_inf / ((rho_inf - 4)(rho_inf - 2) rho_t))`.
    ///    Otherwise the update falls back to the bias-corrected first moment
    ///    alone.
    ///
    /// `params` and `gradients` are not modified; the result is a new array.
    ///
    /// # Panics
    ///
    /// Panics if a numeric constant cannot be represented in `A`, or if the
    /// result cannot be converted back to dimension `D`.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Weight decay is folded into the gradient (L2-style), only when positive.
        let adjusted_gradients = if self.weight_decay > A::zero() {
            &gradients_dyn + &(&params_dyn * self.weight_decay)
        } else {
            gradients_dyn
        };

        // The buffers are unusable when absent, empty, or shaped for another
        // parameter array. Rebuilding them also restarts the step counter:
        // the `1 - beta^t` corrections assume the buffers are `t` steps old.
        let state_is_stale = match (self.m.as_ref(), self.v.as_ref()) {
            (Some(m), Some(v)) => {
                m.first()
                    .map_or(true, |buf| buf.raw_dim() != params_dyn.raw_dim())
                    || v.first()
                        .map_or(true, |buf| buf.raw_dim() != params_dyn.raw_dim())
            }
            _ => true,
        };
        if state_is_stale {
            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.t = 0;
        }
        let m = self.m.as_mut().expect("unwrap failed");
        let v = self.v.as_mut().expect("unwrap failed");

        self.t += 1;

        // Exponential moving averages of the gradient and the squared gradient.
        m[0] = &m[0] * self.beta1 + &(&adjusted_gradients * (A::one() - self.beta1));
        v[0] = &v[0] * self.beta2
            + &(&adjusted_gradients * &adjusted_gradients * (A::one() - self.beta2));

        let one = A::one();
        let two = <A as scirs2_core::numeric::NumCast>::from(2.0).expect("unwrap failed");
        let four = <A as scirs2_core::numeric::NumCast>::from(4.0).expect("unwrap failed");
        let t_float =
            <A as scirs2_core::numeric::NumCast>::from(self.t as f64).expect("unwrap failed");
        // Clamp before narrowing so a huge step count cannot wrap to a negative exponent.
        let t_exp = self.t.min(i32::MAX as usize) as i32;
        let beta1_t = self.beta1.powi(t_exp);
        let beta2_t = self.beta2.powi(t_exp);

        // Bias-corrected first moment.
        let m_hat = &m[0] / (one - beta1_t);

        // Length of the approximated simple moving average at step t.
        let rho_t = self.rho_inf - two * t_float * beta2_t / (one - beta2_t);

        let updated_params = if rho_t > four {
            // Variance of the adaptive learning rate is tractable: take the
            // Adam step scaled by the variance rectifier. `rho_t > 4` implies
            // `rho_inf > 4` because `rho_t <= rho_inf`, so the denominator is
            // strictly positive.
            let v_hat = &v[0] / (one - beta2_t);
            let rectifier = A::sqrt(
                (rho_t - four) * (rho_t - two) * self.rho_inf
                    / ((self.rho_inf - four) * (self.rho_inf - two) * rho_t),
            );
            let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());
            let step = &m_hat / &(&v_hat_sqrt + self.epsilon) * rectifier * self.learning_rate;
            &params_dyn - step
        } else {
            // Too few samples for a reliable second moment: un-adapted momentum step.
            let step = &m_hat * self.learning_rate;
            &params_dyn - step
        };

        // The dynamic array was built from an `Array<A, D>`, so converting
        // back to `D` cannot fail.
        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("unwrap failed"))
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// One step from zero parameters moves every component, and a larger
    /// gradient produces a larger displacement.
    #[test]
    fn test_radam_step() {
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut optimizer = RAdam::new(0.01);
        let new_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert!(new_params.iter().all(|&x| x != 0.0));
        for i in 1..3 {
            assert!(new_params[i].abs() > new_params[i - 1].abs());
        }
    }

    /// After many steps with a constant gradient the displacement is still
    /// ordered by gradient magnitude.
    #[test]
    fn test_radam_multiple_steps() {
        let mut params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut optimizer = RAdam::new(0.01);
        for _ in 0..100 {
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        for i in 1..3 {
            assert!(params[i].abs() > params[i - 1].abs());
        }
    }

    /// With positive weight decay and small positive gradients, positive
    /// parameters shrink towards zero.
    #[test]
    fn test_radam_weight_decay() {
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let gradients = Array1::from_vec(vec![0.01, 0.01, 0.01]);
        let mut optimizer = RAdam::new_with_config(0.01, 0.9, 0.999, 1e-8, 0.1);
        let new_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        for i in 0..3 {
            assert!(new_params[i].abs() < params[i].abs());
        }
    }

    /// `reset` clears the moment buffers and the step counter.
    #[test]
    fn test_radam_reset() {
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut optimizer = RAdam::new(0.01);
        optimizer.step(&params, &gradients).expect("unwrap failed");
        assert_eq!(optimizer.t, 1);
        assert!(optimizer.m.is_some());
        assert!(optimizer.v.is_some());
        optimizer.reset();
        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_none());
        assert!(optimizer.v.is_none());
    }
}