Skip to main content

rlx_optim/
lion.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Lion — EvoLved Sign Momentum (Chen et al., 2023, "Symbolic
17//! Discovery of Optimization Algorithms").
18//!
19//! # Idea
20//!
21//! Lion was *discovered* by a program-synthesis search over candidate
22//! optimizer expressions. The found rule is shockingly simple — one
23//! momentum buffer, and the update is the **sign** of an
24//! interpolation between the momentum and the gradient.
25//!
26//! # Update rule
27//!
28//! ```text
29//! c_t   = β₁·m_{t-1} + (1 − β₁)·g_t
30//! θ_t   = θ_{t-1} − lr · ( sign(c_t) + λ·θ_{t-1} )
31//! m_t   = β₂·m_{t-1} + (1 − β₂)·g_t          // note: different β₂!
32//! ```
33//!
34//! Two distinct betas: `β₁` shapes the *update direction* (faster
35//! adaptation), `β₂` shapes the *carried momentum* (slower memory).
36//!
37//! # When to use
38//!
39//! Half the memory of Adam (one buffer instead of two), often
40//! converges to similar quality on transformers when the LR is
41//! tuned 3–10× lower than the corresponding AdamW LR. Sign updates
42//! get coarse on tiny problems — favor large-batch / large-model
43//! regimes.
44
45use std::collections::HashMap;
46
47use crate::Optimizer;
48use crate::common::{zeros_entry, zip3_for_each};
49
/// Lion ("EvoLved Sign Momentum") optimizer: hyperparameters plus state.
///
/// State is a single `f32` momentum buffer per tensor name, i.e. half
/// of what Adam keeps for the same parameters.
#[derive(Clone, Debug)]
pub struct Lion {
    /// Step size. **Important**: usually set 3–10× below the AdamW
    /// learning rate for the same model, because each coordinate of
    /// the `sign(·)` update has unit magnitude.
    pub lr: f32,
    /// β₁ (Chen et al.): weight given to the stored momentum when it is
    /// blended with the gradient to choose the *update direction*.
    /// [`Lion::new`] sets `0.9`.
    pub beta1: f32,
    /// β₂: EMA coefficient used when refreshing the *carried momentum*.
    /// [`Lion::new`] sets `0.99`.
    pub beta2: f32,
    /// λ, the decoupled weight-decay coefficient. Typically tuned
    /// ~3–10× above the AdamW λ for the same model.
    /// [`Lion::new`] sets `0.0`.
    pub weight_decay: f32,
    /// Momentum buffers, one `Vec<f32>` per tensor name.
    m: HashMap<String, Vec<f32>>,
}
69
70impl Lion {
71    /// Construct with `(β₁, β₂, λ) = (0.9, 0.99, 0.0)`.
72    pub fn new(lr: f32) -> Self {
73        Self {
74            lr,
75            beta1: 0.9,
76            beta2: 0.99,
77            weight_decay: 0.0,
78            m: HashMap::new(),
79        }
80    }
81
82    /// Override (β₁, β₂). They serve different roles — see the
83    /// struct-level docs.
84    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
85        self.beta1 = b1;
86        self.beta2 = b2;
87        self
88    }
89
90    /// Override the decoupled-decay coefficient.
91    pub fn with_weight_decay(mut self, wd: f32) -> Self {
92        self.weight_decay = wd;
93        self
94    }
95}
96
97impl Optimizer for Lion {
98    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
99        debug_assert_eq!(param.len(), grad.len());
100        let b1 = self.beta1;
101        let b2 = self.beta2;
102        let lr = self.lr;
103        let wd = self.weight_decay;
104        let m = zeros_entry(&mut self.m, name, param.len());
105        zip3_for_each(param, m, grad, |p, mi, gi| {
106            // Update direction = sign(b1*m + (1-b1)*g)
107            let c = b1 * *mi + (1.0 - b1) * gi;
108            let sign = if c > 0.0 {
109                1.0
110            } else if c < 0.0 {
111                -1.0
112            } else {
113                0.0
114            };
115            // Decoupled weight decay (matches Chen et al. eq. 1).
116            *p -= lr * (sign + wd * *p);
117            // Then update the momentum with a different β₂.
118            *mi = b2 * *mi + (1.0 - b2) * gi;
119        });
120    }
121}