// optirs_core/optimizers/lion.rs

1// Lion optimizer implementation
2//
3// Based on the paper "Symbolic Discovery of Optimization Algorithms"
4// by Chen et al. (2023).
5
6use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9
10use crate::error::Result;
11use crate::optimizers::Optimizer;
12
/// Lion optimizer
///
/// Implements the Lion (Evolved Sign Momentum) optimization algorithm.
/// Lion is a memory-efficient optimizer that achieves strong performance
/// while keeping only a single momentum state per parameter; the parameter
/// update uses only the *sign* of an interpolated momentum, so the step
/// magnitude is the learning rate itself.
///
/// Formula:
/// u_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// theta_t = theta_{t-1} - alpha * (sign(u_t) + lambda * theta_{t-1})
/// m_t = beta2 * m_{t-1} + (1 - beta2) * g_t
///
/// where alpha is the learning rate and lambda the weight-decay factor.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{Lion, Optimizer};
///
/// // Initialize parameters and gradients
/// let params = Array1::zeros(5);
/// let gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0, 0.5]);
///
/// // Create a Lion optimizer with default hyperparameters
/// let mut optimizer = Lion::new(0.001);
///
/// // Update parameters
/// let new_params = optimizer.step(&params, &gradients).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct Lion<A: Float + ScalarOperand + Debug> {
    /// Learning rate (alpha); also the per-element step magnitude, since
    /// updates use sign(u_t)
    learning_rate: A,
    /// Exponential decay rate beta1 used to form the interpolated update u_t
    beta1: A,
    /// Exponential decay rate beta2 used to update the stored momentum m_t
    beta2: A,
    /// Weight decay factor lambda (decoupled L2 regularization); 0 disables it
    weight_decay: A,
    /// Momentum state, stored with dynamic dimensionality; `None` until the
    /// first call to `step` initializes it
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
53
54impl<A: Float + ScalarOperand + Debug + Send + Sync> Lion<A> {
55    /// Creates a new Lion optimizer with the given learning rate and default settings
56    ///
57    /// # Arguments
58    ///
59    /// * `learning_rate` - The learning rate for parameter updates
60    pub fn new(learning_rate: A) -> Self {
61        Self {
62            learning_rate,
63            beta1: A::from(0.9).unwrap(),
64            beta2: A::from(0.99).unwrap(),
65            weight_decay: A::zero(),
66            m: None,
67        }
68    }
69
70    /// Creates a new Lion optimizer with the full configuration
71    ///
72    /// # Arguments
73    ///
74    /// * `learning_rate` - The learning rate for parameter updates
75    /// * `beta1` - Exponential decay rate for computing the interpolated update (default: 0.9)
76    /// * `beta2` - Exponential decay rate for updating the momentum (default: 0.99)
77    /// * `weight_decay` - Weight decay factor for L2 regularization (default: 0.0)
78    pub fn new_with_config(learning_rate: A, beta1: A, beta2: A, weight_decay: A) -> Self {
79        Self {
80            learning_rate,
81            beta1,
82            beta2,
83            weight_decay,
84            m: None,
85        }
86    }
87
88    /// Sets the beta1 parameter
89    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
90        self.beta1 = beta1;
91        self
92    }
93
94    /// Gets the beta1 parameter
95    pub fn get_beta1(&self) -> A {
96        self.beta1
97    }
98
99    /// Sets the beta2 parameter
100    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
101        self.beta2 = beta2;
102        self
103    }
104
105    /// Gets the beta2 parameter
106    pub fn get_beta2(&self) -> A {
107        self.beta2
108    }
109
110    /// Sets the weight decay parameter
111    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
112        self.weight_decay = weight_decay;
113        self
114    }
115
116    /// Gets the weight decay parameter
117    pub fn get_weight_decay(&self) -> A {
118        self.weight_decay
119    }
120
121    /// Gets the current learning rate
122    pub fn learning_rate(&self) -> A {
123        self.learning_rate
124    }
125
126    /// Sets the learning rate
127    pub fn set_lr(&mut self, lr: A) {
128        self.learning_rate = lr;
129    }
130
131    /// Resets the internal state of the optimizer
132    pub fn reset(&mut self) {
133        self.m = None;
134    }
135}
136
137impl<A, D> Optimizer<A, D> for Lion<A>
138where
139    A: Float + ScalarOperand + Debug + Send + Sync,
140    D: Dimension,
141{
142    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
143        // Convert to dynamic dimension for storage in state vectors
144        let params_dyn = params.to_owned().into_dyn();
145        let gradients_dyn = gradients.to_owned().into_dyn();
146
147        // Initialize state if this is the first step
148        if self.m.is_none() {
149            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
150        }
151
152        let m = self.m.as_mut().unwrap();
153
154        // Ensure we have state for this parameter set
155        if m.is_empty() {
156            m.push(Array::zeros(params_dyn.raw_dim()));
157        } else if m[0].raw_dim() != params_dyn.raw_dim() {
158            // If the parameter dimensions have changed, reset state
159            m[0] = Array::zeros(params_dyn.raw_dim());
160        }
161
162        // Step 1: Compute interpolated update using beta1
163        let interpolated_update = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);
164
165        // Step 2: Compute sign of interpolated update
166        let sign_update = interpolated_update.mapv(|x| {
167            if x > A::zero() {
168                A::one()
169            } else if x < A::zero() {
170                -A::one()
171            } else {
172                A::zero()
173            }
174        });
175
176        // Step 3: Update parameters
177        let mut updated_params = params_dyn.clone();
178
179        // Apply weight decay if specified
180        if self.weight_decay > A::zero() {
181            updated_params = &updated_params * (A::one() - self.weight_decay * self.learning_rate);
182        }
183
184        // Apply the sign update
185        updated_params = &updated_params - &sign_update * self.learning_rate;
186
187        // Step 4: Update momentum using beta2
188        m[0] = &m[0] * self.beta2 + &gradients_dyn * (A::one() - self.beta2);
189
190        // Convert back to original dimension
191        Ok(updated_params.into_dimensionality::<D>().unwrap())
192    }
193
194    fn get_learning_rate(&self) -> A {
195        self.learning_rate
196    }
197
198    fn set_learning_rate(&mut self, learning_rate: A) {
199        self.learning_rate = learning_rate;
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// A freshly constructed optimizer carries the paper's default hyperparameters.
    #[test]
    fn test_lion_basic_creation() {
        let opt: Lion<f64> = Lion::new(0.001);
        assert_abs_diff_eq!(opt.learning_rate(), 0.001);
        assert_abs_diff_eq!(opt.get_beta1(), 0.9);
        assert_abs_diff_eq!(opt.get_beta2(), 0.99);
        assert_abs_diff_eq!(opt.get_weight_decay(), 0.0);
    }

    /// Minimizing f(x) = x^2 from x = 5.0 should drive x toward zero.
    /// Because Lion steps by sign(u_t) * lr, progress is linear in the
    /// iteration count, so 40 steps at lr = 0.1 suffice.
    #[test]
    fn test_lion_convergence() {
        let mut opt: Lion<f64> = Lion::new(0.1);
        let mut x = Array1::from_vec(vec![5.0]);

        for _ in 0..40 {
            // d/dx x^2 = 2x
            let grad = Array1::from_vec(vec![2.0 * x[0]]);
            x = opt.step(&x, &grad).unwrap();
        }

        assert!(x[0].abs() < 1.1);
    }

    /// After `reset`, the optimizer must produce the same update as a
    /// brand-new instance with identical hyperparameters.
    #[test]
    fn test_lion_reset() {
        let params = Array1::from_vec(vec![1.0]);
        let grads = Array1::from_vec(vec![0.1]);

        // Take one step so momentum state exists, then clear it.
        let mut used: Lion<f64> = Lion::new(0.1);
        let _ = used.step(&params, &grads).unwrap();
        used.reset();
        let after_reset = used.step(&params, &grads).unwrap();

        // A never-used optimizer serves as the reference.
        let mut fresh: Lion<f64> = Lion::new(0.1);
        let reference = fresh.step(&params, &grads).unwrap();

        assert_abs_diff_eq!(after_reset[0], reference[0], epsilon = 1e-10);
    }
}