Crate rlx_optim

RLX training-step optimizers.

Host-side f32 step functions for the families surveyed in “A Systematic Review of Optimization Algorithms for Modern Deep Learning” (arXiv:2509.02046v1). Each algorithm exposes a small state struct keyed by parameter name (so the same struct holds moments for every tensor in a model) and a step method that consumes (name, shape, &mut params, &grads).

The API is deliberately minimal: it operates on flat &mut [f32] / &[f32] slices plus a &[usize] shape, matching the rlx_umap::adam pattern. Backends that already ship a fused step kernel (e.g. rlx_metal::splat_adam) are free to bypass this crate on their hot path. This crate serves as the portable reference implementation, the CPU fallback, and the path used when no backend provides a fused kernel for the requested algorithm.
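To make the calling convention concrete, here is a minimal sketch of a name-keyed state struct with a step method taking (name, shape, &mut params, &grads). It is an illustration only: SketchAdam, Moments, the field names, and the default hyperparameters are assumptions and are not taken from this crate's source.

```rust
use std::collections::HashMap;

/// Hypothetical per-tensor state: first moment, second moment, step count.
struct Moments {
    m: Vec<f32>,
    v: Vec<f32>,
    t: u32,
}

/// Illustrative Adam-style optimizer. Names and defaults are assumptions,
/// not the crate's actual definition.
pub struct SketchAdam {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub eps: f32,
    state: HashMap<String, Moments>,
}

impl SketchAdam {
    pub fn new(lr: f32) -> Self {
        Self { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, state: HashMap::new() }
    }

    /// One bias-corrected Adam update for the tensor called `name`.
    pub fn step(&mut self, name: &str, shape: &[usize], params: &mut [f32], grads: &[f32]) {
        let n: usize = shape.iter().product();
        assert_eq!(params.len(), n, "params length must match shape");
        assert_eq!(grads.len(), n, "grads length must match shape");

        // State is created lazily the first time a parameter name is seen.
        let st = self.state.entry(name.to_string()).or_insert_with(|| Moments {
            m: vec![0.0; n],
            v: vec![0.0; n],
            t: 0,
        });
        st.t += 1;
        let bc1 = 1.0 - self.beta1.powi(st.t as i32);
        let bc2 = 1.0 - self.beta2.powi(st.t as i32);

        for i in 0..n {
            let g = grads[i];
            st.m[i] = self.beta1 * st.m[i] + (1.0 - self.beta1) * g;
            st.v[i] = self.beta2 * st.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = st.m[i] / bc1;
            let v_hat = st.v[i] / bc2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}
```

Keying the map by parameter name lets one optimizer value serve every tensor in a model, at the cost of a hash lookup per call. On the first step the bias-corrected update has magnitude close to lr regardless of the gradient scale.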

§Algorithms

| Type | Family |
|---|---|
| Sgd | SGD ± momentum / Nesterov |
| Adam | Adam |
| AdamW | AdamW (decoupled decay) |
| NAdamW | Nesterov AdamW |
| RAdam | Rectified Adam |
| QHAdamW | Quasi-hyperbolic AdamW |
| Lamb | LAMB (layer-wise adaptive) |
| Adafactor | Adafactor (factored second moment) |
| Lion | Lion (sign of EMA) |
| Soap | SOAP (Shampoo-in-Adam-basis) |
| KronPsgd | Kron / PSGD |
| Muon | Muon (Newton–Schulz orthogonalization) |
| Sophia | Sophia-H |
| Mars | MARS (variance-reduced) |
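As an illustration of one family from the table, the following sketch shows the "sign of EMA" idea behind Lion on flat slices, following the published update rule. The function name lion_step_sketch and its parameter list are hypothetical and do not describe this crate's Lion struct.

```rust
/// Sign that maps exact zero (and NaN) to 0.0, unlike `f32::signum`,
/// which returns +1.0 or -1.0 for signed zeros.
fn sign(x: f32) -> f32 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}

/// One Lion-style update (hypothetical helper, not the crate's API).
/// `momentum` is the single per-tensor state buffer.
pub fn lion_step_sketch(
    params: &mut [f32],
    grads: &[f32],
    momentum: &mut [f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    weight_decay: f32,
) {
    assert_eq!(params.len(), grads.len());
    assert_eq!(params.len(), momentum.len());
    for ((p, &g), m) in params.iter_mut().zip(grads).zip(momentum.iter_mut()) {
        // Interpolate momentum and gradient, then keep only the sign.
        let c = beta1 * *m + (1.0 - beta1) * g;
        // Decoupled weight decay is applied alongside the sign update.
        *p -= lr * (sign(c) + weight_decay * *p);
        // The momentum buffer is refreshed with a second, slower coefficient.
        *m = beta2 * *m + (1.0 - beta2) * g;
    }
}
```

Because only the sign of the interpolated direction is used, every coordinate moves by exactly lr (plus decay), which is why a single state buffer per tensor suffices.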

Structs§

Adafactor
Adafactor — factored-second-moment optimizer.
Adam
Bias-corrected first/second moment optimizer.
AdamW
Adam with decoupled weight decay.
KronPsgd
Kron-PSGD — Kronecker-factored preconditioned SGD.
Lamb
Layer-wise Adaptive Moments for Batch training.
Lion
EvoLved sign-momentum optimizer.
Mars
MARS — variance-reduced AdamW. Per-tensor state: three f32 buffers (m, v, previous-gradient cache).
Muon
Muon — Momentum-Orthogonalized-by-Newton-Schulz.
NAdamW
Nesterov AdamW. Per-tensor state: two f32 buffers.
QHAdamW
Quasi-hyperbolic AdamW. Per-tensor state: two f32 buffers.
RAdam
Rectified Adam. Per-tensor state: two f32 buffers.
Sgd
SGD with momentum / Nesterov / L2 weight decay.
Soap
SOAP — Shampoo-in-Adam-basis optimizer.
Sophia
Sophia-H — Hessian-diagonal second-order optimizer.
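For readers unfamiliar with the Newton–Schulz orthogonalization named in the Muon entry, here is a small sketch of the classic cubic iteration X ← 1.5·X − 0.5·X·Xᵀ·X on a row-major matrix. This is the textbook variant chosen for clarity; the coefficients, iteration count, and normalization this crate actually uses are not stated on this page, and all function names below are hypothetical.

```rust
/// Row-major matrix product: (m x k) * (k x n) -> (m x n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let aip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += aip * b[p * n + j];
            }
        }
    }
    out
}

/// Transpose of a row-major (rows x cols) matrix.
fn transpose(a: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; a.len()];
    for i in 0..rows {
        for j in 0..cols {
            out[j * rows + i] = a[i * cols + j];
        }
    }
    out
}

/// Approximates the orthogonal polar factor of `g` with the cubic
/// Newton-Schulz iteration (hypothetical helper, not the crate's API).
pub fn newton_schulz_sketch(g: &[f32], rows: usize, cols: usize, iters: usize) -> Vec<f32> {
    assert_eq!(g.len(), rows * cols);
    // Frobenius normalization bounds all singular values by 1, which keeps
    // the iteration inside its convergence region.
    let fro = g.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mut x: Vec<f32> = g.iter().map(|v| v / (fro + 1e-7)).collect();
    for _ in 0..iters {
        let xt = transpose(&x, rows, cols); // cols x rows
        let xxt = matmul(&x, &xt, rows, cols, rows); // rows x rows
        let xxtx = matmul(&xxt, &x, rows, rows, cols); // rows x cols
        for (xi, yi) in x.iter_mut().zip(&xxtx) {
            *xi = 1.5 * *xi - 0.5 * *yi;
        }
    }
    x
}
```

Each iteration pushes every nonzero singular value toward 1 while leaving the singular vectors unchanged, so for a full-rank input X·Xᵀ approaches the identity. Small singular values grow by roughly 1.5x per step, so ill-conditioned inputs need more iterations.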

Traits§

Optimizer
Common parameter-update interface.
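The value of a common update interface is that training code can stay agnostic of the algorithm. The sketch below shows the idea with a stand-in trait; StepSketch, MomentumSgdSketch, and update_all are hypothetical names, and the real Optimizer trait's methods and signatures may differ.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a common update interface.
pub trait StepSketch {
    fn step(&mut self, name: &str, shape: &[usize], params: &mut [f32], grads: &[f32]);
}

/// SGD with classical (heavy-ball) momentum, state keyed by parameter name.
pub struct MomentumSgdSketch {
    pub lr: f32,
    pub momentum: f32,
    velocity: HashMap<String, Vec<f32>>,
}

impl MomentumSgdSketch {
    pub fn new(lr: f32, momentum: f32) -> Self {
        Self { lr, momentum, velocity: HashMap::new() }
    }
}

impl StepSketch for MomentumSgdSketch {
    fn step(&mut self, name: &str, shape: &[usize], params: &mut [f32], grads: &[f32]) {
        let n: usize = shape.iter().product();
        assert_eq!(params.len(), n);
        assert_eq!(grads.len(), n);
        let v = self
            .velocity
            .entry(name.to_string())
            .or_insert_with(|| vec![0.0; n]);
        for i in 0..n {
            // v <- momentum * v + g, then p <- p - lr * v
            v[i] = self.momentum * v[i] + grads[i];
            params[i] -= self.lr * v[i];
        }
    }
}

/// Applies one update to every named tensor through a trait object, so the
/// caller does not need to know which algorithm is in use.
pub fn update_all(
    opt: &mut dyn StepSketch,
    tensors: &mut [(String, Vec<usize>, Vec<f32>)],
    grads: &[Vec<f32>],
) {
    for ((name, shape, params), g) in tensors.iter_mut().zip(grads) {
        opt.step(name, shape, params, g);
    }
}
```

Swapping algorithms then only changes the value passed as opt; the loop over tensors is untouched.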

Functions§

global_grad_clip_scale
Global L2-norm gradient clipping across many tensors. Returns the scale factor (<= 1.0) by which to multiply every gradient; callers can pre-scale gradients before passing them to Optimizer::step. Identical to rlx_umap::adam::global_grad_clip_scale, but generic over any iterator yielding slices.
l2_norm
L2 norm across a slice (skipping non-finite entries).
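The two helpers above can be sketched as follows. This is a minimal illustration under assumptions: the names l2_norm_sketch and global_clip_scale_sketch, the max_norm parameter, and the exact handling of the threshold are hypothetical and may not match the crate's real signatures.

```rust
/// L2 norm of a slice, ignoring NaN and infinite entries.
pub fn l2_norm_sketch(xs: &[f32]) -> f32 {
    xs.iter()
        .filter(|x| x.is_finite())
        .map(|x| x * x)
        .sum::<f32>()
        .sqrt()
}

/// Scale factor (<= 1.0) that brings the global gradient norm down to
/// `max_norm`. Generic over any iterator that yields `&[f32]` slices.
pub fn global_clip_scale_sketch<'a, I>(grads: I, max_norm: f32) -> f32
where
    I: IntoIterator<Item = &'a [f32]>,
{
    // The global norm is the square root of the summed squared per-tensor norms.
    let sum_sq: f32 = grads
        .into_iter()
        .map(|g| {
            let n = l2_norm_sketch(g);
            n * n
        })
        .sum();
    let total = sum_sq.sqrt();
    if total > max_norm {
        max_norm / total
    } else {
        1.0
    }
}
```

A caller would compute the scale once per training step, multiply every gradient slice by it, and then run the per-tensor optimizer updates. Skipping non-finite entries keeps a single NaN from poisoning the norm, although such gradients usually deserve separate handling.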