use crate::{
shapes::*,
tensor::{NoneTape, OwnedTape, Tensor},
tensor_ops::*,
};
use super::*;
/// Dropout whose probability is fixed at compile time to `1 / N`.
///
/// For example, `DropoutOneIn<2>` uses probability `0.5` and `DropoutOneIn<4>`
/// uses `0.25`. `N` should be at least 1, because `N == 0` would compute
/// `1.0 / 0.0`.
///
/// The type carries no data, so it is marked as a `ZeroSizedModule`.
#[derive(Debug, Default, Clone)]
pub struct DropoutOneIn<const N: usize>;

impl<const N: usize> ZeroSizedModule for DropoutOneIn<N> {}
/// Forward pass for tensors without a gradient tape: the identity.
///
/// Dropout is only applied to inputs that carry an `OwnedTape` (the
/// `ModuleMut` impl below). Tape-free inputs are returned exactly as given.
impl<const N: usize, S, E, D> Module<Tensor<S, E, D, NoneTape>> for DropoutOneIn<N>
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
{
    type Output = Tensor<S, E, D, NoneTape>;
    type Error = D::Err;

    fn try_forward(&self, x: Tensor<S, E, D, NoneTape>) -> Result<Self::Output, D::Err> {
        // Nothing to do without a tape; hand the tensor straight back.
        Ok(x)
    }
}
/// Forward pass for taped tensors: delegates to `try_dropout` with
/// probability `1 / N`.
///
/// NOTE(review): with `N == 0` the probability is `1.0 / 0.0` (infinity).
/// How `try_dropout` handles that value is not checked here.
impl<const N: usize, S, E, D> ModuleMut<Tensor<S, E, D, OwnedTape<E, D>>> for DropoutOneIn<N>
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
{
    type Output = Tensor<S, E, D, OwnedTape<E, D>>;
    type Error = D::Err;

    fn try_forward_mut(
        &mut self,
        x: Tensor<S, E, D, OwnedTape<E, D>>,
    ) -> Result<Self::Output, D::Err> {
        // `as` binds tighter than `/`, so this is 1.0 / (N as f64).
        let prob = 1.0 / N as f64;
        x.try_dropout(prob)
    }
}
/// Dropout with a probability `p` chosen at runtime (`Default` sets it to 0.5).
///
/// Only inputs carrying an `OwnedTape` are affected. Tape-free inputs pass
/// through unchanged.
#[derive(Debug, Clone)]
pub struct Dropout {
    /// Probability handed to `try_dropout` on taped forward passes.
    pub p: f32,
}
impl Default for Dropout {
fn default() -> Self {
Self { p: 0.5 }
}
}
/// `Dropout` holds only an `f32` and no device-resident state, so building it
/// for any device or dtype yields `Dropout` itself.
impl<D, E> BuildOnDevice<D, E> for Dropout
where
    D: Device<E>,
    E: Dtype,
{
    type Built = Self;
}
impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for Dropout {
type To<E2: Dtype, D2: Device<E2>> = Dropout;
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_fields(
<Self as TensorCollection<E, D>>::scalar(
"p",
|s| &s.p,
|s| &mut s.p,
ScalarOptions::from_default(0.5),
),
|p| Dropout { p },
)
}
}
/// Forward pass for tensors without a gradient tape: the identity.
///
/// `p` is ignored here. Dropout is only applied by the `ModuleMut` impl for
/// `OwnedTape` inputs.
impl<S, E, D> Module<Tensor<S, E, D, NoneTape>> for Dropout
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
{
    type Output = Tensor<S, E, D, NoneTape>;
    type Error = D::Err;

    fn try_forward(&self, x: Tensor<S, E, D, NoneTape>) -> Result<Self::Output, D::Err> {
        Ok(x)
    }
}
/// Forward pass for taped tensors: delegates to `try_dropout` with the
/// probability stored in `self.p`.
impl<S, E, D> ModuleMut<Tensor<S, E, D, OwnedTape<E, D>>> for Dropout
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
{
    type Output = Tensor<S, E, D, OwnedTape<E, D>>;
    type Error = D::Err;

    fn try_forward_mut(
        &mut self,
        x: Tensor<S, E, D, OwnedTape<E, D>>,
    ) -> Result<Self::Output, D::Err> {
        let prob = self.p;
        x.try_dropout(prob)
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        shapes::Rank1,
        tensor::{AsArray, OnesTensor, Trace},
        tests::*,
    };

    use super::*;

    /// Outputs from two separate `Dropout` modules, and from two successive
    /// calls on the same module, must differ.
    #[test]
    fn test_dropout_internal_rng_reproduce() {
        let dev: TestDevice = Default::default();
        let mut d1 = Dropout { p: 0.5 };
        let mut d2 = Dropout { p: 0.5 };
        let t: Tensor<Rank1<100>, TestDtype, _> = dev.ones();
        let r1 = d1.forward_mut(t.leaky_trace());
        let r2 = d2.forward_mut(t.leaky_trace());
        let r1_2 = d1.forward_mut(t.leaky_trace());
        assert_ne!(r1.array(), r2.array());
        assert_ne!(r1.array(), r1_2.array());
    }

    /// Without a tape, `Dropout` is the identity.
    #[test]
    fn test_dropout_no_tape() {
        let dev: TestDevice = Default::default();
        let dropout = Dropout { p: 0.5 };
        let t: Tensor<Rank1<100>, TestDtype, _> = dev.ones();
        let r = dropout.forward(t.clone());
        assert_eq!(t.array(), r.array());
    }

    /// With a tape, `Dropout` changes its input.
    #[test]
    fn test_dropout_tape() {
        let dev: TestDevice = Default::default();
        let mut dropout = Dropout { p: 0.5 };
        let t: Tensor<Rank1<100>, TestDtype, _> = dev.ones();
        let r = dropout.forward_mut(t.leaky_trace());
        assert_ne!(t.array(), r.array());
    }

    /// Without a tape, `DropoutOneIn` is the identity.
    #[test]
    fn test_dropout_one_in_no_tape() {
        let dev: TestDevice = Default::default();
        let dropout = DropoutOneIn::<2>;
        let t: Tensor<Rank1<100>, TestDtype, _> = dev.ones();
        let r = dropout.forward(t.clone());
        assert_eq!(t.array(), r.array());
    }

    /// With a tape, `DropoutOneIn<2>` (probability 0.5) changes its input.
    #[test]
    fn test_dropout_one_in_tape() {
        let dev: TestDevice = Default::default();
        let mut dropout = DropoutOneIn::<2>;
        let t: Tensor<Rank1<100>, TestDtype, _> = dev.ones();
        let r = dropout.forward_mut(t.leaky_trace());
        assert_ne!(t.array(), r.array());
    }
}