//! Dropout Layers - Regularization via Random Zeroing
//!
//! # File
//! `crates/axonml-nn/src/layers/dropout.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable, checkpoint_rng_seed};
use axonml_tensor::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::module::Module;

// =============================================================================
// DropoutBackward
// =============================================================================

/// Gradient function for Dropout.
///
/// Applies the same mask used in the forward pass: gradient is scaled where
/// elements were kept, and zeroed where elements were dropped.
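///
/// For `y = m * x` with stored mask `m`, the gradient is `dL/dx = m * dL/dy`,
/// so the backward pass is a single elementwise multiply by the mask.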
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    /// The mask as a tensor, stored on the same device as the input (GPU or CPU).
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let result = grad_output
            .mul(&self.mask_tensor)
            .expect("tensor mul failed");
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// =============================================================================
// Dropout
// =============================================================================

/// During training, randomly zeroes each input element with probability `p`
/// and scales the kept elements by `1 / (1 - p)` (inverted dropout), so the
/// expected value of the output matches the input.
///
/// During evaluation, returns the input unchanged.
///
/// # Arguments
/// * `p` - Probability of an element being zeroed (default: 0.5)
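///
/// # Example
///
/// A minimal usage sketch; the import paths are illustrative and may differ
/// from how the crate actually re-exports these types:
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::dropout::Dropout;
/// use axonml_nn::module::Module;
/// use axonml_tensor::Tensor;
///
/// let mut dropout = Dropout::new(0.5);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 8], &[8]).unwrap(), false);
/// let y = dropout.forward(&x); // training: roughly half the elements are zeroed
///
/// dropout.eval();
/// let y_eval = dropout.forward(&x); // eval: returns the input unchanged
/// ```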
pub struct Dropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a new Dropout layer with the given probability.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a Dropout layer with default probability (0.5).
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Scale factor for inverted dropout
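        // (each kept element is multiplied by 1/(1-p), so E[output] == E[input]
        // and the evaluation-time forward can be a plain identity)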
        let scale = 1.0 / (1.0 - self.p);

        // Build mask on CPU: 0.0 for dropped, scale for kept
        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        // Create mask tensor and move to input device
        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

// =============================================================================
// Dropout2d
// =============================================================================

/// Randomly zeroes entire channels (2D feature maps) during training; kept
/// channels are scaled by `1 / (1 - p)`.
///
/// Useful for spatial data like images, where neighboring activations are
/// strongly correlated and elementwise dropout is less effective at
/// decorrelating features.
///
/// # Shape
/// - Input: `(N, C, H, W)`
/// - Output: same shape as input
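///
/// # Example
///
/// A minimal usage sketch with the same (illustrative) imports as the
/// `Dropout` example above:
///
/// ```ignore
/// let dropout = Dropout2d::new(0.25);
/// let x = Variable::new(
///     Tensor::from_vec(vec![1.0; 2 * 3 * 4 * 4], &[2, 3, 4, 4]).unwrap(),
///     false,
/// );
/// // Each of the 2 * 3 channel planes is kept (scaled by 4/3) or zeroed whole.
/// let y = dropout.forward(&x);
/// ```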
pub struct Dropout2d {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a new Dropout2d layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };
        let scale = 1.0 / (1.0 - self.p);

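        // One Bernoulli draw per (batch, channel) pair: each H*W plane is
        // either kept (and scaled by 1/(1-p)) or zeroed as a unit.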
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).expect("tensor mul failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

// =============================================================================
// AlphaDropout
// =============================================================================

/// Alpha Dropout for Self-Normalizing Neural Networks (SNNs).
///
/// Unlike standard dropout, dropped activations are set to the SELU negative
/// saturation value `alpha' = -scale * alpha` rather than zero, and an affine
/// transform `a * x + b` is applied afterwards so the output keeps the zero
/// mean and unit variance that SELU activations rely on.
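///
/// # Example
///
/// A minimal usage sketch with the same (illustrative) imports as the
/// `Dropout` example above:
///
/// ```ignore
/// let dropout = AlphaDropout::new(0.1);
/// let x = Variable::new(Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).unwrap(), false);
/// let y = dropout.forward(&x); // dropped entries become a * alpha' + b, not 0.0
/// ```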
pub struct AlphaDropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates a new AlphaDropout layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU parameters
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

        let alpha_p = -ALPHA * SCALE;
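        // Affine coefficients that restore zero mean and unit variance after
        // dropping to alpha_p with probability p (Klambauer et al., 2017):
        //   a = 1 / sqrt((1 - p) * (1 + p * alpha_p^2)),  b = -a * alpha_p * p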
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        // Use deterministic RNG during checkpoint recomputation
        let mut rng = if let Some(seed) = checkpoint_rng_seed() {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(rand::thread_rng()).unwrap()
        };

        // Build mask on CPU: 'a' where kept, 0.0 where dropped
        let dropped_val = a * alpha_p + b;
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        // Build bias tensor: dropped_val where dropped, b where kept
        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).expect("tensor creation failed");
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        // output = mask * input + bias  (all Tensor ops, GPU-dispatched)
        let output = input_data
            .mul(&mask_tensor)
            .unwrap()
            .add(&bias_tensor)
            .unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
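            // The bias tensor is constant with respect to the input, so the
            // gradient is just mask * grad_output, which DropoutBackward computes.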
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1000], &[1000]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // Some values should be zero, some should be scaled
        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p=0.5, roughly half should be zero (with some variance)
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        // In eval mode, output should equal input
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
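
    // A sketch test for Dropout2d, using only APIs already exercised above:
    // every H*W plane must be uniformly zero or uniformly scaled by
    // 1/(1-p) = 2.0.
    #[test]
    fn test_dropout2d_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 4 * 3 * 3], &[2, 4, 3, 3])
                .expect("tensor creation failed"),
            false,
        );
        let output = dropout.forward(&input);
        let out = output.data().to_vec();

        for plane in out.chunks(3 * 3) {
            let first = plane[0];
            assert!(first == 0.0 || (first - 2.0).abs() < 1e-6);
            assert!(plane.iter().all(|&x| x == first));
        }
    }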
}