// matrix_oxide/activation.rs

1use crate::Matrix;
2use std::f64::consts::PI;
3use std::ops::{Add, Mul};
4
5impl<T> Matrix<T>
6where
7    T: PartialOrd + Default + Copy + Mul<Output = T>,
8{
9    /// Apply the ReLU activation function onto a `Matrix`
10    pub fn relu(&self) -> Matrix<T>
11    where
12        T: PartialOrd + Default + Copy,
13    {
14        let data: Vec<T> = self
15            .data
16            .iter()
17            .map(|&x| if x < T::default() { T::default() } else { x })
18            .collect();
19
20        Matrix {
21            data,
22            row_size: self.row_size,
23            col_size: self.col_size,
24        }
25    }
26
27    /// Apply the Leaky ReLU activation function onto a `Matrix`
28    pub fn leaky_relu(&self, alpha: T) -> Matrix<T>
29    where
30        T: PartialOrd + Default + Copy + Mul<Output = T>,
31    {
32        let data: Vec<T> = self
33            .data
34            .iter()
35            .map(|&x| if x < T::default() { x * alpha } else { x })
36            .collect();
37
38        Matrix {
39            data,
40            row_size: self.row_size,
41            col_size: self.col_size,
42        }
43    }
44
45    /// Apply backward pass for the ReLU activation function onto a `Matrix`
46    pub fn relu_backward(&self) -> Matrix<T>
47    where
48        T: Copy + PartialOrd + Default + Add<T, Output = T> + Mul<T, Output = T> + From<u8>,
49    {
50        let data: Vec<T> = self
51            .data
52            .iter()
53            .map(|&x| {
54                if x >= T::default() {
55                    T::default() + T::from(1u8)
56                } else {
57                    T::default()
58                }
59            })
60            .collect();
61
62        Matrix {
63            data,
64            row_size: self.row_size,
65            col_size: self.col_size,
66        }
67    }
68
69    /// Apply the GeLU activation function onto a `Matrix`
70    /// NOTE: Smoother (near 0) than ReLU & potential for regularization effects.
71    pub fn gelu(&self) -> Matrix<T>
72    where
73        T: Copy + PartialOrd + Default + From<f64> + Into<f64>,
74    {
75        let data: Vec<T> = self
76            .data
77            .iter()
78            .map(|&x| {
79                let x_f64: f64 = x.into();
80                let x_gelu = 0.5
81                    * x_f64
82                    * (1.0 + ((2.0 / PI).sqrt() * (x_f64 + 0.04715 * x_f64.powi(3))).tanh());
83                T::from(x_gelu)
84            })
85            .collect();
86
87        Matrix {
88            data,
89            row_size: self.row_size,
90            col_size: self.col_size,
91        }
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 3x3 fixture with alternating positive/negative entries.
    fn sample_matrix() -> Matrix<f64> {
        Matrix {
            data: vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0],
            row_size: 3,
            col_size: 3,
        }
    }

    /// Verify each element is correctly ReLU'd
    #[test]
    fn test_relu() {
        let result = sample_matrix().relu();
        assert_eq!(
            result.data,
            vec![1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0]
        );
    }

    /// Verify each element is correctly Leaky ReLU'd
    #[test]
    fn test_leaky_relu() {
        let result = sample_matrix().leaky_relu(0.1);
        let expected = vec![
            1.0, -0.2, 3.0, -0.4, 5.0, -0.6000000000000001, 7.0, -0.8, 9.0,
        ];
        assert_eq!(result.data, expected);
    }

    /// Verify each element in the gradient matrix
    #[test]
    fn test_relu_backward() {
        let matrix = Matrix {
            data: vec![1, -2, 3, -4, 5, -6, 7, -8, 9],
            row_size: 3,
            col_size: 3,
        };
        assert_eq!(matrix.relu_backward().data, vec![1, 0, 1, 0, 1, 0, 1, 0, 1]);
    }

    /// Verify each element is correctly GeLU'd
    #[test]
    fn test_gelu() {
        let result = sample_matrix().gelu();
        let expected = [
            0.841344746,
            -0.045500263,
            2.9963625,
            -0.0003058,
            4.9998675,
            -0.0000009,
            6.999999998,
            -0.000000001,
            8.9999999,
        ];
        // NOTE: due to inaccuracy in floating point arithmetic this has a tolerance of 0.01
        for (idx, exp) in expected.iter().enumerate() {
            let got: f64 = result.data[idx].into();
            assert!((got - exp).abs() < 1e-2);
        }
    }
}