use crate::error::{NeuralError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
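
/// Rectified Adam (RAdam) optimizer.
///
/// RAdam (Liu et al., 2019, "On the Variance of the Adaptive Learning Rate
/// and Beyond") addresses the large variance of Adam's adaptive learning
/// rate early in training. It tracks the length of the approximated simple
/// moving average, `rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t)` with
/// `rho_inf = 2 / (1 - beta2) - 1`, applies a rectified adaptive step once
/// `rho_t > 4`, and falls back to a plain momentum step before that.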
#[derive(Debug, Clone)]
pub struct RAdam<F: Float + NumAssign + ScalarOperand + Debug> {
    /// Step size applied to every update.
    learning_rate: F,
    /// Exponential decay rate for the first-moment estimates.
    beta1: F,
    /// Exponential decay rate for the second-moment estimates.
    beta2: F,
    /// Small constant added to the denominator for numerical stability.
    epsilon: F,
    /// L2 penalty coefficient folded into the gradients.
    weight_decay: F,
    /// First-moment estimates, one array per parameter array.
    m: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
    /// Second-moment estimates, one array per parameter array.
    v: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
    /// Current timestep (number of completed updates).
    t: usize,
    /// Maximum length of the approximated SMA: `2 / (1 - beta2) - 1`.
    rho_inf: F,
}
impl<F: Float + NumAssign + ScalarOperand + Debug> RAdam<F> {
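    /// Creates an RAdam optimizer from explicit hyperparameters.
    ///
    /// `rho_inf`, the maximum length of the approximated SMA, depends only
    /// on `beta2` and is precomputed here.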
    pub fn new(learning_rate: F, beta1: F, beta2: F, epsilon: F, weight_decay: F) -> Result<Self> {
        let two = F::from(2.0).ok_or_else(|| {
            NeuralError::InvalidArgument(
                "Failed to convert 2.0 to the appropriate floating point type".to_string(),
            )
        })?;
        let rho_inf = two / (F::one() - beta2) - F::one();
        Ok(Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
            rho_inf,
        })
    }
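
    /// Creates an RAdam optimizer with the conventional Adam defaults
    /// (`beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, no weight decay).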
    pub fn default_with_lr(learning_rate: F) -> Result<Self> {
        let beta1 = F::from(0.9).ok_or_else(|| {
            NeuralError::InvalidArgument(
                "Failed to convert 0.9 to the appropriate floating point type".to_string(),
            )
        })?;
        let beta2 = F::from(0.999).ok_or_else(|| {
            NeuralError::InvalidArgument(
                "Failed to convert 0.999 to the appropriate floating point type".to_string(),
            )
        })?;
        let epsilon = F::from(1e-8).ok_or_else(|| {
            NeuralError::InvalidArgument(
                "Failed to convert 1e-8 to the appropriate floating point type".to_string(),
            )
        })?;
        Self::new(learning_rate, beta1, beta2, epsilon, F::zero())
    }

    /// Returns the current `beta1` value.
    pub fn get_beta1(&self) -> F {
        self.beta1
    }

    /// Sets the `beta1` value.
    pub fn set_beta1(&mut self, beta1: F) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the current `beta2` value.
    pub fn get_beta2(&self) -> F {
        self.beta2
    }

    /// Sets the `beta2` value.
    pub fn set_beta2(&mut self, beta2: F) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the current `epsilon` value.
    pub fn get_epsilon(&self) -> F {
        self.epsilon
    }

    /// Sets the `epsilon` value.
    pub fn set_epsilon(&mut self, epsilon: F) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the current weight decay factor.
    pub fn get_weight_decay(&self) -> F {
        self.weight_decay
    }

    /// Sets the weight decay factor.
    pub fn set_weight_decay(&mut self, weight_decay: F) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }
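
    /// Clears the moment estimates and resets the timestep, returning the
    /// optimizer to its freshly constructed state.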
    pub fn reset(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}

impl<F: Float + NumAssign + ScalarOperand + Debug> Optimizer<F> for RAdam<F> {
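    /// Performs one RAdam update step.
    ///
    /// The first and second moments are maintained exactly as in Adam. The
    /// step then depends on `rho_t`, the approximated SMA length: when
    /// `rho_t > 4` the variance of the adaptive learning rate is tractable
    /// and the adaptive step is scaled by the rectification term
    /// `r_t = sqrt(((rho_t - 4)(rho_t - 2) rho_inf) / ((rho_inf - 4)(rho_inf - 2) rho_t))`;
    /// otherwise the step degrades to a bias-corrected momentum update.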
    fn update(
        &mut self,
        params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
        grads: &[Array<F, scirs2_core::ndarray::IxDyn>],
    ) -> Result<()> {
        if params.len() != grads.len() {
            return Err(NeuralError::TrainingError(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                params.len(),
                grads.len()
            )));
        }
        self.t += 1;
        // Lazily (re)initialize the moment estimates to match the parameters.
        if self.m.len() != params.len() {
            self.m = params.iter().map(|p| Array::zeros(p.raw_dim())).collect();
            self.v = params.iter().map(|p| Array::zeros(p.raw_dim())).collect();
        }
        let one_minus_beta1 = F::one() - self.beta1;
        let one_minus_beta2 = F::one() - self.beta2;
        let beta1_pow_t = self.beta1.powi(self.t as i32);
        let beta2_pow_t = self.beta2.powi(self.t as i32);
        let bias_correction1 = F::one() - beta1_pow_t;
        let two = F::from(2.0).ok_or_else(|| {
            NeuralError::InvalidArgument("Failed to convert 2.0 to floating point type".to_string())
        })?;
        let four = F::from(4.0).ok_or_else(|| {
            NeuralError::InvalidArgument("Failed to convert 4.0 to floating point type".to_string())
        })?;
        // Length of the approximated simple moving average at step t, computed
        // in f64 for stability: rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t).
        let rho_inf_f64 = self.rho_inf.to_f64().ok_or_else(|| {
            NeuralError::ComputationError("Failed to convert rho_inf to f64".to_string())
        })?;
        let beta2_pow_t_f64 = beta2_pow_t.to_f64().ok_or_else(|| {
            NeuralError::ComputationError("Failed to convert beta2_pow_t to f64".to_string())
        })?;
        let t_f64 = self.t as f64;
        let rho_t_f64 = rho_inf_f64 - 2.0 * t_f64 * beta2_pow_t_f64 / (1.0 - beta2_pow_t_f64);
        let rho_t = F::from(rho_t_f64).ok_or_else(|| {
            NeuralError::ComputationError("Failed to convert rho_t from f64".to_string())
        })?;
        // Variance rectification term from the RAdam paper; the adaptive
        // learning rate is only used once its variance is tractable (rho_t > 4):
        //   r_t = sqrt(((rho_t - 4)(rho_t - 2) rho_inf)
        //              / ((rho_inf - 4)(rho_inf - 2) rho_t))
        let (rect_term, use_adaptive_lr) = if rho_t > four {
            let rt = ((rho_t - four) * (rho_t - two) * self.rho_inf
                / ((self.rho_inf - four) * (self.rho_inf - two) * rho_t))
                .sqrt();
            (rt, true)
        } else {
            (F::one(), false)
        };
        for i in 0..params.len() {
            // Classic (non-decoupled) weight decay: fold the L2 term into the gradient.
            let adjusted_grad = if self.weight_decay > F::zero() {
                &grads[i] + &(&params[i] * self.weight_decay)
            } else {
                grads[i].clone()
            };
            // Update the biased first- and second-moment estimates.
            self.m[i] = &self.m[i] * self.beta1 + &(&adjusted_grad * one_minus_beta1);
            self.v[i] =
                &self.v[i] * self.beta2 + &(adjusted_grad.mapv(|x| x * x) * one_minus_beta2);
            let m_hat = &self.m[i] / bias_correction1;
            if use_adaptive_lr {
                // Rectified adaptive step: bias-correct v and scale by r_t.
                let v_hat = &self.v[i] / (F::one() - beta2_pow_t);
                let denom = v_hat.mapv(|x| x.sqrt()) + self.epsilon;
                params[i] = &params[i] - &(m_hat / denom * self.learning_rate * rect_term);
            } else {
                // Variance not yet tractable: plain bias-corrected momentum step.
                params[i] = &params[i] - &(m_hat * self.learning_rate);
            }
        }
        Ok(())
    }

    fn get_learning_rate(&self) -> F {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, lr: F) {
        self.learning_rate = lr;
    }

    fn name(&self) -> &'static str {
        "RAdam"
    }
}
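
#[cfg(test)]
mod tests {
    // A minimal smoke test sketching how the optimizer is driven through the
    // `Optimizer` trait. It assumes `scirs2_core::ndarray` re-exports the
    // usual ndarray constructors (`Array::from_elem`, `IxDyn`), as the rest
    // of this module suggests.
    use super::*;
    use scirs2_core::ndarray::IxDyn;

    #[test]
    fn first_step_moves_params_against_gradient() {
        let mut opt = RAdam::<f64>::default_with_lr(0.01).unwrap();
        let mut params = vec![Array::from_elem(IxDyn(&[3]), 1.0)];
        let grads = vec![Array::from_elem(IxDyn(&[3]), 0.5)];
        opt.update(&mut params, &grads).unwrap();
        // On the first step rho_t = rho_inf - 2 * beta2 / (1 - beta2) = 1 <= 4,
        // so RAdam takes the un-rectified momentum branch:
        // m_hat = 0.5, step = 0.01 * 0.5 = 0.005, params: 1.0 -> 0.995.
        assert!(params[0].iter().all(|&x| (x - 0.995).abs() < 1e-10));
    }
}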