axonml_nn/layers/dropout.rs

use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use rand::Rng;

use crate::module::Module;

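/// Element-wise (inverted) dropout.
///
/// During training, each element is zeroed independently with probability `p`
/// and the survivors are scaled by `1 / (1 - p)`, so the expected activation
/// is unchanged and no rescaling is needed at inference time.
///
/// Example (a sketch; assumes the `Variable`/`Tensor` constructors used in the
/// tests at the bottom of this file):
///
/// ```ignore
/// let mut dropout = Dropout::new(0.5);
/// let y_train = dropout.forward(&x); // ~half the elements zeroed, rest doubled
/// dropout.eval();
/// let y_eval = dropout.forward(&x);  // identity: y_eval == x
/// ```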
pub struct Dropout {
    /// Probability of an element being zeroed.
    p: f32,
    /// Current train/eval mode; `AtomicBool` gives the interior mutability
    /// that `forward(&self)` needs.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a dropout layer that zeroes elements with probability `p`.
    ///
    /// Panics if `p` is outside `[0, 1)`; `p == 1` would zero everything and
    /// make the `1 / (1 - p)` rescaling divide by zero.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a dropout layer with the conventional default of `p = 0.5`.
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        // Identity in eval mode or when nothing would be dropped.
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let input_vec = input_data.to_vec();
        let mut rng = rand::thread_rng();

        // Inverted dropout: survivors are scaled up so E[output] == E[input].
        let scale = 1.0 / (1.0 - self.p);

        let output_vec: Vec<f32> = input_vec
            .iter()
            .map(|&x| {
                if rng.gen::<f32>() < self.p {
                    0.0
                } else {
                    x * scale
                }
            })
            .collect();

        let output = Tensor::from_vec(output_vec, input_data.shape()).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

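/// Channel-wise dropout for NCHW feature maps (spatial dropout).
///
/// Instead of masking individual elements, entire channels are zeroed with
/// probability `p`, which is more effective when neighboring values within a
/// channel are strongly correlated.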
pub struct Dropout2d {
    /// Probability of an entire channel being zeroed.
    p: f32,
    /// Current train/eval mode; `AtomicBool` gives the interior mutability
    /// that `forward(&self)` needs.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a channel-wise dropout layer; panics if `p` is outside `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        assert!(
            shape.len() >= 2,
            "Dropout2d expects at least a (batch, channels, ...) input"
        );
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let mut output_vec = input_data.to_vec();
        let mut rng = rand::thread_rng();
        let scale = 1.0 / (1.0 - self.p);

        // One Bernoulli draw per (batch, channel) pair: the whole channel is
        // either zeroed or rescaled by 1 / (1 - p).
        for b in 0..batch_size {
            for c in 0..channels {
                // Index of the first element of channel c in batch b.
                let start = b * channels * spatial_size + c * spatial_size;
                if rng.gen::<f32>() < self.p {
                    for i in 0..spatial_size {
                        output_vec[start + i] = 0.0;
                    }
                } else {
                    for i in 0..spatial_size {
                        output_vec[start + i] *= scale;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

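/// Alpha dropout, designed for SELU-based self-normalizing networks
/// (Klambauer et al., 2017).
///
/// Dropped units are set to the SELU saturation value `alpha' = -lambda * alpha`
/// rather than zero, and an affine transform `a * x + b` is applied so that the
/// output keeps zero mean and unit variance.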
pub struct AlphaDropout {
    /// Probability of a unit being dropped (set to the saturation value).
    p: f32,
    /// Current train/eval mode; `AtomicBool` gives the interior mutability
    /// that `forward(&self)` needs.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates an alpha dropout layer; panics if `p` is outside `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU constants: alpha and lambda (the SELU scale).
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        // Dropped units saturate at alpha' = -lambda * alpha. The affine
        // parameters a = ((1 - p) * (1 + p * alpha'^2))^(-1/2) and
        // b = -a * alpha' * p restore zero mean and unit variance.
        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let input_vec = input_data.to_vec();
        let mut rng = rand::thread_rng();

        let output_vec: Vec<f32> = input_vec
            .iter()
            .map(|&x| {
                if rng.gen::<f32>() < self.p {
                    a * alpha_p + b
                } else {
                    a * x + b
                }
            })
            .collect();

        let output = Tensor::from_vec(output_vec, input_data.shape()).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap(), false);
        let output = dropout.forward(&input);

        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // ~500 zeros expected for p = 0.5 over 1000 elements; the wide band
        // makes a spurious failure astronomically unlikely.
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
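
    // Extra check (a sketch using the same Tensor/Variable API as the tests
    // above): Dropout2d must zero or rescale whole channels, never individual
    // elements within a channel.
    #[test]
    fn test_dropout2d_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        // 1 batch, 8 channels of 2x2; survivors are scaled by 1 / (1 - 0.5) = 2.
        let input =
            Variable::new(Tensor::from_vec(vec![1.0; 32], &[1, 8, 2, 2]).unwrap(), false);
        let output = dropout.forward(&input);
        let out = output.data().to_vec();

        for c in 0..8 {
            let channel = &out[c * 4..(c + 1) * 4];
            let all_zero = channel.iter().all(|&x| x == 0.0);
            let all_scaled = channel.iter().all(|&x| (x - 2.0).abs() < 1e-6);
            assert!(all_zero || all_scaled, "channel {c} was partially dropped");
        }
    }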
}