use crate::error::ModelError;
use crate::neural_network::neural_network_trait::{Layer, Optimizer};
use crate::neural_network::optimizer::input_validation_function::validate_learning_rate;
use ndarray::Array2;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};
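/// Element count at or above which parameter updates switch to rayon's
/// parallel iterators; below it, thread-dispatch overhead tends to
/// outweigh the arithmetic.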
const SGD_PARALLEL_THRESHOLD: usize = 1024;
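/// Stochastic gradient descent optimizer with a fixed learning rate.
///
/// Applies the vanilla update rule `param -= learning_rate * grad`,
/// with no momentum, weight decay, or learning-rate schedule.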
pub struct SGD {
learning_rate: f32,
}
impl SGD {
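    /// Creates an `SGD` optimizer after validating the learning rate,
    /// returning a `ModelError` if validation fails.
    ///
    /// A minimal usage sketch (the exact import path depends on how the
    /// crate exposes `SGD`):
    ///
    /// ```ignore
    /// let sgd = SGD::new(0.01).expect("learning rate should be valid");
    /// ```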
pub fn new(learning_rate: f32) -> Result<Self, ModelError> {
validate_learning_rate(learning_rate)?;
Ok(Self { learning_rate })
}
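    /// Performs one in-place SGD step, `param -= lr * grad`, over flat
    /// weight and bias slices. Updates run on rayon's parallel iterators
    /// once either slice reaches `SGD_PARALLEL_THRESHOLD` elements;
    /// smaller slices are updated sequentially to avoid scheduling
    /// overhead. Each gradient slice must be as long as its parameter
    /// slice.
    ///
    /// A minimal sketch with hypothetical values:
    ///
    /// ```ignore
    /// let mut weights = vec![1.0_f32; 4];
    /// let weight_grads = vec![0.5_f32; 4];
    /// let mut bias = vec![0.0_f32; 2];
    /// let bias_grads = vec![0.1_f32; 2];
    /// SGD::update_sgd_parameters(&mut weights, &weight_grads, &mut bias, &bias_grads, 0.01);
    /// // each weight is now 1.0 - 0.01 * 0.5 = 0.995
    /// ```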
pub fn update_sgd_parameters(
weights: &mut [f32],
weight_grads: &[f32],
bias: &mut [f32],
bias_grads: &[f32],
lr: f32,
    ) {
        // `zip` silently truncates to the shorter side, so enforce the
        // length contract explicitly in debug builds.
        debug_assert_eq!(weights.len(), weight_grads.len());
        debug_assert_eq!(bias.len(), bias_grads.len());
        let use_parallel =
            weights.len() >= SGD_PARALLEL_THRESHOLD || bias.len() >= SGD_PARALLEL_THRESHOLD;
        // A single closure covers both updates so the parallel/sequential
        // choice is written once.
        let update_fn = |params: &mut [f32], grads: &[f32]| {
if use_parallel {
params
.par_iter_mut()
.zip(grads.par_iter())
.for_each(|(p, g)| *p -= *g * lr);
} else {
params
.iter_mut()
.zip(grads.iter())
.for_each(|(p, g)| *p -= *g * lr);
}
};
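        // The weight and bias slices are disjoint, so the two updates can
        // safely run on separate rayon tasks.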
if use_parallel {
rayon::join(
|| update_fn(weights, weight_grads),
|| update_fn(bias, bias_grads),
);
} else {
update_fn(weights, weight_grads);
update_fn(bias, bias_grads);
}
}
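    /// Performs one in-place SGD step on the input kernel, recurrent
    /// kernel, and bias of an RNN layer. When any of the three arrays
    /// reaches `SGD_PARALLEL_THRESHOLD` elements, the updates run
    /// concurrently via nested `rayon::join` calls; otherwise they run
    /// sequentially.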
pub fn update_sgd_parameters_rnn(
kernel: &mut Array2<f32>,
grad_kernel: &Array2<f32>,
recurrent_kernel: &mut Array2<f32>,
grad_recurrent_kernel: &Array2<f32>,
bias: &mut Array2<f32>,
grad_bias: &Array2<f32>,
lr: f32,
) {
let use_parallel = kernel.len() >= SGD_PARALLEL_THRESHOLD
|| recurrent_kernel.len() >= SGD_PARALLEL_THRESHOLD
|| bias.len() >= SGD_PARALLEL_THRESHOLD;
        if use_parallel {
            // `scaled_add` computes `self += alpha * rhs` in place, so the
            // clone-and-reassign round trip (and its two temporary arrays
            // per parameter) is avoided.
            rayon::join(
                || {
                    rayon::join(
                        || kernel.scaled_add(-lr, grad_kernel),
                        || recurrent_kernel.scaled_add(-lr, grad_recurrent_kernel),
                    )
                },
                || bias.scaled_add(-lr, grad_bias),
            );
        } else {
            kernel.scaled_add(-lr, grad_kernel);
            recurrent_kernel.scaled_add(-lr, grad_recurrent_kernel);
            bias.scaled_add(-lr, grad_bias);
        }
}
}
impl Optimizer for SGD {
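    /// Delegates to the layer's own SGD update, passing this optimizer's
    /// learning rate.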
fn update(&mut self, layer: &mut dyn Layer) {
layer.update_parameters_sgd(self.learning_rate);
}
}