use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// State and hyper-parameters of the Lion optimizer.
///
/// The optimizer keeps one momentum buffer and moves every parameter by a
/// fixed-size step (`learning_rate`) in the direction given by the sign of
/// a blend of that momentum and the current gradient.
#[derive(Debug, Clone)]
pub struct Lion<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Magnitude of the per-element sign step applied in `step`.
    learning_rate: A,
    /// Weight of the stored momentum when blending it with the gradient to
    /// form the update direction (the gradient gets `1 - beta1`).
    beta1: A,
    /// Weight of the stored momentum when refreshing the momentum buffer
    /// after the update (the gradient gets `1 - beta2`).
    beta2: A,
    /// Decay coefficient; when positive, parameters are scaled by
    /// `1 - weight_decay * learning_rate` before the sign step.
    weight_decay: A,
    /// Momentum buffers. `None` until the first `step` (and again after
    /// `reset`); `step` only ever reads and writes the first entry.
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
impl<A> Lion<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    /// Builds an optimizer with the given learning rate and the default
    /// settings `beta1 = 0.9`, `beta2 = 0.99`, `weight_decay = 0`.
    ///
    /// # Panics
    ///
    /// Panics if `0.9` or `0.99` cannot be converted into `A`.
    pub fn new(learning_rate: A) -> Self {
        let default_beta1 = A::from(0.9).expect("unwrap failed");
        let default_beta2 = A::from(0.99).expect("unwrap failed");
        Self::new_with_config(learning_rate, default_beta1, default_beta2, A::zero())
    }

    /// Builds an optimizer with every hyper-parameter given explicitly.
    ///
    /// No validation is performed on the supplied values. The momentum
    /// state starts out empty and is created by the first call to `step`.
    pub fn new_with_config(learning_rate: A, beta1: A, beta2: A, weight_decay: A) -> Self {
        Self {
            m: None,
            learning_rate,
            beta1,
            beta2,
            weight_decay,
        }
    }

    /// Replaces `beta1` and returns `self` so calls can be chained.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the current `beta1`.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Replaces `beta2` and returns `self` so calls can be chained.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the current `beta2`.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Replaces the weight-decay coefficient and returns `self` so calls
    /// can be chained.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the current weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Discards the accumulated momentum; hyper-parameters are untouched.
    /// The next `step` starts again from a zero momentum buffer.
    pub fn reset(&mut self) {
        self.m = None;
    }
}
impl<A, D> Optimizer<A, D> for Lion<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Performs one Lion update and returns the new parameters.
    ///
    /// The update proceeds as follows:
    /// 1. `c = beta1 * m + (1 - beta1) * gradients`
    /// 2. if `weight_decay > 0`, `params` is scaled by
    ///    `1 - weight_decay * learning_rate`
    /// 3. `params -= learning_rate * sign(c)` (elements of `c` that are
    ///    neither positive nor negative contribute `0`)
    /// 4. `m = beta2 * m + (1 - beta2) * gradients`
    ///
    /// The momentum buffer `m` is created as zeros on first use and is
    /// re-zeroed whenever the shape of `params` differs from the stored
    /// buffer's shape.
    ///
    /// # Panics
    ///
    /// Panics if the result cannot be converted back to dimension `D`.
    /// NOTE(review): no explicit shape check is made between `params` and
    /// `gradients`; mismatched shapes are left to ndarray's arithmetic
    /// operators, which are expected to panic when broadcasting fails.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Hyper-parameters are `Copy`; read them up front so the mutable
        // borrow of the momentum state below stays self-contained.
        let lr = self.learning_rate;
        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let weight_decay = self.weight_decay;
        let zero = A::zero();
        let one = A::one();

        // Single owned, dynamically-dimensioned copy of each input.
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();
        let shape = params_dyn.raw_dim();

        // Lazily create the momentum buffer, or re-zero it on shape change.
        let state = self.m.get_or_insert_with(Vec::new);
        match state.first_mut() {
            Some(buf) => {
                if buf.raw_dim() != shape {
                    *buf = Array::zeros(shape);
                }
            }
            None => state.push(Array::zeros(shape)),
        }
        // The match above guarantees at least one buffer exists.
        let momentum = &mut state[0];

        // Direction: sign of the beta1-blend of momentum and gradient.
        let interpolated_update = &*momentum * beta1 + &gradients_dyn * (one - beta1);
        let sign_update = interpolated_update.mapv(|x| {
            if x > zero {
                one
            } else if x < zero {
                -one
            } else {
                zero
            }
        });

        // Decoupled weight decay, applied only for a positive coefficient.
        let decayed_params = if weight_decay > zero {
            params_dyn * (one - weight_decay * lr)
        } else {
            params_dyn
        };
        let updated_params = decayed_params - sign_update * lr;

        // The momentum is refreshed only after the direction was computed
        // from its previous value.
        *momentum = &*momentum * beta2 + gradients_dyn * (one - beta2);

        Ok(updated_params
            .into_dimensionality::<D>()
            .expect("unwrap failed"))
    }

    /// Returns the current learning rate.
    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Replaces the learning rate.
    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// `Lion::new` stores the learning rate and applies the default
    /// `beta1`, `beta2` and `weight_decay`.
    #[test]
    fn test_lion_basic_creation() {
        let optimizer: Lion<f64> = Lion::new(0.001);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.001);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.99);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
    }

    /// Minimizing `x^2` from `x = 5` with step size `0.1`: each step moves
    /// the parameter by exactly `0.1`, so 40 steps bring it to about `1.0`.
    #[test]
    fn test_lion_convergence() {
        let mut optimizer: Lion<f64> = Lion::new(0.1);
        let mut params = Array1::from_vec(vec![5.0]);
        for _ in 0..40 {
            // Gradient of x^2 is 2x.
            let gradients = Array1::from_vec(vec![2.0 * params[0]]);
            params = optimizer.step(&params, &gradients).expect("unwrap failed");
        }
        assert!(params[0].abs() < 1.1);
    }

    /// After `reset`, a step must match the first step of a brand-new
    /// optimizer with the same configuration.
    #[test]
    fn test_lion_reset() {
        let mut optimizer: Lion<f64> = Lion::new(0.1);
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.1]);
        let _ = optimizer.step(&params, &gradients).expect("unwrap failed");
        optimizer.reset();
        let next_step = optimizer.step(&params, &gradients).expect("unwrap failed");
        let mut fresh_optimizer: Lion<f64> = Lion::new(0.1);
        let fresh_step = fresh_optimizer
            .step(&params, &gradients)
            .expect("unwrap failed");
        assert_abs_diff_eq!(next_step[0], fresh_step[0], epsilon = 1e-10);
    }
}