use crate::error::{NeuralError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Gradient-descent optimizer that accumulates a per-parameter velocity
/// (classical momentum) and optionally applies L2 weight decay.
///
/// The velocity buffers are created lazily on the first call to `update`
/// and are dropped again by `reset`.
#[derive(Debug, Clone)]
pub struct MomentumOptimizer<F>
where
    F: Float + NumAssign + ScalarOperand + Debug,
{
    /// Factor the (possibly weight-decayed) gradient is multiplied by.
    learning_rate: F,
    /// Factor the previous velocity is multiplied by at each step.
    momentum: F,
    /// Coefficient applied to the parameters and added to the gradient;
    /// only used when strictly greater than zero.
    weight_decay: F,
    /// One velocity array per parameter array; empty until the first update.
    velocity: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
}
impl<F> MomentumOptimizer<F>
where
    F: Float + NumAssign + ScalarOperand + Debug,
{
    /// Creates an optimizer with the given learning rate and momentum
    /// coefficient, with weight decay set to zero.
    pub fn new(learning_rate: F, momentum: F) -> Self {
        Self::new_with_weight_decay(learning_rate, momentum, F::zero())
    }

    /// Creates an optimizer with an explicit weight-decay coefficient.
    ///
    /// The velocity state starts out empty.
    pub fn new_with_weight_decay(learning_rate: F, momentum: F, weight_decay: F) -> Self {
        let velocity = Vec::new();
        Self {
            learning_rate,
            momentum,
            weight_decay,
            velocity,
        }
    }

    /// Replaces the momentum coefficient and returns `self` for chaining.
    pub fn set_momentum(&mut self, momentum: F) -> &mut Self {
        self.momentum = momentum;
        self
    }

    /// Returns the current momentum coefficient.
    pub fn get_momentum(&self) -> F {
        self.momentum
    }

    /// Replaces the weight-decay coefficient and returns `self` for chaining.
    pub fn set_weight_decay(&mut self, weight_decay: F) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the current weight-decay coefficient.
    pub fn get_weight_decay(&self) -> F {
        self.weight_decay
    }
}
impl<F: Float + NumAssign + ScalarOperand + Debug> Optimizer<F> for MomentumOptimizer<F> {
    /// Performs one momentum step on every parameter array, in place.
    ///
    /// For each parameter `p` with gradient `g` and velocity `v`:
    ///
    /// ```text
    /// g' = g + weight_decay * p      (only when weight_decay > 0)
    /// v  = momentum * v + learning_rate * g'
    /// p  = p - v
    /// ```
    ///
    /// The velocity buffers are (re)initialised to zeros whenever their
    /// number or shapes no longer match `params`.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::TrainingError` when `params` and `grads` differ
    /// in length, or when any gradient's shape differs from the shape of its
    /// parameter. All validation happens before anything is modified, so an
    /// error leaves both the parameters and the optimizer state untouched.
    fn update(
        &mut self,
        params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
        grads: &[Array<F, scirs2_core::ndarray::IxDyn>],
    ) -> Result<()> {
        if params.len() != grads.len() {
            return Err(NeuralError::TrainingError(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                params.len(),
                grads.len()
            )));
        }

        // Reject mismatched shapes up front: the array arithmetic below would
        // otherwise panic or silently broadcast.
        for (index, (param, grad)) in params.iter().zip(grads.iter()).enumerate() {
            if param.shape() != grad.shape() {
                return Err(NeuralError::TrainingError(format!(
                    "Shape of parameter array {} ({:?}) does not match shape of its gradient ({:?})",
                    index,
                    param.shape(),
                    grad.shape()
                )));
            }
        }

        // The velocity state is stale if either the count or any individual
        // shape differs from the current parameters; start again from zeros.
        let velocity_is_stale = self.velocity.len() != params.len()
            || self
                .velocity
                .iter()
                .zip(params.iter())
                .any(|(velocity, param)| velocity.shape() != param.shape());
        if velocity_is_stale {
            self.velocity = params.iter().map(|p| Array::zeros(p.raw_dim())).collect();
        }

        // Copy the scalars out so the closures below do not borrow `self`
        // while `self.velocity` is mutably borrowed.
        let learning_rate = self.learning_rate;
        let momentum = self.momentum;
        let weight_decay = self.weight_decay;

        for ((param, grad), velocity) in params
            .iter_mut()
            .zip(grads.iter())
            .zip(self.velocity.iter_mut())
        {
            // Only materialise a temporary when weight decay is active;
            // otherwise use the caller's gradient directly without cloning.
            let decayed;
            let effective_grad = if weight_decay > F::zero() {
                decayed = grad + &(&*param * weight_decay);
                &decayed
            } else {
                grad
            };

            // v = momentum * v + learning_rate * g'
            velocity.zip_mut_with(effective_grad, |v, &g| {
                *v = *v * momentum + g * learning_rate;
            });
            // p = p - v
            param.zip_mut_with(&*velocity, |p, &v| {
                *p = *p - v;
            });
        }
        Ok(())
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> F {
        self.learning_rate
    }

    /// Replaces the learning rate.
    fn set_learning_rate(&mut self, lr: F) {
        self.learning_rate = lr;
    }

    /// Discards all accumulated velocity; the next `update` starts from zeros.
    fn reset(&mut self) {
        self.velocity.clear();
    }

    /// Returns the fixed identifier `"MomentumOptimizer"`.
    fn name(&self) -> &'static str {
        "MomentumOptimizer"
    }
}