use super::kind::OptimizerError;
/// Hyperparameters and step counter for one AdamW update.
pub struct AdamwConfig {
    /// Step size applied to the bias-corrected update.
    pub learning_rate: f32,
    /// 1-based step counter used for bias correction; 0 means it wrapped.
    pub step: u32,
    /// Exponential decay rate for the first-moment estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f32,
    /// Decoupled weight-decay coefficient.
    pub weight_decay: f32,
}
/// Performs one in-place AdamW update of `params`, advancing the first- and
/// second-moment buffers `m` and `v`.
pub fn step_adamw(
    params: &mut [f32],
    grads: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    cfg: AdamwConfig,
) -> Result<(), OptimizerError> {
    let AdamwConfig { learning_rate, step, beta1, beta2, eps, weight_decay } = cfg;
    // A step of 0 can only come from a wrapped counter: the first valid step is 1.
    if step == 0 {
        return Err(OptimizerError::StepOverflow);
    }
    // Reject non-finite or out-of-range hyperparameters before touching any state.
    if !learning_rate.is_finite()
        || !beta1.is_finite()
        || !beta2.is_finite()
        || !eps.is_finite()
        || !weight_decay.is_finite()
    {
        return Err(OptimizerError::InvalidHyperParams);
    }
    if !(0.0..1.0).contains(&beta1) || !(0.0..1.0).contains(&beta2) || eps <= 0.0 {
        return Err(OptimizerError::InvalidHyperParams);
    }
    // All four buffers must describe the same number of parameters.
    if params.len() != grads.len() || m.len() != params.len() || v.len() != params.len() {
        return Err(OptimizerError::ShapeMismatch);
    }
    // Bias-correction terms: 1 - beta^t shrinks toward 0 as beta -> 1, so guard
    // against a degenerate (non-positive) denominator.
    let t = step as f32;
    let bc1 = 1.0 - crate::math::powf(beta1, t);
    let bc2 = 1.0 - crate::math::powf(beta2, t);
    if bc1 <= 0.0 || bc2 <= 0.0 {
        return Err(OptimizerError::StepOverflow);
    }
    for i in 0..params.len() {
        let g = grads[i];
        // Update the exponential moving averages of the gradient and its square.
        m[i] = beta1 * m[i] + (1.0 - beta1) * g;
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
        let m_hat = m[i] / bc1;
        let v_hat = v[i] / bc2;
        // AdamW: weight decay is decoupled from the gradient, i.e. applied
        // directly to the parameter rather than folded into `g`.
        params[i] -= learning_rate
            * (m_hat / (crate::math::sqrtf(v_hat) + eps) + weight_decay * params[i]);
    }
    Ok(())
}
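
// A minimal usage sketch as unit tests, assuming `crate::math::powf` and
// `crate::math::sqrtf` behave like the std `f32` counterparts. The concrete
// hyperparameter values below are illustrative, not a recommendation.
#[cfg(test)]
mod tests {
    use super::*;

    fn cfg(step: u32) -> AdamwConfig {
        AdamwConfig {
            learning_rate: 1e-3,
            step,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
        }
    }

    #[test]
    fn single_step_moves_params_against_gradients() {
        let mut params = [1.0_f32, -2.0];
        let grads = [0.5_f32, -0.5];
        let (mut m, mut v) = ([0.0_f32; 2], [0.0_f32; 2]);
        let res = step_adamw(&mut params, &grads, &mut m, &mut v, cfg(1));
        assert!(res.is_ok());
        // Each parameter should have moved opposite to its gradient.
        assert!(params[0] < 1.0);
        assert!(params[1] > -2.0);
    }

    #[test]
    fn rejects_mismatched_shapes() {
        let mut params = [1.0_f32, 2.0];
        let grads = [0.5_f32]; // shorter than params
        let (mut m, mut v) = ([0.0_f32; 2], [0.0_f32; 2]);
        assert!(step_adamw(&mut params, &grads, &mut m, &mut v, cfg(1)).is_err());
    }
}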