use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// RMSprop optimizer: hyperparameters plus the running squared-gradient state.
///
/// Each `step` keeps an exponential moving average of squared gradients and
/// scales the gradient by the inverse of its square root (plus `epsilon`).
#[derive(Debug, Clone)]
pub struct RMSprop<A: Float + ScalarOperand + Debug> {
/// Step size multiplied into every update.
learning_rate: A,
/// Decay factor of the squared-gradient moving average (`0.9` when built with `new`).
rho: A,
/// Constant added to the denominator to avoid division by zero (`1e-8` when built with `new`).
epsilon: A,
/// Coefficient of `params` added to the gradient; only applied when strictly positive.
weight_decay: A,
/// Squared-gradient accumulator stored with dynamic dimensionality; `None` until the
/// first step or after `reset`. `step` only reads and writes element 0.
v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> RMSprop<A> {
    /// Creates an RMSprop optimizer with the given learning rate and the default
    /// hyperparameters `rho = 0.9`, `epsilon = 1e-8` and no weight decay.
    ///
    /// # Panics
    ///
    /// Panics if `0.9` or `1e-8` cannot be converted to `A`; this does not happen
    /// for `f32` or `f64`.
    pub fn new(learning_rate: A) -> Self {
        let rho = A::from(0.9).expect("default rho (0.9) must be representable in A");
        let epsilon = A::from(1e-8).expect("default epsilon (1e-8) must be representable in A");
        Self::new_with_config(learning_rate, rho, epsilon, A::zero())
    }

    /// Creates an RMSprop optimizer with every hyperparameter given explicitly.
    ///
    /// The squared-gradient state starts empty and is created on the first step.
    /// No range checks are performed on the arguments.
    pub fn new_with_config(learning_rate: A, rho: A, epsilon: A, weight_decay: A) -> Self {
        Self {
            learning_rate,
            rho,
            epsilon,
            weight_decay,
            v: None,
        }
    }

    /// Sets the decay factor of the squared-gradient moving average.
    pub fn set_rho(&mut self, rho: A) -> &mut Self {
        self.rho = rho;
        self
    }

    /// Returns the decay factor of the squared-gradient moving average.
    pub fn get_rho(&self) -> A {
        self.rho
    }

    /// Sets the constant added to the update denominator.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the constant added to the update denominator.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight-decay coefficient; values `<= 0` disable weight decay.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Discards the accumulated squared-gradient state; the next step starts from zeros.
    pub fn reset(&mut self) {
        self.v = None;
    }
}
impl<A, D> Optimizer<A, D> for RMSprop<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Performs one RMSprop update and returns the new parameters.
    ///
    /// With `g = gradients + weight_decay * params` (the decay term is only added
    /// when `weight_decay > 0`), the running average is updated as
    /// `v = rho * v + (1 - rho) * g^2`, and the result is
    /// `params - learning_rate * g / (sqrt(v) + epsilon)`.
    ///
    /// The squared-gradient cache is created as zeros on the first call and is
    /// reset to zeros whenever `params` has a different shape than the cache, so
    /// feeding differently-shaped tensors silently restarts the statistics.
    ///
    /// NOTE(review): `gradients` is assumed to have the same shape as `params`.
    /// A mismatch goes through ndarray broadcasting (or panics) instead of
    /// returning an error.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // The state is stored as IxDyn so one optimizer type serves any `D`.
        // Views avoid copying the inputs just to change their dimension type.
        let params_dyn = params.view().into_dyn();
        let gradients_dyn = gradients.view().into_dyn();

        // Coupled (L2-style) weight decay folded into the gradient.
        let adjusted_gradients = if self.weight_decay > A::zero() {
            &gradients_dyn + &(&params_dyn * self.weight_decay)
        } else {
            gradients_dyn.to_owned()
        };

        // Lazily create the cache, and reset it when the parameter shape changes.
        let cache = self.v.get_or_insert_with(Vec::new);
        if cache.is_empty() {
            cache.push(Array::zeros(params_dyn.raw_dim()));
        } else if cache[0].raw_dim() != params_dyn.raw_dim() {
            cache[0] = Array::zeros(params_dyn.raw_dim());
        }
        let v = &mut cache[0];

        // Exponential moving average of the squared gradients.
        let decay = self.rho;
        let fresh = &adjusted_gradients * &adjusted_gradients * (A::one() - decay);
        *v = &*v * decay + &fresh;

        // Denominator sqrt(v) + epsilon, computed in a single pass.
        let epsilon = self.epsilon;
        let denom = v.mapv(|x| x.sqrt() + epsilon);
        let step = &adjusted_gradients * self.learning_rate / &denom;
        let updated_params = &params_dyn - &step;

        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("converting back to D cannot fail: the result has the same rank as `params`"))
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent steps.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}