use crate::ndarray_ext::NdArray;
use crate::op::OpError;
use crate::Float;
/// Hyperparameters for the AdamW (Adam with decoupled weight decay) update op.
///
/// One instance is held per optimizer node; the per-parameter state
/// (first/second moments and timestep) flows through the op's inputs instead
/// (see the `Op` impl below).
pub(crate) struct AdamWOp<F: Float> {
// Learning rate (step size), multiplied into both the Adam step and the
// decoupled weight-decay shrinkage.
pub(crate) alpha: F,
// Numerical-stability term added to sqrt(v_hat) in the denominator.
pub(crate) eps: F,
// Exponential decay rate for the first-moment (mean) estimate.
pub(crate) b1: F,
// Exponential decay rate for the second-moment (uncentered variance) estimate.
pub(crate) b2: F,
// Decoupled weight-decay coefficient (lambda in Loshchilov & Hutter).
pub(crate) weight_decay: F,
}
impl<F: Float> crate::op::Op<F> for AdamWOp<F> {
    /// Performs a single AdamW update step.
    ///
    /// Inputs (in order): `[0]` parameter, `[1]` gradient, `[2]` first moment
    /// `m`, `[3]` second moment `v`, `[4]` timestep `t` (0-d array).
    ///
    /// Outputs (in order): `[0]` updated parameter, `[1]` gradient passed
    /// through unchanged, `[2]` updated `m`, `[3]` updated `v`,
    /// `[4]` incremented `t`.
    ///
    /// # Errors
    /// Returns `OpError::IncompatibleShape` when fewer than 5 inputs are given.
    fn compute(&self, ctx: &mut crate::op::ComputeContext<F>) -> Result<(), OpError> {
        if ctx.inputs().len() < 5 {
            return Err(OpError::IncompatibleShape(format!(
                "AdamWOp requires 5 inputs, but got {}",
                ctx.inputs().len()
            )));
        }

        let param = ctx.input(0).to_owned();
        let grad = ctx.input(1).to_owned();
        let m = ctx.input(2).to_owned();
        let v = ctx.input(3).to_owned();
        let t_array = ctx.input(4).to_owned();

        let grad_shape = grad.shape().to_vec();

        // Increment the (0-d) timestep counter; used for bias correction.
        let t_val = t_array[scirs2_core::ndarray::IxDyn(&[])];
        let new_t = t_val + F::one();
        let new_t_array = NdArray::from_elem(scirs2_core::ndarray::IxDyn(&[]), new_t);

        // State arrays may arrive as 0-d scalars (e.g. on the first step);
        // broadcast them up to the gradient's shape. Otherwise move the
        // already-owned arrays instead of cloning them again.
        let mut new_m = if m.shape().is_empty() && !grad_shape.is_empty() {
            let m_val = m[scirs2_core::ndarray::IxDyn(&[])];
            NdArray::from_elem(scirs2_core::ndarray::IxDyn(&grad_shape), m_val)
        } else {
            m
        };
        let mut new_v = if v.shape().is_empty() && !grad_shape.is_empty() {
            let v_val = v[scirs2_core::ndarray::IxDyn(&[])];
            NdArray::from_elem(scirs2_core::ndarray::IxDyn(&grad_shape), v_val)
        } else {
            v
        };
        let mut new_param = if param.shape().is_empty() && !grad_shape.is_empty() {
            let param_val = param[scirs2_core::ndarray::IxDyn(&[])];
            NdArray::from_elem(scirs2_core::ndarray::IxDyn(&grad_shape), param_val)
        } else {
            param
        };

        // m <- b1 * m + (1 - b1) * g
        let one_minus_b1 = F::one() - self.b1;
        new_m.zip_mut_with(&grad, move |m_val, g_val| {
            *m_val = *m_val * self.b1 + one_minus_b1 * *g_val
        });

        // v <- b2 * v + (1 - b2) * g^2
        let one_minus_b2 = F::one() - self.b2;
        new_v.zip_mut_with(&grad, move |v_val, g_val| {
            *v_val = *v_val * self.b2 + one_minus_b2 * *g_val * *g_val
        });

        // Bias corrections: m_hat = m / (1 - b1^t), v_hat = v / (1 - b2^t).
        let m_correction = F::one() / (F::one() - self.b1.powf(new_t));
        let v_correction = F::one() / (F::one() - self.b2.powf(new_t));
        let v_hat = new_v.mapv(move |v_val| v_val * v_correction);

        // grad_update = m_hat / (sqrt(v_hat) + eps), built directly from
        // new_m (no intermediate m_hat clone).
        let mut grad_update = new_m.mapv(move |m_val| m_val * m_correction);
        grad_update.zip_mut_with(&v_hat, move |m_hat_val, v_hat_val| {
            *m_hat_val /= v_hat_val.sqrt() + self.eps;
        });

        // Decoupled weight decay (AdamW): shrink the parameter toward zero
        // first, then apply the Adam step.
        new_param.zip_mut_with(&grad_update, move |param_val, grad_update_val| {
            *param_val *= F::one() - self.alpha * self.weight_decay;
            *param_val -= self.alpha * *grad_update_val;
        });

        ctx.append_output(new_param);
        ctx.append_output(grad);
        ctx.append_output(new_m);
        ctx.append_output(new_v);
        ctx.append_output(new_t_array);
        Ok(())
    }

    /// The optimizer update itself is not differentiated: no gradient flows
    /// back through any of the five inputs.
    fn grad(&self, ctx: &mut crate::op::GradientContext<F>) {
        for i in 0..5 {
            ctx.append_input_grad(i, None);
        }
    }
}