//! Dropout regularization — `Dropout`, `Dropout2d`, and `AlphaDropout`.
//!
//! `Dropout` randomly zeros elements with probability `p` during training and
//! scales the survivors by `1/(1-p)` (inverted dropout). `Dropout2d` zeros
//! entire channels for spatial data. All three are no-ops during eval mode.
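//!
//! # Example
//! A minimal usage sketch; the `axonml_nn::layers::dropout` import path is an
//! assumption about this crate's public layout, so the snippet is marked
//! `ignore` rather than run as a doctest.
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_nn::layers::dropout::Dropout;
//! use axonml_nn::module::Module;
//! use axonml_tensor::Tensor;
//!
//! let x = Variable::new(Tensor::from_vec(vec![1.0; 8], &[8]).unwrap(), false);
//!
//! let mut dropout = Dropout::new(0.5);
//! let y_train = dropout.forward(&x); // masked, survivors scaled by 1/(1-p)
//!
//! dropout.eval();
//! let y_eval = dropout.forward(&x); // eval mode: returns the input unchanged
//! ```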
//!
//! # File
//! `crates/axonml-nn/src/layers/dropout.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty
//! of any kind, express or implied. The author and AutomataNexus shall not be
//! held liable for any damages arising from the use of this software.

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable, checkpoint_rng_seed};
use axonml_tensor::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::module::Module;

// =============================================================================
// DropoutBackward
// =============================================================================

/// Gradient function for Dropout.
///
/// Applies the same mask used in the forward pass: gradient is scaled where
/// elements were kept, and zeroed where elements were dropped.
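///
/// For example, with `p = 0.5` the kept positions of the mask hold
/// `1/(1-p) = 2.0`, so a kept element's incoming gradient is doubled and a
/// dropped element's gradient becomes `0.0`.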
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    /// The mask as a tensor (stored on same device as input — GPU or CPU).
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let result = grad_output
            .mul(&self.mask_tensor)
            .expect("tensor mul failed");
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// =============================================================================
// Dropout
// =============================================================================

/// During training, randomly zeros some elements with probability `p`.
///
/// During evaluation, returns the input unchanged.
///
/// # Arguments
/// * `p` - Probability of an element to be zeroed (default: 0.5)
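///
/// Because survivors are scaled by `1/(1-p)` during training (inverted
/// dropout), the output equals the input in expectation, and no rescaling is
/// needed at eval time.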
pub struct Dropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a new Dropout layer with the given probability.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a Dropout layer with default probability (0.5).
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Scale factor for inverted dropout
        let scale = 1.0 / (1.0 - self.p);

        // Build mask on CPU: 0.0 for dropped, scale for kept
        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        // Create mask tensor and move to input device
        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

// =============================================================================
// Dropout2d
// =============================================================================

/// Randomly zeros entire channels during training.
///
/// Useful for spatial data like images.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
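///
/// Each `(batch, channel)` pair is dropped independently: a kept channel has
/// every spatial element scaled by `1/(1-p)`, while a dropped channel is
/// zeroed in full.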
pub struct Dropout2d {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a new Dropout2d layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };
        let scale = 1.0 / (1.0 - self.p);

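        // One Bernoulli draw per (batch, channel): either write `scale`
        // across the channel's whole spatial extent, or leave it zeroed.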
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

// =============================================================================
// AlphaDropout
// =============================================================================

/// Alpha Dropout for Self-Normalizing Neural Networks (SNNs).
///
/// Instead of zeroing, dropped elements are set to SELU's negative saturation
/// value and an affine correction is applied, preserving the zero mean and
/// unit variance that SELU activations self-normalize to (Klambauer et al.,
/// 2017).
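///
/// Concretely, for a drop probability `p` and SELU saturation value
/// `α' = -α·λ`, kept elements are mapped to `a·x + b` and dropped elements to
/// `a·α' + b`, where `a = ((1 - p)·(1 + p·α'²))^(-1/2)` and `b = -a·α'·p`
/// (see `forward` below).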
pub struct AlphaDropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates a new AlphaDropout layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU parameters
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;
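        // With these values, kept elements become `a * x + b` and dropped
        // ones collapse to the constant `a * alpha_p + b`; `a` and `b` are
        // chosen so mean and variance are preserved in expectation.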

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Build mask on CPU: 'a' where kept, 0.0 where dropped
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        // Build bias tensor: dropped_val where dropped, b where kept
        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).expect("tensor creation failed");
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        // output = mask * input + bias  (all Tensor ops, GPU-dispatched)
        let output = input_data
            .mul(&mask_tensor)
            .unwrap()
            .add(&bias_tensor)
            .unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1000], &[1000]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // Some values should be zero, some should be scaled
        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p=0.5, roughly half should be zero (with some variance)
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // In eval mode, output should equal input
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
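
    // The two tests below are illustrative additions, sketched under the same
    // assumptions as the suite above (e.g. that `eval()` is a `Module` method
    // and tensors are contiguous row-major).

    // Dropout2d should either zero a whole channel or scale all of its
    // elements by 1/(1-p) = 2.0 — never a mix within one channel.
    #[test]
    fn test_dropout2d_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 8 * 4 * 4], &[4, 8, 4, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        let out = output.data().to_vec();
        let spatial = 4 * 4;
        for channel in out.chunks(spatial) {
            // Every element within a channel shares one value...
            assert!(channel.iter().all(|&x| x == channel[0]));
            // ...and that value is either 0.0 (dropped) or 2.0 (kept, scaled).
            assert!(channel[0] == 0.0 || (channel[0] - 2.0).abs() < 1e-6);
        }
    }

    // AlphaDropout, like the other two layers, is a no-op in eval mode.
    #[test]
    fn test_alpha_dropout_eval() {
        let mut dropout = AlphaDropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, -1.0, 0.5], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, -1.0, 0.5]);
    }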
}