//! cogent/activations.rs — activation functions for network layers.

use arrayfire::{
    and, constant, exp, gt, matmul, max, maxof, pow, sigmoid, sum, tanh, Array, Dim4, MatProp,
};
use serde::{Deserialize, Serialize};
/// Defines the activation of a layer.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum Activation {
    /// Sigmoid activation function.
    ///
    /// $ A(z)=\frac{1}{1+e^{-z}} $
    Sigmoid,
    /// Tanh activation function.
    ///
    /// $ A(z)=\frac{2}{1+e^{-2z}}-1 $
    Tanh,
    /// Softmax activation function.
    ///
    /// $ A(\begin{bmatrix}z_1,\dots,z_k\end{bmatrix})=\begin{bmatrix}\frac{e^{z_1}}{\Sigma_{i=1}^k e^{z_i}} & \dots &\frac{e^{z_k}}{\Sigma_{i=1}^k e^{z_i}}\end{bmatrix} $
    Softmax,
    /// ReLU activation function.
    ///
    /// $ A(z)=max(z,0) $
    // NOTE(review): Rust naming convention (acronyms as one word) would call
    // this `Relu`, but renaming now would break callers and any models already
    // serialized with serde — keep as-is unless a breaking release is planned.
    ReLU,
}
27impl Activation {
28    /// Computes activations given inputs (A(z)).
29    pub fn run(&self, z: &Array<f32>) -> Array<f32> {
30        return match self {
31            Self::Sigmoid => sigmoid(z),
32            Self::Tanh => tanh(z),
33            Self::Softmax => Activation::softmax(z),
34            Self::ReLU => Activation::relu(z),
35        };
36    }
37    /// Derivative w.r.t. layer input (∂a/∂z).
38    pub fn derivative(&self, z: &Array<f32>) -> Array<f32> {
39        // What should we name the derivative functions?
40        return match self {
41            Self::Sigmoid => sigmoid_derivative(z),
42            Self::Tanh => tanh_derivative(z),
43            Self::Softmax => softmax_derivative(z),
44            Self::ReLU => relu_derivative(z),
45        };
46
47        // Derivative of sigmoid
48        // s' = s(1-s)
49        fn sigmoid_derivative(z: &Array<f32>) -> Array<f32> {
50            let s = sigmoid(z);
51            return s.clone() * (1f32 - s); // TODO Can we remove the clone here?
52        }
53        // Derivative of sigmoid
54        // t' = 1-t^2
55        fn tanh_derivative(z: &Array<f32>) -> Array<f32> {
56            1 - pow(&tanh(z), &2, false)
57        }
58        // Derivative of softmax
59        // e^z * (sum of other inputs e^input) / (sum of all inputs e^input)^2 = e^z * (exp_sum-e^z) / (exp_sum)^2
60        fn softmax_derivative(z: &Array<f32>) -> Array<f32> {
61            // e^z
62            let exponents = exp(z);
63            // Gets sum of each example (column)
64            let sums = sum(&exponents, 0);
65
66            // This is done since `add(&a,&b,true)` is very slow.
67            let ones = constant(1f32, Dim4::new(&[z.dims().get()[0], 1, 1, 1]));
68            let sums_matrix = matmul(&ones, &sums, MatProp::NONE, MatProp::NONE);
69
70            // exp_sum-e^z
71            let sums_sub = sums_matrix - &exponents;
72
73            // (exp_sum)^2
74            // Gets squared sum of each example
75            let sqrd_sums = pow(&sums, &2, false); // is this better than `&sums*&sums`?
76
77            // TODO Is it more efficient to do this matrix multiplication before or after squaring?
78            // This is done since `div(&a,&b,true)` is very slow.
79            let sqrd_sums_matrix = matmul(&ones, &sqrd_sums, MatProp::NONE, MatProp::NONE);
80
81            // e^z * (exp_sum-e^z) / (exp_sum)^2
82            let derivatives = exponents * sums_sub / sqrd_sums_matrix;
83
84            return derivatives;
85        }
86        //Deritvative of ReLU
87        // ReLU(z)/1 = if >0 1 else 0
88        fn relu_derivative(z: &Array<f32>) -> Array<f32> {
89            // return Activation::relu(z) / z;
90            // Follow code replaces the above line.
91            // Above line replaced becuase it is prone to floating point error leading to f32:NAN.
92            // Similar performance.
93            let gt = gt(z, &0f32, false);
94            return and(z, &gt, false);
95        }
96    }
97    // TODO Make this better
98    // Applies softmax activation
99    fn softmax(y: &Array<f32>) -> Array<f32> {
100        let ones = constant(1f32, Dim4::new(&[y.dims().get()[0], 1, 1, 1]));
101
102        // Subtracts example max output from all example outputs.
103        //  Allowing softmax to handle large values in y.
104        // ------------------------------------------------
105        // Gets max values in each example
106        let max_axis_vals = max(&y, 0);
107        // Matrix where each value is example max
108        let max_axis_vals_matrix = matmul(&ones, &max_axis_vals, MatProp::NONE, MatProp::NONE);
109        // All values minus there example maxes
110        let max_reduced = y - max_axis_vals_matrix;
111
112        // Applies softmax
113        // ------------------------------------------------
114        // Apply e^(x) to every value in matrix
115        let exp_matrix = exp(&max_reduced);
116        // Calculates sums of examples
117        let row_sums = sum(&exp_matrix, 0);
118        // Matrix where each value is example sum
119        let row_sums_matrix = matmul(&ones, &row_sums, MatProp::NONE, MatProp::NONE);
120        // Divides each value by example sum
121        let softmax = exp_matrix / row_sums_matrix; // TODO Could this div be done using batch operation with `arrayfire::div(...)` using `row_sums`?
122
123        return softmax;
124    }
125    // Applies ReLU activation
126    fn relu(y: &Array<f32>) -> Array<f32> {
127        let zeros = constant(0f32, y.dims());
128        return maxof(y, &zeros, false);
129    }
130}