optirs-core 0.3.1

OptiRS core optimization algorithms and utilities
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
// Sharpness-Aware Minimization (SAM) optimizer
//
// Implements the SAM optimization algorithm from:
// "Sharpness-Aware Minimization for Efficiently Improving Generalization" (Foret et al., 2020)

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::marker::PhantomData;

/// Sharpness-Aware Minimization (SAM) optimizer
///
/// SAM looks for parameters that lie in neighborhoods with uniformly low loss, which tends
/// to improve generalization (Foret et al., 2020). Each update is split into two calls:
///
/// 1. [`SAM::first_step`] moves the parameters to a nearby "worst-case" point (in standard
///    mode, a step of length `rho` along the normalized gradient) and remembers the
///    original parameters.
/// 2. [`SAM::second_step`] takes the gradient evaluated at that perturbed point and applies
///    it, through the wrapped optimizer, to the original parameters.
///
/// The parameter update itself is always performed by the wrapped (inner) optimizer.
///
/// # Parameters
///
/// * `inner_optimizer` - The optimizer that performs the parameter update
/// * `rho` - Neighborhood size for the perturbation (default: 0.05)
/// * `epsilon` - Small constant added to norms for numerical stability (default: 1e-12)
/// * `adaptive` - Whether to scale the perturbation by parameter magnitude (adaptive SAM)
///   (default: false)
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{SAM, SGD};
/// use optirs_core::Optimizer;
///
/// // Create a base optimizer
/// let sgd = SGD::new(0.1);
///
/// // Wrap it with SAM
/// let mut optimizer = SAM::new(sgd);
///
/// // First step: compute the perturbed parameters (the originals are stored internally)
/// let params = Array1::zeros(10);
/// let gradients = Array1::ones(10);
/// let (perturbed_params, _perturbation_norm) = optimizer
///     .first_step(&params, &gradients)
///     .expect("first_step failed");
///
/// // Second step: normally you would evaluate new gradients at `perturbed_params`
/// let new_gradients = Array1::ones(10) * 0.5; // Example new gradients
/// let updated_params = optimizer
///     .second_step(&params, &new_gradients)
///     .expect("second_step failed");
/// ```
pub struct SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Wrapped optimizer that performs the actual parameter update
    inner_optimizer: O,
    /// Neighborhood size for the perturbation (ρ)
    rho: A,
    /// Small constant added to norms for numerical stability (ε)
    epsilon: A,
    /// Whether the perturbation is scaled by parameter magnitude (adaptive SAM)
    adaptive: bool,
    /// Perturbed parameters produced by the pending `first_step`, if any
    perturbed_params: Option<Array<A, D>>,
    /// Parameters passed to the pending `first_step`, if any
    original_params: Option<Array<A, D>>,
    /// Ties the dimension type `D` to the struct
    _phantom: PhantomData<D>,
}

impl<A, O, D> SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Creates a new SAM optimizer around `inner_optimizer` with the default settings:
    /// `rho = 0.05`, `epsilon = 1e-12`, non-adaptive.
    pub fn new(inner_optimizer: O) -> Self {
        Self::with_config(
            inner_optimizer,
            A::from(0.05).expect("0.05 must be representable by the float type"),
            false,
        )
    }

    /// Creates a new SAM optimizer with an explicit neighborhood size `rho` and adaptive
    /// flag; `epsilon` keeps its default of `1e-12`.
    pub fn with_config(inner_optimizer: O, rho: A, adaptive: bool) -> Self {
        Self {
            inner_optimizer,
            rho,
            epsilon: A::from(1e-12).expect("1e-12 must be representable by the float type"),
            adaptive,
            perturbed_params: None,
            original_params: None,
            _phantom: PhantomData,
        }
    }

    /// Set the rho parameter (neighborhood size)
    pub fn with_rho(mut self, rho: A) -> Self {
        self.rho = rho;
        self
    }

    /// Set the epsilon parameter (numerical stability)
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Set whether to use adaptive perturbation size (adaptive SAM)
    pub fn with_adaptive(mut self, adaptive: bool) -> Self {
        self.adaptive = adaptive;
        self
    }

    /// Get the inner optimizer
    pub fn inner_optimizer(&self) -> &O {
        &self.inner_optimizer
    }

    /// Get a mutable reference to the inner optimizer
    pub fn inner_optimizer_mut(&mut self) -> &mut O {
        &mut self.inner_optimizer
    }

    /// Get the rho parameter
    pub fn rho(&self) -> A {
        self.rho
    }

    /// Get the epsilon parameter
    pub fn epsilon(&self) -> A {
        self.epsilon
    }

    /// Check if using adaptive perturbation size
    pub fn is_adaptive(&self) -> bool {
        self.adaptive
    }

    /// First step of SAM: move the parameters to the perturbed ("ascent") point.
    ///
    /// * Standard mode: `e = rho * g / (||g|| + epsilon)`.
    /// * Adaptive mode (ASAM, Kwon et al., 2021): `e = rho * w² ⊙ g / (|| |w| ⊙ g || + epsilon)`.
    ///   If `|w| ⊙ g` is identically zero, it falls back to the standard perturbation.
    ///
    /// On success the original parameters are recorded for [`SAM::second_step`].
    ///
    /// # Arguments
    ///
    /// * `params` - Current parameters
    /// * `gradients` - Gradients of the loss with respect to `params`
    ///
    /// # Returns
    ///
    /// `(perturbed_parameters, perturbation_norm)`, where `perturbation_norm` is the L2
    /// norm of the applied perturbation `e`.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::OptimizationError` if the shapes differ, the gradient norm is
    /// zero or not finite, or any intermediate norm is not finite. After an error no step
    /// is pending, so a following `second_step` fails.
    pub fn first_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<(Array<A, D>, A)> {
        // Discard any pending state first so a failed call can never leave
        // `second_step` working from parameters of an earlier iteration.
        self.reset();

        if params.shape() != gradients.shape() {
            return Err(OptimError::OptimizationError(format!(
                "Parameter shape {:?} does not match gradient shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let grad_norm = calculate_norm(gradients)?;
        if grad_norm.is_zero() || !grad_norm.is_finite() {
            return Err(OptimError::OptimizationError(
                "Gradient norm is zero or not finite".to_string(),
            ));
        }

        let e_w = if self.adaptive {
            self.adaptive_perturbation(params, gradients, grad_norm)?
        } else {
            self.standard_perturbation(gradients, grad_norm)
        };

        let perturbation_norm = calculate_norm(&e_w)?;
        let perturbed_params = params + &e_w;

        // Only record state once every fallible computation has succeeded.
        self.original_params = Some(params.clone());
        self.perturbed_params = Some(perturbed_params.clone());

        Ok((perturbed_params, perturbation_norm))
    }

    /// Standard SAM perturbation: `rho * g / (||g|| + epsilon)`.
    fn standard_perturbation(&self, gradients: &Array<A, D>, grad_norm: A) -> Array<A, D> {
        (gradients / (grad_norm + self.epsilon)) * self.rho
    }

    /// ASAM perturbation `rho * w² ⊙ g / (|| |w| ⊙ g || + epsilon)`, i.e. with the
    /// element-wise operator `T_w = |w|`.
    ///
    /// Falls back to the standard perturbation when `|w| ⊙ g` is identically zero, because
    /// the scaled perturbation would otherwise vanish.
    fn adaptive_perturbation(
        &self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
        grad_norm: A,
    ) -> Result<Array<A, D>> {
        let scaled_grads = params.mapv(|p| p.abs()) * gradients;
        let scaled_norm = calculate_norm(&scaled_grads)?;

        if scaled_norm.is_zero() {
            return Ok(self.standard_perturbation(gradients, grad_norm));
        }

        let scale = self.rho / (scaled_norm + self.epsilon);
        Ok(params.mapv(|p| p * p) * gradients * scale)
    }

    /// Second step of SAM: update the original parameters using gradients taken at the
    /// perturbed point.
    ///
    /// The update is applied to the parameters recorded by [`SAM::first_step`], not to
    /// `params`. This restores the unperturbed point even if the caller passes the
    /// perturbed parameters here. The pending state is cleared on success.
    ///
    /// # Arguments
    ///
    /// * `params` - Parameters from the first step; only their shape is validated
    /// * `gradients` - Gradients of the loss evaluated at the perturbed parameters
    ///
    /// # Returns
    ///
    /// Updated parameters produced by the inner optimizer.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::OptimizationError` if no successful `first_step` is pending
    /// or `params` has a different shape. Also propagates any inner-optimizer error; in
    /// that case the pending state is kept.
    pub fn second_step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<Array<A, D>> {
        let original_params = match &self.original_params {
            Some(original) => original,
            None => {
                return Err(OptimError::OptimizationError(
                    "Must call first_step before second_step".to_string(),
                ))
            }
        };

        if params.shape() != original_params.shape() {
            return Err(OptimError::OptimizationError(format!(
                "Parameter shape {:?} does not match the shape passed to first_step {:?}",
                params.shape(),
                original_params.shape()
            )));
        }

        let updated_params = self.inner_optimizer.step(original_params, gradients)?;

        self.reset();

        Ok(updated_params)
    }

    /// Reset the internal state (drops any pending first step)
    pub fn reset(&mut self) {
        self.perturbed_params = None;
        self.original_params = None;
    }
}

impl<A, O, D> Clone for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone,
    D: Dimension,
{
    /// Produces an independent copy, including the configuration, the wrapped optimizer,
    /// and any parameters captured by a pending `first_step`.
    fn clone(&self) -> Self {
        let Self {
            inner_optimizer,
            rho,
            epsilon,
            adaptive,
            perturbed_params,
            original_params,
            _phantom,
        } = self;

        Self {
            inner_optimizer: inner_optimizer.clone(),
            rho: *rho,
            epsilon: *epsilon,
            adaptive: *adaptive,
            perturbed_params: perturbed_params.clone(),
            original_params: original_params.clone(),
            _phantom: PhantomData,
        }
    }
}

impl<A, O, D> Debug for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug,
    O: Optimizer<A, D> + Clone + Debug,
    D: Dimension,
{
    /// Formats the configuration (inner optimizer, `rho`, `epsilon`, `adaptive`).
    /// Pending step state is intentionally not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SAM");
        out.field("inner_optimizer", &self.inner_optimizer);
        out.field("rho", &self.rho);
        out.field("epsilon", &self.epsilon);
        out.field("adaptive", &self.adaptive);
        out.finish()
    }
}

impl<A, O, D> Optimizer<A, D> for SAM<A, O, D>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    O: Optimizer<A, D> + Clone + Send + Sync,
    D: Dimension,
{
    /// Single-call convenience update.
    ///
    /// The `Optimizer` trait supplies only one gradient per step, so this cannot evaluate the
    /// loss at the perturbed point. It runs `first_step` and then `second_step` with the
    /// *same* gradients. `first_step` validates the gradients and records `params`.
    /// The result is therefore the inner optimizer's ordinary update of `params`,
    /// not a true SAM update. For sharpness-aware behaviour, call `first_step`, recompute
    /// gradients at the perturbed parameters, then call `second_step`.
    ///
    /// # Errors
    ///
    /// Fails if `first_step` rejects the gradients (e.g. zero or non-finite norm), or if the
    /// inner optimizer fails.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // The perturbed parameters are not needed: no new gradient can be computed here.
        self.first_step(params, gradients)?;
        self.second_step(params, gradients)
    }

    /// Forwards to the inner optimizer.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_optimizer.set_learning_rate(learning_rate);
    }

    /// Forwards to the inner optimizer.
    fn get_learning_rate(&self) -> A {
        self.inner_optimizer.get_learning_rate()
    }
}

/// Calculate the L2 (Euclidean) norm of an array.
///
/// Uses the direct sum of squares when it is a normal, finite number. If the sum overflows
/// to infinity or underflows to zero/subnormal, the norm is recomputed as
/// `max|x| * sqrt(sum((x / max|x|)^2))`. Large-but-finite and tiny-but-non-zero inputs
/// therefore still yield a correct, non-zero norm. An all-zero array yields zero.
///
/// # Errors
///
/// Returns `OptimError::OptimizationError` if the norm is not finite (e.g. an element is
/// NaN or infinite).
fn calculate_norm<A, D>(array: &Array<A, D>) -> Result<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    let squared_sum = array.iter().fold(A::zero(), |acc, &x| acc + x * x);

    let norm = if squared_sum.is_nan() {
        // A NaN element; reported as non-finite below.
        squared_sum
    } else if squared_sum.is_finite() && squared_sum >= A::min_positive_value() {
        // Common case: the squares neither overflowed nor underflowed.
        squared_sum.sqrt()
    } else {
        // Overflow, underflow, or a genuinely infinite element: rescale by the largest
        // magnitude so every squared term lies in [0, 1]. An infinite element turns into
        // inf/inf = NaN here and is reported below.
        let max_abs = array.iter().fold(A::zero(), |m, &x| m.max(x.abs()));
        if max_abs.is_zero() {
            A::zero()
        } else {
            let scaled_sum = array.iter().fold(A::zero(), |acc, &x| {
                let r = x / max_abs;
                acc + r * r
            });
            max_abs * scaled_sum.sqrt()
        }
    };

    if !norm.is_finite() {
        return Err(OptimError::OptimizationError(
            "Norm calculation resulted in non-finite value".to_string(),
        ));
    }

    Ok(norm)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::sgd::SGD;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Ix1};

    /// SAM over 1-D `f64` arrays with an SGD inner optimizer, as used by every test here.
    type Sam1 = SAM<f64, SGD<f64>, Ix1>;

    /// Parameter fixture shared by the step tests.
    fn sample_params() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0])
    }

    /// Gradient fixture shared by the step tests.
    fn sample_gradients() -> Array1<f64> {
        Array1::from_vec(vec![0.1, 0.2, 0.3])
    }

    #[test]
    fn test_sam_creation() {
        let optimizer: Sam1 = SAM::new(SGD::new(0.01));

        assert_abs_diff_eq!(optimizer.rho(), 0.05);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01);
        assert!(!optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_with_config() {
        let optimizer: Sam1 = SAM::with_config(SGD::new(0.01), 0.1, true);

        assert_abs_diff_eq!(optimizer.rho(), 0.1);
        assert!(optimizer.is_adaptive());
    }

    #[test]
    fn test_sam_first_step() {
        let mut optimizer: Sam1 = SAM::new(SGD::new(0.1));
        let params = sample_params();
        let gradients = sample_gradients();

        // Standard SAM: e(w) = rho * g / ||g||
        let grad_norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
        let expected = &params + &gradients.mapv(|g| g / grad_norm * 0.05);

        let (perturbed, perturb_size) = optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");

        assert_eq!(perturbed.len(), 3);
        for (got, want) in perturbed.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-6);
        }
        assert_abs_diff_eq!(perturb_size, 0.05, epsilon = 1e-6);
    }

    #[test]
    fn test_sam_adaptive() {
        let mut optimizer: Sam1 = SAM::with_config(SGD::new(0.1), 0.05, true);
        let params = sample_params();
        let gradients = sample_gradients();

        let (perturbed, perturb_size) = optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");

        // The perturbation should be of a reasonable size
        assert!(perturb_size > 0.0 && perturb_size < 1.0);

        // Every coordinate must move
        for i in 0..3 {
            assert!(perturbed[i] != params[i]);
        }

        // Larger parameters should receive larger perturbations
        let shift = |i: usize| (perturbed[i] - params[i]).abs();
        assert!(shift(2) > shift(0));
    }

    #[test]
    fn test_sam_second_step() {
        let mut optimizer: Sam1 = SAM::new(SGD::new(0.1));
        let params = sample_params();
        let gradients = sample_gradients();

        optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");

        // Pretend these were computed at the perturbed point
        let new_gradients = Array1::from_vec(vec![0.15, 0.25, 0.35]);

        let updated = optimizer
            .second_step(&params, &new_gradients)
            .expect("unwrap failed");

        // Plain SGD on the original parameters: w - lr * g
        let expected = [1.0 - 0.1 * 0.15, 2.0 - 0.1 * 0.25, 3.0 - 0.1 * 0.35];
        assert_eq!(updated.len(), expected.len());
        for (got, want) in updated.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_sam_reset() {
        let mut optimizer: Sam1 = SAM::new(SGD::new(0.1));
        let params = sample_params();
        let gradients = sample_gradients();

        optimizer
            .first_step(&params, &gradients)
            .expect("unwrap failed");
        optimizer.reset();

        // With the pending state cleared, second_step has nothing to work from
        assert!(optimizer.second_step(&params, &gradients).is_err());
    }

    #[test]
    fn test_sam_error_handling() {
        let mut optimizer: Sam1 = SAM::new(SGD::new(0.1));
        let params = sample_params();
        let zero_gradients = Array1::zeros(3);

        // A zero gradient has no direction to perturb along
        assert!(optimizer.first_step(&params, &zero_gradients).is_err());
    }
}