1use ghostflow_core::Tensor;
4use crate::module::Module;
5use rand::Rng;
6
/// Element-wise ("standard") dropout layer.
///
/// While training, `forward` zeroes each input element independently with
/// probability `p` and multiplies the surviving elements by `1 / (1 - p)`
/// (inverted dropout). In evaluation mode, or when `p == 0`, `forward`
/// returns a clone of its input.
#[derive(Debug, Clone)]
pub struct Dropout {
    /// Drop probability, validated by [`Dropout::new`] to lie in `[0, 1)`.
    p: f32,
    /// `true` in training mode (dropout active), `false` in eval mode.
    training: bool,
}

impl Dropout {
    /// Creates a dropout layer with drop probability `p`, starting in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in the half-open range `[0, 1)`. This includes `NaN`.
    /// `p == 1` is rejected because the survivor scale `1 / (1 - p)` would be
    /// infinite.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Dropout { p, training: true }
    }
}

impl Default for Dropout {
    /// Creates a dropout layer with `p = 0.5`, in training mode.
    fn default() -> Self {
        Self::new(0.5)
    }
}
25
26impl Module for Dropout {
27 fn forward(&self, input: &Tensor) -> Tensor {
28 if !self.training || self.p == 0.0 {
29 return input.clone();
30 }
31
32 let data = input.data_f32();
33 let mut rng = rand::thread_rng();
34 let scale = 1.0 / (1.0 - self.p);
35
36 let output: Vec<f32> = data.iter()
37 .map(|&x| {
38 if rng.gen::<f32>() < self.p {
39 0.0
40 } else {
41 x * scale
42 }
43 })
44 .collect();
45
46 Tensor::from_slice(&output, input.dims()).unwrap()
47 }
48
49 fn parameters(&self) -> Vec<Tensor> {
50 vec![]
51 }
52
53 fn train(&mut self) {
54 self.training = true;
55 }
56
57 fn eval(&mut self) {
58 self.training = false;
59 }
60
61 fn is_training(&self) -> bool {
62 self.training
63 }
64}
65
/// Channel-wise ("spatial") dropout layer.
///
/// `forward` treats its input as `(batch, channels, ...)`. While training, it
/// zeroes each whole `(batch, channel)` plane with probability `p` and
/// multiplies the surviving planes by `1 / (1 - p)`. In evaluation mode, or
/// when `p == 0`, `forward` returns a clone of its input.
#[derive(Debug, Clone)]
pub struct Dropout2d {
    /// Per-plane drop probability, validated by [`Dropout2d::new`] to lie in `[0, 1)`.
    p: f32,
    /// `true` in training mode (dropout active), `false` in eval mode.
    training: bool,
}

impl Dropout2d {
    /// Creates a channel-wise dropout layer with drop probability `p`,
    /// starting in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in the half-open range `[0, 1)`. This includes `NaN`.
    pub fn new(p: f32) -> Self {
        // Same validation and message as `Dropout::new`, for consistency.
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Dropout2d { p, training: true }
    }
}

impl Default for Dropout2d {
    /// Creates a channel-wise dropout layer with `p = 0.5`, in training mode.
    fn default() -> Self {
        Self::new(0.5)
    }
}
78
79impl Module for Dropout2d {
80 fn forward(&self, input: &Tensor) -> Tensor {
81 if !self.training || self.p == 0.0 {
82 return input.clone();
83 }
84
85 let dims = input.dims();
86 let batch = dims[0];
87 let channels = dims[1];
88 let spatial: usize = dims[2..].iter().product();
89
90 let data = input.data_f32();
91 let mut rng = rand::thread_rng();
92 let scale = 1.0 / (1.0 - self.p);
93
94 let mut output = data.clone();
95
96 for b in 0..batch {
97 for c in 0..channels {
98 if rng.gen::<f32>() < self.p {
99 let start = (b * channels + c) * spatial;
101 for i in 0..spatial {
102 output[start + i] = 0.0;
103 }
104 } else {
105 let start = (b * channels + c) * spatial;
107 for i in 0..spatial {
108 output[start + i] *= scale;
109 }
110 }
111 }
112 }
113
114 Tensor::from_slice(&output, dims).unwrap()
115 }
116
117 fn parameters(&self) -> Vec<Tensor> { vec![] }
118 fn train(&mut self) { self.training = true; }
119 fn eval(&mut self) { self.training = false; }
120 fn is_training(&self) -> bool { self.training }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();
        assert!(!dropout.is_training());

        let input = Tensor::ones(&[10, 10]);
        let output = dropout.forward(&input);

        assert_eq!(output.data_f32(), input.data_f32());
    }

    #[test]
    fn test_dropout_train() {
        let dropout = Dropout::new(0.5);
        let input = Tensor::ones(&[100, 100]);
        let output = dropout.forward(&input);
        let values = output.data_f32();

        // With 10_000 elements and p = 0.5, it is overwhelmingly likely that
        // some elements are dropped and some are kept.
        let zeros = values.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros > 0);
        assert!(zeros < values.len());
        // Survivors are scaled by 1 / (1 - 0.5) = 2, which is exact in f32.
        assert!(values.iter().all(|&x| x == 0.0 || x == 2.0));
    }

    #[test]
    fn test_dropout_zero_p_is_identity() {
        let dropout = Dropout::new(0.0);
        let input = Tensor::ones(&[4, 4]);
        let output = dropout.forward(&input);
        assert_eq!(output.data_f32(), input.data_f32());
    }

    #[test]
    #[should_panic]
    fn test_dropout_rejects_p_one() {
        let _ = Dropout::new(1.0);
    }

    #[test]
    fn test_dropout2d_eval() {
        let mut dropout = Dropout2d::new(0.5);
        dropout.eval();
        assert!(!dropout.is_training());

        let input = Tensor::ones(&[2, 3, 4, 4]);
        let output = dropout.forward(&input);
        assert_eq!(output.data_f32(), input.data_f32());
    }

    #[test]
    fn test_dropout2d_drops_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let (batch, channels, spatial) = (4, 8, 9);
        let input = Tensor::ones(&[batch, channels, 3, 3]);
        let output = dropout.forward(&input);
        let values = output.data_f32();
        assert_eq!(values.len(), batch * channels * spatial);

        // Every (batch, channel) plane is either entirely dropped or entirely
        // kept and scaled by 2.
        for plane in values.chunks(spatial) {
            let first = plane[0];
            assert!(first == 0.0 || first == 2.0);
            assert!(plane.iter().all(|&x| x == first));
        }
    }

    #[test]
    #[should_panic]
    fn test_dropout2d_rejects_rank_one_input() {
        let dropout = Dropout2d::new(0.5);
        let input = Tensor::ones(&[4]);
        let _ = dropout.forward(&input);
    }
}