/// Applies one Adam optimizer step in place, element by element.
///
/// For every index `i` this updates the first-moment estimate `m`, the
/// second-moment estimate `v`, and the parameter `param`:
///
/// ```text
/// m[i]      = beta1 * m[i] + (1 - beta1) * grad[i]
/// v[i]      = beta2 * v[i] + (1 - beta2) * grad[i] * grad[i]
/// param[i] -= lr_t * m[i] / (sqrt(v[i]) + epsilon)
/// ```
///
/// No bias correction is applied to `m` or `v` inside this function; `lr_t`
/// is used exactly as given.
/// NOTE(review): presumably callers fold the bias-correction factors into
/// `lr_t` (the "efficient" Adam formulation) — confirm at call sites.
///
/// Despite the name, this is a scalar loop. It is written with zipped
/// iterators so no per-element bounds checks are emitted, which leaves the
/// compiler free to auto-vectorize it.
///
/// # Panics
///
/// Panics if `m`, `v`, or `param` does not have the same length as `grad`.
// Eight inputs mirror the Adam update formula one-to-one; bundling them into a
// config struct would break the public signature.
#[allow(clippy::too_many_arguments)]
pub fn simd_adam_update(
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    param: &mut [f32],
    beta1: f32,
    beta2: f32,
    lr_t: f32,
    epsilon: f32,
) {
    // `zip` silently stops at the shortest input, so the length contract must
    // be enforced explicitly up front.
    assert_eq!(grad.len(), m.len(), "Gradient and momentum lengths must match");
    assert_eq!(grad.len(), v.len(), "Gradient and variance lengths must match");
    assert_eq!(grad.len(), param.len(), "Gradient and parameter lengths must match");

    let one_minus_beta1 = 1.0 - beta1;
    let one_minus_beta2 = 1.0 - beta2;

    for (((&g, m_i), v_i), p_i) in grad
        .iter()
        .zip(m.iter_mut())
        .zip(v.iter_mut())
        .zip(param.iter_mut())
    {
        *m_i = beta1 * *m_i + one_minus_beta1 * g;
        // Evaluated as `((1 - beta2) * g) * g`, matching the original order so
        // results are bit-for-bit identical.
        *v_i = beta2 * *v_i + one_minus_beta2 * g * g;
        // Uses the freshly updated moments, as in the original element loop.
        *p_i -= lr_t * *m_i / (v_i.sqrt() + epsilon);
    }
}