Skip to main content

mlx_nn/
dropout.rs

//! Dropout module — randomly zeros elements during training.

use mlx_core::{Result, Tensor};
use rand::Rng;

use crate::Module;

/// Dropout: randomly zeros elements with probability `p` during training.
///
/// During training, each element is independently set to zero with probability
/// `p`, and the remaining elements are scaled by `1 / (1 - p)` to preserve the
/// expected value. In eval mode, the input is passed through unchanged.
#[derive(Debug, Clone, PartialEq)]
pub struct Dropout {
    // Drop probability; forward() treats p == 0 as identity and p >= 1 as
    // all-zeros, so values are expected to lie in [0, 1].
    p: f32,
    // When false, forward() is an identity pass-through.
    training: bool,
}

impl Dropout {
    /// Create a new Dropout with drop probability `p`.
    ///
    /// New instances start in training mode; call [`Dropout::eval`] to
    /// disable dropout for inference.
    pub fn new(p: f32) -> Self {
        Self { p, training: true }
    }

    /// Set to training mode (dropout active).
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Set to eval mode (no dropout).
    pub fn eval(&mut self) {
        self.training = false;
    }
}
35impl Module for Dropout {
36    fn forward(&self, input: &Tensor) -> Result<Tensor> {
37        if !self.training || self.p == 0.0 {
38            return Ok(input.clone());
39        }
40        if self.p >= 1.0 {
41            return Tensor::zeros(input.shape(), input.dtype(), input.device());
42        }
43        // Note: the mask is generated eagerly (not as a graph op), so this
44        // allocates even if the result tensor is never evaluated. A graph-level
45        // random op would preserve full laziness but is not yet implemented.
46        let n = input.numel() as usize;
47        let mut rng = rand::rng();
48        let scale = 1.0 / (1.0 - self.p);
49        let mask: Vec<f32> = (0..n)
50            .map(|_| {
51                if rng.random::<f32>() >= self.p {
52                    scale
53                } else {
54                    0.0
55                }
56            })
57            .collect();
58        let mask_t =
59            Tensor::from_data_with_dtype(mask, input.shape(), input.dtype(), input.device())?;
60        input.mul(&mask_t)
61    }
62}