1use super::utils::*;
7
8use ndarray::prelude::*;
9use ndarray::{Data, DataMut, RemoveAxis, ScalarOperand};
10use num::complex::ComplexFloat;
11use num_traits::{Float, One, Zero};
12
/// Declares public single-method traits, one per comma-separated entry.
///
/// Each entry has the form `TraitName::method(self)` or
/// `TraitName::method(&self)` and expands to
///
/// ```text
/// pub trait TraitName {
///     type Output;
///     fn method(self) -> Self::Output; // or `&self`
/// }
/// ```
///
/// Outer attributes written before an entry (including `///` doc comments)
/// are forwarded onto the generated trait. This lets the generated public API
/// be documented at the invocation site. A trailing comma is accepted.
macro_rules! unary {
    ($($(#[$attr:meta])* $name:ident::$call:ident($($rest:tt)*)),* $(,)?) => {
        $(
            unary!(@impl $(#[$attr])* $name::$call($($rest)*));
        )*
    };
    // Receiver taken by value.
    (@impl $(#[$attr:meta])* $name:ident::$call:ident(self)) => {
        $(#[$attr])*
        pub trait $name {
            type Output;

            fn $call(self) -> Self::Output;
        }
    };
    // Receiver taken by shared reference.
    (@impl $(#[$attr:meta])* $name:ident::$call:ident(&self)) => {
        $(#[$attr])*
        pub trait $name {
            type Output;

            fn $call(&self) -> Self::Output;
        }
    };
}
34
35unary! {
36 Heavyside::heavyside(self),
37 LinearActivation::linear(self),
38 Sigmoid::sigmoid(&self),
39 Softmax::softmax(&self),
40 ReLU::relu(&self),
41 Tanh::tanh(&self),
42}
43
44pub trait SoftmaxAxis: Softmax {
45 fn softmax_axis(self, axis: usize) -> Self::Output;
46}
47
48pub trait NdActivate<A, D>
49where
50 A: ScalarOperand,
51 D: Dimension,
52{
53 type Data: Data<Elem = A>;
54
55 fn activate<B, F>(&self, f: F) -> Array<B, D>
56 where
57 F: Fn(A) -> B;
58
59 fn linear(&self) -> Array<A, D>
60 where
61 A: Clone,
62 {
63 self.activate(|x| x.clone())
64 }
65
66 fn linear_derivative(&self) -> Array<A, D>
67 where
68 A: One,
69 {
70 self.activate(|_| A::one())
71 }
72
73 fn heavyside(&self) -> Array<A, D>
74 where
75 A: One + PartialOrd + Zero,
76 {
77 self.activate(heavyside)
78 }
79
80 fn relu(&self) -> Array<A, D>
81 where
82 A: PartialOrd + Zero,
83 {
84 self.activate(relu)
85 }
86
87 fn relu_derivative(&self) -> Array<A, D>
88 where
89 A: PartialOrd + One + Zero,
90 {
91 self.activate(relu_derivative)
92 }
93 fn sigmoid(&self) -> Array<A, D>
95 where
96 A: Float,
97 {
98 self.activate(sigmoid)
99 }
100 fn sigmoid_derivative(&self) -> Array<A, D>
102 where
103 A: Float,
104 {
105 self.activate(sigmoid_derivative)
106 }
107
108 fn sigmoid_complex(&self) -> Array<A, D>
109 where
110 A: ComplexFloat,
111 {
112 self.activate(|x| A::one() / (A::one() + (-x).exp()))
113 }
114 fn sigmoid_complex_derivative(&self) -> Array<A, D>
115 where
116 A: ComplexFloat,
117 {
118 self.activate(|x| {
119 let s = A::one() / (A::one() + (-x).exp());
120 s * (A::one() - s)
121 })
122 }
123
124 fn softmax(&self) -> Array<A, D>
125 where
126 A: ComplexFloat,
127 {
128 let exp = self.activate(A::exp);
129 &exp / exp.sum()
130 }
131
132 fn softmax_axis(&self, axis: usize) -> Array<A, D>
133 where
134 A: ComplexFloat,
135 D: RemoveAxis,
136 {
137 let exp = self.activate(A::exp);
138 let axis = Axis(axis);
139 &exp / &exp.sum_axis(axis)
140 }
141
142 fn tanh(&self) -> Array<A, D>
143 where
144 A: ComplexFloat,
145 {
146 self.activate(A::tanh)
147 }
148
149 fn tanh_derivative(&self) -> Array<A, D>
150 where
151 A: ComplexFloat,
152 {
153 self.activate(|i| A::one() - A::tanh(i) * A::tanh(i))
154 }
155}
156
157pub trait NdActivateMut<A, D>
158where
159 A: ScalarOperand,
160 D: Dimension,
161{
162 type Data: DataMut<Elem = A>;
163
164 fn activate_inplace<'a, F>(&'a mut self, f: F)
165 where
166 A: 'a,
167 F: FnMut(A) -> A;
168}
169impl<A, S, D> NdActivate<A, D> for ArrayBase<S, D>
174where
175 A: ScalarOperand,
176 D: Dimension,
177 S: Data<Elem = A>,
178{
179 type Data = S;
180
181 fn activate<B, F>(&self, f: F) -> Array<B, D>
182 where
183 F: Fn(A) -> B,
184 {
185 self.mapv(f)
186 }
187}
188
189impl<A, S, D> NdActivateMut<A, D> for ArrayBase<S, D>
190where
191 A: ScalarOperand,
192 D: Dimension,
193 S: DataMut<Elem = A>,
194{
195 type Data = S;
196
197 fn activate_inplace<'a, F>(&'a mut self, f: F)
198 where
199 A: 'a,
200 F: FnMut(A) -> A,
201 {
202 self.mapv_inplace(f)
203 }
204}
205
206impl<'a, A, S, D> NdActivate<A, D> for &'a ArrayBase<S, D>
207where
208 A: ScalarOperand,
209 D: Dimension,
210 S: Data<Elem = A>,
211{
212 type Data = S;
213
214 fn activate<B, F>(&self, f: F) -> Array<B, D>
215 where
216 F: Fn(A) -> B,
217 {
218 self.mapv(f)
219 }
220}
221
222impl<'a, A, S, D> NdActivate<A, D> for &'a mut ArrayBase<S, D>
223where
224 A: ScalarOperand,
225 D: Dimension,
226 S: Data<Elem = A>,
227{
228 type Data = S;
229
230 fn activate<B, F>(&self, f: F) -> Array<B, D>
231 where
232 F: Fn(A) -> B,
233 {
234 self.mapv(f)
235 }
236}
237
238impl<'a, A, S, D> NdActivateMut<A, D> for &'a mut ArrayBase<S, D>
239where
240 A: ScalarOperand,
241 D: Dimension,
242 S: DataMut<Elem = A>,
243{
244 type Data = S;
245
246 fn activate_inplace<'b, F>(&'b mut self, f: F)
247 where
248 A: 'b,
249 F: FnMut(A) -> A,
250 {
251 self.mapv_inplace(f)
252 }
253}