pub struct Adafactor {
pub lr: Option<f32>,
pub beta2_decay: f32,
pub eps1: f32,
pub eps2: f32,
pub clip_threshold: f32,
pub weight_decay: f32,
/* private fields */
}
Adafactor — factored-second-moment optimizer.
Per-tensor state: a rows-vector + a cols-vector for 2-D
parameters (sublinear in rows·cols), or a full EMA for non-2-D.
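As a rough illustration of the factored state described above, the sketch below keeps one running average per row and one per column of the squared gradient and rebuilds a per-element estimate as a rank-1 product. `FactoredState`, its fields, and its methods are hypothetical names chosen for this example; the crate's actual private fields and layout are not shown on this page, and a row-major gradient layout is assumed.

```rust
/// Hypothetical per-tensor state for a 2-D parameter of shape (rows, cols).
/// Stores rows + cols floats instead of rows * cols.
pub struct FactoredState {
    pub row: Vec<f32>, // running average of per-row means of (g^2 + eps1)
    pub col: Vec<f32>, // running average of per-column means of (g^2 + eps1)
}

impl FactoredState {
    pub fn new(rows: usize, cols: usize) -> Self {
        Self { row: vec![0.0; rows], col: vec![0.0; cols] }
    }

    /// Fold one row-major gradient into the row and column averages.
    pub fn update(&mut self, grad: &[f32], beta2: f32, eps1: f32) {
        let (rows, cols) = (self.row.len(), self.col.len());
        assert_eq!(grad.len(), rows * cols);
        let mut row_mean = vec![0.0f32; rows];
        let mut col_mean = vec![0.0f32; cols];
        for r in 0..rows {
            for c in 0..cols {
                let g = grad[r * cols + c];
                let sq = g * g + eps1;
                row_mean[r] += sq / cols as f32;
                col_mean[c] += sq / rows as f32;
            }
        }
        for r in 0..rows {
            self.row[r] = beta2 * self.row[r] + (1.0 - beta2) * row_mean[r];
        }
        for c in 0..cols {
            self.col[c] = beta2 * self.col[c] + (1.0 - beta2) * col_mean[c];
        }
    }

    /// Rank-1 reconstruction: v[r][c] ~ row[r] * col[c] / mean(row).
    pub fn estimate(&self, r: usize, c: usize) -> f32 {
        let mean_row = self.row.iter().sum::<f32>() / self.row.len() as f32;
        self.row[r] * self.col[c] / mean_row
    }
}
```

When the squared gradient is itself rank-1, the reconstruction is exact; otherwise it is an approximation that trades accuracy for the smaller state.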
Fields§
lr: Option<f32>
Optional manual learning rate. None ⇒ use the “relative
step” rule min(1/√t, 1e-2) · max(ε₂, RMS(θ)) from the paper.
Default None.
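beta2_decay: f32
Exponent x of the second-moment decay schedule β₂_t = 1 − tˣ.
With the default x = -0.8, β₂_1 = 0 (the first step uses only the
current squared gradient) and β₂_t approaches 1 as t grows, so the
running average lengthens over training.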
eps1: f32
Squared-gradient stability constant added before each row /
column average. Default 1e-30.
eps2: f32
RMS-of-parameter floor for the relative-step rule. Default 1e-3.
clip_threshold: f32
Update-RMS clipping threshold (Shazeer & Stern §6). Default 1.0.
weight_decay: f32
Decoupled weight-decay coefficient λ. Default 0.0.
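The scalar rules that these fields parameterize can be sketched as small free functions. This is a minimal illustration of the formulas quoted in the field descriptions, not the crate's internal code: the function names (`rms`, `relative_step`, `beta2_at`, `clip_update`) are assumptions made for this example, and the step counter `t` is assumed to start at 1.

```rust
/// Root-mean-square of a slice (0.0 for an empty slice).
pub fn rms(x: &[f32]) -> f32 {
    if x.is_empty() {
        return 0.0;
    }
    (x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32).sqrt()
}

/// Relative step size: min(1/sqrt(t), 1e-2) * max(eps2, RMS(param)), t >= 1.
pub fn relative_step(t: u32, param: &[f32], eps2: f32) -> f32 {
    let rho = (1.0 / (t as f32).sqrt()).min(1e-2);
    rho * eps2.max(rms(param))
}

/// Second-moment decay: beta2_t = 1 - t^x, where x is `beta2_decay`, t >= 1.
pub fn beta2_at(t: u32, beta2_decay: f32) -> f32 {
    1.0 - (t as f32).powf(beta2_decay)
}

/// Rescale `update` in place so that RMS(update) <= clip_threshold.
/// Updates already below the threshold are left unchanged.
pub fn clip_update(update: &mut [f32], clip_threshold: f32) {
    let scale = (rms(update) / clip_threshold).max(1.0);
    for u in update.iter_mut() {
        *u /= scale;
    }
}
```

With the default eps2, a freshly zero-initialized tensor still receives a nonzero step, because the RMS floor keeps the relative step from collapsing to zero.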
Implementations§
impl Adafactor
pub fn new() -> Self
Construct with paper defaults (no manual lr ⇒ relative step,
beta2_decay = -0.8, ε₁=1e-30, ε₂=1e-3, clip=1.0, λ=0.0).
pub fn with_lr(self, lr: f32) -> Self
Switch from the relative-step rule to a manual learning rate.
pub fn with_weight_decay(self, wd: f32) -> Self
Override the decoupled-decay coefficient.
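The constructor and builder methods above chain in the usual consuming-builder style. Because the crate path is not shown on this page, the sketch below uses a local stand-in, `AdafactorConfig`, that mirrors only the documented public fields and defaults. The stand-in's name and its method bodies are assumptions; in particular, that `with_lr` stores `Some(lr)` is inferred from the field description rather than confirmed.

```rust
/// Local stand-in mirroring the documented public fields; not the crate's type.
#[derive(Debug, Clone, PartialEq)]
pub struct AdafactorConfig {
    pub lr: Option<f32>,
    pub beta2_decay: f32,
    pub eps1: f32,
    pub eps2: f32,
    pub clip_threshold: f32,
    pub weight_decay: f32,
}

impl AdafactorConfig {
    /// Documented defaults: relative step, x = -0.8, eps1 = 1e-30,
    /// eps2 = 1e-3, clip = 1.0, no weight decay.
    pub fn new() -> Self {
        Self {
            lr: None,
            beta2_decay: -0.8,
            eps1: 1e-30,
            eps2: 1e-3,
            clip_threshold: 1.0,
            weight_decay: 0.0,
        }
    }

    /// Assumed behavior: a manual rate replaces the relative-step rule.
    pub fn with_lr(mut self, lr: f32) -> Self {
        self.lr = Some(lr);
        self
    }

    /// Override the decoupled weight-decay coefficient.
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}
```

With the real type, the equivalent call would presumably read `Adafactor::new().with_lr(1e-3).with_weight_decay(0.01)`; since the fields are public, they can also be set directly after construction.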
Trait Implementations§
impl Optimizer for Adafactor
fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32])
fn end_iteration(&mut self)
step], so most implementations leave this a no-op.
fn lr_scale(&self, _name: &str) -> f32
1.0 for every name. Override when wrapping this crate to
support per-name LR schedules (e.g. embedding-vs-attention
splits, or the Gaussian-splat attribute-typed LR setup). The
CPU impls in this crate currently honor this only when the
caller passes a pre-scaled lr for the relevant call —
backends are encouraged to consult it inside their fused
kernel.
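One way a wrapper might override lr_scale is to map parameter-name prefixes to multipliers. The sketch below is illustrative only. It declares a local stand-in `Optimizer` trait whose method signatures are copied from this page, with default bodies assumed from the descriptions (a no-op and 1.0). `PrefixScaled`, its `rules` field, and the toy `Sgd` inner optimizer are hypothetical, and how a real backend applies the returned multiplier is not specified here.

```rust
/// Local stand-in for the trait documented above; default bodies are assumed.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Hypothetical wrapper: delegates updates and reports a per-name multiplier.
pub struct PrefixScaled<O: Optimizer> {
    pub inner: O,
    /// (name prefix, multiplier) pairs; the first matching prefix wins.
    pub rules: Vec<(String, f32)>,
}

impl<O: Optimizer> Optimizer for PrefixScaled<O> {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        self.inner.step(name, shape, param, grad);
    }

    fn end_iteration(&mut self) {
        self.inner.end_iteration();
    }

    fn lr_scale(&self, name: &str) -> f32 {
        self.rules
            .iter()
            .find(|(prefix, _)| name.starts_with(prefix.as_str()))
            .map(|(_, scale)| *scale)
            .unwrap_or(1.0)
    }
}

/// Toy inner optimizer used only to make the sketch self-contained.
pub struct Sgd {
    pub lr: f32,
}

impl Optimizer for Sgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
    }
}
```

A caller that wants the multiplier to take effect on a CPU path would, per the note above, still need to pass a pre-scaled lr; the wrapper here only reports the scale.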