ice_nine/
layer.rs

1use crate::{HasGrad, Layer};
2use ndarray::Array2;
3
4pub struct Relu {}
5
6impl HasGrad<f64> for Relu {
7    /// Rectified Linear Unit
8    fn f(&self, x: f64) -> f64 {
9        if x < 0.0 {
10            0.0
11        } else {
12            x
13        }
14    }
15
16    /// Derivative of Relu
17    fn d_f(&self, x: f64) -> f64 {
18        if x < 0.0 {
19            0.0
20        } else {
21            1.0
22        }
23    }
24}
25
26impl Relu {
27    /// Make new Relu layer
28    pub fn new_layer(weights: Array2<f64>) -> Layer {
29        let gradients = Array2::zeros(weights.raw_dim());
30        Layer {
31            activation: Box::new(Relu {}),
32            weights,
33            gradients,
34        }
35    }
36}
37
38pub struct LeakyRelu {
39    pub slope: f64,
40}
41
42impl HasGrad<f64> for LeakyRelu {
43    /// Leaky Rectified Linear Unit
44    fn f(&self, x: f64) -> f64 {
45        if x < 0.0 {
46            x * self.slope
47        } else {
48            x
49        }
50    }
51
52    /// Derivative of LeakyRelu
53    fn d_f(&self, x: f64) -> f64 {
54        if x < 0.0 {
55            self.slope
56        } else {
57            1.0
58        }
59    }
60}
61
62impl LeakyRelu {
63    /// Make new LeakyRelu layer
64    pub fn new_layer(weights: Array2<f64>, slope: f64) -> Layer {
65        let gradients = Array2::zeros(weights.raw_dim());
66        Layer {
67            activation: Box::new(LeakyRelu { slope }),
68            weights,
69            gradients,
70        }
71    }
72}
73
74pub struct Linear {}
75
76impl HasGrad<f64> for Linear {
77    /// Identity activation
78    fn f(&self, x: f64) -> f64 {
79        x
80    }
81
82    /// Derivative of identity
83    fn d_f(&self, _x: f64) -> f64 {
84        1.0
85    }
86}
87
88impl Linear {
89    /// Make new linear layer
90    pub fn new_layer(weights: Array2<f64>) -> Layer {
91        let gradients = Array2::zeros(weights.raw_dim());
92        Layer {
93            activation: Box::new(Linear {}),
94            weights,
95            gradients,
96        }
97    }
98}