use crate::error::{OptimError, Result};
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};

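/// AdaBound optimizer (Luo et al., 2019): Adam-style adaptive moment updates
/// whose per-parameter step size is clipped to dynamic lower/upper bounds that
/// converge toward `final_lr` as training progresses, transitioning smoothly
/// from Adam-like behavior early in training to SGD-like behavior later.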
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaBound<T: Float> {
    /// Base learning rate (Adam-style step size before bounding)
    learning_rate: T,

    /// Final (SGD) learning rate that the dynamic bounds converge to
    final_lr: T,

    /// Exponential decay rate for the first moment estimate
    beta1: T,

    /// Exponential decay rate for the second moment estimate
    beta2: T,

    /// Small constant added to the denominator for numerical stability
    epsilon: T,

    /// Convergence speed of the bound functions
    gamma: T,

    /// L2 weight decay coefficient (0 disables weight decay)
    weight_decay: T,

    /// Whether to use the AMSBound variant (keeps a running maximum of the second moment)
    amsbound: bool,

    /// First moment estimate (initialized lazily on the first step)
    momentum: Option<Array1<T>>,

    /// Second moment estimate (initialized lazily on the first step)
    velocity: Option<Array1<T>>,

    /// Maximum of the second moment estimate (AMSBound only)
    max_velocity: Option<Array1<T>>,

    /// Number of optimization steps taken so far
    step_count: usize,
}

use scirs2_core::ndarray::ScalarOperand;

impl<T: Float + ScalarOperand> Default for AdaBound<T> {
    fn default() -> Self {
        Self::new(
            T::from(0.001).unwrap(), // learning_rate
            T::from(0.1).unwrap(),   // final_lr
            T::from(0.9).unwrap(),   // beta1
            T::from(0.999).unwrap(), // beta2
            T::from(1e-8).unwrap(),  // epsilon
            T::from(1e-3).unwrap(),  // gamma
            T::zero(),               // weight_decay
            false,                   // amsbound
        )
        .unwrap()
    }
}

impl<T: Float + ScalarOperand> AdaBound<T> {
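    /// Creates a new AdaBound optimizer with the given hyperparameters.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidParameter` if `learning_rate`, `final_lr`,
    /// `epsilon`, or `gamma` is not positive, if `beta1` or `beta2` is outside
    /// `(0, 1)`, or if `weight_decay` is negative.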
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        learning_rate: T,
        final_lr: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        gamma: T,
        weight_decay: T,
        amsbound: bool,
    ) -> Result<Self> {
        let lr_f64 = learning_rate.to_f64().unwrap();
        let final_f64 = final_lr.to_f64().unwrap();
        let beta1_f64 = beta1.to_f64().unwrap();
        let beta2_f64 = beta2.to_f64().unwrap();
        let eps_f64 = epsilon.to_f64().unwrap();
        let gamma_f64 = gamma.to_f64().unwrap();
        let wd_f64 = weight_decay.to_f64().unwrap();

        if lr_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "learning_rate must be positive, got {}",
                lr_f64
            )));
        }
        if final_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "final_lr must be positive, got {}",
                final_f64
            )));
        }
        if beta1_f64 <= 0.0 || beta1_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "beta1 must be in (0, 1), got {}",
                beta1_f64
            )));
        }
        if beta2_f64 <= 0.0 || beta2_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "beta2 must be in (0, 1), got {}",
                beta2_f64
            )));
        }
        if eps_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be positive, got {}",
                eps_f64
            )));
        }
        if gamma_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "gamma must be positive, got {}",
                gamma_f64
            )));
        }
        if wd_f64 < 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "weight_decay must be non-negative, got {}",
                wd_f64
            )));
        }

        Ok(Self {
            learning_rate,
            final_lr,
            beta1,
            beta2,
            epsilon,
            gamma,
            weight_decay,
            amsbound,
            momentum: None,
            velocity: None,
            max_velocity: None,
            step_count: 0,
        })
    }

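    /// Performs a single optimization step and returns the updated parameters.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` if `grads` does not have the
    /// same length as `params`.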
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();

        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        // Lazily initialize optimizer state on the first step
        if self.momentum.is_none() {
            self.momentum = Some(Array1::zeros(n));
            self.velocity = Some(Array1::zeros(n));
            if self.amsbound {
                self.max_velocity = Some(Array1::zeros(n));
            }
        }

        self.step_count += 1;
        let t = T::from(self.step_count).unwrap();

        let momentum = self.momentum.as_mut().unwrap();
        let velocity = self.velocity.as_mut().unwrap();

        let one = T::one();

        // Apply L2 weight decay by folding it into the gradients
        let effective_grads = if self.weight_decay > T::zero() {
            grads.to_owned() + &(params.to_owned() * self.weight_decay)
        } else {
            grads.to_owned()
        };

        // Update biased first moment estimate
        for i in 0..n {
            momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
        }

        // Update biased second moment estimate
        for i in 0..n {
            let grad_sq = effective_grads[i] * effective_grads[i];
            velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
        }

        // AMSBound: keep the element-wise maximum of the second moment estimate
        if self.amsbound {
            let max_vel = self.max_velocity.as_mut().unwrap();
            for i in 0..n {
                if velocity[i] > max_vel[i] {
                    max_vel[i] = velocity[i];
                }
            }
        }

        // Bias corrections for the moment estimates
        let bias_correction1 = one - self.beta1.powf(t);
        let bias_correction2 = one - self.beta2.powf(t);

        // Dynamic bounds that converge toward final_lr as t grows
        let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));
        let upper_bound = self.final_lr * (one + one / (self.gamma * t));

        let mut updated_params = params.to_owned();

        for i in 0..n {
            let m_hat = momentum[i] / bias_correction1;

            let v_hat = if self.amsbound {
                self.max_velocity.as_ref().unwrap()[i] / bias_correction2
            } else {
                velocity[i] / bias_correction2
            };

            // Adam-style step size, clipped to the dynamic bounds
            let step_size = self.learning_rate / (v_hat.sqrt() + self.epsilon);

            let clipped_step_size = if step_size < lower_bound {
                lower_bound
            } else if step_size > upper_bound {
                upper_bound
            } else {
                step_size
            };

            updated_params[i] = updated_params[i] - clipped_step_size * m_hat;
        }

        Ok(updated_params)
    }

    /// Returns the number of optimization steps taken so far.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Resets the optimizer state (moment estimates and step count).
    pub fn reset(&mut self) {
        self.momentum = None;
        self.velocity = None;
        self.max_velocity = None;
        self.step_count = 0;
    }

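    /// Returns the current `(lower_bound, upper_bound)` on the step size.
    ///
    /// Before the first step both bounds are reported as `final_lr`; afterwards
    /// they are computed from the current step count and converge toward
    /// `final_lr` as training progresses.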
    pub fn current_bounds(&self) -> (T, T) {
        if self.step_count == 0 {
            return (self.final_lr, self.final_lr);
        }

        let t = T::from(self.step_count).unwrap();
        let one = T::one();

        let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));
        let upper_bound = self.final_lr * (one + one / (self.gamma * t));

        (lower_bound, upper_bound)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_adabound_creation() {
        let optimizer = AdaBound::<f32>::default();
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_adabound_single_step() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();

        assert_eq!(updated_params.len(), 3);
        assert_eq!(optimizer.step_count(), 1);

        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_adabound_multiple_steps() {
        let mut optimizer = AdaBound::<f32>::default();
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..10 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert_eq!(optimizer.step_count(), 10);
    }

    #[test]
    fn test_adabound_dynamic_bounds() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let (lower0, upper0) = optimizer.current_bounds();
        assert_relative_eq!(lower0, 0.1, epsilon = 1e-6);
        assert_relative_eq!(upper0, 0.1, epsilon = 1e-6);

        optimizer.step(params.view(), grads.view()).unwrap();
        let (lower1, upper1) = optimizer.current_bounds();
        assert!(lower1 < upper1);
        assert!(lower1 >= 0.0);

        for _ in 0..10000 {
            optimizer.step(params.view(), grads.view()).unwrap();
        }
        let (lower_final, upper_final) = optimizer.current_bounds();
        assert_relative_eq!(lower_final, 0.1, epsilon = 0.01);
        assert_relative_eq!(upper_final, 0.1, epsilon = 0.01);
    }

    #[test]
    fn test_amsbound() {
        let mut optimizer =
            AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.0, true).unwrap();

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(updated_params.len(), 3);
        assert!(optimizer.max_velocity.is_some());
    }

    #[test]
    fn test_adabound_weight_decay() {
        let mut optimizer =
            AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.01, false).unwrap();

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();

        for i in 0..3 {
            assert!(updated_params[i] < params[i]);
        }
    }

    #[test]
    fn test_adabound_convergence() {
        let mut optimizer = AdaBound::<f64>::default();
        let mut params = array![5.0];

        for _ in 0..500 {
            // Gradient of f(x) = x^2
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert!(
            params[0].abs() < 0.1,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adabound_reset() {
        let mut optimizer = AdaBound::<f32>::default();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(optimizer.step_count(), 1);

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.momentum.is_none());
        assert!(optimizer.velocity.is_none());
    }
}