pub struct QHAdamW {
pub lr: f32,
pub beta1: f32,
pub beta2: f32,
pub nu1: f32,
pub nu2: f32,
pub eps: f32,
pub weight_decay: f32,
/* private fields */
}
Quasi-hyperbolic AdamW. Per-tensor state: two f32 buffers.
Fields
lr: f32
Learning rate.
beta1: f32
First-moment EMA decay β₁. Ma & Yarats recommend 0.995, much
closer to 1 than vanilla Adam's usual 0.9, because the QH
interpolation already keeps weight on the current gradient in
the numerator. Default 0.995.
beta2: f32
Second-moment EMA decay β₂. Default 0.999.
nu1: f32
First-moment QH interpolation coefficient ν₁ ∈ [0, 1].
1.0 = pure EMA (standard Adam first moment); 0.0 = pure
current gradient (no momentum). Default 0.7.
nu2: f32
Second-moment QH interpolation coefficient ν₂ ∈ [0, 1].
1.0 = standard Adam denominator. Default 1.0.
eps: f32
Denominator stability constant ε. Default 1e-8.
weight_decay: f32
Decoupled weight-decay coefficient λ. Default 0.01.
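The fields above map onto the QHAdam update of Ma & Yarats combined with decoupled weight decay. The sketch below shows one step for a single scalar parameter. It is an illustration only: the `Hyper` and `State` types and the `qhadamw_scalar` function are hypothetical names, and both the bias correction of the moments and the scaling of the decay term by `lr` are assumptions about this crate's behavior, not taken from its source.

```rust
/// Hypothetical per-element state: one slot from each of the two f32 buffers.
#[derive(Clone, Copy, Default)]
pub struct State {
    pub m: f32, // first-moment EMA of the gradient
    pub v: f32, // second-moment EMA of the squared gradient
}

/// Hypothetical hyperparameter bundle, named after the documented public fields.
#[derive(Clone, Copy)]
pub struct Hyper {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub nu1: f32,
    pub nu2: f32,
    pub eps: f32,
    pub weight_decay: f32,
}

/// One QHAdamW update for a single scalar parameter `p` with gradient `g`
/// at 1-based step `t`. Returns the updated parameter.
/// Assumption: moments are bias-corrected and decay is applied as lr * λ * p.
pub fn qhadamw_scalar(h: &Hyper, s: &mut State, p: f32, g: f32, t: i32) -> f32 {
    // Update both EMAs.
    s.m = h.beta1 * s.m + (1.0 - h.beta1) * g;
    s.v = h.beta2 * s.v + (1.0 - h.beta2) * g * g;

    // Bias-correct them for the zero initialization.
    let m_hat = s.m / (1.0 - h.beta1.powi(t));
    let v_hat = s.v / (1.0 - h.beta2.powi(t));

    // Quasi-hyperbolic interpolation between the raw gradient and the EMAs.
    let num = (1.0 - h.nu1) * g + h.nu1 * m_hat;
    let den = ((1.0 - h.nu2) * g * g + h.nu2 * v_hat).sqrt() + h.eps;

    // Adaptive step plus decoupled weight decay.
    p - h.lr * (num / den + h.weight_decay * p)
}
```

Under these assumptions, ν₁ = 1.0 and ν₂ = 1.0 reduce the numerator and denominator to the usual AdamW terms, while ν₁ = 0.0 leaves only the raw gradient in the numerator.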
Implementations
impl QHAdamW
pub fn new(lr: f32) -> Self
Construct with (β₁, β₂, ν₁, ν₂, ε, λ) = (0.995, 0.999, 0.7, 1.0, 1e-8, 0.01).
pub fn with_betas(self, b1: f32, b2: f32) -> Self
Override (β₁, β₂).
pub fn with_nus(self, n1: f32, n2: f32) -> Self
Override the quasi-hyperbolic coefficients (ν₁, ν₂).
pub fn with_weight_decay(self, wd: f32) -> Self
Override the decoupled-decay coefficient.
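The constructor and the three `with_*` methods take and return `self` by value, so they chain. The sketch below uses a local stand-in type that mirrors the documented public fields and method signatures, because the crate path is not shown on this page. The method bodies and the `adam_like` helper are assumptions made for illustration, not the crate's source.

```rust
// Local stand-in mirroring the documented public fields and builder
// signatures. Private fields are omitted and the bodies are assumed.
#[derive(Debug, Clone, PartialEq)]
pub struct QHAdamW {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub nu1: f32,
    pub nu2: f32,
    pub eps: f32,
    pub weight_decay: f32,
}

impl QHAdamW {
    /// Defaults as documented for `new`.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.995,
            beta2: 0.999,
            nu1: 0.7,
            nu2: 1.0,
            eps: 1e-8,
            weight_decay: 0.01,
        }
    }

    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
        self.beta1 = b1;
        self.beta2 = b2;
        self
    }

    pub fn with_nus(mut self, n1: f32, n2: f32) -> Self {
        self.nu1 = n1;
        self.nu2 = n2;
        self
    }

    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }
}

/// Hypothetical helper: ν₁ = ν₂ = 1.0 and no decay, so the configuration
/// describes a standard Adam first moment and denominator.
pub fn adam_like(lr: f32) -> QHAdamW {
    QHAdamW::new(lr).with_nus(1.0, 1.0).with_weight_decay(0.0)
}
```

Each override touches only the fields it names, so any setting not overridden keeps the default from `new`.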
Trait Implementations
impl Optimizer for QHAdamW
fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32])
fn end_iteration(&mut self)
step], so most implementations leave this a no-op.
fn lr_scale(&self, _name: &str) -> f32
Returns 1.0 for every name by default. Override when wrapping
this crate to support per-name LR schedules (e.g.
embedding-vs-attention splits, or the Gaussian-splat
attribute-typed LR setup). The CPU impls in this crate currently
honor this only when the caller passes a pre-scaled lr for the
relevant call; backends are encouraged to consult it inside
their fused kernel.
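One way to read the pre-scaled-lr note is that a wrapper multiplies the learning rate by `lr_scale(name)` just before delegating `step`, then restores it. The sketch below shows that pattern with a local stand-in `Optimizer` trait that copies the three documented method signatures. The toy `Sgd` optimizer, the `NameScaled` wrapper, and the `"embed"` name prefix are all hypothetical and serve only to make the example self-contained.

```rust
// Local stand-in for the trait documented above. The default bodies for
// `end_iteration` and `lr_scale` are assumptions based on the doc text.
pub trait Optimizer {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]);
    fn end_iteration(&mut self) {}
    fn lr_scale(&self, _name: &str) -> f32 {
        1.0
    }
}

/// Toy optimizer used only for illustration: plain SGD.
pub struct Sgd {
    pub lr: f32,
}

impl Optimizer for Sgd {
    fn step(&mut self, _name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        for (p, g) in param.iter_mut().zip(grad) {
            *p -= self.lr * g;
        }
    }
}

/// Hypothetical wrapper that applies a per-name LR multiplier.
pub struct NameScaled {
    pub inner: Sgd,
    pub embedding_scale: f32,
}

impl Optimizer for NameScaled {
    fn step(&mut self, name: &str, shape: &[usize], param: &mut [f32], grad: &[f32]) {
        // Pre-scale the inner lr for this call only, then restore it.
        let base = self.inner.lr;
        let scaled = base * self.lr_scale(name);
        self.inner.lr = scaled;
        self.inner.step(name, shape, param, grad);
        self.inner.lr = base;
    }

    fn end_iteration(&mut self) {
        self.inner.end_iteration();
    }

    fn lr_scale(&self, name: &str) -> f32 {
        // Assumed naming scheme: embedding tensors start with "embed".
        if name.starts_with("embed") {
            self.embedding_scale
        } else {
            1.0
        }
    }
}
```

Restoring the base rate after each call keeps the multiplier from compounding across tensors or iterations.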