//! Lion (EvoLved Sign Momentum) optimizer, following Chen et al. (2023),
//! "Symbolic Discovery of Optimization Algorithms".

use crate::error::TrainResult;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

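/// Hyperparameters for the Lion optimizer.
///
/// The defaults (`beta1 = 0.9`, `beta2 = 0.99`) follow the paper's
/// recommendations. Because the update is sign-based (unit magnitude per
/// element), Lion is typically run with a smaller learning rate than AdamW.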
#[derive(Debug, Clone)]
pub struct LionConfig {
    /// Step size applied to the sign-based update.
    pub learning_rate: f64,
    /// Interpolation factor between momentum and gradient for the update direction.
    pub beta1: f64,
    /// Decay factor for the momentum state.
    pub beta2: f64,
    /// Decoupled (AdamW-style) weight-decay coefficient.
    pub weight_decay: f64,
}

impl Default for LionConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-4,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.0,
        }
    }
}

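/// Lion optimizer with per-parameter momentum state.
///
/// Each step moves a parameter by
/// `-lr * (sign(beta1 * m + (1 - beta1) * g) + weight_decay * param)`
/// and then updates the momentum as `m = beta2 * m + (1 - beta2) * g`.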
pub struct LionOptimizer {
    config: LionConfig,
    /// Momentum (EMA of gradients) keyed by parameter name; entries are
    /// created lazily on the first step that sees a given parameter.
    momentum: HashMap<String, Array1<f64>>,
}

impl LionOptimizer {
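    /// Validates the configuration and constructs the optimizer.
    ///
    /// Returns a `ConfigError` if the learning rate is non-positive, either
    /// beta lies outside `[0, 1)`, or the weight decay is negative.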
    pub fn new(config: LionConfig) -> TrainResult<Self> {
        if config.learning_rate <= 0.0 {
            return Err(crate::error::TrainError::ConfigError(
                "Learning rate must be positive".to_string(),
            ));
        }
        if !(0.0..1.0).contains(&config.beta1) {
            return Err(crate::error::TrainError::ConfigError(
                "beta1 must be in [0, 1)".to_string(),
            ));
        }
        if !(0.0..1.0).contains(&config.beta2) {
            return Err(crate::error::TrainError::ConfigError(
                "beta2 must be in [0, 1)".to_string(),
            ));
        }
        if config.weight_decay < 0.0 {
            return Err(crate::error::TrainError::ConfigError(
                "weight_decay must be non-negative".to_string(),
            ));
        }
        Ok(Self {
            config,
            momentum: HashMap::new(),
        })
    }

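    /// Performs one optimization step.
    ///
    /// For every parameter with a matching gradient, the update direction is
    /// the beta1-interpolation of momentum and gradient; only its sign is
    /// applied, scaled by the learning rate, together with decoupled weight
    /// decay. The momentum is then advanced with beta2. Parameters without a
    /// gradient entry are left untouched.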
    pub fn step(
        &mut self,
        params: &mut HashMap<String, Array1<f64>>,
        gradients: &HashMap<String, Array1<f64>>,
    ) -> TrainResult<()> {
        // Copy hyperparameters out of `self.config` so the closures below
        // capture plain `f64`s rather than borrowing `self` while
        // `self.momentum` is mutably borrowed.
        let lr = self.config.learning_rate;
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let weight_decay = self.config.weight_decay;
        for (name, param) in params.iter_mut() {
            if let Some(grad) = gradients.get(name) {
                // Lazily initialize momentum to zeros on first use; gradient
                // and parameter shapes are assumed to match.
                let momentum = self
                    .momentum
                    .entry(name.clone())
                    .or_insert_with(|| Array1::zeros(param.len()));
                // Update direction: beta1-interpolation of momentum and gradient.
                let update_direction =
                    momentum.mapv(|m| m * beta1) + grad.mapv(|g| g * (1.0 - beta1));
                for i in 0..param.len() {
                    // Only the sign of the direction is used; a zero component
                    // yields no update (unlike `f64::signum`, which maps 0.0 to 1.0).
                    let sign_update = if update_direction[i] > 0.0 {
                        1.0
                    } else if update_direction[i] < 0.0 {
                        -1.0
                    } else {
                        0.0
                    };
                    // Decoupled (AdamW-style) weight decay folded into the step.
                    let update = sign_update + weight_decay * param[i];
                    param[i] -= lr * update;
                }
                // Advance the momentum EMA with beta2 (separate from beta1 above).
                *momentum = momentum.mapv(|m| m * beta2) + grad.mapv(|g| g * (1.0 - beta2));
            }
        }
        Ok(())
    }

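    /// Returns the current learning rate.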
    pub fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

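    /// Sets the learning rate, e.g. from an external LR schedule.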
    pub fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

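    /// Serializes the momentum state, prefixing each key with `momentum.`.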
    pub fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        self.momentum
            .iter()
            .map(|(k, v)| (format!("momentum.{}", k), v.to_vec()))
            .collect()
    }

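    /// Restores momentum state produced by [`state_dict`](Self::state_dict).
    /// Keys without the `momentum.` prefix are ignored.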
    pub fn load_state_dict(&mut self, state: &HashMap<String, Vec<f64>>) -> TrainResult<()> {
        for (key, value) in state {
            if let Some(param_name) = key.strip_prefix("momentum.") {
                self.momentum
                    .insert(param_name.to_string(), Array1::from_vec(value.clone()));
            }
        }
        Ok(())
    }

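    /// Clears all momentum state.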
    pub fn reset(&mut self) {
        self.momentum.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::collections::HashMap;

    #[test]
    fn test_lion_optimizer() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).expect("unwrap");
        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.2, 0.3]));
        optimizer.step(&mut params, &gradients).expect("unwrap");
        // All gradients are positive, so every weight must decrease.
        let w = params.get("w").expect("unwrap");
        assert!(w[0] < 1.0);
        assert!(w[1] < 2.0);
        assert!(w[2] < 3.0);
    }

    #[test]
    fn test_lion_with_weight_decay() {
        let config = LionConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.01,
        };
        let mut optimizer = LionOptimizer::new(config).expect("unwrap");
        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.1]));
        let initial_w = params.get("w").expect("unwrap")[0];
        optimizer.step(&mut params, &gradients).expect("unwrap");
        let updated_w = params.get("w").expect("unwrap")[0];
        assert!(updated_w < initial_w);
    }

    #[test]
    fn test_lion_sign_based_update() {
        let config = LionConfig {
            learning_rate: 1e-2,
            beta1: 0.0,
            beta2: 0.0,
            weight_decay: 0.0,
        };
        let mut optimizer = LionOptimizer::new(config).expect("unwrap");
        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0, 1.0]));
        let mut gradients = HashMap::new();
        gradients.insert(
            "w".to_string(),
            Array1::from_vec(vec![0.1, 1.0, 100.0]),
        );
        optimizer.step(&mut params, &gradients).expect("unwrap");
        // Lion uses only the sign of the update direction, so the step size
        // must be identical regardless of gradient magnitude.
        let w = params.get("w").expect("unwrap");
        let delta0 = 1.0 - w[0];
        let delta1 = 1.0 - w[1];
        let delta2 = 1.0 - w[2];
        assert!((delta0 - delta1).abs() < 1e-10);
        assert!((delta1 - delta2).abs() < 1e-10);
    }

    #[test]
    fn test_lion_state_dict() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).expect("unwrap");
        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.2]));
        optimizer.step(&mut params, &gradients).expect("unwrap");
        let state = optimizer.state_dict();
        assert!(state.contains_key("momentum.w"));
        let mut optimizer2 = LionOptimizer::new(LionConfig::default()).expect("unwrap");
        optimizer2.load_state_dict(&state).expect("unwrap");
        assert_eq!(
            optimizer.momentum.get("w").expect("unwrap").to_vec(),
            optimizer2.momentum.get("w").expect("unwrap").to_vec()
        );
    }

    #[test]
    fn test_lion_lr_schedule() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).expect("unwrap");
        assert!((optimizer.get_lr() - 1e-4).abs() < 1e-10);
        optimizer.set_lr(1e-3);
        assert!((optimizer.get_lr() - 1e-3).abs() < 1e-10);
    }

    #[test]
    fn test_lion_invalid_config() {
        let config = LionConfig {
            learning_rate: -1.0,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());
        let config = LionConfig {
            beta1: 1.5,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());
        let config = LionConfig {
            beta2: -0.1,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());
        let config = LionConfig {
            weight_decay: -0.1,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());
    }

    #[test]
    fn test_lion_reset() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).expect("unwrap");
        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0]));
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1]));
        optimizer.step(&mut params, &gradients).expect("unwrap");
        assert!(!optimizer.momentum.is_empty());
        optimizer.reset();
        assert!(optimizer.momentum.is_empty());
    }
}