1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
5use crate::tensor::backend::Backend;
6use crate::tensor::{Distribution, Tensor};
7
8#[derive(Config, Debug)]
10pub struct DropoutConfig {
11 pub prob: f64,
13}
14
15#[derive(Module, Clone, Debug)]
24#[module(custom_display)]
25pub struct Dropout {
26 pub prob: f64,
28}
29
30impl DropoutConfig {
31 pub fn init(&self) -> Dropout {
33 if self.prob < 0.0 || self.prob > 1.0 {
34 panic!(
35 "Dropout probability should be between 0 and 1, but got {}",
36 self.prob
37 );
38 }
39 Dropout { prob: self.prob }
40 }
41}
42
43impl Dropout {
44 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
53 if !B::ad_enabled() || self.prob == 0.0 {
54 return input;
55 }
56
57 let prob_keep = 1.0 - self.prob;
58 let random = input.random_like(Distribution::Bernoulli(prob_keep));
59 let x = input * random;
60
61 x * (1.0 / prob_keep)
62 }
63}
64
65impl ModuleDisplay for Dropout {
66 fn custom_settings(&self) -> Option<DisplaySettings> {
67 DisplaySettings::new()
68 .with_new_line_after_attribute(false)
69 .optional()
70 }
71
72 fn custom_content(&self, content: Content) -> Option<Content> {
73 content.add("prob", &self.prob).optional()
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Shape;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// Dimensions of the input used by the forward-pass tests.
    const INPUT_DIMS: [usize; 2] = [100, 100];

    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_should_mark_input() {
        let layer = DropoutConfig::new(0.5).init();
        let ones =
            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new(INPUT_DIMS), &Default::default());

        // Kept elements become 1 * 2 and dropped ones become 0, so no element can stay at 1.
        let dropped = layer.forward(ones.clone());

        assert_ne!(ones.to_data(), dropped.to_data());
    }

    #[test]
    fn without_ad_backend_should_not_change_input() {
        let layer = DropoutConfig::new(0.5).init();
        let ones = Tensor::<TestBackend, 2>::ones(Shape::new(INPUT_DIMS), &Default::default());

        // Without autodiff the layer is a pass-through.
        let passthrough = layer.forward(ones.clone());

        assert_eq!(ones.to_data(), passthrough.to_data());
    }

    #[test]
    fn display() {
        let layer = DropoutConfig::new(0.5).init();

        let rendered = alloc::format!("{layer}");

        assert_eq!(rendered, "Dropout {prob: 0.5}");
    }

    #[test]
    #[should_panic = "Dropout probability should be between 0 and 1,"]
    fn dropout_prob_invalid() {
        let _layer = DropoutConfig::new(-10.).init();
    }
}