// axonml_nn/layers/dropout.rs

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable, checkpoint_rng_seed};
use axonml_tensor::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::module::Module;
/// Backward pass shared by all dropout variants: the incoming gradient is
/// multiplied elementwise by the same (already scaled) mask that was applied
/// in the forward pass.
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let result = grad_output.mul(&self.mask_tensor).expect("tensor mul failed");
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
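
// Why a single mask multiply suffices: the forward pass computes y = m * x
// elementwise, where m is 0 or 1/(1 - p) for standard dropout (0 or `a` for
// alpha dropout, whose extra bias term is constant with respect to x), so
// dy/dx = m and the chain rule reduces to grad_input = grad_output * m.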

/// Standard (inverted) dropout.
///
/// During training, each element is zeroed independently with probability
/// `p` and the surviving elements are scaled by `1 / (1 - p)`, so the
/// expected activation is unchanged and no rescaling is needed at inference
/// time. In eval mode the layer is the identity.
pub struct Dropout {
    /// Probability of zeroing an element, in `[0, 1)`.
    p: f32,
    /// Training-mode flag; interior-mutable so `forward` can take `&self`.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a dropout layer that zeroes elements with probability `p`.
    ///
    /// # Panics
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Convenience constructor using the common default of `p = 0.5`.
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        // Eval mode (or p == 0) makes dropout the identity.
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use the checkpointed seed when one is active, so that gradient
        // checkpointing replays the identical mask; otherwise seed from the
        // thread-local RNG.
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Inverted dropout: kept elements are scaled at training time, so no
        // rescaling is needed at inference time.
        let scale = 1.0 / (1.0 - self.p);

        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}
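
// A minimal usage sketch (assuming only the `Variable` / `Module` APIs
// exercised by the tests at the bottom of this file):
//
//     let mut dropout = Dropout::new(0.25);
//     let x = Variable::new(Tensor::from_vec(vec![1.0; 8], &[8]).unwrap(), false);
//     let y = dropout.forward(&x); // each element is 0.0 or 1/(1 - 0.25)
//     dropout.eval();
//     let z = dropout.forward(&x); // identity in eval mode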

/// Channel-wise (spatial) dropout for inputs shaped `[N, C, ...]`.
///
/// Instead of zeroing individual elements, entire channels are zeroed with
/// probability `p` and the surviving channels are scaled by `1 / (1 - p)`.
/// This is the usual choice for convolutional feature maps, where
/// neighbouring elements are strongly correlated and elementwise dropout
/// regularizes poorly.
pub struct Dropout2d {
    /// Probability of zeroing a whole channel, in `[0, 1)`.
    p: f32,
    /// Training-mode flag; interior-mutable so `forward` can take `&self`.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a channel-wise dropout layer.
    ///
    /// # Panics
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        assert!(
            shape.len() >= 2,
            "Dropout2d expects an input of shape [N, C, ...]"
        );
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        // Build the mask on the host: zeros everywhere, with kept channels
        // filled in with the scale factor below.
        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };
        let scale = 1.0 / (1.0 - self.p);

        // One keep/drop decision per (batch, channel) pair, applied across
        // the whole spatial extent of that channel.
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}
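
// Usage sketch (same API assumptions as the sketch above): for an input of
// shape [N, C, H, W], each of the N * C channels is kept or dropped as a
// unit, e.g.
//
//     let dropout2d = Dropout2d::new(0.5);
//     let x = Variable::new(Tensor::from_vec(vec![1.0; 24], &[2, 3, 2, 2]).unwrap(), false);
//     let y = dropout2d.forward(&x); // each 2x2 channel is all 0.0 or all 2.0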

/// Alpha dropout, intended for SELU activations (self-normalizing networks).
///
/// Instead of zeroing, dropped units are set to SELU's negative saturation
/// value, and an affine correction `a * x + b` keeps the mean and variance
/// of the activations approximately unchanged.
pub struct AlphaDropout {
    /// Probability of dropping a unit, in `[0, 1)`.
    p: f32,
    /// Training-mode flag; interior-mutable so `forward` can take `&self`.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates an alpha-dropout layer.
    ///
    /// # Panics
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU constants (Klambauer et al., 2017).
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        // alpha' = -ALPHA * SCALE is SELU's negative saturation value;
        // dropped units are pinned to it (before the affine correction).
        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Kept units become a * x + b; dropped units become a * alpha' + b.
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).expect("tensor creation failed");
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        let output = input_data
            .mul(&mask_tensor)
            .expect("tensor mul failed")
            .add(&bias_tensor)
            .expect("tensor add failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        // The additive bias is constant with respect to the input, so the
        // shared multiplicative backward (grad * mask) remains correct.
        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}
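
// Where `a` and `b` come from (sketch of the standard derivation for
// self-normalizing networks): for zero-mean, unit-variance x, the
// pre-correction output is u = x with probability 1 - p and u = alpha'
// otherwise, so
//
//     E[u]   = p * alpha'
//     E[u^2] = (1 - p) + p * alpha'^2
//     Var[u] = E[u^2] - E[u]^2 = (1 - p) * (1 + p * alpha'^2)
//
// Choosing a = Var[u]^(-1/2) and b = -a * p * alpha' makes a * u + b
// zero-mean and unit-variance again, matching the values computed in
// `forward` above.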

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1000], &[1000]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p = 0.5 over 1000 elements, the zero count lies in (300, 700)
        // except with negligible probability.
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
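
    // Additional test sketches, relying only on APIs already exercised above.

    #[test]
    fn test_dropout2d_channelwise() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 24], &[2, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);
        let out = output.data().to_vec();

        // Each channel (chunks of spatial_size = 4) must be kept or dropped
        // as a whole: all 0.0, or all equal to the scale 1 / (1 - 0.5) = 2.0.
        for chunk in out.chunks(4) {
            assert!(
                chunk.iter().all(|&x| x == 0.0) || chunk.iter().all(|&x| x == 2.0),
                "channel was only partially dropped: {chunk:?}"
            );
        }
    }

    #[test]
    fn test_alpha_dropout_eval() {
        let mut dropout = AlphaDropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, -1.0, 0.5], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // Like the other variants, alpha dropout is the identity in eval mode.
        assert_eq!(output.data().to_vec(), vec![1.0, -1.0, 0.5]);
    }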
}