use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::neural_network_trait::{Layer, Optimizer};
use crate::neural_network::optimizer::input_validation_function::{
validate_decay_rate, validate_epsilon, validate_learning_rate,
};
use ndarray::{Array2, Array3, Array4, Array5};
/// Minimum element count before moment updates are split across rayon
/// worker threads; tensors smaller than this are processed serially.
const ADAM_PARALLEL_THRESHOLD: usize = 1024;
/// Adam optimizer: adaptive moment estimation with bias correction.
///
/// Holds only the hyperparameters and the global timestep; the per-parameter
/// moment buffers are kept by the layers themselves (see [`AdamStates`]).
pub struct Adam {
    /// Step size applied to each bias-corrected update.
    learning_rate: f32,
    /// Exponential decay rate for the first-moment (mean) estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    epsilon: f32,
    /// Timestep counter; incremented once per `update` call, so layers see
    /// a 1-based step count for bias correction.
    t: u64,
}
impl Adam {
    /// Builds an Adam optimizer after validating every hyperparameter.
    ///
    /// The timestep starts at zero and is advanced on each call to
    /// [`Optimizer::update`].
    ///
    /// # Errors
    ///
    /// Returns a [`ModelError`] when any hyperparameter fails its range
    /// check (delegated to the shared input-validation helpers).
    pub fn new(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
    ) -> Result<Self, ModelError> {
        // Validate in a fixed order so the first offending value is reported.
        validate_learning_rate(learning_rate)?;
        validate_decay_rate(beta1, "beta1")?;
        validate_decay_rate(beta2, "beta2")?;
        validate_epsilon(epsilon)?;

        let optimizer = Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            t: 0,
        };
        Ok(optimizer)
    }
}
impl Optimizer for Adam {
    /// Applies one Adam step to `layer`.
    ///
    /// The timestep is incremented *before* the layer update so the layer
    /// receives the 1-based step count expected by bias correction.
    fn update(&mut self, layer: &mut dyn Layer) {
        self.t += 1;
        let Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            t,
        } = *self;
        layer.update_parameters_adam(learning_rate, beta1, beta2, epsilon, t);
    }
}
/// Per-layer Adam moment buffers for dense / recurrent parameter matrices.
///
/// All buffers hold the *biased* moment estimates; bias correction is
/// applied when computing the parameter updates in
/// [`AdamStates::update_parameter`].
#[derive(Debug, Clone, Default)]
pub struct AdamStates {
    /// First-moment (mean) estimate for the main weight matrix.
    pub m: Array2<f32>,
    /// Second-moment (uncentered variance) estimate for the main weight matrix.
    pub v: Array2<f32>,
    /// First-moment estimate for the recurrent weights; `None` for layers
    /// without a recurrent matrix.
    pub m_recurrent: Option<Array2<f32>>,
    /// Second-moment estimate for the recurrent weights; `None` for layers
    /// without a recurrent matrix.
    pub v_recurrent: Option<Array2<f32>>,
    /// First-moment estimate for the bias.
    pub m_bias: Array2<f32>,
    /// Second-moment estimate for the bias.
    pub v_bias: Array2<f32>,
}
impl AdamStates {
    /// Creates zero-initialized Adam moment buffers.
    ///
    /// * `dims_param` — shape of the main weight matrix.
    /// * `dims_recurrent` — shape of the recurrent weight matrix, if the
    ///   owning layer has one; `None` otherwise.
    /// * `dims_bias` — shape of the bias matrix.
    pub fn new(
        dims_param: (usize, usize),
        dims_recurrent: Option<(usize, usize)>,
        dims_bias: (usize, usize),
    ) -> Self {
        Self {
            m: Array2::zeros(dims_param),
            v: Array2::zeros(dims_param),
            m_recurrent: dims_recurrent.map(|dims| Array2::zeros(dims)),
            v_recurrent: dims_recurrent.map(|dims| Array2::zeros(dims)),
            m_bias: Array2::zeros(dims_bias),
            v_bias: Array2::zeros(dims_bias),
        }
    }

    /// Performs one Adam step and returns the deltas to apply to the
    /// parameters as `(param_update, recurrent_update, bias_update)`.
    ///
    /// The biased moment estimates are advanced in place; each returned
    /// update is `lr * m_hat / (sqrt(v_hat) + epsilon)` where
    /// `m_hat = m / (1 - beta1^t)` and `v_hat = v / (1 - beta2^t)`.
    ///
    /// `recurrent_update` is `Some` only when both the stored recurrent
    /// state and `grad_recurrent` are present.
    ///
    /// `t` must be >= 1 (the step count after incrementing), otherwise the
    /// bias-correction denominators are zero.
    pub fn update_parameter(
        &mut self,
        grad_param: &Array2<f32>,
        grad_recurrent: Option<&Array2<f32>>,
        grad_bias: &Array2<f32>,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        t: u64,
        lr: f32,
    ) -> (Array2<f32>, Option<Array2<f32>>, Array2<f32>) {
        // Advance the biased first/second moment estimates in place.
        Self::update_adam_param(&mut self.m, &mut self.v, grad_param, beta1, beta2);
        if let (Some(m_r), Some(v_r), Some(g_r)) = (
            self.m_recurrent.as_mut(),
            self.v_recurrent.as_mut(),
            grad_recurrent,
        ) {
            Self::update_adam_param(m_r, v_r, g_r, beta1, beta2);
        }
        Self::update_adam_param(&mut self.m_bias, &mut self.v_bias, grad_bias, beta1, beta2);

        // Hoist the bias-correction denominators out of the per-element
        // closures: `powi` is loop-invariant and was previously recomputed
        // for every element. NOTE: `t as i32` truncates for t > i32::MAX,
        // which is unreachable in practice.
        let bias_correction1 = 1.0 - beta1.powi(t as i32);
        let bias_correction2 = 1.0 - beta2.powi(t as i32);

        // Shared bias-corrected Adam step, used for all three tensors.
        let step = |m: &Array2<f32>, v: &Array2<f32>| -> Array2<f32> {
            let m_hat = m.mapv(|x| x / bias_correction1);
            let v_hat = v.mapv(|x| x / bias_correction2);
            lr * &m_hat / &(v_hat.mapv(f32::sqrt) + epsilon)
        };

        // Parallelize only when the main matrix is large enough to amortize
        // rayon's scheduling overhead.
        let (param_update, bias_update) = if self.m.len() >= ADAM_PARALLEL_THRESHOLD {
            rayon::join(
                || step(&self.m, &self.v),
                || step(&self.m_bias, &self.v_bias),
            )
        } else {
            (step(&self.m, &self.v), step(&self.m_bias, &self.v_bias))
        };

        let recurrent_update = match (&self.m_recurrent, &self.v_recurrent, grad_recurrent) {
            (Some(m_r), Some(v_r), Some(_)) => Some(step(m_r, v_r)),
            _ => None,
        };

        (param_update, recurrent_update, bias_update)
    }

    /// Advances the biased moment estimates in place:
    /// `m = beta1 * m + (1 - beta1) * g` and
    /// `v = beta2 * v + (1 - beta2) * g^2`.
    ///
    /// Uses `zip_mut_with` to update in place instead of allocating two
    /// fresh arrays per step; for large tensors the two independent updates
    /// run on separate rayon workers.
    fn update_adam_param(
        m: &mut Array2<f32>,
        v: &mut Array2<f32>,
        g: &Array2<f32>,
        beta1: f32,
        beta2: f32,
    ) {
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let update_m =
            |m: &mut Array2<f32>| m.zip_mut_with(g, |m_i, &g_i| *m_i = *m_i * beta1 + g_i * one_minus_beta1);
        let update_v =
            |v: &mut Array2<f32>| v.zip_mut_with(g, |v_i, &g_i| *v_i = *v_i * beta2 + g_i * g_i * one_minus_beta2);

        if m.len() >= ADAM_PARALLEL_THRESHOLD {
            // `m` and `v` are disjoint `&mut` borrows, so the two closures
            // can run concurrently.
            rayon::join(|| update_m(m), || update_v(v));
        } else {
            update_m(m);
            update_v(v);
        }
    }
}
/// Adam moment buffers for a 1-D convolution layer: rank-3 weight tensor
/// plus a rank-2 bias.
#[derive(Debug, Clone, Default)]
pub struct AdamStatesConv1D {
    /// First-moment estimate for the weight tensor.
    pub m: Array3<f32>,
    /// Second-moment estimate for the weight tensor.
    pub v: Array3<f32>,
    /// First-moment estimate for the bias.
    pub m_bias: Array2<f32>,
    /// Second-moment estimate for the bias.
    pub v_bias: Array2<f32>,
}
/// Adam moment buffers for a 2-D convolution layer: rank-4 weight tensor
/// plus a rank-2 bias.
#[derive(Debug, Clone, Default)]
pub struct AdamStatesConv2D {
    /// First-moment estimate for the weight tensor.
    pub m: Array4<f32>,
    /// Second-moment estimate for the weight tensor.
    pub v: Array4<f32>,
    /// First-moment estimate for the bias.
    pub m_bias: Array2<f32>,
    /// Second-moment estimate for the bias.
    pub v_bias: Array2<f32>,
}
/// Adam moment buffers for a 3-D convolution layer: rank-5 weight tensor
/// plus a rank-2 bias.
#[derive(Debug, Clone, Default)]
pub struct AdamStatesConv3D {
    /// First-moment estimate for the weight tensor.
    pub m: Array5<f32>,
    /// Second-moment estimate for the weight tensor.
    pub v: Array5<f32>,
    /// First-moment estimate for the bias.
    pub m_bias: Array2<f32>,
    /// Second-moment estimate for the bias.
    pub v_bias: Array2<f32>,
}
/// Adam moment buffers for a normalization layer's learnable parameters.
///
/// NOTE(review): `gamma`/`beta` presumably denote the layer's scale and
/// shift parameters — confirm against the normalization layer impl; their
/// use is not visible from this file.
#[derive(Debug, Clone, Default)]
pub struct AdamStatesNormalizationLayer {
    /// First-moment estimate for `gamma`.
    pub m_gamma: Tensor,
    /// Second-moment estimate for `gamma`.
    pub v_gamma: Tensor,
    /// First-moment estimate for `beta`.
    pub m_beta: Tensor,
    /// Second-moment estimate for `beta`.
    pub v_beta: Tensor,
}