use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::neural_network_trait::{Layer, Optimizer};
use crate::neural_network::optimizer::input_validation_function::validate_positive_finite;
use ndarray::{Array2, Array3, Array4, Array5};
/// Element count of the main accumulator at or above which the per-step
/// accumulator/bias work is split across two rayon tasks via `rayon::join`
/// instead of running sequentially.
const ADA_GRAD_PARALLEL_THRESHOLD: usize = 1024;
/// AdaGrad optimizer configuration.
///
/// Holds only the hyperparameters; the per-layer running sums of squared
/// gradients live in the `AdaGradStates*` structs below and are owned by
/// the layers themselves.
pub struct AdaGrad {
    /// Base step size applied to every parameter update.
    learning_rate: f32,
    /// Small constant added to the denominator for numerical stability
    /// (avoids division by zero when an accumulator entry is still 0).
    epsilon: f32,
}
impl AdaGrad {
    /// Builds an AdaGrad optimizer after validating both hyperparameters.
    ///
    /// # Errors
    ///
    /// Returns a `ModelError` when `learning_rate` or `epsilon` fails
    /// `validate_positive_finite` (i.e. is not a strictly positive,
    /// finite number).
    pub fn new(learning_rate: f32, epsilon: f32) -> Result<Self, ModelError> {
        // Validate every hyperparameter with its field name for error reporting.
        for (value, name) in [(learning_rate, "learning_rate"), (epsilon, "epsilon")] {
            validate_positive_finite(value, name)?;
        }
        Ok(Self { learning_rate, epsilon })
    }
}
impl Optimizer for AdaGrad {
    /// Applies one AdaGrad step to `layer` by delegating to the layer's
    /// own `update_parameters_ada_grad` hook, passing this optimizer's
    /// hyperparameters. The per-layer squared-gradient state is assumed to
    /// be owned by the layer itself.
    fn update(&mut self, layer: &mut dyn Layer) {
        layer.update_parameters_ada_grad(self.learning_rate, self.epsilon);
    }
}
/// Per-layer AdaGrad state for dense/recurrent layers: one running sum of
/// squared gradients per parameter group.
#[derive(Debug, Clone, Default)]
pub struct AdaGradStates {
    /// Accumulated squared gradients for the main weight matrix.
    pub accumulator: Array2<f32>,
    /// Accumulated squared gradients for the recurrent weight matrix;
    /// `None` for layers without recurrent weights.
    pub accumulator_recurrent: Option<Array2<f32>>,
    /// Accumulated squared gradients for the bias.
    pub accumulator_bias: Array2<f32>,
}
impl AdaGradStates {
    /// Creates zero-initialized accumulators for a layer's parameter groups.
    ///
    /// `dims_recurrent` is `Some` only for layers that carry a separate
    /// recurrent weight matrix; other layers pass `None`.
    pub fn new(
        dims_param: (usize, usize),
        dims_recurrent: Option<(usize, usize)>,
        dims_bias: (usize, usize),
    ) -> Self {
        Self {
            accumulator: Array2::zeros(dims_param),
            accumulator_recurrent: dims_recurrent.map(|dims| Array2::zeros(dims)),
            accumulator_bias: Array2::zeros(dims_bias),
        }
    }

    /// Performs one AdaGrad step: accumulates the element-wise squared
    /// gradients into the running sums, then returns the parameter deltas
    ///
    /// Δθ = lr · g / (sqrt(G) + ε)
    ///
    /// for (weights, optional recurrent weights, bias) in that order. The
    /// recurrent component of the result is `Some` only when both a
    /// recurrent accumulator and a recurrent gradient are present.
    ///
    /// Weight and bias work is dispatched on two rayon tasks when the main
    /// accumulator is large enough to amortize the scheduling cost.
    pub fn update_parameter(
        &mut self,
        grad_param: &Array2<f32>,
        grad_recurrent: Option<&Array2<f32>>,
        grad_bias: &Array2<f32>,
        epsilon: f32,
        lr: f32,
    ) -> (Array2<f32>, Option<Array2<f32>>, Array2<f32>) {
        let use_parallel = self.accumulator.len() >= ADA_GRAD_PARALLEL_THRESHOLD;

        // Phase 1: G += g² (in place) for the main weights and the bias.
        // The two closures borrow disjoint fields, so they may run in parallel.
        if use_parallel {
            rayon::join(
                || Self::update_ada_grad_param(&mut self.accumulator, grad_param),
                || Self::update_ada_grad_param(&mut self.accumulator_bias, grad_bias),
            );
        } else {
            Self::update_ada_grad_param(&mut self.accumulator, grad_param);
            Self::update_ada_grad_param(&mut self.accumulator_bias, grad_bias);
        }

        // Recurrent weights (if any): accumulate and compute the update in
        // the same match arm, so no later `unwrap` on `grad_recurrent` is
        // needed and there is no panic path. This group is independent of
        // the main/bias fields, so doing it here preserves the results.
        let recurrent_update = match (self.accumulator_recurrent.as_mut(), grad_recurrent) {
            (Some(acc_r), Some(g_r)) => {
                Self::update_ada_grad_param(acc_r, g_r);
                Some(lr * g_r / &(acc_r.mapv(f32::sqrt) + epsilon))
            }
            _ => None,
        };

        // Phase 2: Δθ = lr · g / (sqrt(G) + ε) for main weights and bias,
        // again parallelized for large layers.
        let (param_update, bias_update) = if use_parallel {
            rayon::join(
                || lr * grad_param / &(self.accumulator.mapv(f32::sqrt) + epsilon),
                || lr * grad_bias / &(self.accumulator_bias.mapv(f32::sqrt) + epsilon),
            )
        } else {
            (
                lr * grad_param / &(self.accumulator.mapv(f32::sqrt) + epsilon),
                lr * grad_bias / &(self.accumulator_bias.mapv(f32::sqrt) + epsilon),
            )
        };

        (param_update, recurrent_update, bias_update)
    }

    /// Adds the element-wise square of `g` into `accumulator` in place.
    ///
    /// Uses `zip_mut_with` instead of `*acc = &*acc + &g.mapv(|x| x * x)`,
    /// which allocated two temporary arrays on every optimizer step.
    fn update_ada_grad_param(accumulator: &mut Array2<f32>, g: &Array2<f32>) {
        accumulator.zip_mut_with(g, |a, &gv| *a += gv * gv);
    }
}
/// AdaGrad state for 1-D convolution layers.
#[derive(Debug, Clone, Default)]
pub struct AdaGradStatesConv1D {
    /// Accumulated squared gradients for the kernel weights; 3-D,
    /// presumably (out_channels, in_channels, kernel_len) — TODO confirm
    /// against the Conv1D layer.
    pub accumulator: Array3<f32>,
    /// Accumulated squared gradients for the bias.
    pub accumulator_bias: Array2<f32>,
}
/// AdaGrad state for 2-D convolution layers.
#[derive(Debug, Clone, Default)]
pub struct AdaGradStatesConv2D {
    /// Accumulated squared gradients for the kernel weights; 4-D,
    /// presumably (out_channels, in_channels, kernel_h, kernel_w) — TODO
    /// confirm against the Conv2D layer.
    pub accumulator: Array4<f32>,
    /// Accumulated squared gradients for the bias.
    pub accumulator_bias: Array2<f32>,
}
/// AdaGrad state for 3-D convolution layers.
#[derive(Debug, Clone, Default)]
pub struct AdaGradStatesConv3D {
    /// Accumulated squared gradients for the kernel weights; 5-D,
    /// presumably (out_channels, in_channels, kernel_d, kernel_h, kernel_w)
    /// — TODO confirm against the Conv3D layer.
    pub accumulator: Array5<f32>,
    /// Accumulated squared gradients for the bias.
    pub accumulator_bias: Array2<f32>,
}
/// AdaGrad state for normalization layers (gamma/beta affine parameters).
#[derive(Debug, Clone, Default)]
pub struct AdaGradStatesNormalizationLayer {
    /// Accumulated squared gradients for the scale (gamma) parameter;
    /// stored as the crate's `Tensor` type, whose rank matches the layer's
    /// gamma shape — not visible from here.
    pub acc_grad_gamma: Tensor,
    /// Accumulated squared gradients for the shift (beta) parameter.
    pub acc_grad_beta: Tensor,
}