// axonml_nn/layers/dropout.rs
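//! Dropout regularization layers.
//!
//! Provides element-wise [`Dropout`], channel-wise [`Dropout2d`], and
//! self-normalizing [`AlphaDropout`]. All three behave as the identity in
//! evaluation mode.
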
use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable, checkpoint_rng_seed};
use axonml_tensor::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::module::Module;

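/// Backward node shared by every dropout variant in this file.
///
/// All variants compute `out = in * mask` (alpha dropout adds a constant
/// bias on top), so the input gradient is simply `grad_output * mask`.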
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // d(out)/d(in) of `out = in * mask` is the mask itself.
        let result = grad_output
            .mul(&self.mask_tensor)
            .expect("tensor mul failed");
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

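/// Standard (inverted) dropout.
///
/// During training, each element is zeroed independently with probability
/// `p`, and survivors are scaled by `1 / (1 - p)` so the expected value of
/// every activation is unchanged. In evaluation mode the layer is the
/// identity, so no rescaling is needed at inference time.
///
/// A minimal usage sketch (assumes this crate's `Variable` / `Module` API
/// as exercised in the tests below):
///
/// ```ignore
/// let dropout = Dropout::new(0.5);
/// let y = dropout.forward(&x); // training mode: random mask, scaled by 2.0
/// ```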
pub struct Dropout {
    /// Probability of zeroing each element.
    p: f32,
    /// Training-mode flag (atomic so it can be read through `&self`).
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a dropout layer that zeroes elements with probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a dropout layer with the conventional default of `p = 0.5`.
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        // In eval mode (or with p == 0) dropout is the identity.
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Under gradient checkpointing, reuse the recorded seed so the
        // recomputed forward pass regenerates the identical mask.
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Inverted dropout: survivors are scaled by 1 / (1 - p) so the
        // expected activation matches eval mode.
        let scale = 1.0 / (1.0 - self.p);

        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

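/// Channel-wise (spatial) dropout for inputs shaped `(N, C, ...)`.
///
/// Instead of zeroing individual elements, entire feature maps are zeroed
/// with probability `p`, which is more effective when activations within a
/// channel are strongly correlated (as in convolutional layers). Surviving
/// channels are scaled by `1 / (1 - p)`.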
pub struct Dropout2d {
    /// Probability of zeroing each channel.
    p: f32,
    /// Training-mode flag (atomic so it can be read through `&self`).
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a channel dropout layer that zeroes whole channels with
    /// probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        assert!(
            shape.len() >= 2,
            "Dropout2d expects input shaped (N, C, ...), got {shape:?}"
        );
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };
        let scale = 1.0 / (1.0 - self.p);

        // One Bernoulli draw per (batch, channel): each feature map is kept
        // (and scaled) or zeroed as a whole.
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

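/// Alpha dropout for SELU networks.
///
/// Regular dropout breaks the self-normalizing property of SELU because a
/// zeroed unit shifts the mean. Alpha dropout instead sets dropped units to
/// `alpha' = -ALPHA * SCALE` (the negative saturation value of SELU) and
/// applies an affine transform `a * x + b` chosen so that the mean and
/// variance of the activations are preserved.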
pub struct AlphaDropout {
    /// Probability of dropping each element.
    p: f32,
    /// Training-mode flag (atomic so it can be read through `&self`).
    training: AtomicBool,
}

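// Sketch: `AlphaDropout` had no `Debug` impl; this mirrors the manual impls
// for `Dropout` and `Dropout2d` above, since the fields are identical.
impl std::fmt::Debug for AlphaDropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AlphaDropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}
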
impl AlphaDropout {
    /// Creates an alpha dropout layer that drops elements with
    /// probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU's fixed-point constants (alpha and lambda).
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        // Dropped units are set to alpha' = -ALPHA * SCALE (SELU's negative
        // saturation value). The affine parameters restore zero mean and
        // unit variance:
        //   a = 1 / sqrt((1 - p) * (1 + p * alpha'^2)),  b = -a * alpha' * p
        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Kept units become `a * x + b`; dropped units become the constant
        // `a * alpha' + b`. This is expressed as `x * mask + bias`.
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).expect("tensor creation failed");
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        let output = input_data
            .mul(&mask_tensor)
            .expect("tensor mul failed")
            .add(&bias_tensor)
            .expect("tensor add failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        // The bias is constant w.r.t. the input, so the multiplicative mask
        // alone is the correct backward factor.
        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1000], &[1000]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p = 0.5 over 1000 elements the zero count is binomial with
        // mean 500; [300, 700] is a generous, non-flaky window.
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
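
    // Sketch of an extra check using only the APIs exercised above:
    // Dropout2d must keep or zero whole channels, never single elements.
    #[test]
    fn test_dropout2d_drops_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 8 * 16], &[4, 8, 16])
                .expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);
        let out = output.data().to_vec();

        // Each 16-element channel slice must be uniform: all zeros, or all
        // equal to the survivor scale 1 / (1 - 0.5) = 2.0.
        for channel in out.chunks(16) {
            assert!(
                channel.iter().all(|&x| x == 0.0) || channel.iter().all(|&x| x == 2.0),
                "channel mixes kept and dropped elements: {channel:?}"
            );
        }
    }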
}