use crate::error::ModelError;
use crate::neural_network::layer::recurrent_layer::input_validation_function::validate_dimension_greater_than_zero;
use crate::neural_network::optimizer::OptimizerCache;
use crate::neural_network::optimizer::ada_grad::AdaGradStates;
use crate::neural_network::optimizer::adam::AdamStates;
use crate::neural_network::optimizer::rms_prop::RMSpropCache;
use ndarray::{Array, Array2};
use ndarray_rand::{RandomExt, rand_distr::Uniform};
/// Trainable parameters for one recurrent gate, plus backprop gradients and lazily created optimizer state.
pub struct Gate {
    /// Input-to-gate weights, shape `(input_dim, units)`.
    pub kernel: Array2<f32>,
    /// Hidden-to-gate weights, shape `(units, units)`.
    pub recurrent_kernel: Array2<f32>,
    /// Bias row vector, shape `(1, units)`.
    pub bias: Array2<f32>,
    /// Gradient w.r.t. `kernel`; `None` until a backward pass stores it.
    pub grad_kernel: Option<Array2<f32>>,
    /// Gradient w.r.t. `recurrent_kernel`.
    pub grad_recurrent_kernel: Option<Array2<f32>>,
    /// Gradient w.r.t. `bias`.
    pub grad_bias: Option<Array2<f32>>,
    /// Optimizer state (Adam, RMSprop, AdaGrad), created on first use.
    pub optimizer_cache: OptimizerCache,
}
impl Gate {
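    /// Creates a gate with Glorot-uniform input weights, a recurrent kernel
    /// drawn uniformly and then normalized column-wise to unit L2 norm, and a
    /// constant bias. Fails if `input_dim` or `units` is zero.
    ///
    /// A minimal usage sketch (the dimensions and bias value here are
    /// illustrative):
    ///
    /// ```ignore
    /// // input_dim = 3, units = 4, forget-gate style bias of 1.0
    /// let gate = Gate::new(3, 4, 1.0)?;
    /// assert_eq!(gate.kernel.dim(), (3, 4));
    /// assert_eq!(gate.recurrent_kernel.dim(), (4, 4));
    /// assert_eq!(gate.bias.dim(), (1, 4));
    /// ```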
    pub fn new(input_dim: usize, units: usize, bias_init_value: f32) -> Result<Self, ModelError> {
        validate_dimension_greater_than_zero(input_dim, "input_dim")?;
        validate_dimension_greater_than_zero(units, "units")?;

        // Glorot (Xavier) uniform initialization: limit = sqrt(6 / (fan_in + fan_out)).
        let limit = (6.0 / (input_dim + units) as f32).sqrt();
        let kernel = Array::random((input_dim, units), Uniform::new(-limit, limit).unwrap());

        // Draw the recurrent kernel uniformly, then normalize each column to
        // unit L2 norm (`units` is already validated nonzero above).
        let mut recurrent_kernel = Array::random((units, units), Uniform::new(-1.0, 1.0).unwrap());
        for mut col in recurrent_kernel.columns_mut() {
            let norm = col.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 1e-8 {
                col /= norm;
            }
        }

        let bias = Array::from_elem((1, units), bias_init_value);

        Ok(Self {
            kernel,
            recurrent_kernel,
            bias,
            grad_kernel: None,
            grad_recurrent_kernel: None,
            grad_bias: None,
            optimizer_cache: OptimizerCache::default(),
        })
    }
}
/// Element-wise gradient clipping threshold applied before every parameter update.
const GRADIENT_CLIP_VALUE: f32 = 5.0;
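/// Computes the pre-activation gate value for one time step:
/// `z_t = x_t · kernel + h_prev · recurrent_kernel + bias`.
///
/// A minimal sketch, with illustrative shapes (`batch = 2`, `input_dim = 3`,
/// `units = 4`):
///
/// ```ignore
/// let gate = Gate::new(3, 4, 0.0)?;
/// let x_t = Array2::<f32>::zeros((2, 3));    // input at time t
/// let h_prev = Array2::<f32>::zeros((2, 4)); // previous hidden state
/// let z = compute_gate_value(&gate, &x_t, &h_prev); // shape (2, 4)
/// ```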
#[inline]
pub fn compute_gate_value(gate: &Gate, x_t: &Array2<f32>, h_prev: &Array2<f32>) -> Array2<f32> {
    x_t.dot(&gate.kernel) + h_prev.dot(&gate.recurrent_kernel) + &gate.bias
}
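/// Takes ownership of a cached value, replacing it with `None`, or returns a
/// [`ModelError::ProcessingError`] carrying `error_msg` if the cache is empty.
///
/// Sketch of the intended use (the cache contents here are illustrative):
///
/// ```ignore
/// let mut cached: Option<Array2<f32>> = Some(Array2::zeros((1, 4)));
/// let value = take_cache(&mut cached, "forward cache missing")?;
/// assert!(cached.is_none()); // the cached value is consumed
/// ```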
#[inline]
pub fn take_cache<T>(cache: &mut Option<T>, error_msg: &str) -> Result<T, ModelError> {
    cache
        .take()
        .ok_or_else(|| ModelError::ProcessingError(error_msg.to_string()))
}
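/// Stores the three gradients computed by a backward pass on the gate so that
/// a subsequent `update_gate_*` call can apply them.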
#[inline]
pub fn store_gate_gradients(
    gate: &mut Gate,
    grad_kernel: Array2<f32>,
    grad_recurrent: Array2<f32>,
    grad_bias: Array2<f32>,
) {
    gate.grad_kernel = Some(grad_kernel);
    gate.grad_recurrent_kernel = Some(grad_recurrent);
    gate.grad_bias = Some(grad_bias);
}
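/// Applies one plain SGD step, `param -= lr * clip(grad)`, to all three
/// parameter sets. Silently does nothing if any gradient has not been stored.
///
/// Sketch of a training step (the gradient arrays here are placeholders):
///
/// ```ignore
/// store_gate_gradients(&mut gate, gk, grk, gb); // from the backward pass
/// update_gate_sgd(&mut gate, 0.01);             // lr = 0.01
/// ```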
#[inline]
pub fn update_gate_sgd(gate: &mut Gate, lr: f32) {
    if let (Some(gk), Some(grk), Some(gb)) = (
        &gate.grad_kernel,
        &gate.grad_recurrent_kernel,
        &gate.grad_bias,
    ) {
        // Clip gradients element-wise before applying the update.
        let gk_clipped = gk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let grk_clipped = grk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let gb_clipped = gb.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        gate.kernel = &gate.kernel - &(lr * &gk_clipped);
        gate.recurrent_kernel = &gate.recurrent_kernel - &(lr * &grk_clipped);
        gate.bias = &gate.bias - &(lr * &gb_clipped);
    }
}
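/// Applies one Adam step to the gate, creating the moment estimates on first
/// use. `t` is the global step counter used for bias correction; `input_dim`
/// and `units` must match the shapes the gate was built with.
///
/// Sketch with common defaults (the hyperparameter values are illustrative):
///
/// ```ignore
/// update_gate_adam(&mut gate, 3, 4, 0.001, 0.9, 0.999, 1e-8, t);
/// ```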
#[inline]
pub fn update_gate_adam(
    gate: &mut Gate,
    input_dim: usize,
    units: usize,
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    t: u64,
) {
    // Lazily allocate the first/second moment estimates for all three parameters.
    if gate.optimizer_cache.adam_states.is_none() {
        gate.optimizer_cache.adam_states = Some(AdamStates::new(
            (input_dim, units),
            Some((units, units)),
            (1, units),
        ));
    }
    if let (Some(gk), Some(grk), Some(gb)) = (
        &gate.grad_kernel,
        &gate.grad_recurrent_kernel,
        &gate.grad_bias,
    ) {
        let gk_clipped = gk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let grk_clipped = grk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let gb_clipped = gb.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        // Safe: the states were initialized above.
        let adam_states = gate.optimizer_cache.adam_states.as_mut().unwrap();
        let (k_update, rk_update, b_update) = adam_states.update_parameter(
            &gk_clipped,
            Some(&grk_clipped),
            &gb_clipped,
            beta1,
            beta2,
            epsilon,
            t,
            lr,
        );
        gate.kernel = &gate.kernel - &k_update;
        // Safe: a recurrent gradient was passed in, so an update is returned.
        gate.recurrent_kernel = &gate.recurrent_kernel - &rk_update.unwrap();
        gate.bias = &gate.bias - &b_update;
    }
}
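/// Applies one RMSprop step to the gate, creating the running squared-gradient
/// cache on first use. `rho` is the decay rate of that running average.
///
/// Sketch with common defaults (the hyperparameter values are illustrative):
///
/// ```ignore
/// update_gate_rmsprop(&mut gate, 3, 4, 0.001, 0.9, 1e-8);
/// ```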
#[inline]
pub fn update_gate_rmsprop(
    gate: &mut Gate,
    input_dim: usize,
    units: usize,
    lr: f32,
    rho: f32,
    epsilon: f32,
) {
    // Lazily allocate the running squared-gradient accumulators.
    if gate.optimizer_cache.rmsprop_cache.is_none() {
        gate.optimizer_cache.rmsprop_cache = Some(RMSpropCache::new(
            (input_dim, units),
            Some((units, units)),
            (1, units),
        ));
    }
    if let (Some(gk), Some(grk), Some(gb)) = (
        &gate.grad_kernel,
        &gate.grad_recurrent_kernel,
        &gate.grad_bias,
    ) {
        let gk_clipped = gk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let grk_clipped = grk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let gb_clipped = gb.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        if let Some(ref mut cache) = gate.optimizer_cache.rmsprop_cache {
            // Unlike the other optimizers, the cache updates the parameters in place.
            cache.update_parameters(
                &mut gate.kernel,
                Some(&mut gate.recurrent_kernel),
                &mut gate.bias,
                &gk_clipped,
                Some(&grk_clipped),
                &gb_clipped,
                rho,
                lr,
                epsilon,
            );
        }
    }
}
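/// Applies one AdaGrad step to the gate, creating the accumulated
/// squared-gradient state on first use.
///
/// Sketch (the hyperparameter values are illustrative):
///
/// ```ignore
/// update_gate_ada_grad(&mut gate, 3, 4, 0.01, 1e-8);
/// ```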
#[inline]
pub fn update_gate_ada_grad(
    gate: &mut Gate,
    input_dim: usize,
    units: usize,
    lr: f32,
    epsilon: f32,
) {
    // Lazily allocate the accumulated squared-gradient state.
    if gate.optimizer_cache.ada_grad_cache.is_none() {
        gate.optimizer_cache.ada_grad_cache = Some(AdaGradStates::new(
            (input_dim, units),
            Some((units, units)),
            (1, units),
        ));
    }
    if let (Some(gk), Some(grk), Some(gb)) = (
        &gate.grad_kernel,
        &gate.grad_recurrent_kernel,
        &gate.grad_bias,
    ) {
        let gk_clipped = gk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let grk_clipped = grk.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        let gb_clipped = gb.mapv(|x| x.clamp(-GRADIENT_CLIP_VALUE, GRADIENT_CLIP_VALUE));
        // Safe: the cache was initialized above.
        let ada_grad_cache = gate.optimizer_cache.ada_grad_cache.as_mut().unwrap();
        let (k_update, rk_update, b_update) = ada_grad_cache.update_parameter(
            &gk_clipped,
            Some(&grk_clipped),
            &gb_clipped,
            epsilon,
            lr,
        );
        gate.kernel = &gate.kernel - &k_update;
        // Safe: a recurrent gradient was passed in, so an update is returned.
        gate.recurrent_kernel = &gate.recurrent_kernel - &rk_update.unwrap();
        gate.bias = &gate.bias - &b_update;
    }
}
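
#[cfg(test)]
mod tests {
    use super::*;

    // A smoke-test sketch of the intended call sequence: construct a gate,
    // run the forward affine step, store placeholder gradients (not outputs
    // of a real backward pass), and apply one SGD update.
    #[test]
    fn gate_forward_and_sgd_update() {
        let mut gate = Gate::new(3, 4, 0.0).unwrap();
        assert_eq!(gate.kernel.dim(), (3, 4));

        let x_t = Array2::<f32>::ones((2, 3));
        let h_prev = Array2::<f32>::zeros((2, 4));
        let z = compute_gate_value(&gate, &x_t, &h_prev);
        assert_eq!(z.dim(), (2, 4));

        store_gate_gradients(
            &mut gate,
            Array2::ones((3, 4)),
            Array2::ones((4, 4)),
            Array2::ones((1, 4)),
        );
        let before = gate.kernel.clone();
        update_gate_sgd(&mut gate, 0.1);
        // Unit gradients are within the clip range, so each kernel entry
        // decreases by exactly lr = 0.1 (up to f32 rounding).
        assert!(gate
            .kernel
            .iter()
            .zip(before.iter())
            .all(|(after, before)| (before - after - 0.1).abs() < 1e-5));
    }
}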