use crate::ops::gradient_descent_ops::adam;
use crate::tensor::{Input, Tensor, Variable};
use crate::Float;
use crate::Graph;
use crate::NdArray;
use std::sync::{Arc, RwLock};
/// Adam optimizer front-end.
///
/// Stores the hyperparameters (`StaticParams`) that are copied into every
/// `AdamOp` created by `compute_updates`.
pub struct Adam<F>
where
    F: Float,
{
    /// Hyperparameters handed to each update op.
    static_params: StaticParams<F>,
}
impl<T: Float> Default for Adam<T> {
    /// Builds an optimizer with the stock hyperparameters:
    /// `alpha = 0.001`, `eps = 1e-08`, `b1 = 0.9`, `b2 = 0.999`.
    ///
    /// # Panics
    ///
    /// Panics if any of those literals cannot be converted into `T`.
    fn default() -> Self {
        Self {
            static_params: StaticParams {
                alpha: T::from(0.001).unwrap(),
                eps: T::from(1e-08).unwrap(),
                b1: T::from(0.9).unwrap(),
                b2: T::from(0.999).unwrap(),
            },
        }
    }
}
impl<'t, 's: 't, F: Float> Adam<F> {
    /// Creates an optimizer from explicitly supplied hyperparameters.
    pub fn new(static_params: StaticParams<F>) -> Self {
        Adam { static_params }
    }

    /// Builds one `AdamOp` tensor per parameter.
    ///
    /// For each `(param, grad)` pair the state registered for that parameter
    /// is looked up in `states` (keyed by the address of the parameter's
    /// variable array), its `m`, `v` and `t` arrays are wrapped as graph
    /// variables, and an op is built with the inputs
    /// `[param (mut), grad, m (mut), v (mut), t (mut)]`.
    ///
    /// * `params` - variable tensors to be updated.
    /// * `grads` - gradients, matched with `params` by position; entries
    ///   beyond `params.len()` are ignored.
    /// * `states` - state created from the same variable arrays.
    /// * `g` - graph in which the state variables and ops are created.
    ///
    /// Returns the built op tensors, in the same order as `params`.
    ///
    /// # Panics
    ///
    /// * if `grads` has fewer elements than `params`;
    /// * if a tensor in `params` is not backed by a variable array;
    /// * if `states` holds no entry for one of the parameters.
    pub fn compute_updates(
        &self,
        params: &[Tensor<'s, F>],
        grads: &[Tensor<'s, F>],
        states: &AdamState<F>,
        g: &'s Graph<F>,
    ) -> Vec<Tensor<'s, F>> {
        // Fail fast with a descriptive message instead of an index-out-of-bounds
        // panic part-way through graph construction.
        assert!(
            grads.len() >= params.len(),
            "Adam: expected a gradient for each of the {} params, but got only {}",
            params.len(),
            grads.len()
        );

        let mut ret = Vec::with_capacity(params.len());
        for (param, grad) in params.iter().zip(grads) {
            let a: *const RwLock<NdArray<F>> = param
                .get_variable_array_ptr()
                .expect("Adam requires *variables* as its inputs.");
            // The array's address is the lookup key; `AdamState::new` derives
            // its keys the same way.
            let key = a as usize;
            let state = states
                .var2state
                .get(&key)
                .expect("Adam: state object wasn't fed correctly");

            let m = g.variable(state.m.clone());
            let v = g.variable(state.v.clone());
            let t = g.variable(state.t.clone());

            ret.push(
                Tensor::builder()
                    .set_inputs(&[
                        Input::new_mut(param),
                        Input::new(grad),
                        Input::new_mut(&m),
                        Input::new_mut(&v),
                        Input::new_mut(&t),
                    ])
                    .build(
                        g,
                        adam::AdamOp {
                            static_params: self.static_params.clone(),
                        },
                    ),
            );
        }
        ret
    }
}
/// Per-variable optimizer state: three shared, lock-protected arrays.
///
/// NOTE(review): `m`, `v` and `t` presumably hold Adam's first-moment
/// estimate, second-moment estimate and step counter; confirm against
/// `AdamOp`, which is the code that actually reads and writes them.
struct StateArrays<F>
where
    F: Float,
{
    /// Created as zeros with the variable's shape.
    m: Arc<RwLock<NdArray<F>>>,
    /// Created as zeros with the variable's shape.
    v: Arc<RwLock<NdArray<F>>>,
    /// Created as a scalar holding one.
    t: Arc<RwLock<NdArray<F>>>,
}
/// Optimizer state for a set of variables, consumed by `Adam::compute_updates`.
pub struct AdamState<F>
where
    F: Float,
{
    /// Maps the address of a variable's `RwLock<NdArray<F>>` (as `usize`)
    /// to the state arrays belonging to that variable.
    var2state: crate::FxHashMap<usize, StateArrays<F>>,
}
impl<F: Float> AdamState<F> {
    /// Allocates fresh state for every variable array in `variables`.
    ///
    /// Each variable gets two zero-filled arrays matching its shape (`m`, `v`)
    /// and a scalar array holding one (`t`). Entries are keyed by the address
    /// of the `RwLock` inside the `Arc`, so state is tied to the identity of
    /// the array rather than to its contents.
    ///
    /// # Panics
    ///
    /// Panics if the read lock of a variable cannot be acquired (poisoned).
    pub fn new(variables: &[&Arc<RwLock<NdArray<F>>>]) -> Self {
        let mut var2state = crate::FxHashMap::default();

        for shared in variables.iter().copied() {
            // Deref-coerce `&Arc<RwLock<_>>` to the lock it points at and use
            // that address as the map key.
            let lock: &RwLock<NdArray<F>> = shared;
            let address = lock as *const RwLock<NdArray<F>> as usize;

            let array = shared.read().unwrap();
            let shape = array.shape();

            let entry = StateArrays {
                m: Arc::new(RwLock::new(crate::ndarray_ext::zeros(shape))),
                v: Arc::new(RwLock::new(crate::ndarray_ext::zeros(shape))),
                t: Arc::new(RwLock::new(crate::ndarray_ext::from_scalar(F::one()))),
            };
            var2state.insert(address, entry);
        }

        Self { var2state }
    }
}
/// Hyperparameters shared by every update op an `Adam` instance creates.
///
/// `Adam::default()` uses `alpha = 0.001`, `eps = 1e-08`, `b1 = 0.9` and
/// `b2 = 0.999`.
///
/// NOTE(review): the per-field descriptions below follow the usual Adam
/// naming; confirm them against `AdamOp`, where the values are consumed.
#[derive(Clone)]
pub struct StaticParams<T>
where
    T: Float,
{
    /// Presumably the learning rate (step size).
    pub alpha: T,
    /// Presumably the small constant guarding against division by zero.
    pub eps: T,
    /// Presumably the decay rate of the first-moment estimate.
    pub b1: T,
    /// Presumably the decay rate of the second-moment estimate.
    pub b2: T,
}