optirs_core/optimizers/sam.rs

// Sharpness-Aware Minimization (SAM) optimizer
//
// Implements the SAM optimization algorithm from:
// "Sharpness-Aware Minimization for Efficiently Improving Generalization" (Foret et al., 2020)

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;

/// Sharpness-Aware Minimization (SAM) optimizer
///
/// SAM is an optimization technique that seeks parameters lying in neighborhoods
/// with uniformly low loss, which improves generalization. It achieves this with
/// a two-step update process:
///
/// 1. Perturb the parameters in the direction of the gradient, scaled to the neighborhood radius
/// 2. Compute the gradient at the perturbed parameters and use it to update the original parameters
///
/// This implementation wraps a base optimizer and modifies its behavior to implement
/// the SAM algorithm.
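///
/// In the standard (non-adaptive) form, the first-step perturbation is
/// `e_w = rho * g / (||g||_2 + epsilon)`, where `g` is the gradient at the current
/// parameters; the second step then feeds the gradient evaluated at `w + e_w` to the
/// inner optimizer to update the original parameters `w`.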
///
/// # Parameters
///
/// * `inner_optimizer` - The optimizer to use for the parameter updates
/// * `rho` - The neighborhood size for perturbation (default: 0.05)
/// * `epsilon` - Small constant for numerical stability (default: 1e-12)
/// * `adaptive` - Whether to use adaptive perturbation size (SAM-A) (default: false)
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{SAM, SGD};
/// use optirs_core::Optimizer;
///
/// // Create a base optimizer
/// let sgd = SGD::new(0.1);
///
/// // Wrap it with SAM
/// let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);
///
/// // First step: compute and store the perturbed parameters
/// let params = Array1::zeros(10);
/// let gradients = Array1::ones(10);
/// let (perturbed_params, _perturbation_size) = optimizer.first_step(&params, &gradients).unwrap();
///
/// // Second step: update the original parameters using gradients at the perturbed parameters.
/// // Normally, you would compute new gradients at perturbed_params.
/// let new_gradients = Array1::ones(10) * 0.5; // Example new gradients
/// let updated_params = optimizer.second_step(&params, &new_gradients).unwrap();
/// ```
pub struct SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Inner optimizer for parameter updates
    inner_optimizer: O,
    /// Neighborhood size for perturbation (ρ)
    rho: A,
    /// Small constant for numerical stability (ε)
    epsilon: A,
    /// Whether to use adaptive perturbation size (SAM-A)
    adaptive: bool,
    /// Perturbed parameters from first step
    perturbed_params: Option<Array<A, D>>,
    /// Original parameters from first step
    original_params: Option<Array<A, D>>,
    /// Dimension type marker
    _phantom: PhantomData<D>,
}

impl<A, O, D> SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Creates a new SAM optimizer with the given inner optimizer and default settings
    pub fn new(inner_optimizer: O) -> Self {
        Self {
            inner_optimizer,
            rho: A::from(0.05).unwrap(),
            epsilon: A::from(1e-12).unwrap(),
            adaptive: false,
            perturbed_params: None,
            original_params: None,
            _phantom: PhantomData,
        }
    }

    /// Creates a new SAM optimizer with the given inner optimizer and configuration
    pub fn with_config(inner_optimizer: O, rho: A, adaptive: bool) -> Self {
        Self {
            inner_optimizer,
            rho,
            epsilon: A::from(1e-12).unwrap(),
            adaptive,
            perturbed_params: None,
            original_params: None,
            _phantom: PhantomData,
        }
    }

    /// Set the rho parameter (neighborhood size)
    pub fn with_rho(mut self, rho: A) -> Self {
        self.rho = rho;
        self
    }

    /// Set the epsilon parameter (numerical stability)
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Set whether to use adaptive perturbation size (SAM-A)
    pub fn with_adaptive(mut self, adaptive: bool) -> Self {
        self.adaptive = adaptive;
        self
    }

    /// Get the inner optimizer
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Get a mutable reference to the inner optimizer
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// Get the rho parameter
    pub fn rho(&self) -> A {
        self.rho
    }

    /// Get the epsilon parameter
    pub fn epsilon(&self) -> A {
        self.epsilon
    }

    /// Check if using adaptive perturbation size
    pub fn is_adaptive(&self) -> bool {
        self.adaptive
    }

    /// First step of SAM: compute perturbed parameters by moving in the direction of the gradient
    ///
    /// # Arguments
    ///
    /// * `params` - Current parameters
    /// * `gradients` - Gradients of the loss with respect to the parameters
    ///
    /// # Returns
    ///
    /// Tuple containing (perturbed_parameters, perturbation_size)
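    ///
    /// With `adaptive` enabled (SAM-A), the gradient is scaled element-wise by the
    /// normalized parameter magnitudes before applying `rho`, rather than only by the
    /// global gradient norm (see the comments in the method body for the exact calculation).
    ///
    /// # Errors
    ///
    /// Returns an error if the gradient norm is zero or not finite.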
    pub fn first_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<(Array<A, D>, A)> {
        // Store original parameters
        self.original_params = Some(params.clone());

        // Calculate gradient norm for scaling
        let grad_norm = calculate_norm(gradients)?;

        if grad_norm.is_zero() || !grad_norm.is_finite() {
            return Err(OptimError::OptimizationError(
                "Gradient norm is zero or not finite".to_string(),
            ));
        }

        // Calculate perturbation size
        let e_w = if self.adaptive {
            // Adaptive SAM: scale perturbation by parameter-wise gradient magnitude
            // Note: We need to be careful with parameter scaling to avoid numerical issues
            let param_norm = calculate_norm(params)?;
            if param_norm.is_zero() || !param_norm.is_finite() {
                // Fall back to standard SAM if the parameter norm is problematic
                let perturb = gradients / (grad_norm + self.epsilon);
                &perturb * self.rho
            } else {
                // Use a more stable calculation for adaptive SAM
                let mut perturb = params.mapv(|p| p.abs() + self.epsilon);
                // Normalize first
                perturb = &perturb / param_norm;
                // Element-wise multiply and scale by rho
                gradients * &perturb * self.rho
            }
        } else {
            // Standard SAM: scale perturbation by gradient norm
            let perturb = gradients / (grad_norm + self.epsilon);
            &perturb * self.rho
        };

        // Create perturbed parameters
        let perturbed_params = params + &e_w;
        self.perturbed_params = Some(perturbed_params.clone());

        // Return perturbed parameters and perturbation norm
        Ok((perturbed_params, calculate_norm(&e_w)?))
    }

    /// Second step of SAM: update the original parameters using gradients at the perturbed point
    ///
    /// # Arguments
    ///
    /// * `_params` - Original parameters (unused; the parameters stored by `first_step` are updated)
    /// * `gradients` - Gradients of the loss with respect to the perturbed parameters
    ///
    /// # Returns
    ///
    /// Updated parameters after applying the "sharpness-aware" update
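    ///
    /// The gradients passed here should normally be recomputed at the perturbed
    /// parameters returned by `first_step`, so the update is effectively
    /// `w_new = inner_step(w, grad(w + e_w))` for the stored original parameters `w`.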
    pub fn second_step(
        &mut self,
        _params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<Array<A, D>> {
        // Retrieve (and clear) the original parameters stored by first_step
        let original_params = self.original_params.take().ok_or_else(|| {
            OptimError::OptimizationError("Must call first_step before second_step".to_string())
        })?;

        // Use the inner optimizer to update the original parameters with the perturbed gradients
        let updated_params = self.inner_optimizer.step(&original_params, gradients)?;

        // Reset the stored perturbed parameters
        self.perturbed_params = None;

        Ok(updated_params)
    }

    /// Reset the internal state
    pub fn reset(&mut self) {
        self.perturbed_params = None;
        self.original_params = None;
    }
}

impl<A, O, D> Clone for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    fn clone(&self) -> Self {
        Self {
            inner_optimizer: self.inner_optimizer.clone(),
            rho: self.rho,
            epsilon: self.epsilon,
            adaptive: self.adaptive,
            perturbed_params: self.perturbed_params.clone(),
            original_params: self.original_params.clone(),
            _phantom: PhantomData,
        }
    }
}

impl<A, O, D> Debug for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SAM")
            .field("inner_optimizer", &self.inner_optimizer)
            .field("rho", &self.rho)
            .field("epsilon", &self.epsilon)
            .field("adaptive", &self.adaptive)
            .finish()
    }
}

impl<A, O, D> Optimizer<A, D> for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // This single-step implementation is a convenience that combines first_step and
        // second_step. It reuses the same gradients for both steps, so it only approximates
        // the SAM update; for the full algorithm, recompute the gradients at the perturbed
        // parameters and call first_step/second_step directly.

        // First step: compute perturbed parameters
        let _ = self.first_step(params, gradients)?;

        // Second step: update with the same gradients
        self.second_step(params, gradients)
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_optimizer.set_learning_rate(learning_rate);
    }

    fn get_learning_rate(&self) -> A {
        self.inner_optimizer.get_learning_rate()
    }
}

/// Calculate the L2 norm of an array
fn calculate_norm<A, D>(array: &Array<A, D>) -> Result<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    let squared_sum = array.iter().fold(A::zero(), |acc, &x| acc + x * x);
    let norm = squared_sum.sqrt();

    if !norm.is_finite() {
        return Err(OptimError::OptimizationError(
            "Norm calculation resulted in non-finite value".to_string(),
        ));
    }

    Ok(norm)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sam_creation() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        assert_abs_diff_eq!(optimizer.rho(), 0.05);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
        assert!(!optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_with_config() {
        let sgd = SGD::new(0.01);
        let optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            SAM::with_config(sgd, 0.1, true);

        assert_abs_diff_eq!(optimizer.rho(), 0.1);
        assert!(optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_first_step() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Calculate expected perturbation for standard SAM
        let grad_norm = (0.1f64.powi(2) + 0.2f64.powi(2) + 0.3f64.powi(2)).sqrt();
        let normalized_grads = gradients.mapv(|g| g / grad_norm);
        let expected_perturb = normalized_grads.mapv(|g| g * 0.05);
        let expected_params = &params + &expected_perturb;

        let (perturbed_params, perturb_size) = optimizer.first_step(&params, &gradients).unwrap();

        // Verify perturbed parameters
        assert_abs_diff_eq!(perturbed_params[0], expected_params[0], epsilon = 1e-6);
        assert_abs_diff_eq!(perturbed_params[1], expected_params[1], epsilon = 1e-6);
        assert_abs_diff_eq!(perturbed_params[2], expected_params[2], epsilon = 1e-6);

        // Verify perturbation size
        assert_abs_diff_eq!(perturb_size, 0.05, epsilon = 1e-6);
    }

    #[test]
    fn test_sam_adaptive() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> =
            SAM::with_config(sgd, 0.05, true);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // With the more stable adaptive implementation, just verify the behavior is sensible
        let (perturbed_params, perturb_size) = optimizer.first_step(&params, &gradients).unwrap();

        // Perturbation size should be reasonable
        assert!(perturb_size > 0.0 && perturb_size < 1.0);

        // Params should be perturbed in a way that relates to both parameter magnitude and gradients
        assert!(perturbed_params[0] != params[0]);
        assert!(perturbed_params[1] != params[1]);
        assert!(perturbed_params[2] != params[2]);

        // The perturbation on larger parameters (which also have larger gradients) should be larger
        let delta0 = (perturbed_params[0] - params[0]).abs();
        let delta2 = (perturbed_params[2] - params[2]).abs();
        assert!(delta2 > delta0);
    }

    #[test]
    fn test_sam_second_step() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // First step to set up perturbed parameters
        let _ = optimizer.first_step(&params, &gradients).unwrap();

        // Simulate computing new gradients at the perturbed point
        let new_gradients = Array1::from_vec(vec![0.15, 0.25, 0.35]);

        // Second step should update original parameters with new gradients
        let updated_params = optimizer.second_step(&params, &new_gradients).unwrap();

        // Expected update: params - lr * new_gradients
        let expected_params =
            Array1::from_vec(vec![1.0 - 0.1 * 0.15, 2.0 - 0.1 * 0.25, 3.0 - 0.1 * 0.35]);

        assert_abs_diff_eq!(updated_params[0], expected_params[0], epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[1], expected_params[1], epsilon = 1e-6);
        assert_abs_diff_eq!(updated_params[2], expected_params[2], epsilon = 1e-6);
    }

    #[test]
    fn test_sam_reset() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // First step
        let _ = optimizer.first_step(&params, &gradients).unwrap();

        // Reset
        optimizer.reset();

        // Should fail because first_step needs to be called before second_step
        let result = optimizer.second_step(&params, &gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_sam_error_handling() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        // A gradient of all zeros should return an error
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let zero_gradients = Array1::zeros(3);

        let result = optimizer.first_step(&params, &zero_gradients);
        assert!(result.is_err());
    }
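
    // Added end-to-end sketch of the intended two-step usage: gradients are recomputed at
    // the perturbed parameters before the second step. It assumes the inner SGD performs a
    // plain `params - lr * gradients` update, as exercised by test_sam_second_step above.
    #[test]
    fn test_sam_two_step_loop() {
        let sgd = SGD::new(0.1);
        let mut optimizer: SAM<f64, SGD<f64>, scirs2_core::ndarray::Ix1> = SAM::new(sgd);

        // Stand-in gradient of the simple quadratic loss f(w) = ||w||^2, i.e. 2 * w
        let grad = |w: &Array1<f64>| w.mapv(|x| 2.0 * x);

        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let initial_norm = calculate_norm(&params).unwrap();

        for _ in 0..10 {
            // First step: perturb towards higher loss using the current gradient
            let (perturbed, _) = optimizer.first_step(&params, &grad(&params)).unwrap();
            // Second step: update the original parameters with the gradient at the perturbed point
            params = optimizer.second_step(&params, &grad(&perturbed)).unwrap();
        }

        // Minimizing the quadratic should shrink the parameter norm
        assert!(calculate_norm(&params).unwrap() < initial_norm);
    }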
}