use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
#[derive(Debug, Clone)]
pub struct LAMB<A: Float + ScalarOperand + Debug> {
learning_rate: A,
beta1: A,
beta2: A,
epsilon: A,
weight_decay: A,
bias_correction: bool,
m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
t: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> LAMB<A> {
    /// Creates a LAMB optimizer with the given learning rate and the default
    /// hyper-parameters: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-6`,
    /// `weight_decay = 0`, bias correction enabled.
    ///
    /// # Panics
    ///
    /// Panics if one of the default constants cannot be converted into `A`.
    pub fn new(learning_rate: A) -> Self {
        let beta1 = A::from(0.9).expect("unwrap failed");
        let beta2 = A::from(0.999).expect("unwrap failed");
        let epsilon = A::from(1e-6).expect("unwrap failed");
        Self::new_with_config(learning_rate, beta1, beta2, epsilon, A::zero(), true)
    }

    /// Creates a LAMB optimizer with every hyper-parameter supplied explicitly.
    ///
    /// The optimizer starts with no moment buffers and a step counter of zero;
    /// the buffers are allocated on the first call to `step`.
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
        bias_correction: bool,
    ) -> Self {
        Self {
            m: None,
            v: None,
            t: 0,
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            bias_correction,
        }
    }

    /// Sets the first-moment decay rate and returns `self` for chaining.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the first-moment decay rate.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the second-moment decay rate and returns `self` for chaining.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the second-moment decay rate.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the numerical-stability term and returns `self` for chaining.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the numerical-stability term.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight-decay coefficient and returns `self` for chaining.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Overwrites the learning rate used by subsequent steps.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Discards both moment estimates and restarts the step counter.
    ///
    /// The next `step` call rebuilds the optimizer state from scratch.
    pub fn reset(&mut self) {
        self.t = 0;
        self.m = None;
        self.v = None;
    }
}
impl<A, D> Optimizer<A, D> for LAMB<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Applies one LAMB update to `params` and returns the new parameters.
    ///
    /// The update direction is the Adam-style term
    /// `m_hat / (sqrt(v_hat) + epsilon)`, plus `weight_decay * params` when the
    /// decay is strictly positive. The step is then scaled by
    /// `learning_rate * trust_ratio`, with `trust_ratio = ||params|| / ||update||`
    /// (or `1` when either norm is zero). The whole `params` array is treated as
    /// a single layer.
    ///
    /// Moment buffers are (re)initialised to zeros on the first call, and
    /// whenever the parameter shape differs from the stored state. The step
    /// counter restarts at the same moment, so bias correction always matches
    /// the age of the current buffers.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if `gradients` and `params` have shapes that ndarray cannot
    /// broadcast together.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // The state is stored as `IxDyn` arrays. Borrowing dynamic-dimensional
        // views avoids deep-copying both inputs on every step.
        let params_dyn = params.view().into_dyn();
        let gradients_dyn = gradients.view().into_dyn();
        let shape = params_dyn.raw_dim();

        // Rebuild the state if it is missing, empty, or shaped for different
        // parameters. Resetting `t` together with the zeroed buffers keeps
        // bias correction consistent (it must start again from t = 1).
        let state_matches = matches!(
            (self.m.as_deref(), self.v.as_deref()),
            (Some([m0, ..]), Some([v0, ..])) if m0.raw_dim() == shape && v0.raw_dim() == shape
        );
        if !state_matches {
            self.m = Some(vec![Array::zeros(shape.clone())]);
            self.v = Some(vec![Array::zeros(shape)]);
            self.t = 0;
        }

        self.t += 1;

        let m = self.m.as_mut().expect("unwrap failed");
        let v = self.v.as_mut().expect("unwrap failed");

        // Exponential moving averages of the gradient and of its square.
        m[0] = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);
        v[0] = &v[0] * self.beta2 + &(&gradients_dyn * &gradients_dyn * (A::one() - self.beta2));

        let (m_hat, v_hat) = if self.bias_correction {
            // `powi` takes an `i32`. Saturate instead of wrapping to a negative
            // exponent on extremely long runs; beta^i32::MAX is effectively 0.
            let t = self.t.min(i32::MAX as usize) as i32;
            let bias1 = A::one() - self.beta1.powi(t);
            let bias2 = A::one() - self.beta2.powi(t);
            (&m[0] / bias1, &v[0] / bias2)
        } else {
            (m[0].clone(), v[0].clone())
        };

        let adaptive_term = &m_hat / &(v_hat.mapv(|x| x.sqrt()) + self.epsilon);

        // Weight decay is folded into the update direction before the trust
        // ratio is computed, so it also influences the step's scale.
        let update = if self.weight_decay > A::zero() {
            &adaptive_term + &(&params_dyn * self.weight_decay)
        } else {
            adaptive_term
        };

        let weight_norm = params_dyn.iter().fold(A::zero(), |acc, &x| acc + x * x).sqrt();
        let update_norm = update.iter().fold(A::zero(), |acc, &x| acc + x * x).sqrt();

        // Fall back to a neutral ratio when either norm is zero. This avoids a
        // division by zero and still lets all-zero parameters move.
        let trust_ratio = if weight_norm > A::zero() && update_norm > A::zero() {
            weight_norm / update_norm
        } else {
            A::one()
        };

        let step = &update * (self.learning_rate * trust_ratio);
        let updated_params = &params_dyn - &step;
        // Converting back to `D` cannot fail: the shape originated from `params`.
        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("unwrap failed"))
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Overwrites the learning rate used by subsequent steps.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    /// `new` installs the documented default hyper-parameters.
    #[test]
    fn test_lamb_basic_creation() {
        let optimizer: LAMB<f64> = LAMB::new(0.001);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.001);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.999);
        assert_abs_diff_eq!(optimizer.get_epsilon(), 1e-6);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
        assert!(optimizer.bias_correction);
    }

    /// Minimising f(x) = x0^2 + x1^2 drives both coordinates toward zero.
    #[test]
    fn test_lamb_convergence() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);
        let mut params = Array1::from_vec(vec![5.0, 3.0]);
        for _ in 0..50 {
            let gradients = Array1::from_vec(vec![2.0 * params[0], 2.0 * params[1]]);
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        assert!(params[0].abs() < 1.0);
        assert!(params[1].abs() < 1.0);
    }

    /// With positive gradients and weight decay, parameters shrink.
    #[test]
    fn test_lamb_with_weight_decay() {
        let mut optimizer: LAMB<f64> = LAMB::new_with_config(0.1, 0.9, 0.999, 1e-6, 0.1, true);
        let mut params = Array1::from_vec(vec![1.0, 1.0]);
        for _ in 0..20 {
            let gradients = Array1::from_vec(vec![0.1, 0.1]);
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        assert!(params[0] < 1.0);
        assert!(params[1] < 1.0);
    }

    /// `reset` clears the moment buffers and the step counter.
    #[test]
    fn test_lamb_reset() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.5]);
        let _ = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert!(optimizer.m.is_some());
        assert!(optimizer.v.is_some());
        assert_eq!(optimizer.t, 1);

        optimizer.reset();
        assert!(optimizer.m.is_none());
        assert!(optimizer.v.is_none());
        assert_eq!(optimizer.t, 0);
    }

    /// A single step with positive gradients moves every parameter downward.
    #[test]
    fn test_lamb_trust_ratio() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);
        let params = Array1::from_vec(vec![2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.4, 0.6]);
        let new_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert_ne!(new_params[0], params[0]);
        assert_ne!(new_params[1], params[1]);
        assert!(new_params[0] < params[0]);
        assert!(new_params[1] < params[1]);
    }

    /// `step` works for non-1-D parameters and preserves their shape.
    #[test]
    fn test_lamb_multidimensional_params() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);
        let params =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap failed");
        let gradients = Array2::from_elem((2, 2), 0.5);
        let new_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert_eq!(new_params.shape(), params.shape());
        for (new, old) in new_params.iter().zip(params.iter()) {
            assert!(new < old);
        }
    }

    /// The non-bias-corrected path still produces a descent step.
    #[test]
    fn test_lamb_without_bias_correction() {
        let mut optimizer: LAMB<f64> = LAMB::new_with_config(0.1, 0.9, 0.999, 1e-6, 0.0, false);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.5, 0.5]);
        let new_params = optimizer.step(&params, &gradients).expect("unwrap failed");
        assert!(new_params[0] < params[0]);
        assert!(new_params[1] < params[1]);
    }
}