use tch::nn::ModuleT;
use tch::{Kind, Tensor};
#[derive(Debug)]
pub struct Dropout {
dropout_prob: f64,
}
impl Dropout {
pub fn new(p: f64) -> Dropout {
Dropout { dropout_prob: p }
}
}
impl ModuleT for Dropout {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
input.dropout(self.dropout_prob, train)
}
}
/// Dropout layer that builds its own Bernoulli mask instead of calling `Tensor::dropout`.
#[derive(Debug)]
pub struct XDropout {
    /// Probability of zeroing each element while training.
    dropout_prob: f64,
}

impl XDropout {
    /// Builds an `XDropout` layer with drop probability `p` (stored as-is, not validated here).
    pub fn new(p: f64) -> XDropout {
        XDropout { dropout_prob: p }
    }
}

impl ModuleT for XDropout {
    /// Applies inverted dropout when `train` is true; otherwise returns a shallow clone of `input`.
    ///
    /// In training, each element is zeroed with probability `p = dropout_prob` and the
    /// survivors are scaled by `1 / (1 - p)`. The degenerate rates are short-circuited the
    /// way libtorch's native dropout handles them:
    /// - `p == 0` returns the input unchanged, without sampling a mask.
    /// - `p == 1` returns zeros that stay attached to the autograd graph. The general formula
    ///   would divide by zero there and yield NaN.
    fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
        let p = self.dropout_prob;
        if !train || p == 0.0 {
            return input.shallow_clone();
        }
        if p == 1.0 {
            // Everything is dropped; `masked_fill(.., 0) / (1 - p)` would be 0 / 0 = NaN.
            // Multiplying by zero (rather than `zeros_like`) keeps gradients flowing as zeros.
            return input * 0.0;
        }
        // Sample "keep" decisions with probability `1 - p`, then invert them so the mask is
        // `true` exactly where an element is dropped.
        let mask = (Tensor::ones([1], (input.kind(), input.device()))
            - input.empty_like().bernoulli_float_(1_f64 - p))
        .to_kind(Kind::Bool);
        input.masked_fill(&mask, 0) / (1_f64 - p)
    }
}