// entrenar/optim/simd/adam.rs

//! Fused Adam parameter update kernel (KAIZEN-026)
//!
//! Single-pass loop over all elements — zero temporary allocations.
//! The compiler auto-vectorizes this into SIMD instructions (AVX2/AVX-512).
//!
//! # Contract (C-ADAM-FUSED-001)
//!
//! - **Precondition**: All slices have equal length
//! - **Postcondition**: m, v, param updated in-place per Adam equations
//! - **Invariant**: v[i] >= 0 for all i (squared gradient accumulation)
//! - **Invariant**: All outputs finite for finite inputs

/// Fused Adam parameter update.
///
/// Updates the first moment, second moment, and parameters element-wise in a
/// single pass with zero temporary allocations:
///
/// ```text
/// m[i]     = beta1 * m[i] + (1 - beta1) * grad[i]
/// v[i]     = beta2 * v[i] + (1 - beta2) * grad[i] * grad[i]
/// param[i] = param[i] - lr_t * m[i] / (sqrt(v[i]) + epsilon)
/// ```
///
/// # Arguments
/// * `grad` - Gradient vector
/// * `m` - First moment (momentum) vector (updated in-place)
/// * `v` - Second moment (variance) vector (updated in-place)
/// * `param` - Parameter vector (updated in-place)
/// * `beta1` - Momentum decay rate
/// * `beta2` - Variance decay rate
/// * `lr_t` - Bias-corrected learning rate; it is applied as-is, so any bias
///   correction must already be folded in by the caller
/// * `epsilon` - Small constant added to `sqrt(v[i])` for numerical stability
///
/// # Panics
/// Panics if `m`, `v`, or `param` does not have the same length as `grad`.
// The kernel takes its hyper-parameters as flat scalars, hence the long signature.
#[allow(clippy::too_many_arguments)]
pub fn simd_adam_update(
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    param: &mut [f32],
    beta1: f32,
    beta2: f32,
    lr_t: f32,
    epsilon: f32,
) {
    // Checked up front: `zip` would otherwise silently stop at the shortest slice.
    assert_eq!(grad.len(), m.len(), "Gradient and momentum lengths must match");
    assert_eq!(grad.len(), v.len(), "Gradient and variance lengths must match");
    assert_eq!(grad.len(), param.len(), "Gradient and parameter lengths must match");

    // Loop-invariant blend factors, computed once.
    let one_minus_beta1 = 1.0 - beta1;
    let one_minus_beta2 = 1.0 - beta2;

    // Single fused pass over all four slices in lock-step; iterating instead of
    // indexing leaves no per-element bounds checks in the loop body.
    let elements = grad
        .iter()
        .zip(m.iter_mut())
        .zip(v.iter_mut())
        .zip(param.iter_mut());

    for (((&g, m_i), v_i), p_i) in elements {
        // m_t = β1 * m_{t-1} + (1 - β1) * g
        *m_i = beta1 * *m_i + one_minus_beta1 * g;
        // v_t = β2 * v_{t-1} + (1 - β2) * g²
        *v_i = beta2 * *v_i + one_minus_beta2 * g * g;
        // θ_t = θ_{t-1} - lr_t * m_t / (√v_t + ε)
        *p_i -= lr_t * *m_i / (v_i.sqrt() + epsilon);
    }
}