use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;
/// Sharpness-Aware Minimization (SAM) wrapper around another optimizer.
///
/// An update is split into two phases: `first_step` produces perturbed
/// parameters from the current gradient, and `second_step` hands a gradient to
/// the wrapped optimizer and clears the pending state.
pub struct SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    // ---- configuration ----
    /// Optimizer that performs the actual parameter update in `second_step`.
    inner_optimizer: O,
    /// Scale of the perturbation applied in `first_step`.
    rho: A,
    /// Small stabilising constant used inside the perturbation computation.
    epsilon: A,
    /// Selects the parameter-scaled (adaptive) perturbation in `first_step`.
    adaptive: bool,

    // ---- state between the two phases ----
    /// Parameters recorded by `first_step`; `second_step` requires this to be
    /// `Some`. Cleared by `second_step` and `reset`.
    original_params: Option<Array<A, D>>,
    /// Perturbed parameters produced by `first_step`. Cleared together with
    /// `original_params`; not read by the update logic itself.
    perturbed_params: Option<Array<A, D>>,
    /// Marker tying the dimension type parameter to the struct.
    _phantom: PhantomData<D>,
}
impl<A, O, D> SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Creates a SAM wrapper around `inner_optimizer` with `rho = 0.05`,
    /// `epsilon = 1e-12` and adaptive scaling disabled.
    ///
    /// # Panics
    ///
    /// Panics if `0.05` or `1e-12` cannot be converted to `A`.
    pub fn new(inner_optimizer: O) -> Self {
        Self::with_config(inner_optimizer, A::from(0.05).expect("unwrap failed"), false)
    }

    /// Creates a SAM wrapper with an explicit neighbourhood radius `rho` and
    /// adaptive flag; `epsilon` defaults to `1e-12`.
    ///
    /// # Panics
    ///
    /// Panics if `1e-12` cannot be converted to `A`.
    pub fn with_config(inner_optimizer: O, rho: A, adaptive: bool) -> Self {
        Self {
            inner_optimizer,
            rho,
            epsilon: A::from(1e-12).expect("unwrap failed"),
            adaptive,
            perturbed_params: None,
            original_params: None,
            _phantom: PhantomData,
        }
    }

    /// Builder-style setter for the neighbourhood radius `rho`.
    pub fn with_rho(mut self, rho: A) -> Self {
        self.rho = rho;
        self
    }

    /// Builder-style setter for the stabilising constant `epsilon`.
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder-style setter that enables or disables adaptive (ASAM) scaling.
    pub fn with_adaptive(mut self, adaptive: bool) -> Self {
        self.adaptive = adaptive;
        self
    }

    /// Shared access to the wrapped optimizer.
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Mutable access to the wrapped optimizer.
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// Current neighbourhood radius.
    pub fn rho(&self) -> A {
        self.rho
    }

    /// Current stabilising constant.
    pub fn epsilon(&self) -> A {
        self.epsilon
    }

    /// Whether adaptive (ASAM) scaling is enabled.
    pub fn is_adaptive(&self) -> bool {
        self.adaptive
    }

    /// Phase one of a SAM update: moves `params` towards higher loss.
    ///
    /// Returns the perturbed parameters `w + e_w` and the L2 norm of `e_w`.
    /// The caller should evaluate the loss gradient at the returned
    /// parameters and pass it to [`Self::second_step`].
    ///
    /// * Standard SAM: `e_w = rho * g / (||g|| + epsilon)`.
    /// * Adaptive SAM (ASAM): with `T = |w| + epsilon` element-wise,
    ///   `e_w = rho * T² ⊙ g / (||T ⊙ g|| + epsilon)`. An all-zero parameter
    ///   tensor falls back to the standard perturbation.
    ///
    /// In both cases the perturbation radius (`||e_w||`, resp. `||e_w / T||`)
    /// is about `rho`, independent of the raw gradient magnitude.
    ///
    /// # Errors
    ///
    /// Returns an error if the gradient norm is zero or not finite, or if a
    /// norm computation is not finite. On error no pending state is kept, so
    /// `second_step` fails until a `first_step` call succeeds.
    pub fn first_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<(Array<A, D>, A)> {
        // Drop state from any earlier, unfinished step so a failure below
        // cannot leave `original_params` and `perturbed_params` out of sync.
        self.reset();

        let grad_norm = calculate_norm(gradients)?;
        if grad_norm.is_zero() || !grad_norm.is_finite() {
            return Err(OptimError::OptimizationError(
                "Gradient norm is zero or not finite".to_string(),
            ));
        }

        let adaptive_e_w = if self.adaptive {
            self.adaptive_perturbation(params, gradients)?
        } else {
            None
        };
        let e_w = match adaptive_e_w {
            Some(e_w) => e_w,
            None => self.standard_perturbation(gradients, grad_norm),
        };

        let perturb_size = calculate_norm(&e_w)?;
        let perturbed_params = params + &e_w;

        // Record state only once every fallible computation has succeeded.
        self.original_params = Some(params.clone());
        self.perturbed_params = Some(perturbed_params.clone());
        Ok((perturbed_params, perturb_size))
    }

    /// Standard SAM perturbation `rho * g / (||g|| + epsilon)`.
    fn standard_perturbation(&self, gradients: &Array<A, D>, grad_norm: A) -> Array<A, D> {
        gradients / (grad_norm + self.epsilon) * self.rho
    }

    /// ASAM perturbation `rho * T² ⊙ g / (||T ⊙ g|| + epsilon)` with
    /// `T = |w| + epsilon`.
    ///
    /// Returns `Ok(None)` when the caller should fall back to the standard
    /// perturbation: the parameters are all zero (or empty), or `T ⊙ g`
    /// vanishes entirely (e.g. `epsilon == 0` with zero weights).
    fn adaptive_perturbation(
        &self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<Option<Array<A, D>>> {
        if params.iter().all(|p| p.is_zero()) {
            return Ok(None);
        }
        let eps = self.epsilon;
        let scale = params.mapv(|p| p.abs() + eps);
        let scaled_grad = &scale * gradients;
        let scaled_norm = calculate_norm(&scaled_grad)?;
        if scaled_norm.is_zero() {
            return Ok(None);
        }
        // T ⊙ (T ⊙ g) = T² ⊙ g, normalised so that ||e_w / T|| ≈ rho.
        Ok(Some(&scale * &scaled_grad * (self.rho / (scaled_norm + eps))))
    }

    /// Phase two of a SAM update: restores the parameters saved by the
    /// preceding [`Self::first_step`] and applies the inner optimizer to them
    /// with `gradients`. These are normally the gradients evaluated at the
    /// perturbed parameters.
    ///
    /// `params` may be either the original or the perturbed parameters; only
    /// its shape is checked. The update is always applied to the saved
    /// original parameters, never to the perturbed point.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// * `first_step` has not succeeded since the last `second_step`/`reset`;
    /// * `params` has a different shape from the saved parameters;
    /// * the inner optimizer fails. The pending state is then kept, so the call
    ///   can be retried.
    pub fn second_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<Array<A, D>> {
        let original_params = match self.original_params.as_ref() {
            Some(saved) => saved,
            None => {
                return Err(OptimError::OptimizationError(
                    "Must call first_step before second_step".to_string(),
                ))
            }
        };
        if params.shape() != original_params.shape() {
            return Err(OptimError::OptimizationError(
                "Parameter shape does not match the parameters passed to first_step".to_string(),
            ));
        }
        let updated_params = self.inner_optimizer.step(original_params, gradients)?;
        self.reset();
        Ok(updated_params)
    }

    /// Discards any state saved by `first_step`, so the next `second_step`
    /// fails until `first_step` is called again.
    pub fn reset(&mut self) {
        self.perturbed_params = None;
        self.original_params = None;
    }
}
impl<A, O, D> Clone for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Copies the configuration, clones the inner optimizer and duplicates any
    /// parameter snapshots saved by `first_step`. The clone can therefore finish a
    /// pending two-phase step independently of the original.
    fn clone(&self) -> Self {
        let Self {
            inner_optimizer,
            rho,
            epsilon,
            adaptive,
            perturbed_params,
            original_params,
            ..
        } = self;
        Self {
            inner_optimizer: inner_optimizer.clone(),
            rho: *rho,
            epsilon: *epsilon,
            adaptive: *adaptive,
            perturbed_params: perturbed_params.clone(),
            original_params: original_params.clone(),
            _phantom: PhantomData,
        }
    }
}
impl<A, O, D> Debug for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    /// Prints the configuration: inner optimizer, `rho`, `epsilon` and
    /// `adaptive`. The parameter snapshots held between the two phases are
    /// deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SAM");
        out.field("inner_optimizer", &self.inner_optimizer);
        out.field("rho", &self.rho);
        out.field("epsilon", &self.epsilon);
        out.field("adaptive", &self.adaptive);
        out.finish()
    }
}
impl<A, O, D> Optimizer<A, D> for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    /// Single-call fallback for the generic `Optimizer` interface.
    ///
    /// This interface provides only one gradient, so the same `gradients`
    /// feed both phases. The perturbation computed by `first_step` is
    /// discarded. The effect is the inner optimizer's step on `params`,
    /// except that gradients with zero or non-finite norm are rejected with
    /// an error by `first_step`. Use `first_step`/`second_step` directly for
    /// real SAM updates.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        self.first_step(params, gradients)
            .and_then(|_| self.second_step(params, gradients))
    }

    /// Forwards the learning rate to the wrapped optimizer.
    fn set_learning_rate(&mut self, learning_rate: A) {
        let inner = &mut self.inner_optimizer;
        inner.set_learning_rate(learning_rate);
    }

    /// Reads the learning rate from the wrapped optimizer.
    fn get_learning_rate(&self) -> A {
        let inner = &self.inner_optimizer;
        inner.get_learning_rate()
    }
}
/// Computes the Euclidean (L2) norm over all elements of `array`.
///
/// Before squaring, the elements are divided by the largest magnitude. This
/// avoids two failures of the naive sum of squares:
/// * overflow to infinity for large but finite values, which would be
///   reported as an error;
/// * underflow to zero for tiny non-zero values, which would make a non-zero
///   gradient look like a zero one.
///
/// An empty or all-zero array has norm zero.
///
/// # Errors
///
/// Returns an error if any element is NaN or infinite, or if the resulting
/// norm is not representable as a finite value.
fn calculate_norm<A, D>(array: &Array<A, D>) -> Result<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    let non_finite = || {
        OptimError::OptimizationError("Norm calculation resulted in non-finite value".to_string())
    };

    // Largest magnitude. NaN is propagated explicitly: once `acc` is NaN,
    // `magnitude > acc` is always false, so it stays NaN.
    let scale = array.iter().fold(A::zero(), |acc, &x| {
        let magnitude = x.abs();
        if magnitude.is_nan() || magnitude > acc {
            magnitude
        } else {
            acc
        }
    });
    if !scale.is_finite() {
        return Err(non_finite());
    }
    if scale.is_zero() {
        return Ok(A::zero());
    }

    // Every scaled term lies in [0, 1], so this sum cannot overflow.
    let scaled_sum = array.iter().fold(A::zero(), |acc, &x| {
        let ratio = x / scale;
        acc + ratio * ratio
    });
    let norm = scale * scaled_sum.sqrt();
    if !norm.is_finite() {
        return Err(non_finite());
    }
    Ok(norm)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sam_creation() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
        assert_abs_diff_eq!(optimizer.rho(), 0.05);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
        assert!(!optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_with_config() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            SAM::with_config(sgd, 0.1, true);
        assert_abs_diff_eq!(optimizer.rho(), 0.1);
        assert!(optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_builders() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd)
            .with_rho(0.2)
            .with_epsilon(1e-8)
            .with_adaptive(true);
        assert_abs_diff_eq!(optimizer.rho(), 0.2);
        assert_abs_diff_eq!(optimizer.epsilon(), 1e-8);
        assert!(optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_first_step() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let grad_norm = (0.1f64.powi(2) + 0.2f64.powi(2) + 0.3f64.powi(2)).sqrt();
        let normalized_grads = gradients.mapv(|g| g / grad_norm);
        let expected_perturb = normalized_grads.mapv(|g| g * 0.05);
        let expected_params = &params + &expected_perturb;
        let (perturbed_params, perturb_size) = optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");
        assert_abs_diff_eq!(perturbed_params[0], expected_params[0], epsilon = 1e-6);
        assert_abs_diff_eq!(perturbed_params[1], expected_params[1], epsilon = 1e-6);
        assert_abs_diff_eq!(perturbed_params[2], expected_params[2], epsilon = 1e-6);
        assert_abs_diff_eq!(perturb_size, 0.05, epsilon = 1e-6);
    }

    #[test]
    fn test_sam_adaptive() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            SAM::with_config(sgd, 0.05, true);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let (perturbed_params, perturb_size) = optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");
        assert!(perturb_size > 0.0 && perturb_size < 1.0);
        assert!(perturbed_params[0] != params[0]);
        assert!(perturbed_params[1] != params[1]);
        assert!(perturbed_params[2] != params[2]);
        let delta0 = (perturbed_params[0] - params[0]).abs();
        let delta2 = (perturbed_params[2] - params[2]).abs();
        assert!(delta2 > delta0);
    }

    #[test]
    fn test_sam_second_step() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let _ = optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");
        let new_gradients = Array1::from_vec(vec![0.15, 0.25, 0.35]);
        let updated_params = optimizer
            .second_step(&params, &new_gradients)
            .expect("unwrap failed");
        let expected_params =
            Array1::from_vec(vec![1.0 - 0.1 * 0.15, 2.0 - 0.1 * 0.25, 3.0 - 0.1 * 0.35]);
        assert_abs_diff_eq!(updated_params[0], expected_params[0], epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[1], expected_params[1], epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[2], expected_params[2], epsilon = 1e-6);
    }

    #[test]
    fn test_sam_step_matches_inner_update() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let updated = optimizer
            .step(&params, &gradients)
            .expect("non-zero gradient must step");
        for i in 0..3 {
            assert_abs_diff_eq!(updated[i], params[i] - 0.1 * gradients[i], epsilon = 1e-9);
        }
    }

    #[test]
    fn test_sam_reset() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let _ = optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");
        optimizer.reset();
        let result = optimizer.second_step(&params, &gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_sam_error_handling() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let zero_gradients = Array1::zeros(3);
        let result = optimizer.first_step(&params, &zero_gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_calculate_norm() {
        let v = Array1::from_vec(vec![3.0f64, 4.0]);
        assert_abs_diff_eq!(calculate_norm(&v).expect("finite norm"), 5.0, epsilon = 1e-12);
        let empty: Array1<f64> = Array1::zeros(0);
        assert_abs_diff_eq!(calculate_norm(&empty).expect("empty norm"), 0.0);
        let with_nan = Array1::from_vec(vec![1.0f64, f64::NAN]);
        assert!(calculate_norm(&with_nan).is_err());
    }
}