use crate::error::{OptimError, Result};
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};
/// State and hyperparameters of an AdaBound optimizer over 1-D parameter vectors.
///
/// The optimizer keeps exponential moving averages of the gradient and of the
/// squared gradient, and clamps the resulting per-element step size between a
/// lower and an upper bound that both depend on `final_lr`, `gamma` and the
/// number of steps taken so far.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaBound<T: Float> {
/// Numerator of the unclamped per-element step size.
learning_rate: T,
/// Value that scales both the lower and the upper step-size bound.
final_lr: T,
/// Decay factor of the gradient moving average (`momentum`).
beta1: T,
/// Decay factor of the squared-gradient moving average (`velocity`).
beta2: T,
/// Constant added to the square root of the second-moment estimate before dividing.
epsilon: T,
/// Multiplies the step count inside the bound formulas.
gamma: T,
/// Coefficient of the `weight_decay * params` term added to the gradient when positive.
weight_decay: T,
/// When `true`, the running element-wise maximum of `velocity` is used as the second moment.
amsbound: bool,
/// Moving average of the gradient; `None` until the first `step` and after `reset`.
momentum: Option<Array1<T>>,
/// Moving average of the squared gradient; `None` until the first `step` and after `reset`.
velocity: Option<Array1<T>>,
/// Element-wise maximum of `velocity` seen so far; only allocated when `amsbound` is set.
max_velocity: Option<Array1<T>>,
/// Number of successful `step` calls since construction or the last `reset`.
step_count: usize,
}
use scirs2_core::ndarray::ScalarOperand;
impl<T: Float + ScalarOperand> Default for AdaBound<T> {
    /// Builds an optimizer with `learning_rate = 0.001`, `final_lr = 0.1`,
    /// `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, `gamma = 1e-3`,
    /// no weight decay and AMSBound disabled.
    ///
    /// # Panics
    ///
    /// Panics if one of the constants cannot be converted to `T` or if
    /// [`AdaBound::new`] rejects the resulting values.
    fn default() -> Self {
        // Every default is written as an `f64` literal and converted to `T` the same way.
        let cast = |value: f64| T::from(value).expect("unwrap failed");

        let built = Self::new(
            cast(0.001),
            cast(0.1),
            cast(0.9),
            cast(0.999),
            cast(1e-8),
            cast(1e-3),
            T::zero(),
            false,
        );
        built.expect("unwrap failed")
    }
}
impl<T: Float + ScalarOperand> AdaBound<T> {
    /// Creates an optimizer after validating every hyperparameter.
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - numerator of the unclamped step size; must be positive.
    /// * `final_lr` - scale of the lower/upper step-size bounds; must be positive.
    /// * `beta1` - decay of the gradient moving average; must lie in `(0, 1)`.
    /// * `beta2` - decay of the squared-gradient moving average; must lie in `(0, 1)`.
    /// * `epsilon` - constant added to the denominator; must be positive.
    /// * `gamma` - multiplies the step count in the bound formulas; must be positive.
    /// * `weight_decay` - coefficient of the `weight_decay * params` gradient term;
    ///   must be non-negative.
    /// * `amsbound` - use the running maximum of the second moment when `true`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidParameter`] when a value is outside its
    /// range, is NaN, or cannot be represented as `f64`.
    // One positional argument per hyperparameter is the established public signature.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        learning_rate: T,
        final_lr: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        gamma: T,
        weight_decay: T,
        amsbound: bool,
    ) -> Result<Self> {
        Self::require_positive("learning_rate", learning_rate)?;
        Self::require_positive("final_lr", final_lr)?;
        Self::require_open_unit_interval("beta1", beta1)?;
        Self::require_open_unit_interval("beta2", beta2)?;
        Self::require_positive("epsilon", epsilon)?;
        Self::require_positive("gamma", gamma)?;
        Self::require_non_negative("weight_decay", weight_decay)?;

        Ok(Self {
            learning_rate,
            final_lr,
            beta1,
            beta2,
            epsilon,
            gamma,
            weight_decay,
            amsbound,
            momentum: None,
            velocity: None,
            max_velocity: None,
            step_count: 0,
        })
    }

    /// Converts a hyperparameter to `f64` for validation and error reporting.
    fn hyperparameter_as_f64(name: &str, value: T) -> Result<f64> {
        value.to_f64().ok_or_else(|| {
            OptimError::InvalidParameter(format!("{} cannot be represented as f64", name))
        })
    }

    /// Fails unless `value` is strictly positive; NaN is rejected explicitly
    /// because every ordered comparison against NaN is false.
    fn require_positive(name: &str, value: T) -> Result<()> {
        let as_f64 = Self::hyperparameter_as_f64(name, value)?;
        if as_f64.is_nan() || as_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "{} must be positive, got {}",
                name, as_f64
            )));
        }
        Ok(())
    }

    /// Fails unless `value` lies strictly between 0 and 1 (NaN is rejected).
    fn require_open_unit_interval(name: &str, value: T) -> Result<()> {
        let as_f64 = Self::hyperparameter_as_f64(name, value)?;
        if as_f64.is_nan() || as_f64 <= 0.0 || as_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "{} must be in (0, 1), got {}",
                name, as_f64
            )));
        }
        Ok(())
    }

    /// Fails unless `value` is zero or positive (NaN is rejected).
    fn require_non_negative(name: &str, value: T) -> Result<()> {
        let as_f64 = Self::hyperparameter_as_f64(name, value)?;
        if as_f64.is_nan() || as_f64 < 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "{} must be non-negative, got {}",
                name, as_f64
            )));
        }
        Ok(())
    }

    /// Fails when an already allocated state buffer has a different length
    /// than the parameter vector passed to the current step.
    fn check_state_len(buffer: Option<&Array1<T>>, expected: usize) -> Result<()> {
        match buffer {
            Some(state) if state.len() != expected => {
                Err(OptimError::DimensionMismatch(format!(
                    "Expected parameter size {}, got {}",
                    state.len(),
                    expected
                )))
            }
            _ => Ok(()),
        }
    }

    /// Lower and upper step-size bounds for the (already converted) step number `t`.
    fn bounds_at(&self, t: T) -> (T, T) {
        let one = T::one();
        let lower_bound = self.final_lr * (one - one / (self.gamma * t + one));
        let upper_bound = self.final_lr * (one + one / (self.gamma * t));
        (lower_bound, upper_bound)
    }

    /// Performs one optimization step and returns the updated parameters.
    ///
    /// `params` is not modified; a new array of the same length is returned.
    /// State buffers are allocated on first use, sized to `params`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::DimensionMismatch`] when `grads` and `params`
    /// differ in length, or when `params` differs in length from the state
    /// allocated by an earlier step. The step counter and the state are left
    /// untouched in both cases.
    ///
    /// # Panics
    ///
    /// Panics if the step count cannot be converted to `T`.
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();
        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        // Validate before mutating anything so a rejected call leaves the optimizer unchanged.
        Self::check_state_len(self.momentum.as_ref(), n)?;
        Self::check_state_len(self.velocity.as_ref(), n)?;
        if self.amsbound {
            Self::check_state_len(self.max_velocity.as_ref(), n)?;
        }

        self.step_count += 1;
        let t = T::from(self.step_count).expect("unwrap failed");

        let one = T::one();
        let bias_correction1 = one - self.beta1.powf(t);
        let bias_correction2 = one - self.beta2.powf(t);
        // Computed before the state buffers are mutably borrowed below.
        let (lower_bound, upper_bound) = self.bounds_at(t);

        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let learning_rate = self.learning_rate;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;
        let apply_decay = weight_decay > T::zero();

        // Each buffer is created independently, so partially populated state
        // (for example after deserialization) cannot cause a panic.
        let momentum = self.momentum.get_or_insert_with(|| Array1::zeros(n));
        let velocity = self.velocity.get_or_insert_with(|| Array1::zeros(n));

        // Update both moving averages from the (optionally weight-decayed) gradient.
        for (((m, v), &g), &p) in momentum
            .iter_mut()
            .zip(velocity.iter_mut())
            .zip(grads.iter())
            .zip(params.iter())
        {
            let grad = if apply_decay { g + p * weight_decay } else { g };
            let grad_sq = grad * grad;
            *m = beta1 * *m + (one - beta1) * grad;
            *v = beta2 * *v + (one - beta2) * grad_sq;
        }

        // With AMSBound the second moment is the running element-wise maximum.
        let second_moment: &Array1<T> = if self.amsbound {
            let max_vel = self.max_velocity.get_or_insert_with(|| Array1::zeros(n));
            for (peak, &current) in max_vel.iter_mut().zip(velocity.iter()) {
                if current > *peak {
                    *peak = current;
                }
            }
            &*max_vel
        } else {
            &*velocity
        };

        let mut updated_params = params.to_owned();
        for ((param, &m), &v) in updated_params
            .iter_mut()
            .zip(momentum.iter())
            .zip(second_moment.iter())
        {
            let m_hat = m / bias_correction1;
            let v_hat = v / bias_correction2;
            let step_size = learning_rate / (v_hat.sqrt() + epsilon);
            // Explicit comparisons (rather than min/max) keep a NaN step size visible.
            let clipped_step_size = if step_size < lower_bound {
                lower_bound
            } else if step_size > upper_bound {
                upper_bound
            } else {
                step_size
            };
            *param = *param - clipped_step_size * m_hat;
        }
        Ok(updated_params)
    }

    /// Number of successful steps since construction or the last [`reset`](Self::reset).
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Drops all accumulated state and sets the step counter back to zero.
    ///
    /// Hyperparameters are kept; buffers are re-allocated by the next `step`.
    pub fn reset(&mut self) {
        self.momentum = None;
        self.velocity = None;
        self.max_velocity = None;
        self.step_count = 0;
    }

    /// Returns the `(lower, upper)` step-size bounds for the current step count.
    ///
    /// Before the first step both bounds are reported as `final_lr`.
    ///
    /// # Panics
    ///
    /// Panics if the step count cannot be converted to `T`.
    pub fn current_bounds(&self) -> (T, T) {
        if self.step_count == 0 {
            return (self.final_lr, self.final_lr);
        }
        let t = T::from(self.step_count).expect("unwrap failed");
        self.bounds_at(t)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray_ext::array;

    /// A default optimizer starts with a zero step counter.
    #[test]
    fn test_adabound_creation() {
        let fresh = AdaBound::<f32>::default();
        assert_eq!(fresh.step_count(), 0);
    }

    /// One step with positive gradients lowers every parameter.
    #[test]
    fn test_adabound_single_step() {
        let mut opt = AdaBound::<f32>::default();
        let before = array![1.0, 2.0, 3.0];
        let gradient = array![0.1, 0.2, 0.3];

        let after = opt
            .step(before.view(), gradient.view())
            .expect("unwrap failed");

        assert_eq!(after.len(), 3);
        assert_eq!(opt.step_count(), 1);
        for (new_value, old_value) in after.iter().zip(before.iter()) {
            assert!(new_value < old_value);
        }
    }

    /// The step counter follows the number of `step` calls.
    #[test]
    fn test_adabound_multiple_steps() {
        let mut opt = AdaBound::<f32>::default();
        let mut current = array![1.0, 2.0, 3.0];

        for _ in 0..10 {
            let gradient = array![0.1, 0.2, 0.3];
            current = opt
                .step(current.view(), gradient.view())
                .expect("unwrap failed");
        }

        assert_eq!(opt.step_count(), 10);
    }

    /// Bounds equal `final_lr` before any step, open up after the first
    /// step, and close in on `final_lr` again after many steps.
    #[test]
    fn test_adabound_dynamic_bounds() {
        let mut opt = AdaBound::<f32>::default();
        let weights = array![1.0, 2.0, 3.0];
        let gradient = array![0.1, 0.2, 0.3];

        let (lower0, upper0) = opt.current_bounds();
        assert_relative_eq!(lower0, 0.1, epsilon = 1e-6);
        assert_relative_eq!(upper0, 0.1, epsilon = 1e-6);

        opt.step(weights.view(), gradient.view())
            .expect("unwrap failed");
        let (lower1, upper1) = opt.current_bounds();
        assert!(lower1 < upper1);
        assert!(lower1 >= 0.0);

        for _ in 0..10000 {
            opt.step(weights.view(), gradient.view())
                .expect("unwrap failed");
        }
        let (lower_final, upper_final) = opt.current_bounds();
        assert_relative_eq!(lower_final, 0.1, epsilon = 0.01);
        assert_relative_eq!(upper_final, 0.1, epsilon = 0.01);
    }

    /// The AMSBound variant allocates the running-maximum buffer on the first step.
    #[test]
    fn test_amsbound() {
        let built = AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.0, true);
        let mut opt = built.expect("unwrap failed");
        let weights = array![1.0, 2.0, 3.0];
        let gradient = array![0.1, 0.2, 0.3];

        let after = opt
            .step(weights.view(), gradient.view())
            .expect("unwrap failed");

        assert_eq!(after.len(), 3);
        assert!(opt.max_velocity.is_some());
    }

    /// With weight decay enabled, positive gradients still lower every parameter.
    #[test]
    fn test_adabound_weight_decay() {
        let built = AdaBound::<f32>::new(0.001, 0.1, 0.9, 0.999, 1e-8, 1e-3, 0.01, false);
        let mut opt = built.expect("unwrap failed");
        let before = array![1.0, 2.0, 3.0];
        let gradient = array![0.1, 0.2, 0.3];

        let after = opt
            .step(before.view(), gradient.view())
            .expect("unwrap failed");

        for (new_value, old_value) in after.iter().zip(before.iter()) {
            assert!(new_value < old_value);
        }
    }

    /// Minimizing `x^2` (gradient `2x`) from `x = 5` approaches zero within 500 steps.
    #[test]
    fn test_adabound_convergence() {
        let mut opt = AdaBound::<f64>::default();
        let mut position = array![5.0];

        for _ in 0..500 {
            let gradient = position.mapv(|x| 2.0 * x);
            position = opt
                .step(position.view(), gradient.view())
                .expect("unwrap failed");
        }

        assert!(
            position[0].abs() < 0.1,
            "Failed to converge, got {}",
            position[0]
        );
    }

    /// `reset` clears the step counter and the moment buffers.
    #[test]
    fn test_adabound_reset() {
        let mut opt = AdaBound::<f32>::default();
        let weights = array![1.0, 2.0, 3.0];
        let gradient = array![0.1, 0.2, 0.3];

        opt.step(weights.view(), gradient.view())
            .expect("unwrap failed");
        assert_eq!(opt.step_count(), 1);

        opt.reset();
        assert_eq!(opt.step_count(), 0);
        assert!(opt.momentum.is_none());
        assert!(opt.velocity.is_none());
    }
}