use std::fmt::Debug;

use scirs2_core::ndarray::Array;
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_optim::optimizers as optim_optimizers;

use crate::error::{NeuralError, Result};
use crate::optimizers::Optimizer;
pub struct RMSprop<F: Float + Debug + NumAssign> {
inner: optim, optimizers: RMSprop<F>,
weight_decay: F,
}
impl<F: Float + Debug + NumAssign> RMSprop<F> {
pub fn new(_learningrate: F) -> Self {
let rho = F::from(0.9).unwrap_or(F::one());
let epsilon = F::from(1e-8).unwrap_or(F::zero());
let weight_decay = F::zero();
Self {
inner: optim, optimizers: RMSprop::new_with_config(
learning_rate,
rho,
epsilon,
weight_decay
),
weight_decay,
}
}
pub fn new_with_config(_learning_rate: F, rho: F, epsilon: F, weightdecay: F) -> Self {
pub fn get_rho(&self) -> F {
self.inner.get_rho()
pub fn set_rho(&mut self, rho: F) -> &mut Self {
self.inner.set_rho(rho);
self
pub fn get_epsilon(&self) -> F {
self.inner.get_epsilon()
pub fn set_epsilon(&mut self, epsilon: F) -> &mut Self {
self.inner.set_epsilon(epsilon);
pub fn get_weight_decay(&self) -> F {
self.weight_decay
pub fn set_weight_decay(&mut self, weightdecay: F) -> &mut Self {
self.weight_decay = weight_decay;
self.inner.set_weight_decay(weight_decay);
pub fn reset(&mut self) {
self.inner.reset();
impl<F: Float + Debug + NumAssign> Optimizer<F> for RMSprop<F> {
fn update(&mut self, params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
grads: &[Array<F, scirs2_core::ndarray::IxDyn>]) -> Result<()> {
if params.len() != grads.len() {
return Err(NeuralError::TrainingError(format!(
"Parameter and gradient counts do not match: {} vs {}",
params.len(), grads.len()
)));
for (param, grad) in params.iter_mut().zip(grads.iter()) {
let mut param_copy = param.clone();
match self.inner.step(¶m_copy, grad) {
Ok(updated_param) => {
*param = updated_param;
},
Err(e) => {
return Err(NeuralError::TrainingError(format!(
"Failed to update parameter: {}", e
)));
}
}
Ok(())
fn get_learning_rate(&self) -> F {
self.inner.get_learning_rate()
fn set_learning_rate(&mut self, lr: F) {
self.inner.set_learning_rate(lr);