rlx_optim/sophia.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
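//! Sophia — Second-order Clipped Stochastic Optimization (Liu, Li, Hall,
//! Liang & Ma, 2023). The update below is estimator-agnostic: it works with
//! Hutchinson (Sophia-H) or Gauss–Newton–Bartlett (Sophia-G) Hessian-diagonal
//! estimates.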
18//!
19//! # Idea
20//!
21//! Adam preconditions by `1/√v_t` (a noisy proxy for the inverse
22//! Hessian *diagonal*); Sophia preconditions by the **actual Hessian
23//! diagonal**, computed periodically via a Hutchinson estimator or a
24//! Gauss–Newton approximation. The crucial trick is a *per-coordinate
25//! clip* of the resulting update — even with a noisy Hessian, the
26//! clip caps each coordinate's step at `ρ`, so adversarial curvature
27//! estimates can never blow up the trajectory.
28//!
29//! # Update rule
30//!
31//! ```text
32//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t // first moment EMA
33//! [every K steps the caller updates h via Sophia::update_hessian:]
34//! h ← β₂·h + (1 − β₂)·diag(H_t) // Hessian-diag EMA
35//! u_i = m_{t,i} / max(γ · h_i, ε)
36//! u_i = clip(u_i, −ρ, +ρ)
37//! θ_t = θ_{t-1} − lr · ( u + λ·θ_{t-1} )
38//! ```
39//!
40//! # HVP oracle
41//!
42//! This crate doesn't ship an HVP oracle (it lives in `rlx-autodiff`
43//! as [`rlx_autodiff::hvp`](../../rlx_autodiff/fn.hvp.html)). Call
44//! [`Sophia::update_hessian`] yourself whenever you have a fresh
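//! diagonal estimate (Hutchinson: `H_diag ≈ u ⊙ (∇²L · u)` with random
//! Rademacher `u`; or Gauss–Newton–Bartlett: `H_diag ≈ B · ĝ ⊙ ĝ`, where `ĝ`
//! is the mini-batch gradient of the loss evaluated on labels *sampled from
//! the model's own output distribution* and `B` is the batch size). If you
//! never update it, `h` stays zero and every denominator falls to `ε`. Any
//! coordinate with `|m_i| ≥ ρ·ε` then saturates the clip, so Sophia
//! degenerates to a sign-momentum step of size `lr·ρ` (plus weight decay).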
49//!
50//! # When to use
51//!
52//! Curvature-aware optimization for LLM pre-training; the original
53//! paper reports ~2× wall-clock speedup vs AdamW at the same loss.
54//! State cost: two buffers per parameter (`m`, `h`).
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::zeros_entry;
60
/// Sophia optimizer state: hyper-parameters plus per-parameter EMA buffers.
///
/// Two buffers are kept per parameter name: the first-moment EMA `m`
/// (updated by `step`) and the Hessian-diagonal EMA `h` (updated by
/// [`Sophia::update_hessian`]). Every hyper-parameter is a public field and
/// may be changed between steps.
#[derive(Clone, Debug)]
pub struct Sophia {
    /// Learning rate `lr`. The clipped update satisfies `|u_i| ≤ ρ`, so the
    /// preconditioned part of each coordinate's step is at most `lr·ρ`
    /// (weight decay aside).
    pub lr: f32,
    /// β₁ — decay of the first-moment (gradient) EMA. Default `0.965`.
    pub beta1: f32,
    /// β₂ — decay of the Hessian-diagonal EMA. Default `0.99`.
    pub beta2: f32,
    /// γ — scale applied to the Hessian estimate inside the denominator
    /// `max(γ·h, ε)`. Default `0.01`.
    pub gamma: f32,
    /// ρ — per-coordinate clip bound on the preconditioned update `u`.
    /// Default `0.04`.
    pub rho: f32,
    /// ε — lower floor of the denominator; keeps zero or negative curvature
    /// estimates from producing a zero/negative divisor. Default `1e-12`.
    pub eps: f32,
    /// λ — decoupled weight-decay coefficient. Default `0.1`.
    pub weight_decay: f32,
    // Number of completed iterations; incremented by `end_iteration`.
    step: u64,
    // First-moment EMA buffers, keyed by parameter name.
    m: HashMap<String, Vec<f32>>,
    // Hessian-diagonal EMA buffers, keyed by parameter name.
    h: HashMap<String, Vec<f32>>,
}
86
87impl Sophia {
88 /// Construct with `(β₁, β₂, γ, ρ, ε, λ) = (0.965, 0.99, 0.01, 0.04, 1e-12, 0.1)`.
89 pub fn new(lr: f32) -> Self {
90 Self {
91 lr,
92 beta1: 0.965,
93 beta2: 0.99,
94 gamma: 0.01,
95 rho: 0.04,
96 eps: 1e-12,
97 weight_decay: 0.1,
98 step: 0,
99 m: HashMap::new(),
100 h: HashMap::new(),
101 }
102 }
103
104 /// Override (β₁, β₂).
105 pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
106 self.beta1 = b1;
107 self.beta2 = b2;
108 self
109 }
110
111 /// Override the decoupled-decay coefficient.
112 pub fn with_weight_decay(mut self, wd: f32) -> Self {
113 self.weight_decay = wd;
114 self
115 }
116
117 /// Update the diagonal-Hessian estimate for parameter `name`.
118 /// `h_hat` should be a fresh estimate (typically `H_diag` from a
119 /// Hutchinson estimator or `g²` from a Gauss-Newton approximation).
120 pub fn update_hessian(&mut self, name: &str, h_hat: &[f32]) {
121 let h = zeros_entry(&mut self.h, name, h_hat.len());
122 let b2 = self.beta2;
123 for i in 0..h.len() {
124 h[i] = b2 * h[i] + (1.0 - b2) * h_hat[i];
125 }
126 }
127}
128
129impl Optimizer for Sophia {
130 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
131 debug_assert_eq!(param.len(), grad.len());
132 let b1 = self.beta1;
133 let gamma = self.gamma.max(self.eps);
134 let rho = self.rho;
135 let eps = self.eps;
136 let lr = self.lr;
137 let wd = self.weight_decay;
138 let m = zeros_entry(&mut self.m, name, param.len());
139 for i in 0..param.len() {
140 m[i] = b1 * m[i] + (1.0 - b1) * grad[i];
141 }
142 // Snapshot h (zero if not yet populated).
143 let h_default = vec![0.0f32; param.len()];
144 let h = self.h.get(name).unwrap_or(&h_default);
145 for i in 0..param.len() {
146 let denom = (gamma * h[i]).max(eps);
147 let mut u = m[i] / denom;
148 // Per-coordinate clip to [-rho, rho].
149 if u > rho {
150 u = rho;
151 } else if u < -rho {
152 u = -rho;
153 }
154 // Decoupled decay.
155 param[i] -= lr * (u + wd * param[i]);
156 }
157 }
158
159 fn end_iteration(&mut self) {
160 self.step += 1;
161 }
162}