use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use scirs2_core::Random;
use std::cell::RefCell;
use std::fmt::Debug;
use crate::error::Result;
use crate::regularizers::Regularizer;
/// Dropout regularizer: while in training mode, gradient entries are randomly
/// zeroed and the surviving entries are scaled by `1 / (1 - rate)`.
///
/// The random source and the mask slot live behind `RefCell`s so they can be
/// used through a shared `&self` reference.
#[derive(Debug)]
pub struct Dropout<A>
where
    A: Float + Debug,
{
    /// Probability of dropping an entry; clamped to `[0, 1]` by `new` and `set_rate`.
    rate: A,
    /// Internal random number generator used to sample masks.
    rng: RefCell<Random<scirs2_core::random::rngs::StdRng>>,
    /// `true` in training mode, `false` in evaluation mode.
    training: bool,
    /// Optional stored mask in dynamic-dimension form; starts out empty and is
    /// cleared whenever the rate changes.
    mask: RefCell<Option<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
impl<A: Float + Debug + Send + Sync> Dropout<A> {
    /// Creates a dropout regularizer that starts in training mode.
    ///
    /// `rate` is clamped to `[0, 1]`. Eight bytes are drawn from `rng` to seed
    /// an internal generator; `rng` itself is not retained.
    pub fn new<R: Rng>(rate: A, rng: &mut R) -> Self {
        let rate = rate.max(A::zero()).min(A::one());

        // Decode the seed with a fixed byte order so the same bytes from `rng`
        // give the same seed (and therefore the same masks) on every platform.
        // On little-endian targets this matches native-endian decoding.
        let mut seed_bytes = [0u8; 8];
        rng.fill_bytes(&mut seed_bytes);
        let seed = u64::from_le_bytes(seed_bytes);

        Self {
            rate,
            rng: RefCell::new(Random::seed(seed)),
            training: true,
            mask: RefCell::new(None),
        }
    }

    /// Returns the current dropout rate.
    pub fn rate(&self) -> A {
        self.rate
    }

    /// Sets the dropout rate (clamped to `[0, 1]`) and clears any stored mask.
    ///
    /// Returns `self` so calls can be chained.
    pub fn set_rate(&mut self, rate: A) -> &mut Self {
        self.rate = rate.max(A::zero()).min(A::one());
        *self.mask.borrow_mut() = None;
        self
    }

    /// Switches to training mode. Returns `self` so calls can be chained.
    pub fn train(&mut self) -> &mut Self {
        self.training = true;
        self
    }

    /// Switches to evaluation mode. Returns `self` so calls can be chained.
    pub fn eval(&mut self) -> &mut Self {
        self.training = false;
        self
    }

    /// Returns `true` while in training mode.
    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Samples a mask of the given shape.
    ///
    /// In evaluation mode, or when the rate is not positive, the mask is all
    /// ones. Otherwise each entry is independently either `0` (dropped) or
    /// `1 / (1 - rate)` (kept).
    ///
    /// # Panics
    ///
    /// Panics if the internal generator is already mutably borrowed, or if a
    /// sampled value cannot be converted to `A`.
    fn create_mask<D: Dimension>(&self, shape: D) -> Array<A, D> {
        if !self.training || self.rate <= A::zero() {
            return Array::ones(shape);
        }

        // When `rate == 1` this is a division by zero and `scale` is infinite,
        // but samples come from the half-open range [0, 1) and so never exceed
        // a rate of 1: the infinite value is never written into the mask.
        let scale = A::one() / (A::one() - self.rate);

        let mut rng = self.rng.borrow_mut();
        let mut mask = Array::zeros(shape);
        for slot in mask.iter_mut() {
            let sample =
                A::from(rng.gen_range(0.0..1.0)).expect("failed to convert random value");
            // An entry is kept when its sample lands above the drop rate.
            if sample > self.rate {
                *slot = scale;
            }
        }
        mask
    }
}
// NOTE(review): the body below no longer depends on `Pattern = D`; that bound
// looks satisfiable only by dynamic-dimension arrays. It is kept unchanged to
// preserve the existing interface; consider relaxing it after checking callers.
impl<A, D> Regularizer<A, D> for Dropout<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension<Pattern = D>,
{
    /// Multiplies `gradients` in place by a dropout mask.
    ///
    /// Does nothing in evaluation mode or when the rate is not positive. If a
    /// stored mask with the same shape as `gradients` exists it is used;
    /// otherwise a fresh mask is sampled for this call. The parameters are not
    /// read, and the returned penalty is always zero.
    ///
    /// # Panics
    ///
    /// Panics if the mask slot or the internal generator is already mutably
    /// borrowed.
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        if !self.training || self.rate <= A::zero() {
            return Ok(A::zero());
        }

        let stored = self.mask.borrow();
        match &*stored {
            Some(cached) if cached.shape() == gradients.shape() => {
                // Borrow the stored mask as a view instead of deep-copying it.
                // Equal shapes imply equal rank, so the conversion to `D` is
                // expected to succeed.
                let keep = cached
                    .view()
                    .into_dimensionality::<D>()
                    .expect("mask dimensionality conversion failed");
                Zip::from(gradients).and(keep).for_each(|grad, &factor| {
                    *grad = *grad * factor;
                });
            }
            _ => {
                // `raw_dim()` yields the dimension as a `D` value directly.
                // `create_mask` only touches the generator, so holding the
                // shared borrow of the mask slot here is harmless.
                let fresh = self.create_mask(gradients.raw_dim());
                Zip::from(gradients).and(&fresh).for_each(|grad, &factor| {
                    *grad = *grad * factor;
                });
            }
        }

        Ok(A::zero())
    }

    /// Dropout adds no penalty term; always returns zero.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}