use crate::optimizers::Optimizer;
use crate::tensor::Tensor;
use crate::tensor_ops::gradient_descent_ops::adamw;
use crate::variable::VariableID;
use crate::{Context, Float, VariableEnvironment};
/// AdamW optimizer configuration (Adam with a decoupled `weight_decay` term).
///
/// Per-variable optimizer state (first/second moment and timestep) is stored
/// in a `VariableEnvironment` namespace identified by `adamw_namespace_id`.
pub struct AdamW<F: Float> {
// Learning rate (step size).
pub alpha: F,
// Small constant added for numerical stability.
pub eps: F,
// Exponential decay rate for the first-moment estimate.
pub b1: F,
// Exponential decay rate for the second-moment estimate.
pub b2: F,
// Weight-decay coefficient passed through to the AdamW update op.
pub weight_decay: F,
// Namespace id under which the per-variable state ("{id}m", "{id}v", "{id}t") lives.
pub adamw_namespace_id: &'static str,
}
impl<F: Float> AdamW<F> {
    /// Converts an `f64` hyperparameter constant into `F`.
    ///
    /// Panics when the value is not representable in `F`; that indicates a
    /// misconfigured float type, not a recoverable runtime condition.
    fn constant(c: f64) -> F {
        F::from(c).expect("Failed to convert constant to float")
    }

    /// Creates an `AdamW` with the conventional defaults:
    /// `alpha = 0.001`, `eps = 1e-8`, `b1 = 0.9`, `b2 = 0.999`,
    /// `weight_decay = 0.01`.
    ///
    /// Registers moment/timestep state for every variable in `var_id_list`
    /// under the namespace `unique_namespace_id` (see [`AdamW::new`]).
    pub fn default(
        unique_namespace_id: &'static str,
        var_id_list: impl IntoIterator<Item = VariableID>,
        env_handle: &mut VariableEnvironment<F>,
    ) -> AdamW<F> {
        // Delegate so the default constants live in exactly one place each.
        AdamW::with_weight_decay(
            Self::constant(0.01),
            unique_namespace_id,
            var_id_list,
            env_handle,
        )
    }

    /// Creates an `AdamW` with fully explicit hyperparameters.
    ///
    /// For each variable id in `var_id_list`, allocates optimizer state in
    /// `env` under `adamw_namespace_id`:
    /// - `"{id}m"`: first-moment buffer, zero-initialized to the variable's shape
    /// - `"{id}v"`: second-moment buffer, zero-initialized to the variable's shape
    /// - `"{id}t"`: scalar timestep, initialized to 1
    ///
    /// Panics if a listed variable id has no array registered in `env`.
    pub fn new(
        alpha: F,
        eps: F,
        b1: F,
        b2: F,
        weight_decay: F,
        var_id_list: impl IntoIterator<Item = VariableID>,
        env: &mut VariableEnvironment<F>,
        adamw_namespace_id: &'static str,
    ) -> AdamW<F> {
        for vid in var_id_list.into_iter() {
            let m_name = format!("{vid}m");
            let v_name = format!("{vid}v");
            let t_name = format!("{vid}t");
            // The immutable borrow of the target array is scoped to this
            // block so `env` can be mutably borrowed again afterwards.
            let (m, v, t) = {
                let target_var = env
                    .get_array_by_id(vid)
                    .expect("variable array not found")
                    .borrow();
                let varshape = target_var.shape();
                (
                    crate::ndarray_ext::zeros(varshape),
                    crate::ndarray_ext::zeros(varshape),
                    crate::ndarray_ext::from_scalar(F::one()),
                )
            };
            let mut adamw_ns = env.namespace_mut(adamw_namespace_id);
            adamw_ns.slot().name(m_name).set(m);
            adamw_ns.slot().name(v_name).set(v);
            adamw_ns.slot().name(t_name).set(t);
        }
        AdamW {
            alpha,
            eps,
            b1,
            b2,
            weight_decay,
            adamw_namespace_id,
        }
    }

    /// Like [`AdamW::default`] but with a caller-chosen `weight_decay`.
    pub fn with_weight_decay(
        weight_decay: F,
        unique_namespace_id: &'static str,
        var_id_list: impl IntoIterator<Item = VariableID>,
        env_handle: &mut VariableEnvironment<F>,
    ) -> AdamW<F> {
        AdamW::with_lr_and_weight_decay(
            Self::constant(0.001),
            weight_decay,
            unique_namespace_id,
            var_id_list,
            env_handle,
        )
    }

    /// Like [`AdamW::default`] but with caller-chosen learning rate (`alpha`)
    /// and `weight_decay`.
    pub fn with_lr_and_weight_decay(
        alpha: F,
        weight_decay: F,
        unique_namespace_id: &'static str,
        var_id_list: impl IntoIterator<Item = VariableID>,
        env_handle: &mut VariableEnvironment<F>,
    ) -> AdamW<F> {
        AdamW::new(
            alpha,
            Self::constant(1e-8),
            Self::constant(0.9),
            Self::constant(0.999),
            weight_decay,
            var_id_list,
            env_handle,
            unique_namespace_id,
        )
    }
}
impl<F: Float> Optimizer<F> for AdamW<F> {
    /// Builds one AdamW update node per `(param, grad)` pair.
    ///
    /// Each returned tensor applies the AdamW rule to its parameter using the
    /// moment/timestep variables registered under `self.adamw_namespace_id`
    /// (named `"{id}m"`, `"{id}v"`, `"{id}t"`).
    ///
    /// Panics if `params` and `grads` differ in length, or if any entry in
    /// `params` is not a variable tensor.
    fn compute_updates<'g, A, B>(
        &self,
        params: &[A],
        grads: &[B],
        g: &'g Context<F>,
    ) -> Vec<Tensor<'g, F>>
    where
        A: AsRef<Tensor<'g, F>> + Copy,
        B: AsRef<Tensor<'g, F>> + Copy,
    {
        assert_eq!(params.len(), grads.len());
        // The namespace lookup is loop-invariant; resolve it once instead of
        // once per parameter. (Also drops a leftover debug eprintln! that
        // fired on every parameter.)
        let namespace = g.namespace(self.adamw_namespace_id);
        params
            .iter()
            .zip(grads)
            .map(|(param, grad)| {
                let param = param.as_ref();
                let var_id = param.get_variable_id().expect("Got non-variable tensor");
                // Optimizer state created by `AdamW::new`.
                let m = g.variable_by_name(format!("{var_id}m"), &namespace);
                let v = g.variable_by_name(format!("{var_id}v"), &namespace);
                let t = g.variable_by_name(format!("{var_id}t"), &namespace);
                Tensor::builder(g)
                    .append_input(param, true)
                    .append_input(grad.as_ref(), false)
                    .append_input(m, true)
                    .append_input(v, true)
                    .append_input(t, true)
                    .build(adamw::AdamWOp {
                        alpha: self.alpha,
                        eps: self.eps,
                        b1: self.b1,
                        b2: self.b2,
                        weight_decay: self.weight_decay,
                    })
            })
            .collect()
    }
}