//! `flodl/nn/dropout.rs` — element-wise and channel-wise dropout modules.
use std::cell::Cell;

use crate::autograd::Variable;
use crate::tensor::Result;

use super::Module;
7
/// Inverted dropout module.
///
/// Uses a single fused `torch::dropout` kernel (1 autograd node).
/// During training: randomly zeros elements with probability `p`,
/// scales remaining by `1/(1-p)`.
/// During eval: identity function.
pub struct Dropout {
    // Drop probability, validated to lie in [0.0, 1.0] by `new`.
    p: f64,
    // Training flag. `Cell` gives interior mutability through `&self`
    // so `set_training` can flip the mode without `&mut`.
    training: Cell<bool>,
}

impl Dropout {
    /// Create a dropout module with drop probability `p` (0.0 to 1.0).
    /// Starts in training mode; use `set_training(false)` to disable
    /// during inference.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0.0, 1.0]` — previously an out-of-range
    /// probability was accepted silently and produced a nonsensical
    /// `1/(1-p)` scale at forward time.
    pub fn new(p: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&p),
            "dropout probability must be in [0.0, 1.0], got {}",
            p
        );
        Dropout {
            p,
            training: Cell::new(true),
        }
    }
}
30
31impl Module for Dropout {
32    fn name(&self) -> &str { "dropout" }
33
34    fn forward(&self, input: &Variable) -> Result<Variable> {
35        if !self.training.get() || self.p == 0.0 {
36            return Ok(input.clone());
37        }
38        let result = input.data().dropout(self.p, true)?;
39        Ok(Variable::wrap(result))
40    }
41
42    fn set_training(&self, training: bool) {
43        self.training.set(training);
44    }
45}
46
/// 2D channel dropout — drops entire channels (feature maps) at once.
///
/// Uses a single fused `torch::feature_dropout` kernel (1 autograd node).
/// During training: randomly zeros entire channels with probability `p`,
/// scales remaining by `1/(1-p)`. Mask shape is `[B, C, 1, 1]`.
/// During eval: identity function.
///
/// Expects 4-D input `[B, C, H, W]`.
pub struct Dropout2d {
    // Channel drop probability, validated to lie in [0.0, 1.0] by `new`.
    p: f64,
    // Training flag. `Cell` gives interior mutability through `&self`
    // so `set_training` can flip the mode without `&mut`.
    training: Cell<bool>,
}

impl Dropout2d {
    /// Create a 2D dropout module with channel drop probability `p`
    /// (0.0 to 1.0). Starts in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0.0, 1.0]` — previously an out-of-range
    /// probability was accepted silently and produced a nonsensical
    /// `1/(1-p)` scale at forward time.
    pub fn new(p: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&p),
            "dropout probability must be in [0.0, 1.0], got {}",
            p
        );
        Dropout2d {
            p,
            training: Cell::new(true),
        }
    }
}
69
70impl Module for Dropout2d {
71    fn name(&self) -> &str { "dropout2d" }
72
73    fn forward(&self, input: &Variable) -> Result<Variable> {
74        if !self.training.get() || self.p == 0.0 {
75            return Ok(input.clone());
76        }
77        let result = input.data().feature_dropout(self.p, true)?;
78        Ok(Variable::wrap(result))
79    }
80
81    fn set_training(&self, training: bool) {
82        self.training.set(training);
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{DType, Tensor, TensorOptions};

    /// With p = 0.5 on a [2, 8, 4, 4] tensor of ones, every channel in the
    /// output must be uniform: either entirely zero (dropped) or entirely
    /// 1/(1-p) = 2.0 (kept and rescaled).
    #[test]
    fn test_dropout2d_whole_channels_zeroed() {
        let layer = Dropout2d::new(0.5);
        let opts = TensorOptions { dtype: DType::Float32, device: crate::tensor::test_device() };
        let input = Variable::new(Tensor::ones(&[2, 8, 4, 4], opts).unwrap(), false);

        let output = layer.forward(&input).unwrap();
        let values = output.data().to_f32_vec().unwrap();

        let (channels, h, w) = (8_usize, 4_usize, 4_usize);
        let scale = 1.0 / 0.5;
        // Contiguous [B, C, H, W] layout: the flat buffer is a sequence of
        // h*w-element channel slabs, ordered (b0,c0), (b0,c1), ..., (b1,c7).
        for (idx, channel) in values.chunks(h * w).enumerate() {
            let (b, c) = (idx / channels, idx % channels);
            let first = channel[0];
            // All elements in channel should be equal (either 0 or scale)
            for &v in channel {
                assert!((v - first).abs() < 1e-5,
                    "channel [{},{}] not uniform: {} vs {}", b, c, v, first);
            }
            assert!(first.abs() < 1e-5 || (first - scale as f32).abs() < 1e-5,
                "channel value should be 0 or {}: got {}", scale, first);
        }
    }

    /// In eval mode dropout must be the identity: a tensor of ones passes
    /// through unchanged.
    #[test]
    fn test_dropout2d_eval_identity() {
        let layer = Dropout2d::new(0.5);
        layer.set_training(false);
        let opts = TensorOptions { dtype: DType::Float32, device: crate::tensor::test_device() };
        let input = Variable::new(Tensor::ones(&[1, 3, 4, 4], opts).unwrap(), false);

        let passthrough = layer.forward(&input).unwrap();
        let values = passthrough.data().to_f32_vec().unwrap();
        assert!(values.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }
}
131}