// concision_core/utils/dropout.rs

/// Types that can have dropout applied to them with a probability `p`.
///
/// Each implementor chooses what "dropping" means for itself and which type the
/// operation produces.
pub trait DropOut {
    /// The value returned by [`DropOut::dropout`].
    type Output;

    /// Applies dropout with probability `p` and returns the result.
    fn dropout(&self, p: f64) -> Self::Output;
}
13
/// Configuration for a dropout layer: just the dropout probability `p`.
///
/// The value is stored exactly as given; it is not range-checked on construction.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Dropout {
    /// Dropout probability passed to [`DropOut::dropout`].
    pub(crate) p: f64,
}
26
27impl Dropout {
28 pub fn new(p: f64) -> Self {
29 Self { p }
30 }
31
32 pub fn scale(&self) -> f64 {
33 (1f64 - self.p).recip()
34 }
35
36 pub fn forward<U>(&self, input: &U) -> Option<<U as DropOut>::Output>
37 where
38 U: DropOut,
39 {
40 Some(input.dropout(self.p))
41 }
42}
43
44impl Default for Dropout {
45 fn default() -> Self {
46 Self::new(0.5)
47 }
48}
49
#[cfg(feature = "rand")]
mod impl_rand {
    use super::*;
    use concision_init::NdRandom;
    use concision_traits::Forward;
    use ndarray::{Array, ArrayBase, DataOwned, Dimension};
    use num_traits::Zero;

    /// Element-wise dropout for owned ndarray arrays.
    ///
    /// A boolean mask of the same shape is sampled with `Array::bernoulli(dim, p)`.
    /// Every element whose mask entry is `true` is set to `A::zero()`, and all other
    /// elements are copied unchanged. The input is not modified; a new owned array is
    /// returned. Surviving elements are not rescaled.
    ///
    /// Dropped elements are *assigned* zero rather than multiplied by zero. This matters
    /// for floats: `NaN * 0.0` and `inf * 0.0` are both NaN, so multiplying by a 0/1
    /// mask would leave dropped NaN/infinite inputs as NaN instead of zeroing them.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to create mask" if the mask cannot be sampled. Presumably
    /// this happens when `p` is outside `[0, 1]` — TODO confirm against
    /// `NdRandom::bernoulli`.
    impl<A, S, D> DropOut for ArrayBase<S, D, A>
    where
        A: Clone + Zero,
        D: Dimension,
        S: DataOwned<Elem = A>,
    {
        type Output = Array<A, D>;

        fn dropout(&self, p: f64) -> Self::Output {
            let mask: Array<bool, D> =
                Array::bernoulli(self.dim(), p).expect("Failed to create mask");
            // Copy once, then zero the dropped positions in place. This avoids
            // allocating a second numeric mask and keeps NaN/inf from leaking through.
            let mut out = self.to_owned();
            out.zip_mut_with(&mask, |x, &dropped| {
                if dropped {
                    *x = A::zero();
                }
            });
            out
        }
    }

    /// Implements `Forward` for any input that supports [`DropOut`], by delegating to
    /// [`DropOut::dropout`] with this layer's probability.
    impl<U> Forward<U> for Dropout
    where
        U: DropOut,
    {
        type Output = <U as DropOut>::Output;

        fn forward(&self, input: &U) -> Self::Output {
            input.dropout(self.p)
        }
    }
}
88
#[cfg(all(test, feature = "rand"))]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_dropout() {
        let shape = (512, 2048);
        let arr = Array2::<f64>::ones(shape);
        let dropout = Dropout::new(0.5);
        let out = dropout.forward(&arr).expect("Dropout forward pass failed");

        // The input is left untouched and the output keeps the input's shape.
        assert!(arr.iter().all(|&x| x == 1.0));
        assert_eq!(out.dim(), shape);
        // Each element is either dropped (0) or passed through unscaled (1).
        assert!(out.iter().all(|&x| x == 0.0 || x == 1.0));
        // With ~1M draws at p = 0.5, both outcomes appear with overwhelming probability.
        assert!(out.iter().any(|&x| x == 0.0));
        assert!(out.iter().any(|&x| x == 1.0));
    }
}