// cogent/activations.rs
1use serde::{Deserialize, Serialize};
2
3use arrayfire::{
4 and, constant, exp, gt, matmul, max, maxof, pow, sigmoid, sum, tanh, Array, Dim4, MatProp,
5};
6
7/// Defines the activation of a layer.
8#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
9pub enum Activation {
10 /// Sigmoid activation functions.
11 ///
12 /// $ A(z)=\frac{1}{1+e^{-z}} $
13 Sigmoid,
14 /// Tanh activation functions.
15 ///
16 /// $ A(z)=\frac{2}{1+e^{-2z}}-1 $
17 Tanh,
18 /// Softmax activation function.
19 ///
20 /// $ A(\begin{bmatrix}z_1,\dots,z_k\end{bmatrix})=\begin{bmatrix}\frac{e^{z_1}}{\Sigma_{i=1}^k e^{z_i}} & \dots &\frac{e^{z_k}}{\Sigma_{i=1}^k e^{z_i}}\end{bmatrix} $
21 Softmax,
22 /// ReLU activation function.
23 ///
24 /// $ A(z)=max(z,0) $
25 ReLU, // Name it 'ReLU' or 'Relu'?
26}
27impl Activation {
28 /// Computes activations given inputs (A(z)).
29 pub fn run(&self, z: &Array<f32>) -> Array<f32> {
30 return match self {
31 Self::Sigmoid => sigmoid(z),
32 Self::Tanh => tanh(z),
33 Self::Softmax => Activation::softmax(z),
34 Self::ReLU => Activation::relu(z),
35 };
36 }
37 /// Derivative w.r.t. layer input (∂a/∂z).
38 pub fn derivative(&self, z: &Array<f32>) -> Array<f32> {
39 // What should we name the derivative functions?
40 return match self {
41 Self::Sigmoid => sigmoid_derivative(z),
42 Self::Tanh => tanh_derivative(z),
43 Self::Softmax => softmax_derivative(z),
44 Self::ReLU => relu_derivative(z),
45 };
46
47 // Derivative of sigmoid
48 // s' = s(1-s)
49 fn sigmoid_derivative(z: &Array<f32>) -> Array<f32> {
50 let s = sigmoid(z);
51 return s.clone() * (1f32 - s); // TODO Can we remove the clone here?
52 }
53 // Derivative of sigmoid
54 // t' = 1-t^2
55 fn tanh_derivative(z: &Array<f32>) -> Array<f32> {
56 1 - pow(&tanh(z), &2, false)
57 }
58 // Derivative of softmax
59 // e^z * (sum of other inputs e^input) / (sum of all inputs e^input)^2 = e^z * (exp_sum-e^z) / (exp_sum)^2
60 fn softmax_derivative(z: &Array<f32>) -> Array<f32> {
61 // e^z
62 let exponents = exp(z);
63 // Gets sum of each example (column)
64 let sums = sum(&exponents, 0);
65
66 // This is done since `add(&a,&b,true)` is very slow.
67 let ones = constant(1f32, Dim4::new(&[z.dims().get()[0], 1, 1, 1]));
68 let sums_matrix = matmul(&ones, &sums, MatProp::NONE, MatProp::NONE);
69
70 // exp_sum-e^z
71 let sums_sub = sums_matrix - &exponents;
72
73 // (exp_sum)^2
74 // Gets squared sum of each example
75 let sqrd_sums = pow(&sums, &2, false); // is this better than `&sums*&sums`?
76
77 // TODO Is it more efficient to do this matrix multiplication before or after squaring?
78 // This is done since `div(&a,&b,true)` is very slow.
79 let sqrd_sums_matrix = matmul(&ones, &sqrd_sums, MatProp::NONE, MatProp::NONE);
80
81 // e^z * (exp_sum-e^z) / (exp_sum)^2
82 let derivatives = exponents * sums_sub / sqrd_sums_matrix;
83
84 return derivatives;
85 }
86 //Deritvative of ReLU
87 // ReLU(z)/1 = if >0 1 else 0
88 fn relu_derivative(z: &Array<f32>) -> Array<f32> {
89 // return Activation::relu(z) / z;
90 // Follow code replaces the above line.
91 // Above line replaced becuase it is prone to floating point error leading to f32:NAN.
92 // Similar performance.
93 let gt = gt(z, &0f32, false);
94 return and(z, >, false);
95 }
96 }
97 // TODO Make this better
98 // Applies softmax activation
99 fn softmax(y: &Array<f32>) -> Array<f32> {
100 let ones = constant(1f32, Dim4::new(&[y.dims().get()[0], 1, 1, 1]));
101
102 // Subtracts example max output from all example outputs.
103 // Allowing softmax to handle large values in y.
104 // ------------------------------------------------
105 // Gets max values in each example
106 let max_axis_vals = max(&y, 0);
107 // Matrix where each value is example max
108 let max_axis_vals_matrix = matmul(&ones, &max_axis_vals, MatProp::NONE, MatProp::NONE);
109 // All values minus there example maxes
110 let max_reduced = y - max_axis_vals_matrix;
111
112 // Applies softmax
113 // ------------------------------------------------
114 // Apply e^(x) to every value in matrix
115 let exp_matrix = exp(&max_reduced);
116 // Calculates sums of examples
117 let row_sums = sum(&exp_matrix, 0);
118 // Matrix where each value is example sum
119 let row_sums_matrix = matmul(&ones, &row_sums, MatProp::NONE, MatProp::NONE);
120 // Divides each value by example sum
121 let softmax = exp_matrix / row_sums_matrix; // TODO Could this div be done using batch operation with `arrayfire::div(...)` using `row_sums`?
122
123 return softmax;
124 }
125 // Applies ReLU activation
126 fn relu(y: &Array<f32>) -> Array<f32> {
127 let zeros = constant(0f32, y.dims());
128 return maxof(y, &zeros, false);
129 }
130}