use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;

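/// Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.
///
/// SAM performs a two-phase update. The first phase perturbs the parameters
/// towards higher loss, e(w) = rho * g / (||g|| + epsilon); in the adaptive
/// variant the gradient is instead scaled element-wise by
/// (|w| + epsilon) / ||w||. The second phase applies the wrapped optimizer
/// at the original parameters using gradients computed at the perturbed
/// point, steering the update towards flatter minima.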
pub struct SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Wrapped base optimizer used for the actual parameter update.
    inner_optimizer: O,
    /// Neighborhood size controlling the magnitude of the perturbation.
    rho: A,
    /// Small constant added for numerical stability.
    epsilon: A,
    /// Whether to use the adaptive (element-wise scaled) perturbation.
    adaptive: bool,
    /// Perturbed parameters produced by the most recent `first_step`.
    perturbed_params: Option<Array<A, D>>,
    /// Original parameters saved by the most recent `first_step`.
    original_params: Option<Array<A, D>>,
    /// Marker for the dimension type parameter.
    _phantom: PhantomData<D>,
}

impl<A, O, D> SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
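    /// Creates a SAM optimizer wrapping `inner_optimizer` with the defaults
    /// `rho = 0.05`, `epsilon = 1e-12`, and the non-adaptive perturbation.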
    pub fn new(inner_optimizer: O) -> Self {
        Self {
            inner_optimizer,
            rho: A::from(0.05).unwrap(),
            epsilon: A::from(1e-12).unwrap(),
            adaptive: false,
            perturbed_params: None,
            original_params: None,
            _phantom: PhantomData,
        }
    }

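    /// Creates a SAM optimizer with an explicit neighborhood size `rho` and
    /// adaptive flag; `epsilon` keeps its default of `1e-12`.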
    pub fn with_config(inner_optimizer: O, rho: A, adaptive: bool) -> Self {
        Self {
            inner_optimizer,
            rho,
            epsilon: A::from(1e-12).unwrap(),
            adaptive,
            perturbed_params: None,
            original_params: None,
            _phantom: PhantomData,
        }
    }

    /// Sets the neighborhood size `rho`.
    pub fn with_rho(mut self, rho: A) -> Self {
        self.rho = rho;
        self
    }

    /// Sets the numerical-stability constant `epsilon`.
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Enables or disables the adaptive perturbation.
    pub fn with_adaptive(mut self, adaptive: bool) -> Self {
        self.adaptive = adaptive;
        self
    }

    /// Returns a reference to the wrapped optimizer.
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Returns a mutable reference to the wrapped optimizer.
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// Returns the neighborhood size `rho`.
    pub fn rho(&self) -> A {
        self.rho
    }

    /// Returns the numerical-stability constant `epsilon`.
    pub fn epsilon(&self) -> A {
        self.epsilon
    }

    /// Returns whether the adaptive perturbation is enabled.
    pub fn is_adaptive(&self) -> bool {
        self.adaptive
    }

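    /// First SAM phase: saves the original parameters, computes the
    /// perturbation e(w) from `gradients`, and returns the perturbed
    /// parameters together with the perturbation norm. Gradients should be
    /// recomputed at the perturbed parameters before calling `second_step`.
    /// Returns an error if the gradient norm is zero or not finite.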
    pub fn first_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<(Array<A, D>, A)> {
        // Save the original parameters so second_step can verify call order.
        self.original_params = Some(params.clone());

        let grad_norm = calculate_norm(gradients)?;

        if grad_norm.is_zero() || !grad_norm.is_finite() {
            return Err(OptimError::OptimizationError(
                "Gradient norm is zero or not finite".to_string(),
            ));
        }

        // Compute the perturbation e(w) that moves the parameters towards
        // higher loss within a neighborhood of size rho.
        let e_w = if self.adaptive {
            let param_norm = calculate_norm(params)?;
            if param_norm.is_zero() || !param_norm.is_finite() {
                // Fall back to the non-adaptive perturbation when the
                // parameter norm is degenerate.
                let perturb = gradients / (grad_norm + self.epsilon);
                &perturb * self.rho
            } else {
                // Adaptive SAM: scale the gradient element-wise by
                // (|w| + epsilon) / ||w||.
                let mut perturb = params.mapv(|p| p.abs() + self.epsilon);
                perturb = &perturb / param_norm;
                gradients * &perturb * self.rho
            }
        } else {
            // Standard SAM: normalize the gradient and scale by rho.
            let perturb = gradients / (grad_norm + self.epsilon);
            &perturb * self.rho
        };

        let perturbed_params = params + &e_w;
        self.perturbed_params = Some(perturbed_params.clone());

        Ok((perturbed_params, calculate_norm(&e_w)?))
    }

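    /// Second SAM phase: applies the wrapped optimizer to `params` (the
    /// original, unperturbed parameters) using `gradients`, which are
    /// expected to have been computed at the perturbed parameters returned
    /// by `first_step`. Returns an error if `first_step` has not been called.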
    pub fn second_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<Array<A, D>> {
        // Guard against calling second_step without a preceding first_step.
        if self.original_params.is_none() {
            return Err(OptimError::OptimizationError(
                "Must call first_step before second_step".to_string(),
            ));
        }

        // Apply the base optimizer update at the original parameters using
        // the sharpness-aware gradients.
        let updated_params = self.inner_optimizer.step(params, gradients)?;

        // Clear the stored state for the next SAM iteration.
        self.perturbed_params = None;
        self.original_params = None;

        Ok(updated_params)
    }

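    /// Clears the state stored between `first_step` and `second_step`.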
    pub fn reset(&mut self) {
        self.perturbed_params = None;
        self.original_params = None;
    }
}

impl<A, O, D> Clone for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    fn clone(&self) -> Self {
        Self {
            inner_optimizer: self.inner_optimizer.clone(),
            rho: self.rho,
            epsilon: self.epsilon,
            adaptive: self.adaptive,
            perturbed_params: self.perturbed_params.clone(),
            original_params: self.original_params.clone(),
            _phantom: PhantomData,
        }
    }
}

impl<A, O, D> Debug for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SAM")
            .field("inner_optimizer", &self.inner_optimizer)
            .field("rho", &self.rho)
            .field("epsilon", &self.epsilon)
            .field("adaptive", &self.adaptive)
            .finish()
    }
}

impl<A, O, D> Optimizer<A, D> for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Without a closure to recompute gradients at the perturbed
        // parameters, a single `step` call reuses the same gradients for
        // both SAM phases. For the full two-phase behavior, call
        // `first_step` and `second_step` directly.
        let _ = self.first_step(params, gradients)?;

        self.second_step(params, gradients)
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_optimizer.set_learning_rate(learning_rate);
    }

    fn get_learning_rate(&self) -> A {
        self.inner_optimizer.get_learning_rate()
    }
}

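/// Computes the L2 norm of `array`, returning an error if the result is not
/// finite.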
#[allow(dead_code)]
fn calculate_norm<A, D>(array: &Array<A, D>) -> Result<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    let squared_sum = array.iter().fold(A::zero(), |acc, &x| acc + x * x);
    let norm = squared_sum.sqrt();

    if !norm.is_finite() {
        return Err(OptimError::OptimizationError(
            "Norm calculation resulted in non-finite value".to_string(),
        ));
    }

    Ok(norm)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sam_creation() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        assert_abs_diff_eq!(optimizer.rho(), 0.05);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
        assert!(!optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_with_config() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            SAM::with_config(sgd, 0.1, true);

        assert_abs_diff_eq!(optimizer.rho(), 0.1);
        assert!(optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_first_step() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Expected perturbation: rho * g / ||g|| with the default rho = 0.05.
        let grad_norm = (0.1f64.powi(2) + 0.2f64.powi(2) + 0.3f64.powi(2)).sqrt();
        let normalized_grads = gradients.mapv(|g| g / grad_norm);
        let expected_perturb = normalized_grads.mapv(|g| g * 0.05);
        let expected_params = &params + &expected_perturb;

        let (perturbed_params, perturb_size) = optimizer.first_step(&params, &gradients).unwrap();

        assert_abs_diff_eq!(perturbed_params[0], expected_params[0], epsilon = 1e-6);
        assert_abs_diff_eq!(perturbed_params[1], expected_params[1], epsilon = 1e-6);
        assert_abs_diff_eq!(perturbed_params[2], expected_params[2], epsilon = 1e-6);

        // The norm of the perturbation should equal rho.
        assert_abs_diff_eq!(perturb_size, 0.05, epsilon = 1e-6);
    }

    #[test]
    fn test_sam_adaptive() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            SAM::with_config(sgd, 0.05, true);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let (perturbed_params, perturb_size) = optimizer.first_step(&params, &gradients).unwrap();

        // The perturbation should be non-trivial but bounded.
        assert!(perturb_size > 0.0 && perturb_size < 1.0);
        assert!(perturbed_params[0] != params[0]);
        assert!(perturbed_params[1] != params[1]);
        assert!(perturbed_params[2] != params[2]);

        // Adaptive scaling makes the perturbation grow with |w_i| * g_i, so
        // the last element moves more than the first.
        let delta0 = (perturbed_params[0] - params[0]).abs();
        let delta2 = (perturbed_params[2] - params[2]).abs();
        assert!(delta2 > delta0);
    }

    #[test]
    fn test_sam_second_step() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let _ = optimizer.first_step(&params, &gradients).unwrap();

        // Gradients standing in for those recomputed at the perturbed point.
        let new_gradients = Array1::from_vec(vec![0.15, 0.25, 0.35]);

        let updated_params = optimizer.second_step(&params, &new_gradients).unwrap();

        // SGD with learning rate 0.1 applied at the original parameters.
        let expected_params =
            Array1::from_vec(vec![1.0 - 0.1 * 0.15, 2.0 - 0.1 * 0.25, 3.0 - 0.1 * 0.35]);

        assert_abs_diff_eq!(updated_params[0], expected_params[0], epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[1], expected_params[1], epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[2], expected_params[2], epsilon = 1e-6);
    }

    #[test]
    fn test_sam_reset() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let _ = optimizer.first_step(&params, &gradients).unwrap();

        optimizer.reset();

        // After reset, second_step must fail because the stored state is gone.
        let result = optimizer.second_step(&params, &gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_sam_error_handling() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        // A zero gradient has no direction to perturb along, so first_step errors.
        let zero_gradients = Array1::zeros(3);

        let result = optimizer.first_step(&params, &zero_gradients);
        assert!(result.is_err());
    }
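
    // A sketch of the full two-phase SAM workflow, assuming the plain SGD
    // update rule `p - lr * g` asserted in test_sam_second_step. The
    // "recomputed" gradients are made up for illustration; in real training
    // they would come from a second backward pass at the perturbed point.
    #[test]
    fn test_sam_two_step_workflow() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Phase 1: move the parameters towards higher loss.
        let (perturbed_params, _) = optimizer.first_step(&params, &gradients).unwrap();
        assert!(perturbed_params[0] > params[0]);

        // Phase 2: apply the inner optimizer at the original parameters with
        // gradients taken at the perturbed point (illustrative values here).
        let sharp_gradients = Array1::from_vec(vec![0.12, 0.22, 0.32]);
        let updated = optimizer.second_step(&params, &sharp_gradients).unwrap();

        assert_abs_diff_eq!(updated[0], 1.0 - 0.1 * 0.12, epsilon = 1e-6);
        assert_abs_diff_eq!(updated[1], 2.0 - 0.1 * 0.22, epsilon = 1e-6);
        assert_abs_diff_eq!(updated[2], 3.0 - 0.1 * 0.32, epsilon = 1e-6);
    }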
}