ghostflow_nn/
dropout.rs

1//! Dropout regularization
2
3use ghostflow_core::Tensor;
4use crate::module::Module;
5use rand::Rng;
6
/// Dropout layer.
///
/// During training, zeroes each element independently with probability `p`
/// and scales survivors by `1 / (1 - p)` (inverted dropout), so the expected
/// activation is unchanged. In eval mode `forward` is the identity.
pub struct Dropout {
    /// Probability of zeroing an element; `new` asserts `0.0 <= p < 1.0`.
    p: f32,
    /// `true` after `new`; toggled by `train()` / `eval()`.
    training: bool,
}
12
13impl Dropout {
14    pub fn new(p: f32) -> Self {
15        assert!(p >= 0.0 && p < 1.0, "Dropout probability must be in [0, 1)");
16        Dropout { p, training: true }
17    }
18}
19
20impl Default for Dropout {
21    fn default() -> Self {
22        Self::new(0.5)
23    }
24}
25
26impl Module for Dropout {
27    fn forward(&self, input: &Tensor) -> Tensor {
28        if !self.training || self.p == 0.0 {
29            return input.clone();
30        }
31
32        let data = input.data_f32();
33        let mut rng = rand::thread_rng();
34        let scale = 1.0 / (1.0 - self.p);
35
36        let output: Vec<f32> = data.iter()
37            .map(|&x| {
38                if rng.gen::<f32>() < self.p {
39                    0.0
40                } else {
41                    x * scale
42                }
43            })
44            .collect();
45
46        Tensor::from_slice(&output, input.dims()).unwrap()
47    }
48
49    fn parameters(&self) -> Vec<Tensor> {
50        vec![]
51    }
52
53    fn train(&mut self) {
54        self.training = true;
55    }
56
57    fn eval(&mut self) {
58        self.training = false;
59    }
60
61    fn is_training(&self) -> bool {
62        self.training
63    }
64}
65
/// Dropout2d - drops entire channels.
///
/// Unlike element-wise [`Dropout`], this makes one independent draw per
/// (batch, channel) pair and zeroes that channel's entire spatial extent;
/// kept channels are scaled by `1 / (1 - p)`. `forward` expects input of
/// shape `[batch, channels, ...spatial]`.
pub struct Dropout2d {
    /// Probability of dropping a whole channel; `new` asserts `0.0 <= p < 1.0`.
    p: f32,
    /// Dropout is applied only while `true`; `eval()` makes `forward` the identity.
    training: bool,
}
71
72impl Dropout2d {
73    pub fn new(p: f32) -> Self {
74        assert!(p >= 0.0 && p < 1.0);
75        Dropout2d { p, training: true }
76    }
77}
78
79impl Module for Dropout2d {
80    fn forward(&self, input: &Tensor) -> Tensor {
81        if !self.training || self.p == 0.0 {
82            return input.clone();
83        }
84
85        let dims = input.dims();
86        let batch = dims[0];
87        let channels = dims[1];
88        let spatial: usize = dims[2..].iter().product();
89
90        let data = input.data_f32();
91        let mut rng = rand::thread_rng();
92        let scale = 1.0 / (1.0 - self.p);
93
94        let mut output = data.clone();
95
96        for b in 0..batch {
97            for c in 0..channels {
98                if rng.gen::<f32>() < self.p {
99                    // Drop entire channel
100                    let start = (b * channels + c) * spatial;
101                    for i in 0..spatial {
102                        output[start + i] = 0.0;
103                    }
104                } else {
105                    // Scale
106                    let start = (b * channels + c) * spatial;
107                    for i in 0..spatial {
108                        output[start + i] *= scale;
109                    }
110                }
111            }
112        }
113
114        Tensor::from_slice(&output, dims).unwrap()
115    }
116
117    fn parameters(&self) -> Vec<Tensor> { vec![] }
118    fn train(&mut self) { self.training = true; }
119    fn eval(&mut self) { self.training = false; }
120    fn is_training(&self) -> bool { self.training }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_eval() {
        let mut dropout = Dropout::new(0.5);
        dropout.eval();

        let input = Tensor::ones(&[10, 10]);
        let output = dropout.forward(&input);

        // In eval mode, output should equal input.
        assert_eq!(output.data_f32(), input.data_f32());
    }

    #[test]
    fn test_dropout_zero_p_is_identity() {
        // p == 0 must be a no-op even in training mode.
        let dropout = Dropout::new(0.0);
        let input = Tensor::ones(&[4, 4]);
        let output = dropout.forward(&input);
        assert_eq!(output.data_f32(), input.data_f32());
    }

    #[test]
    fn test_dropout_train() {
        let dropout = Dropout::new(0.5);
        let input = Tensor::ones(&[100, 100]);
        let output = dropout.forward(&input);
        let data = output.data_f32();

        // With p = 0.5 over 10_000 ones, "no zeros" has probability 0.5^10000;
        // treat it as a failure.
        let zeros = data.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros > 0);

        // Inverted dropout: every survivor of a ones-input must be exactly the
        // scale factor 1 / (1 - 0.5) = 2.0. The original test never checked
        // the scaling at all.
        assert!(data.iter().all(|&x| x == 0.0 || x == 2.0));

        // The empirical drop rate should be in the right ballpark (the bound
        // is many standard deviations wide, so the test is not flaky).
        let rate = zeros as f32 / data.len() as f32;
        assert!(rate > 0.3 && rate < 0.7, "drop rate {} far from 0.5", rate);
    }

    #[test]
    fn test_dropout2d_eval() {
        // Dropout2d previously had no coverage; at least pin the eval path.
        let mut dropout = Dropout2d::new(0.5);
        dropout.eval();

        let input = Tensor::ones(&[2, 3, 4, 4]);
        let output = dropout.forward(&input);
        assert_eq!(output.data_f32(), input.data_f32());
    }
}