use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::regularization_layer::dropout_layer::{
dropout_backward, dropout_output_shape,
};
use crate::neural_network::layer::regularization_layer::input_validation_function::{
validate_input_shape, validate_rate,
};
use crate::neural_network::neural_network_trait::Layer;
use ndarray_rand::{RandomExt, rand_distr::Uniform};
/// Element-count cutoff above which the dropout mask is thresholded in
/// parallel (`par_mapv_inplace`) instead of serially; below it the
/// per-thread scheduling overhead would outweigh the work.
const DROPOUT_PARALLEL_THRESHOLD: usize = 10000;
/// Dropout regularization layer.
///
/// During training, zeroes each input element with probability `rate` and
/// scales the survivors by `1 / (1 - rate)` (inverted dropout, see
/// `forward`); during inference it is the identity.
pub struct Dropout {
    // Probability in [0, 1] of dropping (zeroing) each element; validated
    // by `validate_rate` in `new` and again in `forward`.
    rate: f32,
    // Expected shape of incoming tensors, checked on every `forward` call.
    input_shape: Vec<usize>,
    // Binary keep-mask (1.0 = kept, 0.0 = dropped) saved by the most recent
    // training-mode `forward`; consumed by `backward`. `None` until the
    // first masked forward pass.
    mask: Option<Tensor>,
    // Training-mode flag toggled via the mode_dependent_layer_set_training!
    // macro; when false, `forward` passes input through unchanged.
    training: bool,
}
impl Dropout {
    /// Builds a `Dropout` layer in training mode with no mask cached yet.
    ///
    /// # Arguments
    /// * `rate` - probability of zeroing each element; must satisfy
    ///   `validate_rate` (rejected via `ModelError` otherwise).
    /// * `input_shape` - shape every forwarded tensor must match.
    ///
    /// # Errors
    /// Returns `ModelError` when `rate` fails validation.
    pub fn new(rate: f32, input_shape: Vec<usize>) -> Result<Self, ModelError> {
        // Reject invalid rates up front so a constructed layer is always usable.
        validate_rate(rate, "Dropout rate")?;
        let layer = Self {
            rate,
            input_shape,
            mask: None,
            training: true,
        };
        Ok(layer)
    }
    // Generates the training-mode setter shared by mode-dependent layers.
    mode_dependent_layer_set_training!();
}
impl Layer for Dropout {
    /// Applies inverted dropout to `input`.
    ///
    /// In inference mode the input passes through untouched. In training
    /// mode each element survives with probability `1 - rate` and survivors
    /// are scaled by `1 / (1 - rate)` so the expected activation is
    /// unchanged; the keep-mask is cached for `backward`.
    ///
    /// # Errors
    /// Returns `ModelError` when the stored rate fails validation or the
    /// input shape does not match `input_shape`.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        validate_rate(self.rate, "Dropout rate")?;
        validate_input_shape(input.shape(), &self.input_shape)?;

        // Inference mode: dropout is the identity.
        if !self.training {
            return Ok(input.clone());
        }

        match self.rate {
            // Degenerate rates short-circuit without touching the cached mask.
            r if r == 0.0 => {
                eprintln!("Dropout rate is 0.0, so this layer has no effect on the output.");
                Ok(input.clone())
            }
            r if r == 1.0 => {
                eprintln!("Dropout rate is 1.0, so this layer will return all zeros.");
                Ok(Tensor::zeros(input.raw_dim()))
            }
            rate => {
                // Threshold uniform noise into a {0.0, 1.0} keep-mask:
                // an element survives when its draw lands at or above `rate`.
                let threshold = move |x: f32| if x >= rate { 1.0 } else { 0.0 };
                let mut keep_mask =
                    Tensor::random(input.raw_dim(), Uniform::new(0.0, 1.0).unwrap());
                // Large tensors amortize rayon's scheduling cost; small ones
                // are cheaper to process serially.
                if input.len() >= DROPOUT_PARALLEL_THRESHOLD {
                    keep_mask.par_mapv_inplace(threshold);
                } else {
                    keep_mask.mapv_inplace(threshold);
                }
                // Inverted-dropout scaling keeps the expected output equal
                // to the input.
                let scaled = input * &keep_mask * (1.0 / (1.0 - rate));
                self.mask = Some(keep_mask);
                Ok(scaled)
            }
        }
    }

    /// Routes the upstream gradient through the cached keep-mask via the
    /// shared `dropout_backward` helper.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        dropout_backward(grad_output, &self.mask, self.training, self.rate)
    }

    /// Identifier used when printing model summaries.
    fn layer_type(&self) -> &str {
        "Dropout"
    }

    /// Human-readable output shape (dropout never changes the shape).
    fn output_shape(&self) -> String {
        dropout_output_shape(&self.input_shape)
    }

    // Dropout owns no weights: expands to the empty-parameter trait methods.
    no_trainable_parameters_layer_functions!();
    // Expands the shared training/inference-mode trait methods.
    mode_dependent_layer_trait!();
}