pub trait Optimizer {
    // Required method
    fn step(
        &mut self,
        name: &str,
        shape: &[usize],
        param: &mut [f32],
        grad: &[f32],
    );
    // Provided methods
    fn end_iteration(&mut self) { ... }
    fn lr_scale(&self, _name: &str) -> f32 { ... }
}
Common parameter-update interface.
The name argument keys the per-parameter state (moments,
preconditioners). The shape argument is the parameter's logical shape;
matrix-aware algorithms such as Adafactor, SOAP, and Muon use it, while
elementwise ones ignore it. The param slice is updated in place from
grad. The grad slice is treated as read-only; callers that need
gradient clipping should pre-scale it (see global_grad_clip_scale).
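To make the argument roles concrete, here is a minimal sketch of an elementwise optimizer (SGD with momentum) that keys its velocity buffers by name, ignores shape, mutates param in place, and only reads grad. MomentumSgd and clip_scale are hypothetical names; clip_scale stands in for the pre-scaling step and is not the crate's global_grad_clip_scale, whose signature is not shown on this page. The trait is re-declared locally, with assumed default bodies, so the block compiles on its own.

```rust
use std::collections::HashMap;

// Local copy of the documented trait signature so this sketch compiles
// standalone; the default bodies are assumed from the documented behavior.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical elementwise SGD with momentum (not a type from this crate).
pub struct MomentumSgd {
    pub lr: f32,
    pub beta: f32,
    // Per-parameter velocity buffers, keyed by `name`.
    velocity: HashMap<String, Vec<f32>>,
}

impl MomentumSgd {
    pub fn new(lr: f32, beta: f32) -> Self {
        Self { lr, beta, velocity: HashMap::new() }
    }
}

impl Optimizer for MomentumSgd {
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // Elementwise update, so the logical shape is ignored.
        assert_eq!(param.len(), grad.len());
        let vel = self
            .velocity
            .entry(name.to_string())
            .or_insert_with(|| vec![0.0; param.len()]);
        // `grad` is only read; `param` is updated in place.
        for ((p, g), v) in param.iter_mut().zip(grad).zip(vel.iter_mut()) {
            *v = self.beta * *v + *g;
            *p -= self.lr * *v;
        }
    }
}

/// Hypothetical helper: the factor that brings the global L2 norm of all
/// gradients down to `max_norm` (1.0 if they are already within bounds).
pub fn clip_scale(grads: &[&[f32]], max_norm: f32) -> f32 {
    let sq: f32 = grads.iter().map(|g| g.iter().map(|x| x * x).sum::<f32>()).sum();
    let norm = sq.sqrt();
    if norm > max_norm { max_norm / norm } else { 1.0 }
}

fn main() {
    let mut opt = MomentumSgd::new(0.1, 0.9);
    let mut w = vec![1.0f32, -2.0];
    let g = vec![3.0f32, 4.0]; // global L2 norm is 5.0
    // The caller pre-scales a copy of the gradient; step never modifies it.
    let s = clip_scale(&[g.as_slice()], 1.0);
    let clipped: Vec<f32> = g.iter().map(|x| x * s).collect();
    opt.step("w", &[2], &mut w, &clipped);
    println!("{:?}", w);
}
```

Repeated calls with the same name reuse the same velocity buffer, which is why name must be stable across iterations.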
Implementing for a backend
Every algorithm in this crate provides a CPU reference impl. A
backend (e.g. rlx-metal, rlx-cuda) is free to write its own
fused step kernel and impl Optimizer for a wrapper struct that
owns device buffers; the trait places no requirement on where
the state lives, only on the entry-point signature. The
rlx-metal::splat_adam kernel is the canonical example of a
backend that bypasses this crate entirely; a 5-line
impl Optimizer is enough to wrap it if a generic trainer needs a
uniform interface.
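A sketch of the wrapper pattern, assuming a backend that exposes a fused Adam kernel taking the parameter, the gradient, and both moment buffers. DeviceBuffer, fused_adam_kernel, and DeviceAdam are hypothetical; the kernel is a CPU loop standing in for a Metal or CUDA dispatch, and none of this reproduces the rlx-metal::splat_adam API.

```rust
use std::collections::HashMap;

// Local copy of the documented trait signature so this sketch compiles
// standalone; the default bodies are assumed from the documented behavior.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical stand-in for a device allocation; a real backend would hold
/// a GPU buffer handle, a Vec keeps the sketch runnable on the CPU.
struct DeviceBuffer(Vec<f32>);

/// Hypothetical fused Adam kernel: one pass updates both moments and the
/// parameter, with bias correction for step count `t`.
fn fused_adam_kernel(
    param: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    b1: f32,
    b2: f32,
    eps: f32,
    t: i32,
) {
    let bc1 = 1.0 - b1.powi(t);
    let bc2 = 1.0 - b2.powi(t);
    for i in 0..param.len() {
        m[i] = b1 * m[i] + (1.0 - b1) * grad[i];
        v[i] = b2 * v[i] + (1.0 - b2) * grad[i] * grad[i];
        let m_hat = m[i] / bc1;
        let v_hat = v[i] / bc2;
        param[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}

/// Per-parameter state that, in a real backend, would live on the device.
struct AdamState {
    m: DeviceBuffer,
    v: DeviceBuffer,
    t: i32,
}

/// Hypothetical wrapper: owns the state and forwards each step to the kernel.
pub struct DeviceAdam {
    lr: f32,
    b1: f32,
    b2: f32,
    eps: f32,
    state: HashMap<String, AdamState>,
}

impl DeviceAdam {
    pub fn new(lr: f32) -> Self {
        Self { lr, b1: 0.9, b2: 0.999, eps: 1e-8, state: HashMap::new() }
    }
}

impl Optimizer for DeviceAdam {
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        let n = param.len();
        // Lazily allocate zeroed moments the first time a name is seen.
        let s = self.state.entry(name.to_string()).or_insert_with(|| AdamState {
            m: DeviceBuffer(vec![0.0; n]),
            v: DeviceBuffer(vec![0.0; n]),
            t: 0,
        });
        s.t += 1;
        fused_adam_kernel(param, grad, &mut s.m.0, &mut s.v.0, self.lr, self.b1, self.b2, self.eps, s.t);
    }
}

fn main() {
    let mut opt = DeviceAdam::new(0.1);
    let mut w = vec![1.0f32, 1.0];
    // The first Adam step moves each entry by about lr against sign(grad).
    opt.step("w", &[2], &mut w, &[0.5, -2.0]);
    println!("{:?}", w);
}
```

The trainer only ever sees the step signature; where m and v live is entirely the wrapper's business.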
Per-tensor learning rate
For optimizers that don't need per-tensor LR variation (most
transformer pre-training), leave lr_scale at its default, which
returns 1.0. For domain-specific use cases such as 3D Gaussian
splatting, where different attributes need wildly different step
sizes, override lr_scale to return a per-name factor that
multiplies the base lr. The trait does NOT apply this factor
automatically; each algorithm decides whether to consult
Optimizer::lr_scale inside its step.
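Because the trait does not apply the factor itself, an algorithm that wants per-name step sizes has to call self.lr_scale(name) inside its own step. A minimal sketch with plain SGD; ScaledSgd is hypothetical, and the attribute names and factor in main are illustrative assumptions rather than values from any splatting codebase.

```rust
use std::collections::HashMap;

// Local copy of the documented trait signature so this sketch compiles
// standalone; the default bodies are assumed from the documented behavior.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical SGD whose effective lr is `base_lr * self.lr_scale(name)`.
pub struct ScaledSgd {
    pub base_lr: f32,
    pub scales: HashMap<String, f32>,
}

impl Optimizer for ScaledSgd {
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // The trait does not apply lr_scale for us; consult it explicitly.
        let lr = self.base_lr * self.lr_scale(name);
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= lr * g;
        }
    }

    fn lr_scale(&self, name: &str) -> f32 {
        // Names without an entry fall back to the documented default of 1.0.
        self.scales.get(name).copied().unwrap_or(1.0)
    }
}

fn main() {
    // Illustrative factor only: "means" takes half the base step size.
    let mut scales = HashMap::new();
    scales.insert("means".to_string(), 0.5f32);
    let mut opt = ScaledSgd { base_lr: 0.1, scales };
    let mut means = vec![1.0f32];
    let mut opacities = vec![1.0f32];
    opt.step("means", &[1], &mut means, &[1.0]);
    opt.step("opacities", &[1], &mut opacities, &[1.0]);
    println!("{:?} {:?}", means, opacities);
}
```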
Required Methods
fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32])
Provided Methods
fn end_iteration(&mut self)
Advance the global step counter. Most algorithms increment their
step count on each call to step, so most implementations leave this
a no-op.
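A sketch of when overriding end_iteration is useful: a schedule that should change once per training iteration rather than once per tensor. Assuming the trainer calls step once for each named parameter, a counter bumped inside step would advance several times per iteration. DecayedSgd is hypothetical, and exponential decay is just an illustrative schedule.

```rust
// Local copy of the documented trait signature so this sketch compiles
// standalone; the default bodies are assumed from the documented behavior.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical SGD with lr = base_lr * decay^iteration.
pub struct DecayedSgd {
    base_lr: f32,
    decay: f32,
    iteration: u32,
}

impl DecayedSgd {
    pub fn new(base_lr: f32, decay: f32) -> Self {
        Self { base_lr, decay, iteration: 0 }
    }
}

impl Optimizer for DecayedSgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // Every tensor in one iteration sees the same lr, because the
        // counter only moves in end_iteration.
        let lr = self.base_lr * self.decay.powi(self.iteration as i32);
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= lr * g;
        }
    }

    fn end_iteration(&mut self) {
        self.iteration += 1;
    }
}

fn main() {
    let mut opt = DecayedSgd::new(1.0, 0.5);
    let (mut a, mut b) = (vec![0.0f32], vec![0.0f32]);
    for _ in 0..2 {
        opt.step("a", &[1], &mut a, &[1.0]);
        opt.step("b", &[1], &mut b, &[1.0]);
        opt.end_iteration(); // called once, after all tensors
    }
    println!("{:?} {:?}", a, b);
}
```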
fn lr_scale(&self, _name: &str) -> f32
Per-tensor multiplier on the effective learning rate. The default
is 1.0 for every name. Override it when wrapping this crate to
support per-name LR schedules (e.g. embedding-vs-attention
splits, or the attribute-typed LR setup used for Gaussian
splatting). The CPU impls in this crate do not currently consult it
themselves; they honor a per-name scale only when the caller passes
an lr already scaled for the relevant call. Backends are encouraged
to consult it inside their fused kernel.
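The pre-scaled-lr path can be sketched as a wrapper that overrides lr_scale and rewrites the inner optimizer's lr before forwarding each call. PlainSgd (with a public lr field) and PerNameLr are hypothetical; the crate's CPU impls may expose their learning rate differently.

```rust
use std::collections::HashMap;

// Local copy of the documented trait signature so this sketch compiles
// standalone; the default bodies are assumed from the documented behavior.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical CPU optimizer that never consults lr_scale itself.
pub struct PlainSgd {
    pub lr: f32,
}

impl Optimizer for PlainSgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
    }
}

/// Hypothetical wrapper: honors lr_scale by pre-scaling the inner lr.
pub struct PerNameLr {
    pub inner: PlainSgd,
    pub base_lr: f32,
    pub scales: HashMap<String, f32>,
}

impl Optimizer for PerNameLr {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // Pass an lr already scaled for this particular tensor.
        self.inner.lr = self.base_lr * self.lr_scale(name);
        self.inner.step(name, shape, param, grad);
    }

    fn end_iteration(&mut self) {
        self.inner.end_iteration();
    }

    fn lr_scale(&self, name: &str) -> f32 {
        self.scales.get(name).copied().unwrap_or(1.0)
    }
}

fn main() {
    // Illustrative split: embeddings train with a tenth of the base lr.
    let mut scales = HashMap::new();
    scales.insert("embed".to_string(), 0.1f32);
    let mut opt = PerNameLr { inner: PlainSgd { lr: 0.0 }, base_lr: 1.0, scales };
    let (mut embed, mut attn) = (vec![1.0f32], vec![1.0f32]);
    opt.step("embed", &[1], &mut embed, &[1.0]);
    opt.step("attn", &[1], &mut attn, &[1.0]);
    println!("{:?} {:?}", embed, attn);
}
```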
Dyn Compatibility
This trait is dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".
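Dyn compatibility means a generic trainer can hold the optimizer as Box<dyn Optimizer> and pick the algorithm at runtime. A minimal sketch; Sgd, SignSgd, Param, train_step, and make_optimizer are hypothetical names, not items from this crate.

```rust
// Local copy of the documented trait signature so this sketch compiles
// standalone; the default bodies are assumed from the documented behavior.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

struct Sgd {
    lr: f32,
}

impl Optimizer for Sgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
    }
}

/// Steps by lr in the direction opposite to the sign of each gradient entry.
struct SignSgd {
    lr: f32,
}

impl Optimizer for SignSgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g.signum();
        }
    }
}

/// Hypothetical named tensor as a trainer might store it.
pub struct Param {
    pub name: String,
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

/// One iteration: step every parameter, then advance the global counter.
pub fn train_step(opt: &mut dyn Optimizer, params: &mut [Param], grads: &[Vec<f32>]) {
    for (p, g) in params.iter_mut().zip(grads) {
        opt.step(&p.name, &p.shape, &mut p.data, g);
    }
    opt.end_iteration();
}

/// Runtime selection is possible because the trait is dyn compatible.
pub fn make_optimizer(kind: &str) -> Box<dyn Optimizer> {
    match kind {
        "sign" => Box::new(SignSgd { lr: 0.1 }),
        _ => Box::new(Sgd { lr: 0.1 }),
    }
}

fn main() {
    let mut params = vec![Param { name: "w".to_string(), shape: vec![2], data: vec![1.0, 1.0] }];
    let grads = vec![vec![4.0f32, -0.5]];
    for kind in ["sign", "sgd"] {
        let mut opt = make_optimizer(kind);
        train_step(&mut *opt, &mut params, &grads);
    }
    println!("{:?}", params[0].data);
}
```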