

//! Dropout Layers - Regularization via Random Zeroing
//!
//! # File
//! `crates/axonml-nn/src/layers/dropout.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable, checkpoint_rng_seed};
use axonml_tensor::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::module::Module;

// =============================================================================
// DropoutBackward
// =============================================================================

/// Gradient function for Dropout.
///
/// Applies the same mask used in the forward pass: gradient is scaled where
/// elements were kept, and zeroed where elements were dropped.
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    /// The mask tensor, stored on the same device as the input (GPU or CPU).
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
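        // For output = mask * input (elementwise), d(output)/d(input) is the
        // mask itself: kept positions pass the gradient through scaled by
        // 1/(1 - p), dropped positions receive zero.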
        let result = grad_output.mul(&self.mask_tensor).expect("tensor mul failed");
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// =============================================================================
// Dropout
// =============================================================================

/// During training, randomly zeros each element of the input with probability
/// `p`, scaling the kept elements by `1 / (1 - p)` (inverted dropout).
///
/// During evaluation, returns the input unchanged.
///
/// # Arguments
/// * `p` - Probability of an element being zeroed (default: 0.5)
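///
/// # Example
///
/// A minimal usage sketch (assumes only the `Variable`, `Tensor`, and
/// `Module` APIs already used in this file):
///
/// ```ignore
/// let dropout = Dropout::new(0.5);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 8], &[8]).expect("tensor creation failed"),
///     false,
/// );
/// // Roughly half the elements become 0.0; the rest are scaled to 2.0.
/// let output = dropout.forward(&input);
/// ```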
pub struct Dropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a new Dropout layer with the given probability.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a Dropout layer with default probability (0.5).
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Scale factor for inverted dropout
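        // Kept elements are multiplied by 1/(1 - p) so that E[output] == input,
        // which lets eval mode return the input with no rescaling.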
        let scale = 1.0 / (1.0 - self.p);

        // Build mask on CPU: 0.0 for dropped, scale for kept
        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        // Create mask tensor and move to input device
        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

// =============================================================================
// Dropout2d
// =============================================================================

/// Randomly zeros entire channels during training.
///
/// Useful for spatial data like images.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
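///
/// # Example
///
/// A hedged sketch (same API as the `Dropout` example above); each of the
/// `N * C` channels is kept or dropped as a whole:
///
/// ```ignore
/// let dropout2d = Dropout2d::new(0.25);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 2 * 3 * 4 * 4], &[2, 3, 4, 4])
///         .expect("tensor creation failed"),
///     false,
/// );
/// // Dropped channels appear as all-zero 4x4 blocks in the output.
/// let output = dropout2d.forward(&input);
/// ```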
pub struct Dropout2d {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a new Dropout2d layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };
        let scale = 1.0 / (1.0 - self.p);

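        // One Bernoulli draw per (batch, channel) pair; the verdict covers the
        // channel's whole H*W block at flattened NCHW offset
        // (b * channels + c) * spatial_size.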
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

// =============================================================================
// AlphaDropout
// =============================================================================

/// Alpha Dropout for Self-Normalizing Neural Networks (SNNs).
///
/// Unlike standard dropout, dropped activations are set to the negative
/// saturation value of SELU rather than zero, and an affine correction keeps
/// the output's mean and variance unchanged.
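///
/// With `alpha' = -ALPHA * SCALE`, the correction computed in `forward` is
///
/// ```text
/// a = ((1 - p) * (1 + p * alpha'^2))^(-1/2)
/// b = -a * alpha' * p
/// output = a * (x if kept else alpha') + b
/// ```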
pub struct AlphaDropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates a new AlphaDropout layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU parameters
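        // Constants from Klambauer et al., "Self-Normalizing Neural Networks"
        // (2017); SELU saturates at -ALPHA * SCALE for large negative inputs.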
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Build mask on CPU: 'a' where kept, 0.0 where dropped
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        // Build bias tensor: dropped_val where dropped, b where kept
        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).expect("tensor creation failed");
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        // output = mask * input + bias  (all Tensor ops, GPU-dispatched)
        let output = input_data
            .mul(&mask_tensor)
            .unwrap()
            .add(&bias_tensor)
            .unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

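        // The additive bias is constant w.r.t. the input, so d(output)/d(input)
        // is just the mask; DropoutBackward can therefore be reused unchanged.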
        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1000], &[1000]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // Some values should be zero, some should be scaled
        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p=0.5, roughly half should be zero (with some variance)
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // In eval mode, output should equal input
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
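
    // A hedged addition: verifies that Dropout2d drops whole channels at a
    // time, assuming only the constructor and `forward` API exercised by the
    // tests above.
    #[test]
    fn test_dropout2d_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 8 * 2 * 2], &[4, 8, 2, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // Each channel's 4 spatial values must be uniformly zero (dropped)
        // or uniformly nonzero (kept); a mixed channel would indicate a bug.
        for chunk in output.data().to_vec().chunks(4) {
            let zeros = chunk.iter().filter(|&&x| x == 0.0).count();
            assert!(zeros == 0 || zeros == 4);
        }
    }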
}