use std::cmp::Ordering;

use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::linalg::{Matrix, MatrixTrait, Scalar};
use crate::network::NetworkLayer;
use crate::{activation::ActivationLayer, layer::dense_layer::DenseLayer, layer::Layer};

use super::{DropoutLayer, LearnableLayer, ParameterableLayer};

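/// A fully-connected block: a `DenseLayer` followed by an `ActivationLayer`,
/// with optional dropout applied to the dense layer's input during training.
///
/// With dropout rate `p`, each input unit is zeroed independently with
/// probability `p` while training; at inference the weights are scaled by the
/// keep probability `1 - p` so expected pre-activations match
/// (Srivastava et al., 2014).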
#[derive(Debug)]
pub struct FullLayer {
    dense: DenseLayer,
    activation: ActivationLayer,
    // Dropout paper: https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf
    dropout_enabled: bool,
    dropout_rate: Option<Scalar>,
    mask: Option<Matrix>,
}

impl FullLayer {
    pub fn new(dense: DenseLayer, activation: ActivationLayer, dropout: Option<Scalar>) -> Self {
        Self {
            dense,
            activation,
            dropout_rate: dropout,
            dropout_enabled: false,
            mask: None,
        }
    }

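    /// Builds a Bernoulli dropout mask of the given shape: each entry is
    /// `1.0` (keep) with probability `1 - dropout_rate` and `0.0` (drop)
    /// otherwise. Returns `None` when the layer was configured without
    /// dropout.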
    fn generate_dropout_mask(&mut self, output_shape: (usize, usize)) -> Option<(Matrix, Scalar)> {
        if let Some(dropout_rate) = self.dropout_rate {
            let mut rng = SmallRng::from_entropy();
            let dropout_mask = Matrix::from_fn(output_shape.0, output_shape.1, |_, _| {
                // Keep a unit (1.0) when the uniform draw exceeds the dropout
                // rate, i.e. with probability 1 - dropout_rate.
                if rng
                    .gen_range((0.0 as Scalar)..(1.0 as Scalar))
                    .total_cmp(&dropout_rate)
                    == Ordering::Greater
                {
                    1.0
                } else {
                    0.0
                }
            });
            Some((dropout_mask, dropout_rate))
        } else {
            None
        }
    }
}

impl Layer for FullLayer {
    fn forward(&mut self, mut input: Matrix) -> Matrix {
        let output = if self.dropout_enabled {
            // Training: zero out each input unit with probability
            // `dropout_rate`, remembering the mask for the backward pass.
            if let Some((mask, _)) = self.generate_dropout_mask(input.dim()) {
                input = input.component_mul(&mask);
                self.mask = Some(mask);
            }
            self.dense.forward(input)
        } else {
            // Drop any mask left over from training so that backward() does
            // not gate gradients with a stale mask.
            self.mask = None;
            if let Some(dropout_rate) = self.dropout_rate {
                // Inference: scale the weights by the keep probability
                // (1 - p) so expected pre-activations match training
                // (Srivastava et al., 2014), then restore them.
                self.dense.weights = self.dense.weights.scalar_mul(1.0 - dropout_rate);
                let output = self.dense.forward(input);
                self.dense.weights = self.dense.weights.scalar_div(1.0 - dropout_rate);
                output
            } else {
                self.dense.forward(input)
            }
        };

        self.activation.forward(output)
    }

    fn backward(&mut self, epoch: usize, output_gradient: Matrix) -> Matrix {
        let activation_input_gradient = self.activation.backward(epoch, output_gradient);
        let input_gradient = self.dense.backward(epoch, activation_input_gradient);

        // Gradients flow only through the units kept in the forward pass;
        // dropped units receive zero gradient.
        if let Some(mask) = &self.mask {
            input_gradient.component_mul(mask)
        } else {
            input_gradient
        }
    }
}

impl NetworkLayer for FullLayer {}

impl ParameterableLayer for FullLayer {
    fn as_learnable_layer(&self) -> Option<&dyn LearnableLayer> {
        Some(self)
    }

    fn as_learnable_layer_mut(&mut self) -> Option<&mut dyn LearnableLayer> {
        Some(self)
    }

    fn as_dropout_layer(&mut self) -> Option<&mut dyn DropoutLayer> {
        Some(self)
    }
}

impl LearnableLayer for FullLayer {
    // Returns the (j x i) weight matrix with the (j) biases appended as a final column.
    fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> {
        self.dense.get_learnable_parameters()
    }

    // Takes a (j x (i + 1)) matrix: the (j x i) weights with the (j) biases as the final column.
    fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) {
        self.dense.set_learnable_parameters(params_matrix)
    }
}

impl DropoutLayer for FullLayer {
    fn enable_dropout(&mut self) {
        self.dropout_enabled = true;
    }

    fn disable_dropout(&mut self) {
        self.dropout_enabled = false;
    }
}
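
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity check of the masking scheme used by
    // `generate_dropout_mask`, written against `rand` directly so it does
    // not assume anything about how `DenseLayer` or `ActivationLayer` are
    // constructed. With dropout rate p, a unit is kept (mask = 1.0) when the
    // uniform draw exceeds p, i.e. with probability 1 - p, so the empirical
    // keep fraction should converge to 1 - p.
    #[test]
    fn mask_keeps_units_with_probability_one_minus_p() {
        let dropout_rate: Scalar = 0.3;
        let mut rng = SmallRng::from_entropy();
        let n: usize = 100_000;
        let kept = (0..n)
            .filter(|_| {
                rng.gen_range((0.0 as Scalar)..(1.0 as Scalar))
                    .total_cmp(&dropout_rate)
                    == Ordering::Greater
            })
            .count();
        let keep_fraction = kept as Scalar / n as Scalar;
        // ~7 standard deviations of slack at 100_000 draws.
        assert!((keep_fraction - (1.0 - dropout_rate)).abs() < 0.01);
    }
}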