// axonml_nn/layers/dropout.rs

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;
use rand::Rng;

use crate::module::Module;

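/// Shared backward pass for the dropout variants in this module. The saved
/// mask already folds in the keep/scale factor, so the input gradient is
/// just `grad_output` times the mask, element-wise. (For `AlphaDropout` the
/// additive bias is constant with respect to the input, so it contributes
/// nothing to the gradient.)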
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    /// Mask with the scale factor baked in: `0.0` where elements were dropped.
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let result = grad_output.mul(&self.mask_tensor).unwrap();
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

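/// Randomly zeroes elements of its input with probability `p` during
/// training, scaling the survivors by `1 / (1 - p)` (inverted dropout) so
/// the expected activation is unchanged. In eval mode the input passes
/// through untouched.
///
/// A minimal usage sketch (assumes an input `Variable` named `x`):
///
/// ```ignore
/// let mut dropout = Dropout::new(0.5);
/// let y = dropout.forward(&x); // training: ~half zeroed, survivors doubled
/// dropout.set_training(false);
/// let y = dropout.forward(&x); // eval: identity
/// ```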
pub struct Dropout {
    /// Probability of zeroing each element, in `[0, 1)`.
    p: f32,
    /// Training flag; atomic so `forward(&self)` stays thread-safe.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a dropout layer that zeroes elements with probability `p`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a dropout layer with the conventional default of `p = 0.5`.
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        // Identity in eval mode or when nothing can be dropped.
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = rand::thread_rng();

        // Inverted dropout: survivors are scaled by 1 / (1 - p) so the
        // expected activation matches eval mode.
        let scale = 1.0 / (1.0 - self.p);

        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask, &shape).unwrap();
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).unwrap();

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

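/// Channel-wise (spatial) dropout for inputs shaped `[N, C, ...]`. Whole
/// feature maps are zeroed together instead of individual elements, which
/// works better when activations within a channel are strongly correlated,
/// as in convolutional feature maps.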
pub struct Dropout2d {
    /// Probability of zeroing each channel, in `[0, 1)`.
    p: f32,
    /// Training flag; atomic so `forward(&self)` stays thread-safe.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a layer that zeroes whole channels with probability `p`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        let mut rng = rand::thread_rng();
        let scale = 1.0 / (1.0 - self.p);

        // One keep/drop draw per (batch, channel) pair, applied across the
        // channel's whole spatial extent.
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).unwrap();
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

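/// Alpha dropout for self-normalizing (SELU) networks. Dropped activations
/// are set to SELU's negative saturation value `alpha_p = -SCALE * ALPHA`
/// rather than zero, and an affine transform restores zero mean and unit
/// variance:
///
///     a = ((1 - p) * (1 + p * alpha_p^2))^(-1/2)
///     b = -a * alpha_p * p
///
/// so kept elements become `a * x + b` and dropped ones `a * alpha_p + b`.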
pub struct AlphaDropout {
    /// Probability of dropping each activation, in `[0, 1)`.
    p: f32,
    /// Training flag; atomic so `forward(&self)` stays thread-safe.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates an alpha dropout layer that drops activations with probability `p`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU constants; alpha_p below is the negative saturation value.
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        let alpha_p = -ALPHA * SCALE;
        // Affine parameters that restore zero mean and unit variance:
        // a = ((1 - p)(1 + p * alpha_p^2))^(-1/2), b = -a * alpha_p * p.
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = rand::thread_rng();

        // Kept elements are scaled by `a`; dropped ones end up at a * alpha_p + b.
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        // Per-element additive term: the dropped value where the mask is zero,
        // the shift `b` where the element is kept.
        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).unwrap();
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).unwrap();
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        // output = input * mask + bias; the bias is constant w.r.t. the input,
        // so DropoutBackward (multiply by mask) is the correct gradient.
        let output = input_data
            .mul(&mask_tensor)
            .unwrap()
            .add(&bias_tensor)
            .unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap(), false);
        let output = dropout.forward(&input);

        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p = 0.5 over 1000 elements, falling outside (300, 700) is a >12-sigma event.
        assert!(num_zeros > 300 && num_zeros < 700);
    }

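    // Inverted dropout scales survivors by 1 / (1 - p), so the sample mean of
    // the output should stay close to the input mean. A sketch of that
    // property check; the 0.2 tolerance is an assumed loose bound (~6 sigma
    // at n = 1000).
    #[test]
    fn test_dropout_preserves_expectation() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap(), false);
        let output = dropout.forward(&input);

        let mean: f32 = output.data().to_vec().iter().sum::<f32>() / 1000.0;
        assert!((mean - 1.0).abs() < 0.2);
    }
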
    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

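    // DropoutBackward multiplies the incoming gradient by the saved mask, so
    // it can be exercised directly, without building a graph. A minimal
    // sketch: the mask mirrors a p = 0.5 forward pass (0.0 dropped, 2.0
    // kept-and-scaled).
    #[test]
    fn test_dropout_backward_masks_gradient() {
        let backward = DropoutBackward {
            next_fns: vec![None],
            mask_tensor: Tensor::from_vec(vec![0.0, 2.0, 2.0], &[3]).unwrap(),
        };
        let grad_output = Tensor::from_vec(vec![1.0, 1.0, 1.0], &[3]).unwrap();
        let grads = backward.apply(&grad_output);

        assert_eq!(grads[0].as_ref().unwrap().to_vec(), vec![0.0, 2.0, 2.0]);
    }
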
    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
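
    // The channel-wise and alpha variants have no coverage above; two sketches
    // under stated assumptions. Dropout2d must zero whole (batch, channel)
    // feature maps, and AlphaDropout on an all-ones input can only emit two
    // values: a + b (kept) or a * alpha_p + b (dropped).
    #[test]
    fn test_dropout2d_drops_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input =
            Variable::new(Tensor::from_vec(vec![1.0; 2 * 4 * 9], &[2, 4, 3, 3]).unwrap(), false);
        let output = dropout.forward(&input);

        // Each 3x3 feature map is either entirely dropped or entirely scaled by 2.0.
        for chunk in output.data().to_vec().chunks(9) {
            assert!(chunk.iter().all(|&x| x == 0.0) || chunk.iter().all(|&x| x == 2.0));
        }
    }

    #[test]
    fn test_alpha_dropout_two_values() {
        let dropout = AlphaDropout::new(0.5);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap(), false);
        let output = dropout.forward(&input);

        // Recompute a, b, and the dropped value with the same constants as forward().
        let alpha_p = -1.673_263_2_f32 * 1.050_701;
        let a = (0.5 * (1.0 + 0.5 * alpha_p.powi(2))).sqrt().recip();
        let b = -a * alpha_p * 0.5;
        let (kept, dropped) = (a + b, a * alpha_p + b);

        for &x in &output.data().to_vec() {
            assert!((x - kept).abs() < 1e-5 || (x - dropped).abs() < 1e-5);
        }
    }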
}