use crate::asg::NodeType;
use crate::nn::Module;
use crate::tensor::Tensor;
use std::rc::Rc;
/// Dropout layer settings: a drop probability plus a train/eval switch.
///
/// The layer holds no tensors of its own; its `Module::forward`
/// implementation only reads these two fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Dropout {
    /// Probability forwarded to the dropout mask node.
    ///
    /// `Dropout::new` enforces `0.0 <= p < 1.0`, but the field is public, so
    /// values written directly are not validated.
    pub p: f32,
    /// `true` while the layer is in training mode; in eval mode the forward
    /// pass returns its input unchanged.
    pub training: bool,
}
impl Dropout {
    /// Builds a dropout layer with drop probability `p`, starting in
    /// training mode.
    ///
    /// # Panics
    ///
    /// Panics when `p` is not inside the half-open range `[0, 1)`. NaN is
    /// rejected too, since it is not contained in any range.
    pub fn new(p: f32) -> Self {
        let in_range = (0.0..1.0).contains(&p);
        if !in_range {
            panic!("Dropout probability must be in [0, 1), got {}", p);
        }
        Self { training: true, p }
    }

    /// Switches the layer to training mode.
    pub fn train(&mut self) {
        self.set_mode(true);
    }

    /// Switches the layer to evaluation mode.
    pub fn eval(&mut self) {
        self.set_mode(false);
    }

    /// Reports whether the layer is currently in training mode.
    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Single place where the train/eval flag is written.
    fn set_mode(&mut self, training: bool) {
        self.training = training;
    }
}
impl Module for Dropout {
    /// Applies dropout to `x` by extending the graph that `x` lives in.
    ///
    /// When the layer is not training, or `p` is exactly zero, a clone of the
    /// input handle is returned and the graph is left untouched. Otherwise a
    /// `NodeType::DropoutMask` node is added to the main graph, with `x`'s
    /// node as `shape_provider` and this layer's `p`, and the result is
    /// `x * mask`.
    ///
    /// NOTE(review): no `1 / (1 - p)` rescaling is applied here; if inverted
    /// dropout is intended it must happen where the mask node is evaluated —
    /// confirm.
    fn forward(&self, x: &Tensor) -> Tensor {
        let dropout_active = self.training && self.p != 0.0;
        if !dropout_active {
            return x.clone();
        }

        let mask_node = NodeType::DropoutMask {
            shape_provider: x.node_id,
            p: self.p,
        };
        // The mutable borrow is a temporary that ends with this statement.
        let mask_id = x
            .context
            .borrow_mut()
            .main_graph_mut()
            .add_node(None, mask_node);

        let mask = Tensor {
            node_id: mask_id,
            context: Rc::clone(&x.context),
        };
        x * &mask
    }

    /// Dropout has nothing to train, so the parameter list is always empty.
    fn parameters(&self) -> Vec<Tensor> {
        vec![]
    }
}
/// Settings for the spatial dropout layer: a drop probability plus a
/// train/eval switch.
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialDropout {
    /// Drop probability.
    ///
    /// `SpatialDropout::new` enforces `0.0 <= p < 1.0`, but the field is
    /// public, so values written directly are not validated.
    pub p: f32,
    /// `true` while the layer is in training mode.
    pub training: bool,
}
impl SpatialDropout {
    /// Builds a spatial dropout layer with drop probability `p`, starting in
    /// training mode.
    ///
    /// # Panics
    ///
    /// Panics when `p` is not inside the half-open range `[0, 1)` (NaN
    /// included).
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1), got {}",
            p
        );
        Self { p, training: true }
    }

    /// Switches the layer to training mode.
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Switches the layer to evaluation mode.
    pub fn eval(&mut self) {
        self.training = false;
    }

    /// Reports whether the layer is currently in training mode.
    ///
    /// Mirrors `Dropout::is_training` so both layers expose the same
    /// mode-query API.
    pub fn is_training(&self) -> bool {
        self.training
    }
}
impl Module for SpatialDropout {
    /// Runs the forward pass by handing `x` to a temporary [`Dropout`] built
    /// from this layer's `p` and `training` values.
    ///
    /// NOTE(review): this reuses the plain `Dropout` path, so no channel-wise
    /// masking is visible in this file — confirm whether that is intended.
    fn forward(&self, x: &Tensor) -> Tensor {
        let delegate = Dropout {
            training: self.training,
            p: self.p,
        };
        delegate.forward(x)
    }

    /// The layer has nothing to train, so the parameter list is always empty.
    fn parameters(&self) -> Vec<Tensor> {
        vec![]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::GraphContext;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Probabilities inside `[0, 1)` are accepted without panicking.
    #[test]
    fn dropout_creation_validates_p() {
        for &p in &[0.0, 0.5, 0.99] {
            let _ = Dropout::new(p);
        }
    }

    /// The upper bound itself is excluded from the valid range.
    #[test]
    #[should_panic(expected = "Dropout probability must be in [0, 1)")]
    fn dropout_rejects_p_one() {
        let _rejected = Dropout::new(1.0);
    }

    /// A new layer starts in training mode and `eval`/`train` flip the flag.
    #[test]
    fn dropout_train_eval_toggles() {
        let mut layer = Dropout::new(0.5);
        assert!(layer.is_training());

        layer.eval();
        assert!(!layer.is_training());

        layer.train();
        assert!(layer.is_training());
    }

    /// In eval mode the output refers to the input node and the graph does
    /// not grow.
    #[test]
    fn dropout_eval_is_passthrough() {
        let graph = Rc::new(RefCell::new(GraphContext::new()));
        let node_count = || graph.borrow().main_graph().nodes.len();
        let input = Tensor::new_input(&graph, "x");
        let count_before = node_count();

        let mut layer = Dropout::new(0.5);
        layer.eval();
        let output = layer.forward(&input);
        let count_after = node_count();

        assert_eq!(output.node_id, input.node_id);
        assert_eq!(count_before, count_after);
    }

    /// In training mode the forward pass adds exactly two nodes: the mask
    /// and the multiplication.
    #[test]
    fn dropout_train_emits_mask_and_multiply() {
        let graph = Rc::new(RefCell::new(GraphContext::new()));
        let node_count = || graph.borrow().main_graph().nodes.len();
        let input = Tensor::new_input(&graph, "x");
        let count_before = node_count();

        let layer = Dropout::new(0.3);
        let _output = layer.forward(&input);

        let added = node_count() - count_before;
        assert_eq!(added, 2);
    }
}