Skip to main content

entrenar/optim/simd/
adamw.rs

//! Fused AdamW parameter update kernel (KAIZEN-026)
//!
//! Single-pass loop over all elements — zero temporary allocations.
//! The loop is written so the compiler can auto-vectorize it into SIMD
//! instructions (e.g. AVX2/AVX-512 when those target features are enabled).
//!
//! The previous implementation (pre-KAIZEN-026) created 14 temporary Vector
//! allocations per call via trueno::vector::Vector operations. For Qwen3-4B
//! LoRA (5.9M params across ~200 tensors) that amounted to ~330 MB of
//! temporaries and 14 passes over the data per optimizer step.
//!
//! # Contract (C-ADAMW-FUSED-001)
//!
//! - **Precondition**: All slices have equal length (checked; panics otherwise)
//! - **Postcondition**: m, v, param updated in-place per AdamW equations
//! - **Invariant**: v[i] >= 0 for all i (squared gradient accumulation), given
//!   non-negative initial v and 0 <= beta2 <= 1
//! - **Invariant**: All outputs finite for finite inputs, given epsilon > 0 and
//!   no f32 overflow in intermediates such as g²

/// Fused AdamW parameter update with decoupled weight decay.
///
/// Updates momentum, variance, and parameters in a single pass over the
/// slices with zero temporary allocations. For every element `i`:
///
/// ```text
/// m[i]     = β1 · m[i] + (1 − β1) · g[i]
/// v[i]     = β2 · v[i] + (1 − β2) · g[i]²
/// param[i] = (1 − lr · λ) · param[i] − lr_t · m[i] / (√v[i] + ε)
/// ```
///
/// This function performs no bias correction itself. `lr` only scales the
/// decoupled weight-decay term. The adaptive step uses `lr_t`, so any bias
/// correction must already be folded into `lr_t` by the caller. `epsilon` is
/// added to the square root of the *uncorrected* second moment `v`.
///
/// # Arguments
/// * `grad` - Gradient vector
/// * `m` - First moment (momentum) vector (updated in-place)
/// * `v` - Second moment (variance) vector (updated in-place)
/// * `param` - Parameter vector (updated in-place)
/// * `beta1` - Momentum decay rate
/// * `beta2` - Variance decay rate
/// * `lr` - Learning rate; only used for the decoupled weight-decay factor
/// * `lr_t` - Bias-corrected learning rate for adaptive update
/// * `weight_decay` - Weight decay coefficient (λ)
/// * `epsilon` - Small constant for numerical stability
///
/// # Panics
///
/// Panics if `m`, `v`, or `param` does not have the same length as `grad`.
#[allow(clippy::too_many_arguments)] // flat arguments mirror the AdamW equations one-to-one
pub fn simd_adamw_update(
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    param: &mut [f32],
    beta1: f32,
    beta2: f32,
    lr: f32,
    lr_t: f32,
    weight_decay: f32,
    epsilon: f32,
) {
    // `zip` below would silently truncate to the shortest slice, so the
    // equal-length precondition must be enforced explicitly.
    assert_eq!(grad.len(), m.len(), "Gradient and momentum lengths must match");
    assert_eq!(grad.len(), v.len(), "Gradient and variance lengths must match");
    assert_eq!(grad.len(), param.len(), "Gradient and parameter lengths must match");

    // Loop-invariant scalars, computed once.
    let one_minus_beta1 = 1.0 - beta1;
    let one_minus_beta2 = 1.0 - beta2;
    let wd_factor = 1.0 - lr * weight_decay;

    // Single fused pass. Zipped iterators carry no per-element bounds checks,
    // leaving a straight-line body the compiler can auto-vectorize.
    for (((&g, m_i), v_i), p_i) in grad
        .iter()
        .zip(m.iter_mut())
        .zip(v.iter_mut())
        .zip(param.iter_mut())
    {
        // m_t = β1 * m_{t-1} + (1 - β1) * g
        let m_new = beta1 * *m_i + one_minus_beta1 * g;
        // v_t = β2 * v_{t-1} + (1 - β2) * g²
        // (evaluated as ((1 - β2) * g) * g, matching the original rounding)
        let v_new = beta2 * *v_i + one_minus_beta2 * g * g;
        *m_i = m_new;
        *v_i = v_new;
        // θ_t = (1 - lr * λ) * θ_{t-1} - lr_t * m_t / (√v_t + ε)
        *p_i = wd_factor * *p_i - lr_t * m_new / (v_new.sqrt() + epsilon);
    }
}