Trait Optimizer 

pub trait Optimizer {
    // Required method
    fn step(
        &mut self,
        name: &str,
        shape: &[usize],
        param: &mut [f32],
        grad: &[f32],
    );

    // Provided methods
    fn end_iteration(&mut self) { ... }
    fn lr_scale(&self, _name: &str) -> f32 { ... }
}
Common parameter-update interface.

step takes four arguments. name keys the per-parameter state (moments, preconditioners). shape is the parameter’s logical shape; matrix-aware algorithms such as Adafactor, SOAP, and Muon use it, and elementwise ones ignore it. param is updated in place from grad. grad is treated as read-only, so callers that need gradient clipping should pre-scale it before calling step (see global_grad_clip_scale).
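As a rough illustration of the contract above, here is a minimal sketch of an elementwise implementor. The trait is re-declared locally so the snippet compiles on its own; the empty default body of end_iteration is an assumption (the page only shows `{ ... }`). PlainSgd, sgd_demo, and the clip factor are hypothetical names invented for this example, not items from this crate, and the clip factor stands in for whatever global_grad_clip_scale would produce.

```rust
// Local re-declaration of the trait shown above, so this sketch is self-contained.
// The default bodies are assumptions: a no-op end_iteration and lr_scale == 1.0.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical stateless SGD (not a type from this crate).
pub struct PlainSgd {
    pub lr: f32,
}

impl Optimizer for PlainSgd {
    // Elementwise algorithm: `name` and `shape` are ignored, `grad` is only read.
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        assert_eq!(param.len(), grad.len(), "param and grad must have equal length");
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
    }
}

/// One update of a 2x2 weight with a caller-side pre-scaled (clipped) gradient.
pub fn sgd_demo() -> Vec<f32> {
    let mut opt = PlainSgd { lr: 0.5 };
    let mut w: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let grad: Vec<f32> = vec![2.0, 2.0, 2.0, 2.0];

    // Hypothetical clip factor computed by the caller; `grad` itself stays untouched.
    let clip: f32 = 0.5;
    let clipped: Vec<f32> = grad.iter().map(|g| g * clip).collect();

    opt.step("w", &[2, 2], &mut w, &clipped);
    opt.end_iteration();
    w
}
```

Clipping happens in the caller because step only receives an immutable grad slice; the optimizer never needs to know whether the gradient was rescaled.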

§Implementing for a backend

Every algorithm in this crate provides a CPU reference impl. A backend (e.g. rlx-metal, rlx-cuda) is free to write its own fused step kernel and impl Optimizer for a wrapper struct that owns device buffers — the trait places no requirement on where the state lives, only on the entry-point signature. The rlx-metal::splat_adam kernel is the canonical example of a backend that bypasses this crate entirely; you can wrap it with a 5-line impl Optimizer if you want a uniform interface from a generic trainer.
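The wrapper pattern described above might look roughly like the following sketch. FakeFusedKernel is a hypothetical stand-in that runs on the CPU; it does not reflect the actual rlx-metal or rlx-cuda API, and a real wrapper would own device buffers and launch a device kernel instead. The trait is re-declared locally (with assumed default bodies) so the snippet is self-contained.

```rust
// Local re-declaration of the trait; default bodies are assumptions.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical stand-in for a backend's fused update kernel.
/// A real backend would dispatch to the GPU here.
pub struct FakeFusedKernel {
    pub lr: f32,
    pub launches: usize,
}

impl FakeFusedKernel {
    pub fn launch(&mut self, param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
        self.launches += 1;
    }
}

/// Wrapper struct that owns the backend state and exposes the uniform interface.
pub struct KernelOptimizer {
    pub kernel: FakeFusedKernel,
}

impl Optimizer for KernelOptimizer {
    // The whole impl is a forwarding call: the trait only fixes the signature.
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        self.kernel.launch(param, grad);
    }
}

/// Drives the wrapper once and reports the updated weights and launch count.
pub fn wrapper_demo() -> (Vec<f32>, usize) {
    let mut opt = KernelOptimizer {
        kernel: FakeFusedKernel { lr: 0.5, launches: 0 },
    };
    let mut w: Vec<f32> = vec![1.0, 2.0];
    let grad: [f32; 2] = [2.0, 4.0];
    opt.step("w", &[2], &mut w, &grad);
    (w, opt.kernel.launches)
}
```

Because the state lives entirely inside the wrapper, a generic trainer can call step without knowing which backend it is driving.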

§Per-tensor learning rate

For optimizers that don’t need per-tensor LR variation (most transformer pre-training), leave lr_scale at its default, which returns 1.0. For domain-specific use cases, e.g. 3D Gaussian splatting, where different attributes need wildly different step sizes, override lr_scale to multiply the base lr by a per-name factor. The trait does NOT apply this factor automatically; an algorithm takes it into account only if its step calls Optimizer::lr_scale.
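A minimal sketch of a per-name override that is consulted inside step, under stated assumptions: ScaledSgd, the ".position" suffix rule, the 0.25 factor, and the parameter names are all invented for illustration and are not taken from this crate or from any splatting pipeline. The trait is re-declared locally with assumed default bodies.

```rust
// Local re-declaration of the trait; default bodies are assumptions.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical SGD variant with a per-name learning-rate multiplier.
pub struct ScaledSgd {
    pub base_lr: f32,
}

impl Optimizer for ScaledSgd {
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // The trait does not apply lr_scale for us, so step has to ask for it.
        let lr = self.base_lr * self.lr_scale(name);
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= lr * g;
        }
    }

    // Illustrative rule: tensors whose name ends in ".position" take smaller steps.
    fn lr_scale(&self, name: &str) -> f32 {
        if name.ends_with(".position") {
            0.25
        } else {
            1.0
        }
    }
}

/// Applies the same gradient to two differently named tensors.
pub fn lr_scale_demo() -> (f32, f32) {
    let mut opt = ScaledSgd { base_lr: 1.0 };
    let mut position: [f32; 1] = [1.0];
    let mut color: [f32; 1] = [1.0];
    let grad: [f32; 1] = [1.0];
    opt.step("splat.position", &[1], &mut position, &grad);
    opt.step("splat.color", &[1], &mut color, &grad);
    (position[0], color[0])
}
```

With identical gradients, the two tensors move by different amounts purely because of their names, which is the behaviour the override is meant to provide.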

Required Methods§

fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32])

Provided Methods§

fn end_iteration(&mut self)

Advance the global step counter. Most algorithms increment their counters on each call to step, so most implementations leave this as a no-op.
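One case where overriding end_iteration can be useful is an optimizer whose schedule depends on the training iteration rather than on the number of step calls. The sketch below assumes a hypothetical WarmupSgd with a linear warmup; neither the type nor the schedule comes from this crate, and the trait is re-declared locally with assumed default bodies.

```rust
// Local re-declaration of the trait; default bodies are assumptions.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical SGD with linear warmup driven by a global iteration counter.
pub struct WarmupSgd {
    pub base_lr: f32,
    pub warmup_iters: u32,
    pub iteration: u32,
}

impl Optimizer for WarmupSgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // Every parameter updated within one iteration sees the same learning rate,
        // because the counter only moves in end_iteration.
        let progress = (self.iteration + 1) as f32 / self.warmup_iters as f32;
        let lr = self.base_lr * progress.min(1.0);
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= lr * g;
        }
    }

    // Called once per training iteration, after all parameters were stepped.
    fn end_iteration(&mut self) {
        self.iteration += 1;
    }
}

/// Two iterations over two parameters; returns both values and the counter.
pub fn warmup_demo() -> (f32, f32, u32) {
    let mut opt = WarmupSgd { base_lr: 1.0, warmup_iters: 2, iteration: 0 };
    let mut a: [f32; 1] = [4.0];
    let mut b: [f32; 1] = [2.0];
    for _ in 0..2 {
        opt.step("a", &[1], &mut a, &[1.0]);
        opt.step("b", &[1], &mut b, &[2.0]);
        opt.end_iteration();
    }
    (a[0], b[0], opt.iteration)
}
```

Had the counter been incremented inside step instead, the second parameter of each iteration would already be one tick further along the warmup than the first.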

fn lr_scale(&self, _name: &str) -> f32

Per-tensor multiplier on the effective learning rate. Default is 1.0 for every name. Override when wrapping this crate to support per-name LR schedules (e.g. embedding-vs-attention splits, or the Gaussian-splat attribute-typed LR setup). The CPU impls in this crate do not currently consult lr_scale themselves: the factor takes effect only if the caller passes a pre-scaled lr for the relevant call. Backends are encouraged to consult it inside their fused kernel.

Dyn Compatibility§

This trait is dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".
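Since the trait is dyn compatible, an optimizer can be chosen at runtime and driven through a trait object. The following is a minimal sketch of that usage; Sgd, SignSgd, make_optimizer, and apply are hypothetical names created for this example, and the trait is re-declared locally with assumed default bodies.

```rust
// Local re-declaration of the trait; default bodies are assumptions.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical plain SGD.
pub struct Sgd {
    pub lr: f32,
}

impl Optimizer for Sgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
    }
}

/// Hypothetical sign-based SGD: moves each element by a fixed amount.
pub struct SignSgd {
    pub lr: f32,
}

impl Optimizer for SignSgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            if *g != 0.0 {
                *p -= self.lr * g.signum();
            }
        }
    }
}

/// Picks an implementation at runtime and erases its concrete type.
pub fn make_optimizer(kind: &str, lr: f32) -> Box<dyn Optimizer> {
    match kind {
        "sign" => Box::new(SignSgd { lr }),
        _ => Box::new(Sgd { lr }),
    }
}

/// Generic trainer fragment that only knows about the trait object.
pub fn apply(opt: &mut dyn Optimizer, name: &str, param: &mut [f32], grad: &[f32]) {
    let shape = [param.len()];
    opt.step(name, &shape, param, grad);
    opt.end_iteration();
}

/// Runs the same gradient through both optimizers via `dyn Optimizer`.
pub fn dyn_demo() -> (Vec<f32>, Vec<f32>) {
    let grad: [f32; 2] = [4.0, -2.0];
    let mut a: Vec<f32> = vec![1.0, 1.0];
    let mut b: Vec<f32> = vec![1.0, 1.0];
    let mut sgd = make_optimizer("sgd", 0.25);
    let mut sign = make_optimizer("sign", 0.25);
    apply(&mut *sgd, "a", &mut a, &grad);
    apply(&mut *sign, "b", &mut b, &grad);
    (a, b)
}
```

The trait works as a trait object because none of its methods are generic or return Self, so a trainer can hold a Box<dyn Optimizer> without being generic over the optimizer type.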

Implementors§