Skip to main content

rlx_optim/
sophia.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Sophia-H — Second-order Clipped Stochastic Optimization (Liu, Xie,
17//! Zhang, Ma, 2023).
18//!
19//! # Idea
20//!
21//! Adam preconditions by `1/√v_t` (a noisy proxy for the inverse
22//! Hessian *diagonal*); Sophia preconditions by the **actual Hessian
23//! diagonal**, computed periodically via a Hutchinson estimator or a
24//! Gauss–Newton approximation. The crucial trick is a *per-coordinate
//! clip* of the resulting update — even with a noisy Hessian, the
//! clip caps each coordinate's pre-learning-rate update at `ρ` (so a
//! parameter moves by at most `lr·ρ`, plus weight decay, per step), and
//! adversarial curvature estimates can never blow up the trajectory.
28//!
29//! # Update rule
30//!
31//! ```text
32//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t                  // first moment EMA
33//! [every K steps the caller updates h via Sophia::update_hessian:]
34//!   h ← β₂·h + (1 − β₂)·diag(H_t)                  // Hessian-diag EMA
35//! u_i = m_{t,i} / max(γ · h_i, ε)
36//! u_i = clip(u_i, −ρ, +ρ)
37//! θ_t = θ_{t-1} − lr · ( u + λ·θ_{t-1} )
38//! ```
39//!
40//! # HVP oracle
41//!
42//! This crate doesn't ship an HVP oracle (it lives in `rlx-autodiff`
43//! as [`rlx_autodiff::hvp`](../../rlx_autodiff/fn.hvp.html)). Call
44//! [`Sophia::update_hessian`] yourself whenever you have a fresh
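//! diagonal estimate (Hutchinson: `H_diag ≈ u ⊙ (∇²L · u)` with random
//! Rademacher `u`; or Gauss–Newton–Bartlett: `H_diag ≈ B · ĝ ⊙ ĝ`, where
//! `ĝ` is the gradient of a size-`B` mini-batch evaluated against labels
//! sampled from the model's own output distribution). If you never
//! update it, `h` stays zero, the denominator falls to `ε`, and almost
//! every coordinate saturates the clip: Sophia degenerates to a
//! sign-of-momentum step of magnitude `lr·ρ` (plus weight decay).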
49//!
50//! # When to use
51//!
52//! Curvature-aware optimization for LLM pre-training; the original
53//! paper reports ~2× wall-clock speedup vs AdamW at the same loss.
54//! State cost: two buffers per parameter (`m`, `h`).
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::zeros_entry;
60
/// Sophia-H: a clipped optimizer preconditioned by the Hessian diagonal.
///
/// For every parameter (keyed by name) it keeps a first-moment EMA `m`
/// and a diagonal-Hessian EMA `h`.
#[derive(Clone, Debug)]
pub struct Sophia {
    /// Learning rate. The clip bounds each coordinate's pre-`lr` update
    /// to `ρ`, so this is often set slightly *above* the AdamW LR one
    /// would use on the same model.
    pub lr: f32,
    /// EMA decay β₁ for the first moment (`0.965` from [`Sophia::new`]).
    pub beta1: f32,
    /// EMA decay β₂ for the Hessian-diagonal estimate (`0.99` from
    /// [`Sophia::new`]).
    pub beta2: f32,
    /// Hessian scale γ (`0.01` from [`Sophia::new`], the Liu et al. value).
    /// It multiplies the Hessian estimate when the denominator is formed.
    pub gamma: f32,
    /// Clip threshold ρ (`0.04` from [`Sophia::new`]). This is the
    /// dimensionless bound on each coordinate's update before it is
    /// scaled by `lr`.
    pub rho: f32,
    /// Lower bound on the denominator (`1e-12` from [`Sophia::new`]).
    pub eps: f32,
    /// Decoupled weight-decay coefficient λ (`0.1` from [`Sophia::new`]).
    /// That is large by AdamW standards; Sophia tolerates more decay.
    pub weight_decay: f32,
    /// Number of completed iterations, bumped by `end_iteration`.
    step: u64,
    /// First-moment EMA buffers, keyed by parameter name.
    m: HashMap<String, Vec<f32>>,
    /// Hessian-diagonal EMA buffers, keyed by parameter name.
    h: HashMap<String, Vec<f32>>,
}
86
87impl Sophia {
88    /// Construct with `(β₁, β₂, γ, ρ, ε, λ) = (0.965, 0.99, 0.01, 0.04, 1e-12, 0.1)`.
89    pub fn new(lr: f32) -> Self {
90        Self {
91            lr,
92            beta1: 0.965,
93            beta2: 0.99,
94            gamma: 0.01,
95            rho: 0.04,
96            eps: 1e-12,
97            weight_decay: 0.1,
98            step: 0,
99            m: HashMap::new(),
100            h: HashMap::new(),
101        }
102    }
103
104    /// Override (β₁, β₂).
105    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
106        self.beta1 = b1;
107        self.beta2 = b2;
108        self
109    }
110
111    /// Override the decoupled-decay coefficient.
112    pub fn with_weight_decay(mut self, wd: f32) -> Self {
113        self.weight_decay = wd;
114        self
115    }
116
117    /// Update the diagonal-Hessian estimate for parameter `name`.
118    /// `h_hat` should be a fresh estimate (typically `H_diag` from a
119    /// Hutchinson estimator or `g²` from a Gauss-Newton approximation).
120    pub fn update_hessian(&mut self, name: &str, h_hat: &[f32]) {
121        let h = zeros_entry(&mut self.h, name, h_hat.len());
122        let b2 = self.beta2;
123        for i in 0..h.len() {
124            h[i] = b2 * h[i] + (1.0 - b2) * h_hat[i];
125        }
126    }
127}
128
129impl Optimizer for Sophia {
130    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
131        debug_assert_eq!(param.len(), grad.len());
132        let b1 = self.beta1;
133        let gamma = self.gamma.max(self.eps);
134        let rho = self.rho;
135        let eps = self.eps;
136        let lr = self.lr;
137        let wd = self.weight_decay;
138        let m = zeros_entry(&mut self.m, name, param.len());
139        for i in 0..param.len() {
140            m[i] = b1 * m[i] + (1.0 - b1) * grad[i];
141        }
142        // Snapshot h (zero if not yet populated).
143        let h_default = vec![0.0f32; param.len()];
144        let h = self.h.get(name).unwrap_or(&h_default);
145        for i in 0..param.len() {
146            let denom = (gamma * h[i]).max(eps);
147            let mut u = m[i] / denom;
148            // Per-coordinate clip to [-rho, rho].
149            if u > rho {
150                u = rho;
151            } else if u < -rho {
152                u = -rho;
153            }
154            // Decoupled decay.
155            param[i] -= lr * (u + wd * param[i]);
156        }
157    }
158
159    fn end_iteration(&mut self) {
160        self.step += 1;
161    }
162}