ducky_learn/
activations.rs1extern crate ndarray;
2
3use ndarray::prelude::*;
4
5pub fn relu_1d(input_array: Array1<f64>) -> Array1<f64> {
25 input_array.map(|value| value.max(0.))
26}
27
28pub fn deriv_relu_1d(input_array: Array1<f64>) -> Array1<f64> {
46 input_array.map(|value| (*value > 0f64) as i32 as f64)
47}
48
49pub fn softmax_1d(input_array: Array1<f64>) -> Array1<f64> {
71 let sum_exp_input_array = input_array.map(|value| value.exp()).sum();
72
73 input_array.map(|value| value.exp() / sum_exp_input_array)
74}
75
#[cfg(test)]
mod activations_tests {
    use super::*;
    use ndarray::arr1;

    /// Relative tolerance used when comparing softmax outputs, whose values go
    /// through `exp` and may differ in the last bits between platforms.
    const RELATIVE_TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` and `expected` have the same length and that each
    /// pair of elements agrees within `RELATIVE_TOLERANCE` of the expected
    /// value. An expected `0.0` therefore still requires an exact `0.0`.
    fn assert_all_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {:?}, expected {:?}",
            actual,
            expected
        );

        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            let allowed = RELATIVE_TOLERANCE * want.abs();
            assert!(
                (got - want).abs() <= allowed,
                "element {}: got {:e}, expected {:e}",
                index,
                got,
                want
            );
        }
    }

    // ReLU results are either the untouched input or a literal 0.0, so exact
    // equality is appropriate for these tests.

    #[test]
    fn relu_1d_1() {
        let input_array = arr1(&[0., 1., -1., 0.01, -0.1]);

        assert_eq!(relu_1d(input_array), arr1(&[0., 1., 0., 0.01, 0.]));
    }

    #[test]
    fn relu_1d_2() {
        let input_array = arr1(&[]);

        assert_eq!(relu_1d(input_array), arr1(&[]));
    }

    #[test]
    fn relu_1d_3() {
        let input_array = arr1(&[-1.3456435325242, -32145324321., -132432888.]);

        assert_eq!(relu_1d(input_array), arr1(&[0., 0., 0.]));
    }

    // The ReLU derivative only ever produces the literals 0.0 and 1.0.

    #[test]
    fn deriv_relu_1d_1() {
        let input_array = arr1(&[1.3456435325242, -32145324321., 132432888.]);
        assert_eq!(deriv_relu_1d(input_array), arr1(&[1., 0., 1.]));
    }

    #[test]
    fn deriv_relu_1d_2() {
        let input_array = arr1(&[-1.3456435325242, -32145324321., 132432888.]);
        assert_eq!(deriv_relu_1d(input_array), arr1(&[0., 0., 1.]));
    }

    #[test]
    fn deriv_relu_1d_3() {
        let input_array = arr1(&[]);
        assert_eq!(deriv_relu_1d(input_array), arr1(&[]));
    }

    // Softmax values come out of `exp` and a division, so they are compared
    // with a relative tolerance instead of bit-for-bit equality.

    #[test]
    fn softmax_1d_1() {
        let input_array = arr1(&[0., 1., -1., 0.01, -0.1]);

        assert_all_close(
            &softmax_1d(input_array).to_vec(),
            &[
                0.16663753690463112,
                0.4529677885070323,
                0.0613025239546613,
                0.16831227199301688,
                0.15077987864065834,
            ],
        );
    }

    #[test]
    fn softmax_1d_2() {
        let input_array = arr1(&[]);

        assert_all_close(&softmax_1d(input_array).to_vec(), &[]);
    }

    #[test]
    fn softmax_1d_3() {
        let input_array = arr1(&[-0.3456435325242, 232., -888.]);

        assert_all_close(
            &softmax_1d(input_array).to_vec(),
            &[1.2404210269803915e-101, 1.0, 0.0],
        );
    }
}