rlx_optim/lib.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX training-step optimizers.
17//!
18//! Host-side `f32` step functions for the families surveyed in
19//! "A Systematic Review of Optimization Algorithms for Modern Deep
20//! Learning" (arXiv:2509.02046v1). Each algorithm exposes a small
21//! state struct keyed by parameter *name* (so the same struct holds
22//! moments for every tensor in a model) and a `step` method that
23//! consumes `(name, shape, &mut params, &grads)`.
24//!
25//! The API is deliberately minimal: it operates on flat `&mut [f32]`
26//! / `&[f32]` slices plus a `&[usize]` shape — matching the
27//! [`rlx_umap::adam`](../rlx_umap/adam/index.html) pattern. Backends
28//! that already ship a fused step kernel (see e.g.
29//! `rlx_metal::splat_adam`) are free to bypass this crate for their
30//! hot path; this crate is the portable reference / CPU fallback / the
31//! one used when there is no backend fused kernel for the requested
32//! algorithm.
33//!
34//! # Algorithms
35//!
//! | Type | Family |
37//! |-----------------|-------------------------------|
38//! | [`Sgd`] | SGD ± momentum / Nesterov |
39//! | [`Adam`] | Adam |
40//! | [`AdamW`] | AdamW (decoupled decay) |
41//! | [`NAdamW`] | Nesterov AdamW |
42//! | [`RAdam`] | Rectified Adam |
43//! | [`QHAdamW`] | Quasi-hyperbolic AdamW |
44//! | [`Lamb`] | LAMB (layer-wise adaptive) |
45//! | [`Adafactor`] | Adafactor (factored 2nd mom.) |
46//! | [`Lion`] | Lion (sign of EMA) |
47//! | [`Soap`] | SOAP (Shampoo-in-Adam-basis) |
48//! | [`KronPsgd`] | Kron / PSGD |
49//! | [`Muon`] | Muon (Newton–Schulz orth.) |
50//! | [`Sophia`] | Sophia-H |
51//! | [`Mars`] | MARS (variance-reduced) |
52
53#![forbid(unsafe_code)]
54
55mod common;
56
57mod adafactor;
58mod adam;
59mod adamw;
60mod kron_psgd;
61mod lamb;
62mod lion;
63mod mars;
64mod muon;
65mod nadamw;
66mod qhadamw;
67mod radam;
68mod sgd;
69mod soap;
70mod sophia;
71
72pub use adafactor::Adafactor;
73pub use adam::Adam;
74pub use adamw::AdamW;
75pub use kron_psgd::KronPsgd;
76pub use lamb::Lamb;
77pub use lion::Lion;
78pub use mars::Mars;
79pub use muon::Muon;
80pub use nadamw::NAdamW;
81pub use qhadamw::QHAdamW;
82pub use radam::RAdam;
83pub use sgd::Sgd;
84pub use soap::Soap;
85pub use sophia::Sophia;
86
87pub use common::{global_grad_clip_scale, l2_norm};
88
89/// Common parameter-update interface.
90///
91/// `name` keys the per-parameter state (moments, preconditioners),
92/// `shape` is the parameter's logical shape (used by matrix-aware
93/// algorithms like Adafactor / SOAP / Muon — ignored by elementwise
94/// ones), `param` is updated in place from `grad`. `grad` is treated
95/// as read-only; callers that need gradient clipping should pre-scale
96/// it (see [`global_grad_clip_scale`]).
97///
98/// # Implementing for a backend
99///
100/// Every algorithm in this crate provides a CPU reference impl. A
101/// backend (e.g. `rlx-metal`, `rlx-cuda`) is free to write its own
102/// fused step kernel and impl `Optimizer` for a wrapper struct that
103/// owns device buffers — the trait places no requirement on where
104/// the state lives, only on the entry-point signature. The
/// `rlx_metal::splat_adam` kernel is the canonical example of a
106/// backend that bypasses this crate entirely; you can wrap it with a
107/// 5-line `impl Optimizer` if you want a uniform interface from a
108/// generic trainer.
109///
110/// # Per-tensor learning rate
111///
/// For optimizers that don't need per-tensor LR variation (most
/// transformer pre-training), leave [`lr_scale`](Self::lr_scale) at
/// its default, which returns `1.0`. For domain-specific use cases — e.g.
115/// 3D Gaussian splatting, where different attributes need wildly
116/// different step sizes — override [`lr_scale`](Self::lr_scale) to
117/// multiply the base `lr` by a per-name factor. The provided method
118/// on the trait does NOT scale automatically; algorithms are free to
119/// consult it via [`Optimizer::lr_scale`] inside their `step`.
pub trait Optimizer {
    /// Apply one update to a single parameter tensor, in place.
    ///
    /// * `name` — key under which the per-parameter state (moments,
    ///   preconditioners) is stored.
    /// * `shape` — logical shape of the parameter; consulted by
    ///   matrix-aware algorithms and ignored by elementwise ones.
    /// * `param` — flat parameter values, overwritten with the updated
    ///   values.
    /// * `grad` — flat gradient values; treated as read-only.
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);

    /// Advance the global step counter. Most algorithms increment per
    /// call to [`step`](Self::step), so most implementations leave this
    /// a no-op.
    fn end_iteration(&mut self) {}

    /// Per-tensor multiplier on the effective learning rate. Default
    /// is `1.0` for every name. Override when wrapping this crate to
    /// support per-name LR schedules (e.g. embedding-vs-attention
    /// splits, or the Gaussian-splat attribute-typed LR setup). The
    /// CPU impls in this crate currently honor this only when the
    /// caller passes a pre-scaled `lr` for the relevant call —
    /// backends are encouraged to consult it inside their fused
    /// kernel.
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}