
Function simd_adamw_update 

pub fn simd_adamw_update(
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    param: &mut [f32],
    beta1: f32,
    beta2: f32,
    lr: f32,
    lr_t: f32,
    weight_decay: f32,
    epsilon: f32,
)

Fused AdamW parameter update with decoupled weight decay.

Updates momentum, variance, and parameters in a single pass with zero temporary allocations.
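A plain scalar loop can serve as a reference for what one fused AdamW step computes. This is a hypothetical sketch: the name `adamw_update_scalar` is illustrative and not part of this crate. It assumes that `lr` scales the decoupled weight-decay term, `lr_t` scales the adaptive term, weight decay is applied before the adaptive step, and `epsilon` is added after the square root. The actual SIMD kernel may arrange these details differently.

```rust
/// Hypothetical scalar reference for one fused AdamW step (not this crate's code).
/// Assumptions: `lr` scales decoupled weight decay, `lr_t` scales the adaptive
/// step, and `epsilon` is added after the square root.
#[allow(clippy::too_many_arguments)]
fn adamw_update_scalar(
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    param: &mut [f32],
    beta1: f32,
    beta2: f32,
    lr: f32,
    lr_t: f32,
    weight_decay: f32,
    epsilon: f32,
) {
    // Assumption: all four slices have the same length.
    assert!(grad.len() == m.len() && m.len() == v.len() && v.len() == param.len());
    for i in 0..grad.len() {
        let g = grad[i];
        // First moment: exponential moving average of the gradient.
        m[i] = beta1 * m[i] + (1.0 - beta1) * g;
        // Second moment: exponential moving average of the squared gradient.
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
        // Decoupled weight decay: shrink the parameter directly, not via the gradient.
        param[i] -= lr * weight_decay * param[i];
        // Adaptive step scaled by the bias-corrected learning rate.
        param[i] -= lr_t * m[i] / (v[i].sqrt() + epsilon);
    }
}

fn main() {
    let grad = [0.5_f32, -1.0];
    let mut m = [0.0_f32; 2];
    let mut v = [0.0_f32; 2];
    let mut param = [1.0_f32, 2.0];
    adamw_update_scalar(&grad, &mut m, &mut v, &mut param, 0.9, 0.999, 1e-3, 1e-3, 0.01, 1e-8);
    println!("m = {:?}, v = {:?}, param = {:?}", m, v, param);
}
```

Because the weight-decay term is applied to `param` directly rather than folded into `grad`, a zero gradient still shrinks the parameters. This is what distinguishes AdamW from Adam with L2 regularization.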

Arguments

  • grad - Gradient vector
  • m - First moment (momentum) vector (updated in-place)
  • v - Second moment (variance) vector (updated in-place)
  • param - Parameter vector (updated in-place)
  • beta1 - Momentum decay rate
  • beta2 - Variance decay rate
  • lr - Base (uncorrected) learning rate
  • lr_t - Bias-corrected learning rate used for the adaptive update
  • weight_decay - Weight decay coefficient
  • epsilon - Small constant for numerical stability
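Since `lr_t` is passed in already bias-corrected, the caller presumably tracks the step count and computes it before each call. The following minimal sketch assumes the standard Adam correction `lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)` with a 1-based step `t`. The helper `bias_corrected_lr` is hypothetical and not part of this crate. Check the crate's own training code for the exact formula it expects.

```rust
/// Hypothetical helper: bias-corrected step size for an Adam-style optimizer
/// at 1-based step `t`, assuming lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t).
fn bias_corrected_lr(lr: f32, beta1: f32, beta2: f32, t: i32) -> f32 {
    assert!(t >= 1, "step counter is 1-based");
    lr * (1.0 - beta2.powi(t)).sqrt() / (1.0 - beta1.powi(t))
}

fn main() {
    let (lr, beta1, beta2) = (1e-3_f32, 0.9_f32, 0.999_f32);
    for t in [1, 10, 100, 10_000] {
        // The correction matters most early in training; lr_t approaches lr as t grows.
        println!("t = {t:>5}: lr_t = {:.6e}", bias_corrected_lr(lr, beta1, beta2, t));
    }
}
```

Keeping the correction outside the kernel means it is computed once per step rather than once per element, which fits a fused, allocation-free inner loop.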