use crate as burn;
use crate::config::Config;
use crate::module::Forward;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
/// Configuration for the [`Dropout`] layer.
#[derive(Config)]
pub struct DropoutConfig {
    /// Probability of zeroing out each element of the input tensor
    /// during the forward pass. Expected to be in `[0.0, 1.0)`; a value
    /// of `0.0` makes the layer a no-op.
    pub prob: f64,
}
/// Dropout layer: randomly zeroes elements of its input with probability
/// `prob`, rescaling the survivors so the expected activation is unchanged
/// (inverted dropout). Only active when the backend has autodiff enabled
/// (used here as a proxy for training mode — see `forward`).
#[derive(Clone, Debug)]
pub struct Dropout {
    // Per-element drop probability, taken from `DropoutConfig::prob`.
    prob: f64,
}
impl Dropout {
pub fn new(config: &DropoutConfig) -> Self {
Self { prob: config.prob }
}
}
impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for Dropout {
    /// Applies inverted dropout to `input`.
    ///
    /// When the backend has autodiff disabled (inference) or `prob` is
    /// zero, the input is returned untouched. Otherwise each element is
    /// zeroed where a Bernoulli sample equals 1, and the survivors are
    /// scaled by `1 / (1 - prob)` so the expected value is preserved.
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Dropout is the identity at inference time or with zero probability.
        if self.prob == 0.0 || !B::ad_enabled() {
            return input;
        }
        // NOTE(review): assumes `Distribution::Bernoulli(p)` samples 1 with
        // probability p, so elements are dropped with probability `prob`.
        let bernoulli = input.random_like(Distribution::Bernoulli(self.prob));
        let drop_mask = bernoulli.equal_scalar(1);
        let kept = input.mask_fill(&drop_mask, 0.0_f32);
        // Rescale survivors (inverted dropout).
        kept / (1.0 - self.prob)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Shape;
    use crate::{TestADBackend, TestBackend};

    // With an autodiff backend the layer is active: some of the 10,000
    // ones should be zeroed (p = 0.5, so all-surviving is astronomically
    // unlikely), making the output differ from the input.
    #[test]
    fn with_ad_backend_should_mark_input() {
        let input = Tensor::<TestADBackend, 2>::ones(Shape::new([100, 100]));
        let layer = Dropout::new(&DropoutConfig { prob: 0.5 });

        let output = layer.forward(input.clone());

        assert_ne!(input.to_data(), output.to_data());
    }

    // Without autodiff the layer must be the identity regardless of `prob`.
    #[test]
    fn without_ad_backend_should_not_change_input() {
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]));
        let layer = Dropout::new(&DropoutConfig { prob: 0.5 });

        let output = layer.forward(input.clone());

        assert_eq!(input.to_data(), output.to_data());
    }
}