// axonml_nn/layers/dropout.rs

//! Dropout layers: element-wise `Dropout`, channel-wise `Dropout2d`, and the
//! SELU-compatible `AlphaDropout`.

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable, checkpoint_rng_seed};
use axonml_tensor::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::module::Module;

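/// Shared backward function for the dropout layers in this module.
///
/// The forward pass saves the (already scaled) keep mask, and the gradient
/// of `output = input * mask` with respect to `input` is just
/// `grad_output * mask`, which is what `apply` computes.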
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let result = grad_output
            .mul(&self.mask_tensor)
            .expect("tensor mul failed");
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

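/// Element-wise (inverted) dropout.
///
/// During training, each element is zeroed with probability `p` and the
/// survivors are scaled by `1 / (1 - p)`, so the expected activation is
/// unchanged and evaluation mode can be a plain identity.
///
/// A minimal usage sketch (marked `ignore`; it assumes only the `Variable`
/// and `Tensor` constructors already used by the tests in this file):
///
/// ```ignore
/// let mut dropout = Dropout::new(0.5);
/// let y_train = dropout.forward(&x); // random mask, survivors scaled by 2.0
/// dropout.eval();
/// let y_eval = dropout.forward(&x); // identity
/// ```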
pub struct Dropout {
    p: f32,
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a dropout layer that zeroes each element with probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Convenience constructor using the conventional default of `p = 0.5`.
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        // In eval mode (or with p == 0) dropout is the identity.
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use the checkpoint seed when replaying under gradient
        // checkpointing, so the same mask is regenerated.
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Inverted dropout: scale survivors at train time so no rescaling
        // is needed at eval time.
        let scale = 1.0 / (1.0 - self.p);

        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

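/// Channel-wise dropout for inputs shaped `(N, C, ...)`, e.g. `(N, C, H, W)`.
///
/// Instead of zeroing individual elements, whole channels are dropped. This
/// is more effective than `Dropout` when neighboring activations are highly
/// correlated, as in early convolutional feature maps.
///
/// A hedged sketch of the effect (shapes are illustrative):
///
/// ```ignore
/// let dropout = Dropout2d::new(0.5);
/// // For a (2, 4, 3, 3) input, each of the 8 (batch, channel) slices of the
/// // output is either all zeros or the input scaled by 1 / (1 - 0.5) = 2.0.
/// let output = dropout.forward(&input);
/// ```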
pub struct Dropout2d {
    p: f32,
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a channel-wise dropout layer with drop probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // Expects an (N, C, ...) layout: one Bernoulli draw per
        // (batch, channel) pair, applied across the whole spatial extent.
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };
        let scale = 1.0 / (1.0 - self.p);

        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

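/// Alpha dropout for self-normalizing networks (SELU activations).
///
/// Standard dropout destroys the zero-mean / unit-variance property that
/// SELU layers maintain. Alpha dropout instead sets dropped units to the
/// SELU saturation value `alpha' = -ALPHA * SCALE` and applies an affine
/// correction `a * x + b`, with `a` and `b` chosen (as computed in the
/// forward pass below) so that mean and variance are preserved:
///
/// ```text
/// a = ((1 - p) * (1 + p * alpha'^2))^(-1/2)
/// b = -a * alpha' * p
/// ```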
pub struct AlphaDropout {
    p: f32,
    training: AtomicBool,
}

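// Manual Debug to match the other dropout layers, snapshotting the training
// flag rather than printing the AtomicBool wrapper.
impl std::fmt::Debug for AlphaDropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AlphaDropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}
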
impl AlphaDropout {
    /// Creates an alpha dropout layer with drop probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1)`.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU constants (Klambauer et al., "Self-Normalizing Neural
        // Networks", 2017).
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        // Dropped units saturate at alpha' = -ALPHA * SCALE; the affine
        // correction a * x + b restores zero mean and unit variance.
        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Kept positions are scaled by `a` and shifted by `b`; dropped
        // positions become the constant `a * alpha' + b`.
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).expect("tensor creation failed");
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        // output = input * mask + bias
        let output = input_data
            .mul(&mask_tensor)
            .unwrap()
            .add(&bias_tensor)
            .unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            // The additive bias is constant w.r.t. the input, so the
            // backward pass only needs the multiplicative mask.
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1000], &[1000]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p = 0.5 over 1000 elements the expected zero count is ~500;
        // the wide 300..700 band keeps this test deterministic in practice.
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // In eval mode dropout is the identity.
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
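
    // Dropout2d zeroes whole channels: every (batch, channel) slice of the
    // output must be uniformly zero or uniformly scaled by 1 / (1 - p).
    // Uses only the constructors exercised by the tests above.
    #[test]
    fn test_dropout2d_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 4 * 9], &[2, 4, 3, 3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // Each chunk of 9 elements is one (batch, channel) spatial slice;
        // kept channels are scaled by 1 / (1 - 0.5) = 2.0 exactly.
        for slice in output.data().to_vec().chunks(9) {
            assert!(
                slice.iter().all(|&x| x == 0.0) || slice.iter().all(|&x| x == 2.0),
                "each channel must be dropped or kept as a whole"
            );
        }
    }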
}