use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::neural_network_trait::{Layer, Optimizer};
use crate::neural_network::optimizer::input_validation_function::{
    validate_decay_rate, validate_epsilon, validate_learning_rate,
};
use ndarray::{Array2, Array3, Array4, Array5, Zip};
use rayon::prelude::*;
/// Element count at or above which `RMSpropCache::update_param` switches from
/// the serial update path to the parallel one.
const RMS_PROP_PARALLEL_THRESHOLD: usize = 1024;
/// RMSprop optimizer configuration.
///
/// Holds the validated hyper-parameters and delegates the actual parameter
/// update to each layer via `update_parameters_rmsprop`.
pub struct RMSprop {
    // Step size applied to every parameter update.
    learning_rate: f32,
    // Decay rate of the squared-gradient moving average (often written rho/beta).
    rho: f32,
    // Small constant added to the denominator for numerical stability.
    epsilon: f32,
}
impl RMSprop {
pub fn new(learning_rate: f32, rho: f32, epsilon: f32) -> Result<Self, ModelError> {
validate_learning_rate(learning_rate)?;
validate_decay_rate(rho, "rho")?;
validate_epsilon(epsilon)?;
Ok(Self {
learning_rate,
rho,
epsilon,
})
}
}
impl Optimizer for RMSprop {
    /// Performs one optimization step by handing the hyper-parameters to the
    /// layer, which owns the per-parameter RMSprop state.
    fn update(&mut self, layer: &mut dyn Layer) {
        // `f32` is `Copy`, so destructuring reads the fields without moving `self`.
        let Self {
            learning_rate,
            rho,
            epsilon,
        } = *self;
        layer.update_parameters_rmsprop(learning_rate, rho, epsilon);
    }
}
/// Per-layer RMSprop accumulator state for dense/recurrent layers.
///
/// Each field stores the exponentially decayed average of squared gradients
/// for the matching parameter matrix.
#[derive(Debug, Clone, Default)]
pub struct RMSpropCache {
    // Squared-gradient accumulator for the main weight matrix.
    pub cache: Array2<f32>,
    // Accumulator for the recurrent weight matrix; `None` for layers
    // without recurrent weights (see `RMSpropCache::new`).
    pub cache_recurrent: Option<Array2<f32>>,
    // Squared-gradient accumulator for the bias.
    pub bias: Array2<f32>,
}
impl RMSpropCache {
pub fn new(
dims: (usize, usize),
recurrent_dims: Option<(usize, usize)>,
bias_dims: (usize, usize),
) -> Self {
Self {
cache: Array2::<f32>::zeros(dims),
cache_recurrent: recurrent_dims.map(|dims| Array2::<f32>::zeros(dims)),
bias: Array2::<f32>::zeros(bias_dims),
}
}
pub fn update_param(
param: &mut Array2<f32>,
grad: &Array2<f32>,
cache: &mut Array2<f32>,
rho: f32,
lr: f32,
epsilon: f32,
) {
let use_parallel = param.len() >= RMS_PROP_PARALLEL_THRESHOLD;
if use_parallel {
let (new_cache, new_param) = rayon::join(
|| cache.mapv(|x| x * rho) + &(grad.mapv(|x| x * x) * (1.0 - rho)),
|| {
let temp_cache =
cache.mapv(|x| x * rho) + &(grad.mapv(|x| x * x) * (1.0 - rho));
&*param - &(lr * grad / &(temp_cache.mapv(f32::sqrt) + epsilon))
},
);
*cache = new_cache;
*param = new_param;
} else {
*cache = cache.mapv(|x| x * rho) + &(grad.mapv(|x| x * x) * (1.0 - rho));
*param = &*param - &(lr * grad / &(cache.mapv(f32::sqrt) + epsilon));
}
}
pub fn update_parameters(
&mut self,
param: &mut Array2<f32>,
recurrent_param: Option<&mut Array2<f32>>,
bias_param: &mut Array2<f32>,
grad: &Array2<f32>,
recurrent_grad: Option<&Array2<f32>>,
bias_grad: &Array2<f32>,
rho: f32,
lr: f32,
epsilon: f32,
) {
Self::update_param(param, grad, &mut self.cache, rho, lr, epsilon);
if let (Some(rec_param), Some(rec_grad), Some(ref mut rec_cache)) = (
recurrent_param,
recurrent_grad,
self.cache_recurrent.as_mut(),
) {
Self::update_param(rec_param, rec_grad, rec_cache, rho, lr, epsilon);
}
Self::update_param(bias_param, bias_grad, &mut self.bias, rho, lr, epsilon);
}
}
/// RMSprop accumulator state for a 1D convolution layer.
///
/// Fields start as `None` (via `Default`); they are presumably allocated
/// lazily on the first update — NOTE(review): the allocation site is outside
/// this file, confirm against the Conv1D layer implementation.
#[derive(Debug, Clone, Default)]
pub struct RMSpropCacheConv1D {
    // Squared-gradient accumulator for the convolution kernel.
    pub cache: Option<Array3<f32>>,
    // Squared-gradient accumulator for the bias.
    pub bias: Option<Array2<f32>>,
}
/// RMSprop accumulator state for a 2D convolution layer.
#[derive(Debug, Clone, Default)]
pub struct RMSpropCacheConv2D {
    // Squared-gradient accumulator for the 4D convolution kernel.
    pub cache: Array4<f32>,
    // Squared-gradient accumulator for the bias.
    pub bias: Array2<f32>,
}
/// RMSprop accumulator state for a 3D convolution layer.
#[derive(Debug, Clone, Default)]
pub struct RMSpropCacheConv3D {
    // Squared-gradient accumulator for the 5D convolution kernel.
    pub cache: Array5<f32>,
    // Squared-gradient accumulator for the bias.
    pub bias: Array2<f32>,
}
/// RMSprop accumulator state for a normalization layer's learnable
/// scale (`gamma`) and shift (`beta`) parameters.
#[derive(Debug, Clone, Default)]
pub struct RMSpropCacheNormalizationLayer {
    // Squared-gradient accumulator for gamma.
    pub cache_gamma: Tensor,
    // Squared-gradient accumulator for beta.
    pub cache_beta: Tensor,
}