use crate::prelude::*;
use rand::{prelude::StdRng, Rng, SeedableRng};
use std::{cell::RefCell, ops::DerefMut};
/// Dropout layer whose probability is fixed at compile time to `1 / N`.
///
/// The layer owns its own [StdRng], wrapped in a [RefCell] so that
/// `forward(&self, ..)` can advance it without needing `&mut self`.
/// Cloning a layer clones the RNG state, so a clone starts from the same
/// point in the random sequence as the original.
#[derive(Debug, Clone)]
pub struct DropoutOneIn<const N: usize> {
    rng: RefCell<StdRng>,
}

impl<const N: usize> Default for DropoutOneIn<N> {
    /// Builds the layer with an RNG seeded from `unique_id()`.
    ///
    /// NOTE(review): assuming `unique_id()` yields a distinct value per call,
    /// separately constructed layers receive different seeds.
    fn default() -> Self {
        Self {
            rng: RefCell::new(StdRng::seed_from_u64(unique_id().as_u64())),
        }
    }
}
/// `DropoutOneIn` holds no trainable parameters, so gradient updates are a no-op.
impl<const N: usize> CanUpdateWithGradients for DropoutOneIn<N> {
    fn update<G: GradientProvider>(&mut self, _grads: &mut G, _unused: &mut UnusedTensors) {}
}

/// There are no parameters to re-initialise; the RNG is intentionally left as-is.
impl<const N: usize> ResetParams for DropoutOneIn<N> {
    fn reset_params<R: Rng>(&mut self, _rng: &mut R) {}
}

// Serialization relies entirely on the traits' default methods; no code here
// persists the internal RNG state.
impl<const N: usize> SaveToNpz for DropoutOneIn<N> {}
impl<const N: usize> LoadFromNpz for DropoutOneIn<N> {}
impl<const N: usize, T: Tensor<Dtype = f32>> Module<T> for DropoutOneIn<N> {
    type Output = T;

    /// Applies [dropout] with probability `1 / N`, drawing randomness from the
    /// layer's internal RNG (which advances on every call).
    fn forward(&self, input: T) -> Self::Output {
        let mut rng_guard = self.rng.borrow_mut();
        let prob = 1.0 / N as f32;
        dropout(input, prob, rng_guard.deref_mut())
    }
}
/// Dropout layer with a runtime-configurable probability `p`.
///
/// The internal [StdRng] lives in a [RefCell] so `forward(&self, ..)` can
/// advance it. Cloning the layer clones the RNG state as well.
#[derive(Debug, Clone)]
pub struct Dropout {
    /// The dropout probability handed to [dropout] on every forward pass.
    pub p: f32,
    rng: RefCell<StdRng>,
}

impl Dropout {
    /// Creates a layer with probability `p` and an RNG seeded with `rng_seed`.
    ///
    /// Two layers built with the same `p` and seed produce identical outputs
    /// for identical inputs.
    pub fn new(p: f32, rng_seed: u64) -> Self {
        let rng = StdRng::seed_from_u64(rng_seed);
        Self {
            p,
            rng: RefCell::new(rng),
        }
    }

    /// Creates a layer with probability `p`, seeding the RNG from `unique_id()`.
    pub fn p(p: f32) -> Self {
        Self::new(p, unique_id().as_u64())
    }
}
impl Default for Dropout {
fn default() -> Self {
Self::new(0.5, 0)
}
}
impl CanUpdateWithGradients for Dropout {
fn update<G: GradientProvider>(&mut self, _: &mut G, _: &mut UnusedTensors) {}
}
impl ResetParams for Dropout {
fn reset_params<R: rand::Rng>(&mut self, _: &mut R) {}
}
impl SaveToNpz for Dropout {}
impl LoadFromNpz for Dropout {}
impl<T: Tensor<Dtype = f32>> Module<T> for Dropout {
    type Output = T;

    /// Applies [dropout] with `self.p`, drawing randomness from the layer's
    /// internal RNG (which advances on every call).
    fn forward(&self, input: T) -> Self::Output {
        let prob = self.p;
        let mut rng_guard = self.rng.borrow_mut();
        dropout(input, prob, rng_guard.deref_mut())
    }
}
/// Dropout driven by a caller-supplied RNG instead of the layer's internal one.
///
/// The RNG is taken by value, used for this call, and handed back alongside
/// the output so the caller can keep threading it through later calls. Only
/// [Rng] is required: the RNG is never re-seeded here, so any generator works,
/// including a borrowed `&mut StdRng`.
impl<R: Rng, T: Tensor<Dtype = f32>> Module<(T, R)> for Dropout {
    type Output = (T, R);

    fn forward(&self, (input, mut rng): (T, R)) -> Self::Output {
        let output = dropout(input, self.p, &mut rng);
        (output, rng)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identically seeded layers agree; repeated calls on one layer advance its RNG.
    #[test]
    fn test_dropout_internal_rng_reproduce() {
        let layer_a = Dropout::new(0.5, 0);
        let layer_b = Dropout::new(0.5, 0);
        let ones: Tensor1D<100> = Tensor1D::ones();
        let first = layer_a.forward(ones.trace());
        let twin = layer_b.forward(ones.trace());
        let second = layer_a.forward(ones.trace());
        assert_eq!(first.data(), twin.data());
        assert_ne!(first.data(), second.data());
    }

    /// The `(tensor, rng)` form uses the caller's RNG and still alters a traced input.
    #[test]
    fn test_dropout_external_rng() {
        let external = StdRng::seed_from_u64(0);
        let layer = Dropout::p(0.5);
        let ones: Tensor1D<100> = Tensor1D::ones();
        let (out, _returned_rng) = layer.forward((ones.trace(), external));
        assert_ne!(ones.data(), out.data());
    }

    /// Without a gradient tape, the input passes through unchanged.
    #[test]
    fn test_dropout_no_tape() {
        let layer = Dropout::p(0.5);
        let ones: Tensor1D<100> = Tensor1D::ones();
        let out = layer.forward(ones.clone());
        assert_eq!(ones.data(), out.data());
    }

    /// With a gradient tape, at least some elements are changed.
    #[test]
    fn test_dropout_tape() {
        let layer = Dropout::p(0.5);
        let ones: Tensor1D<100> = Tensor1D::ones();
        let out = layer.forward(ones.trace());
        assert_ne!(ones.data(), out.data());
    }
}