rlx_optim/lion.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Lion — EvoLved Sign Momentum (Chen et al., 2023, "Symbolic
17//! Discovery of Optimization Algorithms").
18//!
19//! # Idea
20//!
21//! Lion was *discovered* by a program-synthesis search over candidate
22//! optimizer expressions. The found rule is shockingly simple — one
23//! momentum buffer, and the update is the **sign** of an
24//! interpolation between the momentum and the gradient.
25//!
26//! # Update rule
27//!
28//! ```text
29//! c_t = β₁·m_{t-1} + (1 − β₁)·g_t
30//! θ_t = θ_{t-1} − lr · ( sign(c_t) + λ·θ_{t-1} )
//! m_t = β₂·m_{t-1} + (1 − β₂)·g_t       // note: uses β₂, not β₁
32//! ```
33//!
34//! Two distinct betas: `β₁` shapes the *update direction* (faster
35//! adaptation), `β₂` shapes the *carried momentum* (slower memory).
36//!
37//! # When to use
38//!
//! Lion needs half the optimizer-state memory of Adam (one buffer
//! instead of two) and often converges to similar quality on
//! transformers when the LR is tuned 3–10× lower than the
//! corresponding AdamW LR. Sign updates get coarse on tiny problems,
//! so favor large-batch / large-model regimes.
44
45use std::collections::HashMap;
46
47use crate::Optimizer;
48use crate::common::{zeros_entry, zip3_for_each};
49
/// Lion ("EvoLved Sign Momentum") optimizer state and hyper-parameters.
///
/// The only per-tensor state is a single `f32` momentum buffer, keyed by
/// tensor name — half of what Adam keeps (which tracks two buffers).
#[derive(Clone, Debug)]
pub struct Lion {
    /// Step size. **Important**: usually chosen 3–10× smaller than the
    /// AdamW learning rate for the same model, since every coordinate of
    /// the update direction has unit magnitude (`sign(·)`).
    pub lr: f32,
    /// β₁ in Chen et al.: blends momentum and gradient to form the
    /// *update direction*. `Lion::new` sets it to `0.9`.
    pub beta1: f32,
    /// β₂: EMA coefficient of the momentum that is *carried* to the next
    /// step. `Lion::new` sets it to `0.99`.
    pub beta2: f32,
    /// λ, the decoupled weight-decay coefficient. Usually tuned ~3–10×
    /// higher than the AdamW λ for the same model. `Lion::new` sets it
    /// to `0.0`.
    pub weight_decay: f32,
    // Momentum buffers, one per parameter tensor, keyed by tensor name.
    m: HashMap<String, Vec<f32>>,
}
69
70impl Lion {
71 /// Construct with `(β₁, β₂, λ) = (0.9, 0.99, 0.0)`.
72 pub fn new(lr: f32) -> Self {
73 Self {
74 lr,
75 beta1: 0.9,
76 beta2: 0.99,
77 weight_decay: 0.0,
78 m: HashMap::new(),
79 }
80 }
81
82 /// Override (β₁, β₂). They serve different roles — see the
83 /// struct-level docs.
84 pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
85 self.beta1 = b1;
86 self.beta2 = b2;
87 self
88 }
89
90 /// Override the decoupled-decay coefficient.
91 pub fn with_weight_decay(mut self, wd: f32) -> Self {
92 self.weight_decay = wd;
93 self
94 }
95}
96
97impl Optimizer for Lion {
98 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
99 debug_assert_eq!(param.len(), grad.len());
100 let b1 = self.beta1;
101 let b2 = self.beta2;
102 let lr = self.lr;
103 let wd = self.weight_decay;
104 let m = zeros_entry(&mut self.m, name, param.len());
105 zip3_for_each(param, m, grad, |p, mi, gi| {
106 // Update direction = sign(b1*m + (1-b1)*g)
107 let c = b1 * *mi + (1.0 - b1) * gi;
108 let sign = if c > 0.0 {
109 1.0
110 } else if c < 0.0 {
111 -1.0
112 } else {
113 0.0
114 };
115 // Decoupled weight decay (matches Chen et al. eq. 1).
116 *p -= lr * (sign + wd * *p);
117 // Then update the momentum with a different β₂.
118 *mi = b2 * *mi + (1.0 - b2) * gi;
119 });
120 }
121}