use ndarray::prelude::*;
use ndarray::{Data, DataMut, RemoveAxis, ScalarOperand};
use num::complex::ComplexFloat;

/// Declares a unary activation trait with an associated `Output` type, the
/// activation method itself, and a matching `*_derivative` method; the two
/// `@impl` arms differ only in whether the receiver is taken by value or by
/// reference.
macro_rules! unary {
    ($($name:ident::$call:ident($($rest:tt)*)),* $(,)?) => {
        $(
            unary!(@impl $name::$call($($rest)*));
        )*
    };
    (@impl $name:ident::$call:ident(self)) => {
        paste::paste! {
            pub trait $name {
                type Output;

                fn $call(self) -> Self::Output;

                fn [<$call _derivative>](self) -> Self::Output;
            }
        }
    };
    (@impl $name:ident::$call:ident(&self)) => {
        paste::paste! {
            pub trait $name {
                type Output;

                fn $call(&self) -> Self::Output;

                fn [<$call _derivative>](&self) -> Self::Output;
            }
        }
    };
}

unary! {
    Heavyside::heavyside(self),
    LinearActivation::linear(self),
    Sigmoid::sigmoid(&self),
    Softmax::softmax(&self),
    ReLU::relu(&self),
    Tanh::tanh(&self),
}
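
// For reference, a `&self` entry such as `Sigmoid::sigmoid(&self)` expands to
// roughly the following trait, with the `sigmoid_derivative` name assembled by
// `paste`:
//
// pub trait Sigmoid {
//     type Output;
//
//     fn sigmoid(&self) -> Self::Output;
//
//     fn sigmoid_derivative(&self) -> Self::Output;
// }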

/// Softmax evaluated along a single axis of the value rather than over the
/// value as a whole.
pub trait SoftmaxAxis: Softmax {
    fn softmax_axis(self, axis: usize) -> Self::Output;
}

/// Element-wise activation for `ndarray` arrays: `activate` maps a function
/// over every element, and the provided methods cover the common activation
/// functions together with their derivatives.
pub trait NdActivate<A, D>
where
    A: ScalarOperand,
    D: Dimension,
{
    type Data: Data<Elem = A>;

    fn activate<B, F>(&self, f: F) -> Array<B, D>
    where
        F: Fn(A) -> B;

    fn linear(&self) -> Array<A::Output, D>
    where
        A: LinearActivation,
    {
        self.activate(|x| x.linear())
    }

    fn linear_derivative(&self) -> Array<A::Output, D>
    where
        A: LinearActivation,
    {
        self.activate(|x| x.linear_derivative())
    }

    fn heavyside(&self) -> Array<A::Output, D>
    where
        A: Heavyside,
    {
        self.activate(|x| x.heavyside())
    }

    fn heavyside_derivative(&self) -> Array<A::Output, D>
    where
        A: Heavyside,
    {
        self.activate(|x| x.heavyside_derivative())
    }

    fn relu(&self) -> Array<A::Output, D>
    where
        A: ReLU,
    {
        self.activate(|x| x.relu())
    }

    fn relu_derivative(&self) -> Array<A::Output, D>
    where
        A: ReLU,
    {
        self.activate(|x| x.relu_derivative())
    }

    fn sigmoid(&self) -> Array<A::Output, D>
    where
        A: Sigmoid,
    {
        self.activate(|x| x.sigmoid())
    }

    fn sigmoid_derivative(&self) -> Array<A::Output, D>
    where
        A: Sigmoid,
    {
        self.activate(|x| x.sigmoid_derivative())
    }

    /// Softmax over every element of the array: each entry is `exp(x)`
    /// divided by the sum of `exp` over the whole array.
    fn softmax(&self) -> Array<A, D>
    where
        A: ComplexFloat,
    {
        let exp = self.activate(A::exp);
        &exp / exp.sum()
    }

    /// Softmax computed independently along `axis`: every lane along the
    /// chosen axis is normalized by its own sum of exponentials.
    fn softmax_axis(&self, axis: usize) -> Array<A, D>
    where
        A: ComplexFloat,
        D: RemoveAxis,
    {
        let mut exp = self.activate(A::exp);
        // normalize each lane along `axis` by its own sum so the result is
        // well-formed for any choice of axis
        for mut lane in exp.lanes_mut(Axis(axis)) {
            let sum = lane.sum();
            lane.mapv_inplace(|x| x / sum);
        }
        exp
    }

    fn tanh(&self) -> Array<A::Output, D>
    where
        A: Tanh,
    {
        self.activate(|x| x.tanh())
    }

    fn tanh_derivative(&self) -> Array<A::Output, D>
    where
        A: Tanh,
    {
        self.activate(|x| x.tanh_derivative())
    }

    /// Logistic sigmoid for any [`ComplexFloat`] element, usable for real and
    /// complex floats without implementing the [`Sigmoid`] trait.
    fn sigmoid_complex(&self) -> Array<A, D>
    where
        A: ComplexFloat,
    {
        self.activate(|x| A::one() / (A::one() + (-x).exp()))
    }

    /// Derivative of the logistic sigmoid, `s * (1 - s)`.
    fn sigmoid_complex_derivative(&self) -> Array<A, D>
    where
        A: ComplexFloat,
    {
        self.activate(|x| {
            let s = A::one() / (A::one() + (-x).exp());
            s * (A::one() - s)
        })
    }

    /// Hyperbolic tangent for any [`ComplexFloat`] element.
    fn tanh_complex(&self) -> Array<A, D>
    where
        A: ComplexFloat,
    {
        self.activate(|x| x.tanh())
    }

    /// Derivative of the hyperbolic tangent, `1 - tanh(x)^2`.
    fn tanh_complex_derivative(&self) -> Array<A, D>
    where
        A: ComplexFloat,
    {
        self.activate(|x| {
            let s = x.tanh();
            A::one() - s * s
        })
    }
}

/// In-place counterpart to [`NdActivate`] for arrays backed by mutable data.
pub trait NdActivateMut<A, D>
where
    A: ScalarOperand,
    D: Dimension,
{
    type Data: DataMut<Elem = A>;

    fn activate_inplace<'a, F>(&'a mut self, f: F)
    where
        A: 'a,
        F: FnMut(A) -> A;
}

impl<A, S, D> NdActivate<A, D> for ArrayBase<S, D>
where
    A: ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    type Data = S;

    fn activate<B, F>(&self, f: F) -> Array<B, D>
    where
        F: Fn(A) -> B,
    {
        self.mapv(f)
    }
}

impl<A, S, D> NdActivateMut<A, D> for ArrayBase<S, D>
where
    A: ScalarOperand,
    D: Dimension,
    S: DataMut<Elem = A>,
{
    type Data = S;

    fn activate_inplace<'a, F>(&'a mut self, f: F)
    where
        A: 'a,
        F: FnMut(A) -> A,
    {
        self.mapv_inplace(f)
    }
}

impl<'a, A, S, D> NdActivate<A, D> for &'a ArrayBase<S, D>
where
    A: ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    type Data = S;

    fn activate<B, F>(&self, f: F) -> Array<B, D>
    where
        F: Fn(A) -> B,
    {
        self.mapv(f)
    }
}

impl<'a, A, S, D> NdActivate<A, D> for &'a mut ArrayBase<S, D>
where
    A: ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    type Data = S;

    fn activate<B, F>(&self, f: F) -> Array<B, D>
    where
        F: Fn(A) -> B,
    {
        self.mapv(f)
    }
}

impl<'a, A, S, D> NdActivateMut<A, D> for &'a mut ArrayBase<S, D>
where
    A: ScalarOperand,
    D: Dimension,
    S: DataMut<Elem = A>,
{
    type Data = S;

    fn activate_inplace<'b, F>(&'b mut self, f: F)
    where
        A: 'b,
        F: FnMut(A) -> A,
    {
        self.mapv_inplace(f)
    }
}
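
// A minimal usage sketch: these tests only exercise the `ComplexFloat`-based
// helpers, since the unary traits declared by `unary!` above are expected to
// be implemented for scalar types elsewhere in the crate.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn softmax_sums_to_one() {
        let x = array![0.0_f64, 1.0, 2.0];
        let p = x.softmax();
        assert!((p.sum() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn softmax_axis_normalizes_each_lane() {
        let x = array![[0.0_f64, 1.0], [2.0, 3.0]];
        let p = x.softmax_axis(1);
        // summing along the same axis should give 1 for every row
        for s in p.sum_axis(Axis(1)).iter() {
            assert!((s - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn sigmoid_complex_of_zero_is_one_half() {
        let x = array![0.0_f64];
        assert_eq!(x.sigmoid_complex()[0], 0.5);
    }

    #[test]
    fn activate_inplace_applies_the_closure() {
        let mut x = array![1.0_f64, 2.0, 3.0];
        x.activate_inplace(|v| v * v);
        assert_eq!(x, array![1.0, 4.0, 9.0]);
    }
}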