//! Fused AdamW parameter update kernel (KAIZEN-026)
//!
//! Single-pass loop over all elements with zero temporary allocations.
//! The loop body is simple enough for the compiler to auto-vectorize into
//! SIMD instructions (e.g. AVX2/AVX-512 on x86-64).
//!
//! The previous implementation (pre-KAIZEN-026) created 14 temporary `Vector`
//! allocations per call via `trueno::vector::Vector` operations. For Qwen3-4B
//! LoRA (5.9M params across ~200 tensors), that is ~330 MB of temporaries per
//! optimizer step (14 × 5.9M × 4 bytes, assuming `f32`), with 14 passes over
//! the data.
//!
//! # Contract (C-ADAMW-FUSED-001)
//!
//! - **Precondition**: All slices have equal length
//! - **Postcondition**: `m`, `v`, and `param` are updated in-place per the AdamW equations
//! - **Invariant**: `v[i] >= 0` for all `i` (squared gradient accumulation),
//!   provided `v[i] >= 0` on entry and `0 <= beta2 <= 1`
//! - **Invariant**: All outputs finite for finite inputs
/// Fused AdamW parameter update with decoupled weight decay.
///
/// Updates momentum, variance, and parameters in a single pass with
/// zero temporary allocations.
///
/// # Arguments
/// * `grad` - Gradient vector
/// * `m` - First moment (momentum) vector (updated in-place)
/// * `v` - Second moment (variance) vector (updated in-place)
/// * `param` - Parameter vector (updated in-place)
/// * `beta1` - Momentum decay rate
/// * `beta2` - Variance decay rate
/// * `lr` - Base learning rate (without bias correction)
/// * `lr_t` - Bias-corrected learning rate for the adaptive update
/// * `weight_decay` - Decoupled weight decay coefficient
/// * `epsilon` - Small constant for numerical stability
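///
/// # Example
///
/// The sketch below illustrates how the arguments above could fit together in
/// a per-element update. It is a minimal sketch, not this kernel's actual
/// implementation: `adamw_reference` and `bias_corrected_lr` are hypothetical
/// names, `f32` elements are assumed, and the update follows the standard
/// AdamW form. The formula `lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)` is
/// likewise an assumption about how a caller might derive `lr_t`.
///
/// ```rust
/// /// Hypothetical reference update (assumed standard AdamW form).
/// fn adamw_reference(
///     grad: &[f32],
///     m: &mut [f32],
///     v: &mut [f32],
///     param: &mut [f32],
///     beta1: f32,
///     beta2: f32,
///     lr: f32,
///     lr_t: f32,
///     weight_decay: f32,
///     epsilon: f32,
/// ) {
///     // Precondition from the contract: all slices have equal length.
///     assert_eq!(grad.len(), m.len());
///     assert_eq!(grad.len(), v.len());
///     assert_eq!(grad.len(), param.len());
///
///     for i in 0..grad.len() {
///         let g = grad[i];
///         // First and second moment estimates, updated in place.
///         m[i] = beta1 * m[i] + (1.0 - beta1) * g;
///         v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
///         // Adaptive step uses the bias-corrected rate `lr_t`.
///         let adaptive = lr_t * m[i] / (v[i].sqrt() + epsilon);
///         // Decoupled weight decay uses the base rate `lr`.
///         param[i] -= adaptive + lr * weight_decay * param[i];
///     }
/// }
///
/// /// Hypothetical helper: bias-corrected learning rate at 1-based step `t`.
/// fn bias_corrected_lr(lr: f32, beta1: f32, beta2: f32, t: i32) -> f32 {
///     lr * (1.0 - beta2.powi(t)).sqrt() / (1.0 - beta1.powi(t))
/// }
/// ```
///
/// Under these assumptions, the first step from zero-initialized `m` and `v`
/// moves each parameter by roughly `lr * sign(grad[i])` plus the decay term
/// `lr * weight_decay * param[i]`.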