optirs_core/optimizers/lion.rs

//! Lion (EvoLved Sign Momentum) optimizer.
//!
//! Lion keeps a single momentum buffer and updates parameters with the sign
//! of an interpolation between that buffer and the current gradient, which
//! makes it more memory-efficient than Adam-style optimizers.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Lion optimizer with a single sign-based momentum buffer.
#[derive(Debug, Clone)]
pub struct Lion<A: Float + ScalarOperand + Debug> {
    /// Step size applied to the sign of the update direction
    learning_rate: A,
    /// Interpolation coefficient between the momentum and the current gradient
    beta1: A,
    /// Coefficient used to update the momentum buffer after each step
    beta2: A,
    /// Decoupled weight decay factor (zero disables decay)
    weight_decay: A,
    /// Momentum buffer, lazily initialized on the first call to `step`
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Lion<A> {
    /// Creates a new Lion optimizer with default betas (0.9, 0.99) and no weight decay.
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.99).unwrap(),
            weight_decay: A::zero(),
            m: None,
        }
    }

    /// Creates a new Lion optimizer with explicit beta and weight decay values.
    pub fn new_with_config(learning_rate: A, beta1: A, beta2: A, weight_decay: A) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            weight_decay,
            m: None,
        }
    }

    /// Sets the beta1 (update interpolation) coefficient.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the beta1 coefficient.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the beta2 (momentum update) coefficient.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the beta2 coefficient.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the decoupled weight decay factor.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight decay factor.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the learning rate.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Clears the momentum buffer, restoring the optimizer to its initial state.
    pub fn reset(&mut self) {
        self.m = None;
    }
}

impl<A, D> Optimizer<A, D> for Lion<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Work in dynamic dimensions so the momentum buffer can be stored uniformly.
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Lazily initialize the momentum buffer on the first step.
        if self.m.is_none() {
            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
        }

        let m = self.m.as_mut().unwrap();

        // Re-create the buffer if it is empty or its shape no longer matches the parameters.
        if m.is_empty() {
            m.push(Array::zeros(params_dyn.raw_dim()));
        } else if m[0].raw_dim() != params_dyn.raw_dim() {
            m[0] = Array::zeros(params_dyn.raw_dim());
        }

        // Interpolate between the momentum and the current gradient with beta1.
        let interpolated_update = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);

        // Lion only uses the sign of the interpolated update.
        let sign_update = interpolated_update.mapv(|x| {
            if x > A::zero() {
                A::one()
            } else if x < A::zero() {
                -A::one()
            } else {
                A::zero()
            }
        });

        let mut updated_params = params_dyn.clone();

        // Apply decoupled weight decay before the sign update.
        if self.weight_decay > A::zero() {
            updated_params = &updated_params * (A::one() - self.weight_decay * self.learning_rate);
        }

        // Parameter update: params <- params - lr * sign(update).
        updated_params = &updated_params - &sign_update * self.learning_rate;

        // Update the momentum buffer with beta2.
        m[0] = &m[0] * self.beta2 + &gradients_dyn * (A::one() - self.beta2);

        // Convert back to the caller's dimensionality; the shape is unchanged, so this cannot fail.
        Ok(updated_params.into_dimensionality::<D>().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lion_basic_creation() {
        let optimizer: Lion<f64> = Lion::new(0.001);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.001);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.99);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
    }
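
    // A sketch of a configuration check: exercises `new_with_config` and the
    // builder-style setters defined above with arbitrary illustrative
    // hyperparameter values.
    #[test]
    fn test_lion_config_and_setters() {
        let mut optimizer: Lion<f64> = Lion::new_with_config(0.001, 0.95, 0.98, 0.01);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.95);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.98);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.01);

        optimizer.set_beta1(0.9).set_beta2(0.99);
        optimizer.set_weight_decay(0.0);
        optimizer.set_lr(0.01);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.99);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.01);
    }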

    // Minimize f(x) = x^2 (gradient 2x); the sign updates should drive x toward zero.
    #[test]
    fn test_lion_convergence() {
        let mut optimizer: Lion<f64> = Lion::new(0.1);
        let mut params = Array1::from_vec(vec![5.0]);

        for _ in 0..40 {
            let gradients = Array1::from_vec(vec![2.0 * params[0]]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        assert!(params[0].abs() < 1.1);
    }

    // After `reset`, the next step should match the first step of a fresh optimizer.
    #[test]
    fn test_lion_reset() {
        let mut optimizer: Lion<f64> = Lion::new(0.1);

        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.1]);
        let _ = optimizer.step(&params, &gradients).unwrap();

        optimizer.reset();

        let next_step = optimizer.step(&params, &gradients).unwrap();

        let mut fresh_optimizer: Lion<f64> = Lion::new(0.1);
        let fresh_step = fresh_optimizer.step(&params, &gradients).unwrap();

        assert_abs_diff_eq!(next_step[0], fresh_step[0], epsilon = 1e-10);
    }
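
    // A sketch of a single-step check derived from the update rule in `step`:
    // with a fresh optimizer the momentum buffer is zero and weight decay is
    // disabled, so the result should be `params - lr * sign(gradient)` element-wise.
    #[test]
    fn test_lion_single_step_matches_sign_rule() {
        let mut optimizer: Lion<f64> = Lion::new(0.01);
        let params = Array1::from_vec(vec![1.0, -2.0, 0.5]);
        let gradients = Array1::from_vec(vec![0.5, -0.3, 0.0]);

        let updated = optimizer.step(&params, &gradients).unwrap();

        // sign(0.5) = 1, sign(-0.3) = -1, sign(0.0) = 0
        assert_abs_diff_eq!(updated[0], 1.0 - 0.01, epsilon = 1e-12);
        assert_abs_diff_eq!(updated[1], -2.0 + 0.01, epsilon = 1e-12);
        assert_abs_diff_eq!(updated[2], 0.5, epsilon = 1e-12);
    }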
}