1use ndarray::prelude::*;
7use ndarray::{Data, DataMut, ScalarOperand};
8use num::complex::ComplexFloat;
9
/// Declares one activation trait per comma-separated entry.
///
/// Each entry has the form `TraitName::method(self)` or `TraitName::method(&self)`.
/// For every entry the macro emits a public trait with an associated `Output` type,
/// the method itself, and a companion `method_derivative` (name built with `paste`),
/// both taking the receiver exactly as written in the entry:
///
/// ```ignore
/// pub trait TraitName {
///     type Output;
///     fn method(&self) -> Self::Output;
///     fn method_derivative(&self) -> Self::Output;
/// }
/// ```
///
/// A trailing comma after the last entry is accepted. Any other receiver form
/// (e.g. `&mut self`) matches no rule and is rejected at compile time.
macro_rules! unary {
    // Internal rule: the trait's methods consume the receiver.
    (@impl $name:ident::$call:ident(self)) => {
        paste::paste! {
            pub trait $name {
                type Output;

                fn $call(self) -> Self::Output;

                fn [<$call _derivative>](self) -> Self::Output;
            }
        }
    };
    // Internal rule: the trait's methods borrow the receiver.
    (@impl $name:ident::$call:ident(&self)) => {
        paste::paste! {
            pub trait $name {
                type Output;

                fn $call(&self) -> Self::Output;

                fn [<$call _derivative>](&self) -> Self::Output;
            }
        }
    };
    // Public entry point: forward each entry to the matching internal rule.
    ($($name:ident::$call:ident($($recv:tt)*)),* $(,)?) => {
        $(
            unary!(@impl $name::$call($($recv)*));
        )*
    };
}
41
42unary! {
43 Heavyside::heavyside(self),
44 LinearActivation::linear(self),
45 Sigmoid::sigmoid(self),
46 Softmax::softmax(&self),
47 ReLU::relu(&self),
48 Tanh::tanh(&self),
49}
50
51pub trait SoftmaxAxis: Softmax {
52 fn softmax_axis(self, axis: usize) -> Self::Output;
53}
54
/// A container of `A` values that can be mapped element-wise into a
/// container of another element type.
pub trait Activate<A> {
    /// The container returned by [`activate`](Activate::activate), holding `B` values.
    type Cont<B>;

    /// Applies `f` to every element (passed by value) and returns the results
    /// in a `Self::Cont<B>`; `self` is only borrowed.
    fn activate<B, F: Fn(A) -> B>(&self, f: F) -> Self::Cont<B>;
}
/// A container of `A` values whose elements can be rewritten by an
/// activation function, as opposed to [`Activate`], which produces a new container.
pub trait ActivateMut<A> {
    /// Container type parameterised by element type. Not referenced by
    /// [`activate_inplace`](ActivateMut::activate_inplace) itself.
    type Cont<B>;

    /// Replaces each element `x` with `f(x)`, mutating `self` in place.
    ///
    /// `A: 'a` stays in the where-clause on purpose: it makes `'a` early-bound,
    /// and implementors must declare the same bound to match this signature.
    fn activate_inplace<'a, F>(&'a mut self, f: F)
    where
        F: FnMut(A) -> A,
        A: 'a;
}
72pub trait ActivateExt<U>: Activate<U> {
76 fn linear(&self) -> Self::Cont<U::Output>
77 where
78 U: LinearActivation,
79 {
80 self.activate(|x| x.linear())
81 }
82
83 fn linear_derivative(&self) -> Self::Cont<U::Output>
84 where
85 U: LinearActivation,
86 {
87 self.activate(|x| x.linear_derivative())
88 }
89
90 fn heavyside(&self) -> Self::Cont<U::Output>
91 where
92 U: Heavyside,
93 {
94 self.activate(|x| x.heavyside())
95 }
96
97 fn heavyside_derivative(&self) -> Self::Cont<U::Output>
98 where
99 U: Heavyside,
100 {
101 self.activate(|x| x.heavyside_derivative())
102 }
103
104 fn relu(&self) -> Self::Cont<U::Output>
105 where
106 U: ReLU,
107 {
108 self.activate(|x| x.relu())
109 }
110
111 fn relu_derivative(&self) -> Self::Cont<U::Output>
112 where
113 U: ReLU,
114 {
115 self.activate(|x| x.relu_derivative())
116 }
117
118 fn sigmoid(&self) -> Self::Cont<U::Output>
119 where
120 U: Sigmoid,
121 {
122 self.activate(|x| x.sigmoid())
123 }
124
125 fn sigmoid_derivative(&self) -> Self::Cont<U::Output>
126 where
127 U: Sigmoid,
128 {
129 self.activate(|x| x.sigmoid_derivative())
130 }
131
132 fn tanh(&self) -> Self::Cont<U::Output>
133 where
134 U: Tanh,
135 {
136 self.activate(|x| x.tanh())
137 }
138
139 fn tanh_derivative(&self) -> Self::Cont<U::Output>
140 where
141 U: Tanh,
142 {
143 self.activate(|x| x.tanh_derivative())
144 }
145
146 fn sigmoid_complex(&self) -> Self::Cont<U>
147 where
148 U: ComplexFloat,
149 {
150 self.activate(|x| U::one() / (U::one() + (-x).exp()))
151 }
152
153 fn sigmoid_complex_derivative(&self) -> Self::Cont<U>
154 where
155 U: ComplexFloat,
156 {
157 self.activate(|x| {
158 let s = U::one() / (U::one() + (-x).exp());
159 s * (U::one() - s)
160 })
161 }
162
163 fn tanh_complex(&self) -> Self::Cont<U>
164 where
165 U: ComplexFloat,
166 {
167 self.activate(|x| x.tanh())
168 }
169
170 fn tanh_complex_derivative(&self) -> Self::Cont<U>
171 where
172 U: ComplexFloat,
173 {
174 self.activate(|x| {
175 let s = x.tanh();
176 U::one() - s * s
177 })
178 }
179}
180
181pub trait NdActivateMut<A, D>
182where
183 A: ScalarOperand,
184 D: Dimension,
185{
186 type Data: DataMut<Elem = A>;
187}
188impl<U, S> ActivateExt<U> for S where S: Activate<U> {}
192
193impl<A, S, D> Activate<A> for ArrayBase<S, D>
194where
195 A: ScalarOperand,
196 D: Dimension,
197 S: Data<Elem = A>,
198{
199 type Cont<V> = Array<V, D>;
200
201 fn activate<V, F>(&self, f: F) -> Self::Cont<V>
202 where
203 F: Fn(A) -> V,
204 {
205 self.mapv(f)
206 }
207}
208
209impl<A, S, D> Activate<A> for &ArrayBase<S, D>
210where
211 A: ScalarOperand,
212 D: Dimension,
213 S: Data<Elem = A>,
214{
215 type Cont<V> = Array<V, D>;
216
217 fn activate<B, F>(&self, f: F) -> Array<B, D>
218 where
219 F: Fn(A) -> B,
220 {
221 self.mapv(f)
222 }
223}
224
225impl<A, S, D> Activate<A> for &mut ArrayBase<S, D>
226where
227 A: ScalarOperand,
228 D: Dimension,
229 S: Data<Elem = A>,
230{
231 type Cont<V> = Array<V, D>;
232
233 fn activate<B, F>(&self, f: F) -> Array<B, D>
234 where
235 F: Fn(A) -> B,
236 {
237 self.mapv(f)
238 }
239}
240
241impl<A, S, D> ActivateMut<A> for ArrayBase<S, D>
242where
243 A: ScalarOperand,
244 D: Dimension,
245 S: DataMut<Elem = A>,
246{
247 type Cont<V> = Array<V, D>;
248
249 fn activate_inplace<'a, F>(&'a mut self, f: F)
250 where
251 A: 'a,
252 F: FnMut(A) -> A,
253 {
254 self.mapv_inplace(f)
255 }
256}
257
258impl<A, S, D> ActivateMut<A> for &mut ArrayBase<S, D>
259where
260 A: ScalarOperand,
261 D: Dimension,
262 S: DataMut<Elem = A>,
263{
264 type Cont<V> = Array<V, D>;
265
266 fn activate_inplace<'b, F>(&'b mut self, f: F)
267 where
268 A: 'b,
269 F: FnMut(A) -> A,
270 {
271 self.mapv_inplace(f)
272 }
273}