optirs_core/optimizers/adabound.rs

// OptiRS - AdaBound Optimizer
// Adaptive Gradient Methods with Dynamic Bound of Learning Rate
// Reference: "Adaptive Gradient Methods with Dynamic Bound of Learning Rate" (ICLR 2019)
//
// Algorithm:
//   AdaBound employs dynamic bounds on learning rates to achieve a smooth transition
//   from adaptive methods to SGD. This prevents the generalization gap observed
//   in pure adaptive methods.
//
//   Lower bound:           α_l(t) = α_final * (1 - 1/(γ*t + 1))
//   Upper bound:           α_u(t) = α_final * (1 + 1/(γ*t))
//   Clipped learning rate: η_t(i) = Clip(α / (√v_t(i) + ε), α_l(t), α_u(t))
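//
//   An illustrative sketch with the defaults used below (α_final = 0.1, γ = 1e-3);
//   the values follow directly from the bound formulas:
//     t = 1:         α_l ≈ 9.99e-5, α_u ≈ 100.1  (essentially unclipped Adam)
//     t = 1_000:     α_l = 0.05,    α_u = 0.2
//     t = 1_000_000: α_l ≈ 0.0999,  α_u ≈ 0.1001 (effectively SGD at α_final)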

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};

/// AdaBound optimizer configuration
///
/// AdaBound combines the benefits of adaptive learning rate methods (like Adam)
/// with the strong generalization of SGD by dynamically bounding the learning rates.
///
/// # Key Features
/// - Smooth transition from Adam to SGD during training
/// - Dynamic bounds prevent learning rates from becoming too large or too small
/// - Better generalization than pure Adam
/// - Maintains fast convergence of adaptive methods
///
/// # Type Parameters
/// - `T`: Floating-point type (f32 or f64)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaBound<T: Float> {
    /// Initial learning rate (α)
    learning_rate: T,

    /// Final learning rate for SGD convergence
    /// Typically 0.1 (the `Default` value used here)
    final_lr: T,

    /// First moment decay rate (β₁) - typically 0.9
    beta1: T,

    /// Second moment decay rate (β₂) - typically 0.999
    beta2: T,

    /// Small constant for numerical stability (ε) - typically 1e-8
    epsilon: T,

    /// Convergence speed parameter (γ) - typically 1e-3
    /// Controls how fast bounds converge to final_lr
    gamma: T,

    /// Weight decay coefficient (L2 regularization)
    weight_decay: T,

    /// Whether to use AMSBound variant (max of v_t)
    amsbound: bool,

    /// First moment vector (m_t)
    momentum: Option<Array1<T>>,

    /// Second moment vector (v_t)
    velocity: Option<Array1<T>>,

    /// Max of second moment (v̂_t) - only for AMSBound
    max_velocity: Option<Array1<T>>,

    /// Number of optimization steps performed
    step_count: usize,
}

impl<T: Float + ScalarOperand> Default for AdaBound<T> {
    fn default() -> Self {
        Self::new(
            T::from(0.001).unwrap(), // learning_rate
            T::from(0.1).unwrap(),   // final_lr
            T::from(0.9).unwrap(),   // beta1
            T::from(0.999).unwrap(), // beta2
            T::from(1e-8).unwrap(),  // epsilon
            T::from(1e-3).unwrap(),  // gamma
            T::zero(),               // weight_decay
            false,                   // amsbound
        )
        .unwrap()
    }
}

impl<T: Float + ScalarOperand> AdaBound<T> {
    /// Create a new AdaBound optimizer
    ///
    /// # Arguments
    /// - `learning_rate`: Initial learning rate (typically 0.001)
    /// - `final_lr`: Final learning rate for SGD convergence (typically 0.1)
    /// - `beta1`: First moment decay rate (typically 0.9)
    /// - `beta2`: Second moment decay rate (typically 0.999)
    /// - `epsilon`: Small constant for numerical stability (typically 1e-8)
    /// - `gamma`: Convergence speed parameter (typically 1e-3)
    /// - `weight_decay`: L2 regularization coefficient (typically 0.0)
    /// - `amsbound`: Use AMSBound variant if true
    ///
    /// # Example
    /// ```
    /// use optirs_core::optimizers::AdaBound;
    ///
    /// let optimizer = AdaBound::<f32>::new(
    ///     0.001,  // learning_rate
    ///     0.1,    // final_lr
    ///     0.9,    // beta1
    ///     0.999,  // beta2
    ///     1e-8,   // epsilon
    ///     1e-3,   // gamma
    ///     0.0,    // weight_decay
    ///     false   // amsbound
    /// ).unwrap();
    /// ```
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        learning_rate: T,
        final_lr: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        gamma: T,
        weight_decay: T,
        amsbound: bool,
    ) -> Result<Self> {
        let lr_f64 = learning_rate.to_f64().unwrap();
        let final_f64 = final_lr.to_f64().unwrap();
        let beta1_f64 = beta1.to_f64().unwrap();
        let beta2_f64 = beta2.to_f64().unwrap();
        let eps_f64 = epsilon.to_f64().unwrap();
        let gamma_f64 = gamma.to_f64().unwrap();
        let wd_f64 = weight_decay.to_f64().unwrap();

        if lr_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "learning_rate must be positive, got {}",
                lr_f64
            )));
        }
        if final_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "final_lr must be positive, got {}",
                final_f64
            )));
        }
        if beta1_f64 <= 0.0 || beta1_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "beta1 must be in (0, 1), got {}",
                beta1_f64
            )));
        }
        if beta2_f64 <= 0.0 || beta2_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "beta2 must be in (0, 1), got {}",
                beta2_f64
            )));
        }
        if eps_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be positive, got {}",
                eps_f64
            )));
        }
        if gamma_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "gamma must be positive, got {}",
                gamma_f64
            )));
        }
        if wd_f64 < 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "weight_decay must be non-negative, got {}",
                wd_f64
            )));
        }

        Ok(Self {
            learning_rate,
            final_lr,
            beta1,
            beta2,
            epsilon,
            gamma,
            weight_decay,
            amsbound,
            momentum: None,
            velocity: None,
            max_velocity: None,
            step_count: 0,
        })
    }

    /// Perform a single optimization step
    ///
    /// # Arguments
    /// - `params`: Current parameter values
    /// - `grads`: Gradient values
    ///
    /// # Returns
    /// Result containing updated parameters or error
    ///
    /// # Algorithm
    /// 1. Initialize moments on first step
    /// 2. Apply weight decay if configured
    /// 3. Update biased first moment: m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
    /// 4. Update biased second moment: v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
    /// 5. Compute bias-corrected moments
    /// 6. Compute dynamic bounds: [α_l(t), α_u(t)]
    /// 7. Compute clipped learning rate per parameter
    /// 8. Apply parameter update: θ_{t+1} = θ_t - η_t * m̂_t
    ///
    /// # Example
    /// ```
    /// use optirs_core::optimizers::AdaBound;
    /// use scirs2_core::ndarray_ext::array;
    ///
    /// let mut optimizer = AdaBound::<f32>::default();
    /// let params = array![1.0, 2.0, 3.0];
    /// let grads = array![0.1, 0.2, 0.3];
    ///
    /// let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
    /// ```
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();

        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        // Initialize moments on first step
        if self.momentum.is_none() {
            self.momentum = Some(Array1::zeros(n));
            self.velocity = Some(Array1::zeros(n));
            if self.amsbound {
                self.max_velocity = Some(Array1::zeros(n));
            }
        }

        self.step_count += 1;
        let t = T::from(self.step_count).unwrap();

        let momentum = self.momentum.as_mut().unwrap();
        let velocity = self.velocity.as_mut().unwrap();

        let one = T::one();

        // Apply weight decay if configured
        let effective_grads = if self.weight_decay > T::zero() {
            grads.to_owned() + &(params.to_owned() * self.weight_decay)
        } else {
            grads.to_owned()
        };
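        // Worked example (hypothetical values): with weight_decay = 0.01 and a
        // parameter at 2.0, the effective gradient gains 0.01 * 2.0 = 0.02, i.e.
        // plain L2 regularization folded into the gradient before the moments.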

        // Update biased first moment: m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
        for i in 0..n {
            momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
        }

        // Update biased second moment: v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
        for i in 0..n {
            let grad_sq = effective_grads[i] * effective_grads[i];
            velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
        }

        // For AMSBound: v̂_t = max(v̂_{t-1}, v_t)
        if self.amsbound {
            let max_vel = self.max_velocity.as_mut().unwrap();
            for i in 0..n {
                if velocity[i] > max_vel[i] {
                    max_vel[i] = velocity[i];
                }
            }
        }

        // Compute bias correction terms
        let bias_correction1 = one - self.beta1.powf(t);
        let bias_correction2 = one - self.beta2.powf(t);
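        // Sanity check on the correction (with the default β₁ = 0.9): at t = 1,
        // bias_correction1 = 1 - 0.9 = 0.1 and m_1 = 0.1 * g_1, so the corrected
        // m̂_1 = m_1 / 0.1 recovers g_1 exactly.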

        // Compute dynamic bounds
        // Lower bound: α_l(t) = α_final * (1 - 1/(γ*t + 1))
        let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));

        // Upper bound: α_u(t) = α_final * (1 + 1/(γ*t))
        let upper_bound = self.final_lr * (one + one / (self.gamma * t));

        // Apply parameter updates with clipped learning rates
        let mut updated_params = params.to_owned();

        for i in 0..n {
            // Bias-corrected first moment
            let m_hat = momentum[i] / bias_correction1;

            // Bias-corrected second moment (or max for AMSBound)
            let v_hat = if self.amsbound {
                self.max_velocity.as_ref().unwrap()[i] / bias_correction2
            } else {
                velocity[i] / bias_correction2
            };

            // Compute adaptive learning rate: α / (√v̂_t + ε)
            let step_size = self.learning_rate / (v_hat.sqrt() + self.epsilon);

            // Clip learning rate to dynamic bounds
            let clipped_step_size = if step_size < lower_bound {
                lower_bound
            } else if step_size > upper_bound {
                upper_bound
            } else {
                step_size
            };

            // Apply update: θ_{t+1} = θ_t - η_clipped * m̂_t
            updated_params[i] = updated_params[i] - clipped_step_size * m_hat;
        }

        Ok(updated_params)
    }

    /// Get the number of optimization steps performed
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Reset the optimizer state
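    ///
    /// A minimal usage sketch: after `reset()` the optimizer behaves as if
    /// freshly constructed (moments cleared, step counter back to zero).
    /// ```
    /// use optirs_core::optimizers::AdaBound;
    /// use scirs2_core::ndarray_ext::array;
    ///
    /// let mut optimizer = AdaBound::<f32>::default();
    /// let params = array![1.0, 2.0];
    /// let grads = array![0.1, 0.2];
    /// optimizer.step(params.view(), grads.view()).unwrap();
    ///
    /// optimizer.reset();
    /// assert_eq!(optimizer.step_count(), 0);
    /// ```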
    pub fn reset(&mut self) {
        self.momentum = None;
        self.velocity = None;
        self.max_velocity = None;
        self.step_count = 0;
    }

    /// Get the current dynamic bounds as `(lower, upper)`
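    ///
    /// A small illustrative sketch: before any step both bounds collapse to
    /// `final_lr`; they widen after the first step (see the formulas above).
    /// ```
    /// use optirs_core::optimizers::AdaBound;
    ///
    /// let optimizer = AdaBound::<f32>::default();
    /// let (lower, upper) = optimizer.current_bounds();
    /// assert_eq!(lower, upper); // both equal final_lr before any step
    /// ```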
    pub fn current_bounds(&self) -> (T, T) {
        if self.step_count == 0 {
            return (self.final_lr, self.final_lr);
        }

        let t = T::from(self.step_count).unwrap();
        let one = T::one();

        let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));
        let upper_bound = self.final_lr * (one + one / (self.gamma * t));

        (lower_bound, upper_bound)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_adabound_creation() {
        let optimizer = AdaBound::<f32>::default();
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_adabound_single_step() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();

        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);

        // Parameters should decrease (gradient descent)
        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_adabound_multiple_steps() {
        let mut optimizer = AdaBound::<f32>::default();
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..10 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert_eq!(optimizer.step_count(), 10);
    }

    #[test]
    fn test_adabound_dynamic_bounds() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        // Before any steps, bounds should be equal to final_lr
        let (lower0, upper0) = optimizer.current_bounds();
        assert_relative_eq!(lower0, 0.1, epsilon = 1e-6);
        assert_relative_eq!(upper0, 0.1, epsilon = 1e-6);

        // After the first step, bounds should widen
        optimizer.step(params.view(), grads.view()).unwrap();
        let (lower1, upper1) = optimizer.current_bounds();
        assert!(lower1 < upper1);
        assert!(lower1 >= 0.0);

        // After many steps, bounds should converge to final_lr
        // (bound convergence needs many more steps than the tests above take)
        for _ in 0..10000 {
            optimizer.step(params.view(), grads.view()).unwrap();
        }
        let (lower_final, upper_final) = optimizer.current_bounds();
        assert_relative_eq!(lower_final, 0.1, epsilon = 0.01);
        assert_relative_eq!(upper_final, 0.1, epsilon = 0.01);
    }

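    // A supplementary sketch: by the bound formulas, the lower bound increases
    // and the upper bound decreases with t, so the clipping interval should
    // tighten monotonically toward final_lr.
    #[test]
    fn test_adabound_bounds_tighten_monotonically() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0];
        let grads = array![0.1];

        optimizer.step(params.view(), grads.view()).unwrap();
        let (mut prev_lower, mut prev_upper) = optimizer.current_bounds();

        for _ in 0..100 {
            optimizer.step(params.view(), grads.view()).unwrap();
            let (lower, upper) = optimizer.current_bounds();
            assert!(lower >= prev_lower);
            assert!(upper <= prev_upper);
            prev_lower = lower;
            prev_upper = upper;
        }
    }
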
    #[test]
    fn test_amsbound() {
        let mut optimizer =
            AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.0, true).unwrap();

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(updated_params.len(), 3);
        assert!(optimizer.max_velocity.is_some());
    }

    #[test]
    fn test_adabound_weight_decay() {
        let mut optimizer =
            AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.01, false).unwrap();

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();

        // Weight decay adds w * θ to each gradient; parameters should still decrease
        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_adabound_convergence() {
        // Test convergence on quadratic function f(x) = x²
        let mut optimizer = AdaBound::<f64>::default();
        let mut params = array![5.0];

        // AdaBound needs more iterations for tight convergence
        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        // Should converge close to zero
        assert!(
            params[0].abs() < 0.1,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adabound_reset() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(optimizer.step_count(), 1);

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.momentum.is_none());
        assert!(optimizer.velocity.is_none());
    }
}