use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
use crate::simd_optimizer::SimdOptimizer;
/// Stochastic gradient descent with optional momentum and weight decay.
///
/// The element-wise arithmetic is delegated to the `SimdOptimizer` kernels
/// (`simd_weight_decay`, `simd_momentum_update`, `simd_sgd_update`) by the
/// `Optimizer` implementations for `f32` and `f64` in this file.
#[derive(Debug, Clone)]
pub struct SimdSGD<A: Float> {
/// Step size, passed to the SIMD update kernels on every `step`.
learning_rate: A,
/// Momentum factor; the momentum update path is used only when this is > 0.
momentum: A,
/// Weight-decay coefficient; folded into the gradient only when this is > 0.
weight_decay: A,
/// Momentum buffer: `None` until the first `step`, and cleared again by `reset`.
velocity: Option<Array1<A>>,
}
impl<A: Float> SimdSGD<A> {
    /// Builds a plain SGD optimizer: momentum and weight decay both start at zero
    /// and no velocity buffer exists yet.
    pub fn new(learning_rate: A) -> Self {
        Self::new_with_config(learning_rate, A::zero(), A::zero())
    }

    /// Builds an optimizer with every hyper-parameter supplied explicitly.
    /// The velocity buffer starts out empty.
    pub fn new_with_config(learning_rate: A, momentum: A, weight_decay: A) -> Self {
        Self {
            velocity: None,
            learning_rate,
            momentum,
            weight_decay,
        }
    }

    /// Overwrites the momentum factor in place and returns `self` for chaining.
    pub fn set_momentum(&mut self, momentum: A) -> &mut Self {
        self.momentum = momentum;
        self
    }

    /// Builder-style variant of [`SimdSGD::set_momentum`] that consumes `self`.
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.set_momentum(momentum);
        self
    }

    /// Current momentum factor.
    pub fn get_momentum(&self) -> A {
        self.momentum
    }

    /// Current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Overwrites the weight-decay coefficient in place and returns `self` for chaining.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder-style variant of [`SimdSGD::set_weight_decay`] that consumes `self`.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.set_weight_decay(weight_decay);
        self
    }

    /// Current weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Discards the accumulated velocity so the next `step` starts from zero velocity.
    pub fn reset(&mut self) {
        self.velocity = None;
    }
}
impl Optimizer<f32, scirs2_core::ndarray::Ix1> for SimdSGD<f32> {
    /// Performs one SGD update and returns the new parameters (`params` itself
    /// is only borrowed, never modified).
    ///
    /// When `weight_decay > 0` the gradient is first combined with the parameters
    /// via `f32::simd_weight_decay`. When `momentum > 0` the stored velocity is
    /// advanced with `f32::simd_momentum_update`; otherwise a plain
    /// `f32::simd_sgd_update` is applied. The velocity buffer is (re)allocated
    /// as zeros whenever it is missing or its length differs from `params`.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` if `params` and `gradients`
    /// have different shapes.
    fn step(&mut self, params: &Array1<f32>, gradients: &Array1<f32>) -> Result<Array1<f32>> {
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_view = params.view();
        let gradients_view = gradients.view();

        // Only allocate a new gradient array when weight decay actually changes it;
        // otherwise borrow the caller's gradients instead of copying them.
        let decayed_gradients;
        let effective_gradients = if self.weight_decay > 0.0 {
            decayed_gradients =
                f32::simd_weight_decay(&gradients_view, &params_view, self.weight_decay);
            decayed_gradients.view()
        } else {
            gradients_view
        };

        let len = params.len();
        let velocity = self.velocity.get_or_insert_with(|| Array1::zeros(len));
        if velocity.len() != len {
            // The parameter count changed since the last step: restart from zero velocity.
            *velocity = Array1::zeros(len);
        }

        let new_params = if self.momentum > 0.0 {
            let (updated_params, updated_velocity) = f32::simd_momentum_update(
                &params_view,
                &effective_gradients,
                &velocity.view(),
                self.learning_rate,
                self.momentum,
            );
            *velocity = updated_velocity;
            updated_params
        } else {
            f32::simd_sgd_update(&params_view, &effective_gradients, self.learning_rate)
        };
        Ok(new_params)
    }

    /// Current learning rate.
    fn get_learning_rate(&self) -> f32 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent steps.
    fn set_learning_rate(&mut self, learning_rate: f32) {
        self.learning_rate = learning_rate;
    }
}
impl Optimizer<f64, scirs2_core::ndarray::Ix1> for SimdSGD<f64> {
    /// Performs one SGD update and returns the new parameters (`params` itself
    /// is only borrowed, never modified).
    ///
    /// When `weight_decay > 0` the gradient is first combined with the parameters
    /// via `f64::simd_weight_decay`. When `momentum > 0` the stored velocity is
    /// advanced with `f64::simd_momentum_update`; otherwise a plain
    /// `f64::simd_sgd_update` is applied. The velocity buffer is (re)allocated
    /// as zeros whenever it is missing or its length differs from `params`.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` if `params` and `gradients`
    /// have different shapes.
    fn step(&mut self, params: &Array1<f64>, gradients: &Array1<f64>) -> Result<Array1<f64>> {
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_view = params.view();
        let gradients_view = gradients.view();

        // Only allocate a new gradient array when weight decay actually changes it;
        // otherwise borrow the caller's gradients instead of copying them.
        let decayed_gradients;
        let effective_gradients = if self.weight_decay > 0.0 {
            decayed_gradients =
                f64::simd_weight_decay(&gradients_view, &params_view, self.weight_decay);
            decayed_gradients.view()
        } else {
            gradients_view
        };

        let len = params.len();
        let velocity = self.velocity.get_or_insert_with(|| Array1::zeros(len));
        if velocity.len() != len {
            // The parameter count changed since the last step: restart from zero velocity.
            *velocity = Array1::zeros(len);
        }

        let new_params = if self.momentum > 0.0 {
            let (updated_params, updated_velocity) = f64::simd_momentum_update(
                &params_view,
                &effective_gradients,
                &velocity.view(),
                self.learning_rate,
                self.momentum,
            );
            *velocity = updated_velocity;
            updated_params
        } else {
            f64::simd_sgd_update(&params_view, &effective_gradients, self.learning_rate)
        };
        Ok(new_params)
    }

    /// Current learning rate.
    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Replaces the learning rate used by subsequent steps.
    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_sgd_basic() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut optimizer = SimdSGD::new(0.1);
        let result = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert_relative_eq!(result[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-6);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_momentum() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut optimizer = SimdSGD::new_with_config(0.1, 0.9, 0.0);
        let result1 = optimizer.step(&params, &gradients).expect("unwrap failed");
        let result2 = optimizer.step(&result1, &gradients).expect("unwrap failed");
        assert!(result1[0] < params[0]);
        assert!(result2[0] < result1[0]);
    }

    #[test]
    fn test_simd_sgd_weight_decay() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut optimizer = SimdSGD::new_with_config(0.1, 0.0, 0.01);
        let result = optimizer.step(&params, &gradients).expect("unwrap failed");
        let expected_grad = 0.1 + 0.01 * 1.0;
        assert_relative_eq!(result[0], 1.0 - 0.1 * expected_grad, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_large_array() {
        let size = 1000;
        let params: Array1<f32> = Array1::from_vec((0..size).map(|i| i as f32).collect());
        let gradients: Array1<f32> = Array1::from_elem(size, 0.1);
        let mut optimizer = SimdSGD::new(0.01);
        let result = optimizer.step(&params, &gradients).expect("unwrap failed");
        for i in 0..size {
            assert_relative_eq!(result[i], (i as f32) - 0.01 * 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simd_sgd_f64() {
        let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut optimizer = SimdSGD::new(0.1);
        let result = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert_relative_eq!(result[0], 0.99, epsilon = 1e-10);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-10);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-10);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_sgd_reset() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut optimizer = SimdSGD::new_with_config(0.1, 0.9, 0.0);
        let _ = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert!(optimizer.velocity.is_some());
        optimizer.reset();
        assert!(optimizer.velocity.is_none());
    }

    #[test]
    fn test_simd_sgd_shape_mismatch() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1f32, 0.2]);
        let mut optimizer = SimdSGD::new(0.1);
        assert!(matches!(
            optimizer.step(&params, &gradients),
            Err(crate::error::OptimError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn test_simd_sgd_velocity_resized_on_length_change() {
        let mut optimizer = SimdSGD::new_with_config(0.1f32, 0.9, 0.0);
        let params4 = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients4 = Array1::from_elem(4, 0.1f32);
        optimizer.step(&params4, &gradients4).expect("unwrap failed");

        let params2 = Array1::from_vec(vec![1.0f32, 2.0]);
        let gradients2 = Array1::from_elem(2, 0.1f32);
        let result = optimizer.step(&params2, &gradients2).expect("unwrap failed");
        assert_eq!(result.len(), 2);
        assert_eq!(optimizer.velocity.as_ref().map(|v| v.len()), Some(2));
    }

    #[test]
    fn test_simd_sgd_learning_rate_accessors() {
        let mut optimizer = SimdSGD::new(0.1f64);
        optimizer.set_learning_rate(0.05);
        assert_relative_eq!(optimizer.get_learning_rate(), 0.05, epsilon = 1e-12);
        assert_relative_eq!(optimizer.learning_rate(), 0.05, epsilon = 1e-12);
    }
}