// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! RLX training-step optimizers.
//!
//! Host-side `f32` step functions for the families surveyed in
//! "A Systematic Review of Optimization Algorithms for Modern Deep
//! Learning" (arXiv:2509.02046v1). Each algorithm exposes a small
//! state struct keyed by parameter *name* (so the same struct holds
//! moments for every tensor in a model) and a `step` method that
//! consumes `(name, shape, &mut params, &grads)`.
//!
//! The API is deliberately minimal: it operates on flat `&mut [f32]`
//! / `&[f32]` slices plus a `&[usize]` shape, matching the
//! [`rlx_umap::adam`](../rlx_umap/adam/index.html) pattern. Backends
//! that already ship a fused step kernel (e.g.
//! `rlx_metal::splat_adam`) are free to bypass this crate on their
//! hot path. This crate serves as the portable reference
//! implementation and CPU fallback, and it is what runs when no
//! backend provides a fused kernel for the requested algorithm.
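//!
//! As a rough illustration of this calling convention, the sketch
//! below keys momentum buffers by parameter name and updates a flat
//! slice in place. `MomentumSketch` and its fields are hypothetical
//! stand-ins, not this crate's `Sgd`; plain heavy-ball momentum is
//! assumed here only because it is short.
//!
//! ```rust
//! use std::collections::HashMap;
//!
//! // Hypothetical state struct: one velocity buffer per parameter name.
//! struct MomentumSketch {
//!     lr: f32,
//!     momentum: f32,
//!     velocity: HashMap<String, Vec<f32>>,
//! }
//!
//! impl MomentumSketch {
//!     fn new(lr: f32, momentum: f32) -> Self {
//!         Self { lr, momentum, velocity: HashMap::new() }
//!     }
//!
//!     // Consumes `(name, shape, &mut params, &grads)` on flat slices.
//!     fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
//!         let numel: usize = shape.iter().product();
//!         assert_eq!(param.len(), numel);
//!         assert_eq!(grad.len(), numel);
//!         let (lr, momentum) = (self.lr, self.momentum);
//!         // Lazily create zeroed state the first time a name is seen.
//!         let vel = self
//!             .velocity
//!             .entry(name.to_string())
//!             .or_insert_with(|| vec![0.0; numel]);
//!         for ((p, g), v) in param.iter_mut().zip(grad).zip(vel.iter_mut()) {
//!             *v = momentum * *v + *g;
//!             *p -= lr * *v;
//!         }
//!     }
//! }
//! ```
//!
//! Calling `step` again with the same `name` reuses the stored
//! buffer, which is how one struct can hold state for every tensor in
//! a model.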
//!
//! # Algorithms
//!
//! | Type            | Family                        |
//! |-----------------|-------------------------------|
//! | [`Sgd`]         | SGD ± momentum / Nesterov     |
//! | [`Adam`]        | Adam                          |
//! | [`AdamW`]       | AdamW (decoupled decay)       |
//! | [`NAdamW`]      | Nesterov AdamW                |
//! | [`RAdam`]       | Rectified Adam                |
//! | [`QHAdamW`]     | Quasi-hyperbolic AdamW        |
//! | [`Lamb`]        | LAMB (layer-wise adaptive)    |
//! | [`Adafactor`]   | Adafactor (factored 2nd mom.) |
//! | [`Lion`]        | Lion (sign of EMA)            |
//! | [`Soap`]        | SOAP (Shampoo-in-Adam-basis)  |
//! | [`KronPsgd`]    | Kron / PSGD                   |
//! | [`Muon`]        | Muon (Newton–Schulz orth.)    |
//! | [`Sophia`]      | Sophia-H                      |
//! | [`Mars`]        | MARS (variance-reduced)       |
pub use Adafactor;
pub use Adam;
pub use AdamW;
pub use KronPsgd;
pub use Lamb;
pub use Lion;
pub use Mars;
pub use Muon;
pub use NAdamW;
pub use QHAdamW;
pub use RAdam;
pub use Sgd;
pub use Soap;
pub use Sophia;
pub use ;
/// Common parameter-update interface.
///
/// `name` keys the per-parameter state (moments, preconditioners),
/// `shape` is the parameter's logical shape (used by matrix-aware
/// algorithms like Adafactor / SOAP / Muon — ignored by elementwise
/// ones), `param` is updated in place from `grad`. `grad` is treated
/// as read-only; callers that need gradient clipping should pre-scale
/// it (see [`global_grad_clip_scale`]).
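///
/// A minimal sketch of that pre-scaling, assuming global L2-norm
/// clipping with a positive `max_norm`. `clip_scale_sketch` and
/// `prescaled` are hypothetical helpers written for this example;
/// they are not [`global_grad_clip_scale`], whose signature may
/// differ.
///
/// ```rust
/// // Hypothetical helper: factor that brings the global L2 norm of
/// // all gradients down to at most `max_norm` (1.0 if already below).
/// fn clip_scale_sketch(grads: &[&[f32]], max_norm: f32) -> f32 {
///     let sum_sq: f64 = grads
///         .iter()
///         .flat_map(|g| g.iter())
///         .map(|&x| (x as f64) * (x as f64))
///         .sum();
///     let norm = sum_sq.sqrt() as f32;
///     if norm > max_norm { max_norm / norm } else { 1.0 }
/// }
///
/// // Scaled copy of `grad`, leaving the original untouched, ready
/// // to pass to `step`.
/// fn prescaled(grad: &[f32], scale: f32) -> Vec<f32> {
///     grad.iter().map(|g| g * scale).collect()
/// }
/// ```
///
/// The scale is computed once over all tensors and then applied to
/// each gradient before its `step` call, so `step` itself never
/// mutates `grad`.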
///
/// # Implementing for a backend
///
/// Every algorithm in this crate provides a CPU reference impl. A
/// backend (e.g. `rlx-metal`, `rlx-cuda`) is free to write its own
/// fused step kernel and impl `Optimizer` for a wrapper struct that
/// owns device buffers — the trait places no requirement on where
/// the state lives, only on the entry-point signature. The
/// `rlx_metal::splat_adam` kernel is the canonical example of a
/// backend that bypasses this crate entirely; you can wrap it with a
/// 5-line `impl Optimizer` if you want a uniform interface from a
/// generic trainer.
///
/// # Per-tensor learning rate
///
/// For optimizers that don't need per-tensor LR variation (most
/// transformer pre-training), leave [`lr_scale`](Self::lr_scale) at
/// its default, which returns `1.0`. For domain-specific use cases
/// such as 3D Gaussian splatting, where different attributes need
/// wildly different step sizes, override
/// [`lr_scale`](Self::lr_scale) to return a per-name factor that
/// multiplies the base `lr`. The trait does NOT apply this factor
/// automatically: an algorithm that wants per-tensor scaling has to
/// call [`Optimizer::lr_scale`] inside its own `step`.
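///
/// A minimal sketch of that opt-in pattern. `OptimizerSketch` is a
/// simplified stand-in for this trait (its exact signatures are an
/// assumption), and `SplatSgdSketch` and the `"position"` suffix
/// rule are hypothetical.
///
/// ```rust
/// trait OptimizerSketch {
///     fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
///
///     // Default: no per-tensor variation.
///     fn lr_scale(&self, _name: &str) -> f32 {
///         1.0
///     }
/// }
///
/// struct SplatSgdSketch {
///     lr: f32,
/// }
///
/// impl OptimizerSketch for SplatSgdSketch {
///     fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
///         // Nothing scales for us, so look up the factor explicitly.
///         let lr = self.lr * self.lr_scale(name);
///         for (p, g) in param.iter_mut().zip(grad) {
///             *p -= lr * *g;
///         }
///     }
///
///     // Hypothetical rule: positions take a quarter of the base step.
///     fn lr_scale(&self, name: &str) -> f32 {
///         if name.ends_with("position") { 0.25 } else { 1.0 }
///     }
/// }
/// ```
///
/// With `lr = 0.5`, a tensor named `"splat.position"` is stepped
/// with an effective rate of `0.125`, while every other name keeps
/// the base rate.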