1use std::cell::Cell;
2
3use crate::autograd::Variable;
4use crate::tensor::Result;
5
6use super::Module;
7
/// Element-wise dropout layer.
///
/// In training mode (the state right after construction) `forward` runs the
/// tensor `dropout` op with probability `p`; in eval mode, or when `p == 0.0`,
/// the input is passed through unchanged.
#[derive(Debug)]
pub struct Dropout {
    /// Dropout probability handed to the tensor op; guaranteed to lie in `[0, 1]`.
    p: f64,
    /// Training/eval flag. A `Cell` so `set_training(&self)` can flip it
    /// through a shared reference.
    training: Cell<bool>,
}

impl Dropout {
    /// Creates a dropout layer with probability `p`, starting in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not within `[0.0, 1.0]` (NaN is rejected too). Checking
    /// here surfaces a bad probability immediately instead of at the first
    /// training-mode `forward`, or never at all if the layer only runs in eval mode.
    pub fn new(p: f64) -> Self {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well.
        assert!(
            (0.0..=1.0).contains(&p),
            "Dropout probability must be in [0, 1], got {}",
            p
        );
        Dropout {
            p,
            training: Cell::new(true),
        }
    }
}
30
31impl Module for Dropout {
32 fn name(&self) -> &str { "dropout" }
33
34 fn forward(&self, input: &Variable) -> Result<Variable> {
35 if !self.training.get() || self.p == 0.0 {
36 return Ok(input.clone());
37 }
38 let result = input.data().dropout(self.p, true)?;
39 Ok(Variable::wrap(result))
40 }
41
42 fn set_training(&self, training: bool) {
43 self.training.set(training);
44 }
45}
46
/// Channel-wise (feature) dropout layer.
///
/// In training mode (the state right after construction) `forward` runs the
/// tensor `feature_dropout` op with probability `p`, which drops whole channels
/// rather than individual elements. In eval mode, or when `p == 0.0`, the input
/// is passed through unchanged.
#[derive(Debug)]
pub struct Dropout2d {
    /// Dropout probability handed to the tensor op; guaranteed to lie in `[0, 1]`.
    p: f64,
    /// Training/eval flag. A `Cell` so `set_training(&self)` can flip it
    /// through a shared reference.
    training: Cell<bool>,
}

impl Dropout2d {
    /// Creates a channel-dropout layer with probability `p`, starting in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not within `[0.0, 1.0]` (NaN is rejected too). Checking
    /// here surfaces a bad probability immediately instead of at the first
    /// training-mode `forward`, or never at all if the layer only runs in eval mode.
    pub fn new(p: f64) -> Self {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well.
        assert!(
            (0.0..=1.0).contains(&p),
            "Dropout2d probability must be in [0, 1], got {}",
            p
        );
        Dropout2d {
            p,
            training: Cell::new(true),
        }
    }
}
69
70impl Module for Dropout2d {
71 fn name(&self) -> &str { "dropout2d" }
72
73 fn forward(&self, input: &Variable) -> Result<Variable> {
74 if !self.training.get() || self.p == 0.0 {
75 return Ok(input.clone());
76 }
77 let result = input.data().feature_dropout(self.p, true)?;
78 Ok(Variable::wrap(result))
79 }
80
81 fn set_training(&self, training: bool) {
82 self.training.set(training);
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{DType, Tensor, TensorOptions};

    /// Float32 tensor options on the shared test device.
    fn f32_opts() -> TensorOptions {
        TensorOptions { dtype: DType::Float32, device: crate::tensor::test_device() }
    }

    /// Asserts that every value is (within 1e-5) either 0 or `scale`.
    fn assert_zero_or(values: &[f32], scale: f32) {
        for &v in values {
            assert!(v.abs() < 1e-5 || (v - scale).abs() < 1e-5,
                "value should be 0 or {}: got {}", scale, v);
        }
    }

    #[test]
    fn test_dropout_values_zero_or_scaled() {
        // p != 0.5 so that 1/(1-p) and 1/p differ and the scale formula is really checked.
        let p = 0.2_f64;
        let d = Dropout::new(p);
        let input = Variable::new(Tensor::ones(&[4, 32], f32_opts()).unwrap(), false);

        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert_eq!(data.len(), 4 * 32);
        assert_zero_or(&data, (1.0 / (1.0 - p)) as f32);
    }

    #[test]
    fn test_dropout_eval_identity() {
        let d = Dropout::new(0.5);
        d.set_training(false);
        let input = Variable::new(Tensor::ones(&[2, 8], f32_opts()).unwrap(), false);

        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_dropout_zero_p_identity_in_training() {
        let d = Dropout::new(0.0);
        let input = Variable::new(Tensor::ones(&[2, 8], f32_opts()).unwrap(), false);

        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_dropout2d_whole_channels_zeroed() {
        let (channels, h, w) = (8_usize, 4_usize, 4_usize);
        let p = 0.5_f64;
        let d = Dropout2d::new(p);
        let input = Variable::new(Tensor::ones(&[2, 8, 4, 4], f32_opts()).unwrap(), false);

        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert_eq!(data.len(), 2 * channels * h * w);

        // Surviving channels are rescaled by 1 / (1 - p) to preserve the expected activation.
        let scale = (1.0 / (1.0 - p)) as f32;
        for (idx, channel) in data.chunks_exact(h * w).enumerate() {
            let (b, c) = (idx / channels, idx % channels);
            let first = channel[0];
            for &v in channel {
                assert!((v - first).abs() < 1e-5,
                    "channel [{},{}] not uniform: {} vs {}", b, c, v, first);
            }
            assert!(first.abs() < 1e-5 || (first - scale).abs() < 1e-5,
                "channel value should be 0 or {}: got {}", scale, first);
        }
    }

    #[test]
    fn test_dropout2d_eval_identity() {
        let d = Dropout2d::new(0.5);
        d.set_training(false);
        let input = Variable::new(Tensor::ones(&[1, 3, 4, 4], f32_opts()).unwrap(), false);

        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_dropout2d_zero_p_identity_in_training() {
        let d = Dropout2d::new(0.0);
        let input = Variable::new(Tensor::ones(&[1, 3, 4, 4], f32_opts()).unwrap(), false);

        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }
}
131}