// optirs_core/optimizers/lion.rs
//
// Lion (EvoLved Sign Momentum) optimizer.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;
/// Lion (EvoLved Sign Momentum) optimizer state.
///
/// Holds the hyperparameters and the lazily-initialized momentum buffer used
/// by the `Optimizer::step` implementation below. The step applies
/// `sign(beta1 * m + (1 - beta1) * g)` as the update direction and then
/// refreshes the buffer as `m = beta2 * m + (1 - beta2) * g`.
#[derive(Debug, Clone)]
pub struct Lion<A: Float + ScalarOperand + Debug> {
    // Step size applied to the sign of the interpolated update.
    learning_rate: A,
    // Interpolation factor between momentum and gradient for the direction.
    beta1: A,
    // Decay factor for the momentum buffer update.
    beta2: A,
    // Decoupled weight-decay coefficient; applied in `step` only when > 0.
    weight_decay: A,
    // Momentum buffer; `None` until the first `step` call.
    // NOTE(review): stored as a single-element Vec — presumably for parity
    // with multi-parameter-group optimizers elsewhere in the crate; confirm.
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
53
54impl<A: Float + ScalarOperand + Debug + Send + Sync> Lion<A> {
55 pub fn new(learning_rate: A) -> Self {
61 Self {
62 learning_rate,
63 beta1: A::from(0.9).expect("unwrap failed"),
64 beta2: A::from(0.99).expect("unwrap failed"),
65 weight_decay: A::zero(),
66 m: None,
67 }
68 }
69
70 pub fn new_with_config(learning_rate: A, beta1: A, beta2: A, weight_decay: A) -> Self {
79 Self {
80 learning_rate,
81 beta1,
82 beta2,
83 weight_decay,
84 m: None,
85 }
86 }
87
88 pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
90 self.beta1 = beta1;
91 self
92 }
93
94 pub fn get_beta1(&self) -> A {
96 self.beta1
97 }
98
99 pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
101 self.beta2 = beta2;
102 self
103 }
104
105 pub fn get_beta2(&self) -> A {
107 self.beta2
108 }
109
110 pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
112 self.weight_decay = weight_decay;
113 self
114 }
115
116 pub fn get_weight_decay(&self) -> A {
118 self.weight_decay
119 }
120
121 pub fn learning_rate(&self) -> A {
123 self.learning_rate
124 }
125
126 pub fn set_lr(&mut self, lr: A) {
128 self.learning_rate = lr;
129 }
130
131 pub fn reset(&mut self) {
133 self.m = None;
134 }
135}
136
137impl<A, D> Optimizer<A, D> for Lion<A>
138where
139 A: Float + ScalarOperand + Debug + Send + Sync,
140 D: Dimension,
141{
142 fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
143 let params_dyn = params.to_owned().into_dyn();
145 let gradients_dyn = gradients.to_owned().into_dyn();
146
147 if self.m.is_none() {
149 self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
150 }
151
152 let m = self.m.as_mut().expect("unwrap failed");
153
154 if m.is_empty() {
156 m.push(Array::zeros(params_dyn.raw_dim()));
157 } else if m[0].raw_dim() != params_dyn.raw_dim() {
158 m[0] = Array::zeros(params_dyn.raw_dim());
160 }
161
162 let interpolated_update = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);
164
165 let sign_update = interpolated_update.mapv(|x| {
167 if x > A::zero() {
168 A::one()
169 } else if x < A::zero() {
170 -A::one()
171 } else {
172 A::zero()
173 }
174 });
175
176 let mut updated_params = params_dyn.clone();
178
179 if self.weight_decay > A::zero() {
181 updated_params = &updated_params * (A::one() - self.weight_decay * self.learning_rate);
182 }
183
184 updated_params = &updated_params - &sign_update * self.learning_rate;
186
187 m[0] = &m[0] * self.beta2 + &gradients_dyn * (A::one() - self.beta2);
189
190 Ok(updated_params
192 .into_dimensionality::<D>()
193 .expect("unwrap failed"))
194 }
195
196 fn get_learning_rate(&self) -> A {
197 self.learning_rate
198 }
199
200 fn set_learning_rate(&mut self, learning_rate: A) {
201 self.learning_rate = learning_rate;
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// `Lion::new` must install the paper-recommended defaults.
    #[test]
    fn test_lion_basic_creation() {
        let optimizer: Lion<f64> = Lion::new(0.001);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.001);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.99);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
    }

    /// Minimizing f(x) = x^2 (gradient 2x): the sign-based update with
    /// lr = 0.1 should drive the iterate into a small band around zero.
    #[test]
    fn test_lion_convergence() {
        let mut optimizer: Lion<f64> = Lion::new(0.1);
        let mut params = Array1::from_vec(vec![5.0]);

        for _ in 0..40 {
            let gradients = Array1::from_vec(vec![2.0 * params[0]]);
            params = optimizer
                .step(&params, &gradients)
                .expect("step should succeed");
        }

        assert!(params[0].abs() < 1.1);
    }

    /// After `reset`, the next step must match a freshly constructed
    /// optimizer's first step (momentum state fully cleared).
    #[test]
    fn test_lion_reset() {
        let mut optimizer: Lion<f64> = Lion::new(0.1);

        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.1]);
        let _ = optimizer
            .step(&params, &gradients)
            .expect("step should succeed");

        optimizer.reset();

        let next_step = optimizer
            .step(&params, &gradients)
            .expect("step should succeed");

        let mut fresh_optimizer: Lion<f64> = Lion::new(0.1);
        let fresh_step = fresh_optimizer
            .step(&params, &gradients)
            .expect("step should succeed");

        assert_abs_diff_eq!(next_step[0], fresh_step[0], epsilon = 1e-10);
    }
}