//! Dropout Layers - Regularization via Random Zeroing
//!
//! # File
//! `crates/axonml-nn/src/layers/dropout.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty
//! of any kind, express or implied. The author and AutomataNexus shall not
//! be held liable for any damages arising from the use of this software.

use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;
use rand::Rng;

use crate::module::Module;

// =============================================================================
// DropoutBackward
// =============================================================================

/// Gradient function for Dropout.
///
/// Applies the same mask used in the forward pass: the gradient is scaled
/// where elements were kept and zeroed where elements were dropped.
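/// The stored mask already carries the forward scaling (for inverted dropout,
/// `1/(1 - p)`), so the backward pass is a single elementwise multiply.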
#[derive(Debug)]
struct DropoutBackward {
    next_fns: Vec<Option<GradFn>>,
    /// The mask as a tensor (stored on the same device as the input, GPU or CPU).
    mask_tensor: Tensor<f32>,
}

impl GradientFunction for DropoutBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let result = grad_output.mul(&self.mask_tensor).unwrap();
        vec![Some(result)]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// =============================================================================
// Dropout
// =============================================================================

/// During training, randomly zeroes each element of the input with
/// probability `p`; surviving elements are scaled by `1/(1 - p)` (inverted
/// dropout) so the expected activation is unchanged.
///
/// During evaluation, returns the input unchanged.
///
/// # Arguments
/// * `p` - Probability of an element being zeroed (default: 0.5)
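///
/// # Examples
///
/// A minimal sketch of typical use (import paths assumed; mirrors the tests
/// at the bottom of this file):
///
/// ```ignore
/// let dropout = Dropout::new(0.5);
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 8], &[8]).unwrap(), false);
/// // Roughly half the entries are zeroed; the survivors become 2.0.
/// let output = dropout.forward(&input);
/// ```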
pub struct Dropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout {
    /// Creates a new Dropout layer with the given probability.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }

    /// Creates a Dropout layer with default probability (0.5).
    pub fn default_p() -> Self {
        Self::new(0.5)
    }
}

impl Default for Dropout {
    fn default() -> Self {
        Self::default_p()
    }
}

impl Module for Dropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = rand::thread_rng();

        // Scale factor for inverted dropout.
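        // Survivors are multiplied by `scale` so that
        // E[mask] = p * 0 + (1 - p) * scale = 1, keeping the expected
        // activation unchanged (no rescaling needed at eval time).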
        let scale = 1.0 / (1.0 - self.p);

        // Build mask on CPU: 0.0 for dropped, `scale` for kept.
        let mask: Vec<f32> = (0..numel)
            .map(|_| {
                if rng.r#gen::<f32>() < self.p {
                    0.0
                } else {
                    scale
                }
            })
            .collect();

        // Create mask tensor and move to input device.
        let mut mask_tensor = Tensor::from_vec(mask, &shape).unwrap();
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).unwrap();

        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout"
    }
}

// =============================================================================
// Dropout2d
// =============================================================================

/// Randomly zeroes entire channels (feature maps) during training.
///
/// Useful for spatial data such as images.
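///
/// Adjacent pixels within a channel are strongly correlated, so element-wise
/// dropout lets information leak back in through neighbours; dropping whole
/// feature maps (spatial dropout) regularizes such inputs more effectively.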
///
/// # Shape
/// - Input: `(N, C, H, W)`
/// - Output: same as input
pub struct Dropout2d {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl std::fmt::Debug for Dropout2d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout2d")
            .field("p", &self.p)
            .field("training", &self.training.load(Ordering::Relaxed))
            .finish()
    }
}

impl Dropout2d {
    /// Creates a new Dropout2d layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for Dropout2d {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        // Total element count (avoids copying the data to host just for its length).
        let total = input_data.numel();
        let mut mask = vec![0.0f32; total];
        let mut rng = rand::thread_rng();
        let scale = 1.0 / (1.0 - self.p);

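        // Row-major NCHW layout: channel (b, c) occupies the contiguous run of
        // `spatial_size` elements starting at
        // b * channels * spatial_size + c * spatial_size,
        // so each feature map is kept or dropped as a single stretch.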
        for b in 0..batch_size {
            for c in 0..channels {
                let keep = rng.r#gen::<f32>() >= self.p;
                let start = b * channels * spatial_size + c * spatial_size;
                if keep {
                    for i in 0..spatial_size {
                        mask[start + i] = scale;
                    }
                }
            }
        }

        let mut mask_tensor = Tensor::from_vec(mask, &shape).unwrap();
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
        }
        let output = input_data.mul(&mask_tensor).unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "Dropout2d"
    }
}

// =============================================================================
// AlphaDropout
// =============================================================================

/// Alpha Dropout for Self-Normalizing Neural Networks (SNNs).
///
/// Unlike standard dropout, dropped activations are set to the SELU negative
/// saturation value rather than zero, and an affine correction keeps the
/// input's mean and variance unchanged.
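///
/// See Klambauer et al., "Self-Normalizing Neural Networks" (2017).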
pub struct AlphaDropout {
    /// Dropout probability.
    p: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl AlphaDropout {
    /// Creates a new AlphaDropout layer.
    pub fn new(p: f32) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Self {
            p,
            training: AtomicBool::new(true),
        }
    }
}

impl Module for AlphaDropout {
    fn forward(&self, input: &Variable) -> Variable {
        if !self.training.load(Ordering::Relaxed) || self.p == 0.0 {
            return input.clone();
        }

        // SELU parameters.
        const ALPHA: f32 = 1.673_263_2;
        const SCALE: f32 = 1.050_701;

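        // Dropped units are set to the SELU negative saturation value
        // alpha' = -SCALE * ALPHA. The affine correction y = a * x + b with
        // a = ((1 - p) * (1 + p * alpha'^2))^(-1/2) and b = -a * alpha' * p
        // restores zero mean and unit variance for SELU-normalized inputs.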
        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2)))
            .sqrt()
            .recip();
        let b = -a * alpha_p * self.p;

        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let numel = input_data.numel();
        let mut rng = rand::thread_rng();

        // Value a dropped element takes after the affine transform.
        let dropped_val = a * alpha_p + b;

        // Build mask on CPU: `a` where kept, 0.0 where dropped.
        let mask_raw: Vec<f32> = (0..numel)
            .map(|_| if rng.r#gen::<f32>() < self.p { 0.0 } else { a })
            .collect();

        // Build bias on CPU: `dropped_val` where dropped, `b` where kept.
        let bias_raw: Vec<f32> = mask_raw
            .iter()
            .map(|&m| if m == 0.0 { dropped_val } else { b })
            .collect();

        let mut mask_tensor = Tensor::from_vec(mask_raw, &shape).unwrap();
        let mut bias_tensor = Tensor::from_vec(bias_raw, &shape).unwrap();
        if input_data.device().is_gpu() {
            mask_tensor = mask_tensor.to_device(input_data.device()).unwrap();
            bias_tensor = bias_tensor.to_device(input_data.device()).unwrap();
        }

        // output = mask * input + bias  (all Tensor ops, GPU-dispatched).
        let output = input_data
            .mul(&mask_tensor)
            .unwrap()
            .add(&bias_tensor)
            .unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(DropoutBackward {
                next_fns: vec![input.grad_fn().cloned()],
                mask_tensor,
            });
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "AlphaDropout"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_training() {
        let dropout = Dropout::new(0.5);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap(), false);
        let output = dropout.forward(&input);

        // Some values should be zero, some should be scaled.
        let output_vec = output.data().to_vec();
        let num_zeros = output_vec.iter().filter(|&&x| x == 0.0).count();

        // With p = 0.5, roughly half should be zero (allowing for variance).
        assert!(num_zeros > 300 && num_zeros < 700);
    }

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = dropout.forward(&input);

        // In eval mode, output should equal input.
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_zero_probability() {
        let dropout = Dropout::new(0.0);
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = dropout.forward(&input);

        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
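
    // Additional sketch tests; they rely only on the APIs already exercised
    // above (Tensor::from_vec, Variable::new, forward, to_vec, apply).

    #[test]
    fn test_dropout2d_drops_whole_channels() {
        let dropout = Dropout2d::new(0.5);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16], &[1, 4, 2, 2]).unwrap(),
            false,
        );
        let output = dropout.forward(&input);
        let out = output.data().to_vec();

        // Each channel must be dropped or kept as a unit: all four of its
        // elements are either 0.0 or the inverted-dropout scale 2.0.
        for c in 0..4 {
            let chan = &out[c * 4..(c + 1) * 4];
            assert!(
                chan.iter().all(|&x| x == 0.0) || chan.iter().all(|&x| x == 2.0)
            );
        }
    }

    #[test]
    fn test_dropout_backward_applies_mask() {
        // The backward pass is the incoming gradient times the stored
        // (pre-scaled) mask.
        let backward = DropoutBackward {
            next_fns: vec![None],
            mask_tensor: Tensor::from_vec(vec![2.0, 0.0, 2.0], &[3]).unwrap(),
        };
        let grads = backward.apply(&Tensor::from_vec(vec![1.0, 1.0, 1.0], &[3]).unwrap());
        assert_eq!(grads[0].as_ref().unwrap().to_vec(), vec![2.0, 0.0, 2.0]);
    }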
}