use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use scirs2_optimize::stochastic::{minimize_adam, AdamOptions};
use crate::error::Result;
use crate::optimizers::Optimizer;
/// Adam optimiser with optional L2 weight decay.
///
/// Holds the hyper-parameters together with the lazily created moment
/// buffers and the step counter that drives bias correction.
#[derive(Debug, Clone)]
pub struct Adam<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Step size multiplied into every update.
    learning_rate: A,
    /// Exponential decay rate of the first-moment (gradient average) estimate.
    beta1: A,
    /// Exponential decay rate of the second-moment (squared gradient) estimate.
    beta2: A,
    /// Small constant added to the denominator for numerical stability.
    epsilon: A,
    /// L2 coefficient; `params * weight_decay` is added to the gradient when positive.
    weight_decay: A,
    /// First-moment buffers; `None` until the first step or after `reset`.
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Second-moment buffers; `None` until the first step or after `reset`.
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Step counter used as the exponent in the bias-correction terms.
    t: usize,
}
impl<A> Adam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    /// Creates an optimiser with the given learning rate and the usual Adam
    /// defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, and no
    /// weight decay. Moment buffers start empty and the step counter at zero.
    ///
    /// # Panics
    ///
    /// Panics if `A` cannot represent one of the default constants.
    pub fn new(learning_rate: A) -> Self {
        let default_beta1 = A::from(0.9).expect("unwrap failed");
        let default_beta2 = A::from(0.999).expect("unwrap failed");
        let default_epsilon = A::from(1e-8).expect("unwrap failed");
        Self::new_with_config(
            learning_rate,
            default_beta1,
            default_beta2,
            default_epsilon,
            A::zero(),
        )
    }

    /// Creates an optimiser with every hyper-parameter given explicitly.
    ///
    /// The values are stored as-is (no range validation); moment buffers
    /// start empty and the step counter at zero.
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            m: None,
            v: None,
            t: 0,
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
        }
    }

    /// Overwrites the first-moment decay rate; returns `self` for chaining.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Builder-style variant of [`Adam::set_beta1`].
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.set_beta1(beta1);
        self
    }

    /// Current first-moment decay rate.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Overwrites the second-moment decay rate; returns `self` for chaining.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Builder-style variant of [`Adam::set_beta2`].
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.set_beta2(beta2);
        self
    }

    /// Current second-moment decay rate.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Overwrites the numerical-stability constant; returns `self` for chaining.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder-style variant of [`Adam::set_epsilon`].
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.set_epsilon(epsilon);
        self
    }

    /// Current numerical-stability constant.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Overwrites the L2 weight-decay coefficient; returns `self` for chaining.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder-style variant of [`Adam::set_weight_decay`].
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.set_weight_decay(weight_decay);
        self
    }

    /// Current L2 weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Discards both moment buffers and rewinds the step counter, so the
    /// next step starts from a fresh optimiser state.
    pub fn reset(&mut self) {
        self.t = 0;
        self.v = None;
        self.m = None;
    }
}
impl<A, D> Optimizer<A, D> for Adam<A>
where
A: Float + ScalarOperand + Debug + Send + Sync,
D: Dimension,
{
fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
if params.shape() != gradients.shape() {
return Err(crate::error::OptimError::DimensionMismatch(format!(
"Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
params.shape(),
gradients.shape()
)));
}
let params_dyn = params.to_owned().into_dyn();
let gradients_dyn = gradients.to_owned().into_dyn();
let adjusted_gradients = if self.weight_decay > A::zero() {
&gradients_dyn + &(¶ms_dyn * self.weight_decay)
} else {
gradients_dyn
};
if self.m.is_none() {
self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
self.t = 0;
}
let m = self.m.as_mut().expect("unwrap failed");
let v = self.v.as_mut().expect("unwrap failed");
if m.is_empty() {
m.push(Array::zeros(params_dyn.raw_dim()));
v.push(Array::zeros(params_dyn.raw_dim()));
} else if m[0].raw_dim() != params_dyn.raw_dim() {
m[0] = Array::zeros(params_dyn.raw_dim());
v[0] = Array::zeros(params_dyn.raw_dim());
}
self.t = self.t.checked_add(1).ok_or_else(|| {
crate::error::OptimError::InvalidConfig(
"Timestep counter overflow - too many optimization steps".to_string(),
)
})?;
m[0] = &m[0] * self.beta1 + &(&adjusted_gradients * (A::one() - self.beta1));
v[0] = &v[0] * self.beta2
+ &(&adjusted_gradients * &adjusted_gradients * (A::one() - self.beta2));
let exp_beta1 = i32::try_from(self.t).map_err(|_| {
crate::error::OptimError::InvalidConfig(
"Timestep too large for bias correction calculation".to_string(),
)
})?;
let m_hat = &m[0] / (A::one() - self.beta1.powi(exp_beta1));
let exp_beta2 = i32::try_from(self.t).map_err(|_| {
crate::error::OptimError::InvalidConfig(
"Timestep too large for bias correction calculation".to_string(),
)
})?;
let v_hat = &v[0] / (A::one() - self.beta2.powi(exp_beta2));
let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());
let step = &m_hat / &(&v_hat_sqrt + self.epsilon) * self.learning_rate;
let updated_params = ¶ms_dyn - step;
Ok(updated_params
.into_dimensionality::<D>()
.expect("unwrap failed"))
}
fn get_learning_rate(&self) -> A {
self.learning_rate
}
fn set_learning_rate(&mut self, learning_rate: A) {
self.learning_rate = learning_rate;
}
}