1use crate::error::{OptimError, Result};
15use scirs2_core::ndarray_ext::{Array1, ArrayView1};
16use scirs2_core::numeric::{Float, Zero};
17use serde::{Deserialize, Serialize};
18
/// AdaBound optimizer state and hyperparameters.
///
/// AdaBound behaves like Adam early in training and gradually transitions
/// towards SGD by clipping the per-parameter step size into dynamic bounds
/// that converge to `final_lr` (see `current_bounds` in the impl below).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaBound<T: Float> {
    /// Base (Adam-style) learning rate used before bound clipping.
    learning_rate: T,

    /// Final (SGD-style) learning rate that both dynamic bounds converge to.
    final_lr: T,

    /// Exponential decay rate for the first-moment (momentum) estimate.
    beta1: T,

    /// Exponential decay rate for the second-moment (velocity) estimate.
    beta2: T,

    /// Small constant added to the denominator for numerical stability.
    epsilon: T,

    /// Convergence speed of the dynamic bounds: larger gamma tightens the
    /// bounds around `final_lr` faster (appears as `gamma * t` in the bound formulas).
    gamma: T,

    /// L2 penalty coefficient added to the gradients; 0 disables weight decay.
    weight_decay: T,

    /// When true, uses the AMSBound variant (keeps the running maximum of the
    /// second moment, analogous to AMSGrad).
    amsbound: bool,

    /// First-moment estimate; lazily allocated on the first `step` call.
    momentum: Option<Array1<T>>,

    /// Second-moment estimate; lazily allocated on the first `step` call.
    velocity: Option<Array1<T>>,

    /// Running element-wise maximum of `velocity`; only allocated when `amsbound` is set.
    max_velocity: Option<Array1<T>>,

    /// Number of `step` calls since construction or the last `reset`.
    step_count: usize,
}
72
73use scirs2_core::ndarray::ScalarOperand;
74
75impl<T: Float + ScalarOperand> Default for AdaBound<T> {
76 fn default() -> Self {
77 Self::new(
78 T::from(0.001).expect("unwrap failed"), T::from(0.1).expect("unwrap failed"), T::from(0.9).expect("unwrap failed"), T::from(0.999).expect("unwrap failed"), T::from(1e-8).expect("unwrap failed"), T::from(1e-3).expect("unwrap failed"), T::zero(), false, )
87 .expect("unwrap failed")
88 }
89}
90
91impl<T: Float + ScalarOperand> AdaBound<T> {
92 #[allow(clippy::too_many_arguments)]
120 pub fn new(
121 learning_rate: T,
122 final_lr: T,
123 beta1: T,
124 beta2: T,
125 epsilon: T,
126 gamma: T,
127 weight_decay: T,
128 amsbound: bool,
129 ) -> Result<Self> {
130 let lr_f64 = learning_rate.to_f64().expect("unwrap failed");
131 let final_f64 = final_lr.to_f64().expect("unwrap failed");
132 let beta1_f64 = beta1.to_f64().expect("unwrap failed");
133 let beta2_f64 = beta2.to_f64().expect("unwrap failed");
134 let eps_f64 = epsilon.to_f64().expect("unwrap failed");
135 let gamma_f64 = gamma.to_f64().expect("unwrap failed");
136 let wd_f64 = weight_decay.to_f64().expect("unwrap failed");
137
138 if lr_f64 <= 0.0 {
139 return Err(OptimError::InvalidParameter(format!(
140 "learning_rate must be positive, got {}",
141 lr_f64
142 )));
143 }
144 if final_f64 <= 0.0 {
145 return Err(OptimError::InvalidParameter(format!(
146 "final_lr must be positive, got {}",
147 final_f64
148 )));
149 }
150 if beta1_f64 <= 0.0 || beta1_f64 >= 1.0 {
151 return Err(OptimError::InvalidParameter(format!(
152 "beta1 must be in (0, 1), got {}",
153 beta1_f64
154 )));
155 }
156 if beta2_f64 <= 0.0 || beta2_f64 >= 1.0 {
157 return Err(OptimError::InvalidParameter(format!(
158 "beta2 must be in (0, 1), got {}",
159 beta2_f64
160 )));
161 }
162 if eps_f64 <= 0.0 {
163 return Err(OptimError::InvalidParameter(format!(
164 "epsilon must be positive, got {}",
165 eps_f64
166 )));
167 }
168 if gamma_f64 <= 0.0 {
169 return Err(OptimError::InvalidParameter(format!(
170 "gamma must be positive, got {}",
171 gamma_f64
172 )));
173 }
174 if wd_f64 < 0.0 {
175 return Err(OptimError::InvalidParameter(format!(
176 "weight_decay must be non-negative, got {}",
177 wd_f64
178 )));
179 }
180
181 Ok(Self {
182 learning_rate,
183 final_lr,
184 beta1,
185 beta2,
186 epsilon,
187 gamma,
188 weight_decay,
189 amsbound,
190 momentum: None,
191 velocity: None,
192 max_velocity: None,
193 step_count: 0,
194 })
195 }
196
197 pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
228 let n = params.len();
229
230 if grads.len() != n {
231 return Err(OptimError::DimensionMismatch(format!(
232 "Expected gradient size {}, got {}",
233 n,
234 grads.len()
235 )));
236 }
237
238 if self.momentum.is_none() {
240 self.momentum = Some(Array1::zeros(n));
241 self.velocity = Some(Array1::zeros(n));
242 if self.amsbound {
243 self.max_velocity = Some(Array1::zeros(n));
244 }
245 }
246
247 self.step_count += 1;
248 let t = T::from(self.step_count).expect("unwrap failed");
249
250 let momentum = self.momentum.as_mut().expect("unwrap failed");
251 let velocity = self.velocity.as_mut().expect("unwrap failed");
252
253 let one = T::one();
254 let two = T::from(2).expect("unwrap failed");
255
256 let effective_grads = if self.weight_decay > T::zero() {
258 grads.to_owned() + &(params.to_owned() * self.weight_decay)
259 } else {
260 grads.to_owned()
261 };
262
263 for i in 0..n {
265 momentum[i] = self.beta1 * momentum[i] + (one - self.beta1) * effective_grads[i];
266 }
267
268 for i in 0..n {
270 let grad_sq = effective_grads[i] * effective_grads[i];
271 velocity[i] = self.beta2 * velocity[i] + (one - self.beta2) * grad_sq;
272 }
273
274 if self.amsbound {
276 let max_vel = self.max_velocity.as_mut().expect("unwrap failed");
277 for i in 0..n {
278 if velocity[i] > max_vel[i] {
279 max_vel[i] = velocity[i];
280 }
281 }
282 }
283
284 let bias_correction1 = one - self.beta1.powf(t);
286 let bias_correction2 = one - self.beta2.powf(t);
287
288 let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));
291
292 let upper_bound = self.final_lr * (one + one / (self.gamma * t));
294
295 let mut updated_params = params.to_owned();
297
298 for i in 0..n {
299 let m_hat = momentum[i] / bias_correction1;
301
302 let v_hat = if self.amsbound {
304 self.max_velocity.as_ref().expect("unwrap failed")[i] / bias_correction2
305 } else {
306 velocity[i] / bias_correction2
307 };
308
309 let step_size = self.learning_rate / (v_hat.sqrt() + self.epsilon);
311
312 let clipped_step_size = if step_size < lower_bound {
314 lower_bound
315 } else if step_size > upper_bound {
316 upper_bound
317 } else {
318 step_size
319 };
320
321 updated_params[i] = updated_params[i] - clipped_step_size * m_hat;
323 }
324
325 Ok(updated_params)
326 }
327
328 pub fn step_count(&self) -> usize {
330 self.step_count
331 }
332
333 pub fn reset(&mut self) {
335 self.momentum = None;
336 self.velocity = None;
337 self.max_velocity = None;
338 self.step_count = 0;
339 }
340
341 pub fn current_bounds(&self) -> (T, T) {
343 if self.step_count == 0 {
344 return (self.final_lr, self.final_lr);
345 }
346
347 let t = T::from(self.step_count).expect("unwrap failed");
348 let one = T::one();
349
350 let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));
351 let upper_bound = self.final_lr * (one + one / (self.gamma * t));
352
353 (lower_bound, upper_bound)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_adabound_creation() {
        // A freshly built optimizer has not taken any steps.
        let opt = AdaBound::<f32>::default();
        assert_eq!(opt.step_count(), 0);
    }

    #[test]
    fn test_adabound_single_step() {
        let mut opt = AdaBound::<f32>::default();
        let weights = array![1.0, 2.0, 3.0];
        let grad = array![0.1, 0.2, 0.3];

        let next = opt.step(weights.view(), grad.view()).expect("unwrap failed");

        assert_eq!(next.len(), 3);
        assert_eq!(opt.step_count(), 1);

        // Positive gradients must decrease every parameter.
        assert!(next.iter().zip(weights.iter()).all(|(new, old)| new < old));
    }

    #[test]
    fn test_adabound_multiple_steps() {
        let mut opt = AdaBound::<f32>::default();
        let mut weights = array![1.0, 2.0, 3.0];

        // Feed the same gradient ten times; only the counter is checked here.
        for _ in 0..10 {
            let grad = array![0.1, 0.2, 0.3];
            weights = opt.step(weights.view(), grad.view()).expect("unwrap failed");
        }

        assert_eq!(opt.step_count(), 10);
    }

    #[test]
    fn test_adabound_dynamic_bounds() {
        let mut opt = AdaBound::<f32>::default();
        let weights = array![1.0, 2.0, 3.0];
        let grad = array![0.1, 0.2, 0.3];

        // Before any step, both bounds collapse onto final_lr (= 0.1).
        let (lo_start, hi_start) = opt.current_bounds();
        assert_relative_eq!(lo_start, 0.1, epsilon = 1e-6);
        assert_relative_eq!(hi_start, 0.1, epsilon = 1e-6);

        // After the first step the interval opens up around final_lr.
        opt.step(weights.view(), grad.view()).expect("unwrap failed");
        let (lo_one, hi_one) = opt.current_bounds();
        assert!(lo_one < hi_one);
        assert!(lo_one >= 0.0);

        // After many steps the bounds tighten back towards final_lr.
        for _ in 0..10000 {
            opt.step(weights.view(), grad.view()).expect("unwrap failed");
        }
        let (lo_end, hi_end) = opt.current_bounds();
        assert_relative_eq!(lo_end, 0.1, epsilon = 0.01);
        assert_relative_eq!(hi_end, 0.1, epsilon = 0.01);
    }

    #[test]
    fn test_amsbound() {
        // Enable the AMSBound variant and check that its extra state exists.
        let mut opt = AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.0, true)
            .expect("unwrap failed");

        let weights = array![1.0, 2.0, 3.0];
        let grad = array![0.1, 0.2, 0.3];

        let next = opt.step(weights.view(), grad.view()).expect("unwrap failed");
        assert_eq!(next.len(), 3);
        assert!(opt.max_velocity.is_some());
    }

    #[test]
    fn test_adabound_weight_decay() {
        // Non-zero weight decay should still move parameters downhill.
        let mut opt = AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.01, false)
            .expect("unwrap failed");

        let weights = array![1.0, 2.0, 3.0];
        let grad = array![0.1, 0.2, 0.3];

        let next = opt.step(weights.view(), grad.view()).expect("unwrap failed");

        assert!(next.iter().zip(weights.iter()).all(|(new, old)| new < old));
    }

    #[test]
    fn test_adabound_convergence() {
        // Minimize f(x) = x^2 (gradient 2x) starting from x = 5.
        let mut opt = AdaBound::<f64>::default();
        let mut weights = array![5.0];

        for _ in 0..500 {
            let grad = weights.mapv(|x| 2.0 * x);
            weights = opt.step(weights.view(), grad.view()).expect("unwrap failed");
        }

        assert!(
            weights[0].abs() < 0.1,
            "Failed to converge, got {}",
            weights[0]
        );
    }

    #[test]
    fn test_adabound_reset() {
        let mut opt = AdaBound::<f32>::default();
        let weights = array![1.0, 2.0, 3.0];
        let grad = array![0.1, 0.2, 0.3];

        opt.step(weights.view(), grad.view()).expect("unwrap failed");
        assert_eq!(opt.step_count(), 1);

        // Reset must clear both the counter and the lazily created state.
        opt.reset();
        assert_eq!(opt.step_count(), 0);
        assert!(opt.momentum.is_none());
        assert!(opt.velocity.is_none());
    }
}