use super::common::{compute_gradient_norm, Optimizer};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;
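/// Sharpness-Aware Minimization (SAM) as a wrapper around a base optimizer.
///
/// SAM updates parameters in two stages: [`Self::first_step`] climbs to an
/// approximate worst-case point within a `rho`-ball around the current
/// weights, and [`Self::second_step`] restores the weights and applies the
/// base optimizer using gradients computed at the perturbed point.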
#[derive(Debug)]
pub struct SamOptimizer<O: Optimizer> {
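/// Inner optimizer that performs the actual parameter update.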
base_optimizer: O,
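/// Neighborhood radius for the adversarial weight perturbation.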
rho: f64,
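/// Perturbations applied in `first_step`, keyed by parameter name.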
perturbations: HashMap<String, Array<f64, Ix2>>,
}
impl<O: Optimizer> SamOptimizer<O> {
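/// Creates a SAM optimizer wrapping `base_optimizer` with radius `rho`.
///
/// Returns an error if `rho` is not strictly positive.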
pub fn new(base_optimizer: O, rho: f64) -> TrainResult<Self> {
if rho <= 0.0 {
return Err(TrainError::OptimizerError(
"SAM rho must be positive".to_string(),
));
}
Ok(Self {
base_optimizer,
rho,
perturbations: HashMap::new(),
})
}
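/// First SAM step: perturbs each parameter to `w + rho * g / ||g||`, the
/// first-order approximation of the worst-case point in the `rho`-ball,
/// and records the perturbation for later removal.
///
/// A zero gradient norm is a no-op; a missing gradient is an error.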
pub fn first_step(
&mut self,
parameters: &mut HashMap<String, Array<f64, Ix2>>,
gradients: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<()> {
let grad_norm = compute_gradient_norm(gradients);
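// A zero norm would make the perturbation direction undefined; leave the weights as they are.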
if grad_norm == 0.0 {
return Ok(());
}
for (name, param) in parameters.iter_mut() {
let grad = gradients.get(name).ok_or_else(|| {
TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
})?;
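// epsilon = rho * g / ||g||: scale the gradient direction to the rho-ball boundary.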
let perturbation = grad.mapv(|g| self.rho * g / grad_norm);
*param += &perturbation;
self.perturbations.insert(name.clone(), perturbation);
}
Ok(())
}
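/// Second SAM step: removes the stored perturbations to restore the
/// original weights, then delegates the actual update to the base
/// optimizer using the gradients computed at the perturbed point.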
pub fn second_step(
&mut self,
parameters: &mut HashMap<String, Array<f64, Ix2>>,
gradients: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<()> {
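// Undo the first-step perturbation so the base update starts from the original weights.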
for (name, param) in parameters.iter_mut() {
if let Some(perturbation) = self.perturbations.get(name) {
*param -= perturbation;
}
}
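// Perturbations are one-shot; drop them so a stale entry is never re-applied.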
self.perturbations.clear();
self.base_optimizer.step(parameters, gradients)
}
}
impl<O: Optimizer> Optimizer for SamOptimizer<O> {
fn step(
&mut self,
parameters: &mut HashMap<String, Array<f64, Ix2>>,
gradients: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<()> {
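// Callers are expected to invoke `first_step` beforehand; `step` completes the SAM update.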
self.second_step(parameters, gradients)
}
fn zero_grad(&mut self) {
self.base_optimizer.zero_grad();
}
fn get_lr(&self) -> f64 {
self.base_optimizer.get_lr()
}
fn set_lr(&mut self, lr: f64) {
self.base_optimizer.set_lr(lr);
}
fn state_dict(&self) -> HashMap<String, Vec<f64>> {
let mut state = self.base_optimizer.state_dict();
state.insert("rho".to_string(), vec![self.rho]);
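// Flatten each perturbation in row-major order under a "perturbation_<name>" key.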
for (name, perturbation) in &self.perturbations {
state.insert(
format!("perturbation_{}", name),
perturbation.iter().copied().collect(),
);
}
state
}
fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
if let Some(&rho) = state.get("rho").and_then(|v| v.first()) {
self.rho = rho;
}
self.base_optimizer.load_state_dict(state.clone());
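// Perturbation shapes are not serialized, so flattened values can only be
// restored for names that already have a same-shaped perturbation entry.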
for (key, values) in state {
if let Some(name) = key.strip_prefix("perturbation_") {
if let Some(pert) = self.perturbations.get(name) {
let shape = pert.raw_dim();
if let Ok(arr) = Array::from_shape_vec(shape, values) {
self.perturbations.insert(name.to_string(), arr);
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::super::common::OptimizerConfig;
use super::super::sgd::SgdOptimizer;
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_sam_optimizer() {
let inner_config = OptimizerConfig {
learning_rate: 0.01,
..Default::default()
};
let inner_optimizer = SgdOptimizer::new(inner_config);
let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).expect("rho is positive");
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0]]);
let mut grads = HashMap::new();
grads.insert("w".to_string(), array![[0.1, 0.1]]);
let original_w = params.get("w").expect("parameter w exists").clone();
optimizer.first_step(&mut params, &grads).expect("first step succeeds");
let perturbed_w = params.get("w").expect("parameter w exists");
assert_ne!(perturbed_w[[0, 0]], original_w[[0, 0]]);
optimizer.second_step(&mut params, &grads).expect("second step succeeds");
let final_w = params.get("w").expect("parameter w exists");
assert!(final_w[[0, 0]] < original_w[[0, 0]]);
let state = optimizer.state_dict();
assert!(state.contains_key("rho"));
}
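// Minimal sketch of a regression test for the zero-gradient guard in
// `first_step`: an all-zero gradient must leave the parameters untouched,
// since the perturbation direction would otherwise be undefined.
#[test]
fn test_sam_zero_gradient_is_noop() {
let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).expect("rho is positive");
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0]]);
let mut grads = HashMap::new();
grads.insert("w".to_string(), array![[0.0, 0.0]]);
optimizer.first_step(&mut params, &grads).expect("first step succeeds");
assert_eq!(params.get("w").expect("parameter w exists"), &array![[1.0, 2.0]]);
}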
#[test]
fn test_sam_invalid_rho() {
let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
let result = SamOptimizer::new(inner_optimizer, 0.0);
assert!(result.is_err());
let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
let result = SamOptimizer::new(inner_optimizer, -0.1);
assert!(result.is_err());
}
}