// mlx_nn/dropout.rs
1//! Dropout module — randomly zeros elements during training.
2
3use mlx_core::{Result, Tensor};
4use rand::Rng;
5
6use crate::Module;
7
/// Dropout: randomly zeros elements with probability `p` during training.
///
/// During training, each element is independently set to zero with probability
/// `p`, and the remaining elements are scaled by `1 / (1 - p)` to preserve the
/// expected value. In eval mode, the input is passed through unchanged.
#[derive(Debug, Clone)]
pub struct Dropout {
    // Probability of zeroing each element; expected to lie in [0, 1].
    p: f32,
    // When false (eval mode), `forward` passes the input through unchanged.
    training: bool,
}
17
18impl Dropout {
19 /// Create a new Dropout with drop probability `p`.
20 pub fn new(p: f32) -> Self {
21 Self { p, training: true }
22 }
23
24 /// Set to training mode.
25 pub fn train(&mut self) {
26 self.training = true;
27 }
28
29 /// Set to eval mode (no dropout).
30 pub fn eval(&mut self) {
31 self.training = false;
32 }
33}
34
35impl Module for Dropout {
36 fn forward(&self, input: &Tensor) -> Result<Tensor> {
37 if !self.training || self.p == 0.0 {
38 return Ok(input.clone());
39 }
40 if self.p >= 1.0 {
41 return Tensor::zeros(input.shape(), input.dtype(), input.device());
42 }
43 // Note: the mask is generated eagerly (not as a graph op), so this
44 // allocates even if the result tensor is never evaluated. A graph-level
45 // random op would preserve full laziness but is not yet implemented.
46 let n = input.numel() as usize;
47 let mut rng = rand::rng();
48 let scale = 1.0 / (1.0 - self.p);
49 let mask: Vec<f32> = (0..n)
50 .map(|_| {
51 if rng.random::<f32>() >= self.p {
52 scale
53 } else {
54 0.0
55 }
56 })
57 .collect();
58 let mask_t =
59 Tensor::from_data_with_dtype(mask, input.shape(), input.dtype(), input.device())?;
60 input.mul(&mask_t)
61 }
62}