use crate::error::{NeuralError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::Array;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use scirs2_optim::optimizers as optim_optimizers;

/// Stochastic gradient descent optimizer that wraps the SGD implementation
/// from `scirs2_optim` and tracks the weight decay applied to it.
pub struct SGD<F: Float + Debug + NumAssign> {
    inner: optim_optimizers::SGD<F>,
    weight_decay: F,
}

impl<F: Float + Debug + NumAssign> SGD<F> {
    /// Creates a new SGD optimizer with the given learning rate and optional momentum.
    pub fn new(learning_rate: F, momentum: Option<F>) -> Self {
        let momentum_value = momentum.unwrap_or(F::zero());
        Self {
            inner: optim_optimizers::SGD::new_with_config(
                learning_rate,
                momentum_value,
                F::zero(),
            ),
            weight_decay: F::zero(),
        }
    }

    /// Creates a new SGD optimizer with learning rate, momentum, and weight decay.
    pub fn new_with_decay(learning_rate: F, momentum: F, weight_decay: F) -> Self {
        Self {
            inner: optim_optimizers::SGD::new_with_config(learning_rate, momentum, weight_decay),
            weight_decay,
        }
    }

    pub fn get_momentum(&self) -> F {
        self.inner.get_momentum()
    }

    pub fn set_momentum(&mut self, momentum: F) {
        self.inner.set_momentum(momentum);
    }

    pub fn get_weight_decay(&self) -> F {
        self.weight_decay
    }

    /// Keeps the cached weight decay and the inner optimizer in sync.
    pub fn set_weight_decay(&mut self, weight_decay: F) {
        self.weight_decay = weight_decay;
        self.inner.set_weight_decay(weight_decay);
    }
}

impl<F: Float + Debug + NumAssign> Optimizer<F> for SGD<F> {
    fn update(
        &mut self,
        params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
        grads: &[Array<F, scirs2_core::ndarray::IxDyn>],
    ) -> Result<()> {
        if params.len() != grads.len() {
            return Err(NeuralError::TrainingError(format!(
                "Parameter and gradient counts do not match: {} vs {}",
                params.len(),
                grads.len()
            )));
        }

        for (param, grad) in params.iter_mut().zip(grads.iter()) {
            // Step the inner optimizer on a copy and write the result back in place.
            let param_copy = param.clone();
            match self.inner.step(&param_copy, grad) {
                Ok(updated_param) => {
                    *param = updated_param;
                }
                Err(e) => {
                    return Err(NeuralError::TrainingError(format!(
                        "Failed to update parameter: {}",
                        e
                    )));
                }
            }
        }
        Ok(())
    }

    fn get_learning_rate(&self) -> F {
        self.inner.get_learning_rate()
    }

    fn set_learning_rate(&mut self, lr: F) {
        self.inner.set_learning_rate(lr);
    }
}