use std::convert::TryFrom;

use crate::Result;
use super::hip_adamw_bridge::run_rocm_hip_adamw_step;
use super::parameter::{Parameter, f32_to_fp16_bits};
/// AdamW optimizer hyperparameters plus a running update counter.
///
/// The hyperparameter fields are forwarded unchanged to
/// `run_rocm_hip_adamw_step`; the update arithmetic itself lives in that
/// HIP bridge, not in this type.
#[derive(Debug, Clone)]
pub struct AdamW {
    /// Learning rate forwarded to the kernel.
    pub lr: f32,
    /// `beta1` coefficient forwarded to the kernel (conventionally the
    /// first-moment decay rate — exact semantics defined by the HIP bridge).
    pub beta1: f32,
    /// `beta2` coefficient forwarded to the kernel (conventionally the
    /// second-moment decay rate — exact semantics defined by the HIP bridge).
    pub beta2: f32,
    /// Epsilon forwarded to the kernel (conventionally a small constant for
    /// numerical stability — confirm against the bridge).
    pub eps: f32,
    /// Weight-decay coefficient forwarded to the kernel.
    pub weight_decay: f32,
    /// Running count of parameter updates issued through this optimizer.
    /// Starts at 0 in `AdamW::new` and saturates at `u32::MAX`.
    pub step: u32,
}
impl AdamW {
pub fn new(lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay,
step: 0,
}
}
pub fn step(&mut self, param: &mut Parameter, grad: &[f32]) -> Result<()> {
let n = param.weight.data.len();
if grad.len() != n {
return Err(crate::Error::backend(format!(
"AdamW step grad length {} does not match parameter length {}",
grad.len(),
n
)));
}
let grad_bits: Vec<u16> = grad.iter().map(|&g| f32_to_fp16_bits(g)).collect();
self.step = self.step.saturating_add(1);
param.step = param.step.saturating_add(1);
run_rocm_hip_adamw_step(
&mut param.weight.data,
&mut param.m,
&mut param.v,
&grad_bits,
self.lr,
self.beta1,
self.beta2,
self.eps,
self.weight_decay,
self.step as i32,
)
}
}