concision_core/utils/
dropout.rs

1/*
2    Appellation: dropout <module>
3    Created At: 2025.11.26:17:01:56
4    Contrib: @FL03
5*/
6
/// [`DropOut`] is implemented by values whose elements can be randomly zeroed ("dropped")
/// with a given probability `p`.
///
/// The receiver is borrowed immutably; the result is returned as [`DropOut::Output`].
pub trait DropOut {
    /// The type produced by applying dropout to `Self`.
    type Output;

    /// Returns the result of randomly zeroing elements of `self`, each with probability `p`.
    fn dropout(&self, p: f64) -> Self::Output;
}
13
/// The [`Dropout`] layer randomly zeroes elements of its input with a given probability (`p`).
/// This regularization technique is often used to prevent overfitting.
///
/// ### Config
///
/// - `p`: the probability of dropping (zeroing) an element; expected to lie in `[0, 1]`.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Dropout {
    pub(crate) p: f64,
}
26
27impl Dropout {
28    pub fn new(p: f64) -> Self {
29        Self { p }
30    }
31
32    pub fn scale(&self) -> f64 {
33        (1f64 - self.p).recip()
34    }
35
36    pub fn forward<U>(&self, input: &U) -> Option<<U as DropOut>::Output>
37    where
38        U: DropOut,
39    {
40        Some(input.dropout(self.p))
41    }
42}
43
44impl Default for Dropout {
45    fn default() -> Self {
46        Self::new(0.5)
47    }
48}
49
#[cfg(feature = "rand")]
mod impl_rand {
    use super::*;
    use concision_init::NdRandom;
    use concision_traits::Forward;
    use ndarray::{Array, ArrayBase, DataOwned, Dimension, ScalarOperand};
    use num_traits::Num;

    impl<A, S, D> DropOut for ArrayBase<S, D, A>
    where
        A: Num + ScalarOperand,
        D: Dimension,
        S: DataOwned<Elem = A>,
    {
        type Output = Array<A, D>;

        /// Returns an owned copy of the array in which every element whose entry in a freshly
        /// sampled Bernoulli(`p`) mask is `true` is replaced by zero; all other elements are
        /// copied unchanged.
        ///
        /// Dropped elements are *assigned* zero rather than multiplied by a `0`/`1` mask, so
        /// non-finite values (`inf`, `NaN`) are dropped to zero as well instead of turning
        /// into `NaN`, and no intermediate mask of `A` values has to be allocated.
        ///
        /// # Panics
        ///
        /// Panics if `Array::bernoulli` fails to build the mask (presumably when `p` is not a
        /// valid probability).
        fn dropout(&self, p: f64) -> Self::Output {
            // `true` marks an element to drop.
            let mask: Array<bool, D> =
                Array::bernoulli(self.dim(), p).expect("Failed to create mask");

            let mut out = self.to_owned();
            out.zip_mut_with(&mask, |x, &dropped| {
                if dropped {
                    *x = A::zero();
                }
            });
            out
        }
    }

    /// Applies dropout with the layer's probability `p`. Unlike the inherent
    /// `Dropout::forward`, the result is returned directly rather than wrapped in `Option`.
    impl<U> Forward<U> for Dropout
    where
        U: DropOut,
    {
        type Output = <U as DropOut>::Output;

        fn forward(&self, input: &U) -> Self::Output {
            input.dropout(self.p)
        }
    }
}
88
#[cfg(all(test, feature = "rand"))]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_dropout() {
        let shape = (512, 2048);
        let arr = Array2::<f64>::ones(shape);
        let dropout = Dropout::new(0.5);
        let out = dropout.forward(&arr).expect("Dropout forward pass failed");

        // The input must be left untouched.
        assert!(arr.iter().all(|&x| x == 1.0));
        // The output keeps the input's shape.
        assert_eq!(out.dim(), shape);
        // Every element is either dropped (0) or passed through unchanged (1).
        assert!(out.iter().all(|&x| x == 0.0 || x == 1.0));
        assert!(out.iter().any(|&x| x == 0.0));
        assert!(out.iter().any(|&x| x == 1.0));

        // With p = 0.5 and ~1M elements the dropped fraction should be very close to 0.5;
        // the tolerance is far wider than the sampling noise, so this is not flaky.
        let dropped = out.iter().filter(|&&x| x == 0.0).count() as f64 / out.len() as f64;
        assert!(
            (dropped - 0.5).abs() < 0.05,
            "dropped fraction {dropped} is too far from 0.5"
        );
    }
}