concision_core/activate/
traits.rs

1/*
2    Appellation: traits <activate>
3    Contrib: @FL03
4*/
5
6use ndarray::prelude::*;
7use ndarray::{Data, DataMut, RemoveAxis, ScalarOperand};
8use num::complex::ComplexFloat;
9
/// Declares single-method activation traits, each paired with a derivative.
///
/// Every entry has the form `TraitName::method(self)` or
/// `TraitName::method(&self)`. For each entry the macro emits:
///
/// ```text
/// pub trait TraitName {
///     type Output;
///     fn method(<receiver>) -> Self::Output;
///     fn method_derivative(<receiver>) -> Self::Output;
/// }
/// ```
///
/// The `_derivative` identifier is assembled with `paste`.
macro_rules! unary {
    // Public entry point: a comma-separated list, optional trailing comma.
    ($($name:ident::$call:ident($($rest:tt)*)),* $(,)?) => {
        $(
            unary!(@impl $name::$call($($rest)*));
        )*
    };
    // Only the two supported receiver forms are accepted; each forwards its
    // receiver tokens to the shared emitter below.
    (@impl $name:ident::$call:ident(self)) => {
        unary!(@emit $name, $call, [self]);
    };
    (@impl $name:ident::$call:ident(&self)) => {
        unary!(@emit $name, $call, [&self]);
    };
    // Shared emitter: splices the receiver into both method signatures.
    (@emit $name:ident, $call:ident, [$($recv:tt)+]) => {
        paste::paste! {
            pub trait $name {
                type Output;

                fn $call($($recv)+) -> Self::Output;

                fn [<$call _derivative>]($($recv)+) -> Self::Output;
            }
        }
    };
}
41
42unary! {
43    Heavyside::heavyside(self),
44    LinearActivation::linear(self),
45    Sigmoid::sigmoid(&self),
46    Softmax::softmax(&self),
47    ReLU::relu(&self),
48    Tanh::tanh(&self),
49}
50
51pub trait SoftmaxAxis: Softmax {
52    fn softmax_axis(self, axis: usize) -> Self::Output;
53}
54
55pub trait NdActivate<A, D>
56where
57    A: ScalarOperand,
58    D: Dimension,
59{
60    type Data: Data<Elem = A>;
61
62    fn activate<B, F>(&self, f: F) -> Array<B, D>
63    where
64        F: Fn(A) -> B;
65
66    fn linear(&self) -> Array<A::Output, D>
67    where
68        A: LinearActivation,
69    {
70        self.activate(|x| x.linear())
71    }
72
73    fn linear_derivative(&self) -> Array<A::Output, D>
74    where
75        A: LinearActivation,
76    {
77        self.activate(|x| x.linear_derivative())
78    }
79
80    fn heavyside(&self) -> Array<A::Output, D>
81    where
82        A: Heavyside,
83    {
84        self.activate(|x| x.heavyside())
85    }
86
87    fn heavyside_derivative(&self) -> Array<A::Output, D>
88    where
89        A: Heavyside,
90    {
91        self.activate(|x| x.heavyside_derivative())
92    }
93
94    fn relu(&self) -> Array<A::Output, D>
95    where
96        A: ReLU,
97    {
98        self.activate(|x| x.relu())
99    }
100
101    fn relu_derivative(&self) -> Array<A::Output, D>
102    where
103        A: ReLU,
104    {
105        self.activate(|x| x.relu_derivative())
106    }
107    ///
108    fn sigmoid(&self) -> Array<A::Output, D>
109    where
110        A: Sigmoid,
111    {
112        self.activate(|x| x.sigmoid())
113    }
114    ///
115    fn sigmoid_derivative(&self) -> Array<A::Output, D>
116    where
117        A: Sigmoid,
118    {
119        self.activate(|x| x.sigmoid_derivative())
120    }
121    /// Softmax activation function
122    /// The softmax function is defined as:
123    /// $$ \sigma(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} $$
124    fn softmax(&self) -> Array<A, D>
125    where
126        A: ComplexFloat,
127    {
128        let exp = self.activate(A::exp);
129        &exp / exp.sum()
130    }
131
132    fn softmax_axis(&self, axis: usize) -> Array<A, D>
133    where
134        A: ComplexFloat,
135        D: RemoveAxis,
136    {
137        let exp = self.activate(A::exp);
138        let axis = Axis(axis);
139        &exp / &exp.sum_axis(axis)
140    }
141
142    fn tanh(&self) -> Array<A::Output, D>
143    where
144        A: Tanh,
145    {
146        self.activate(|x| x.tanh())
147    }
148
149    fn tanh_derivative(&self) -> Array<A::Output, D>
150    where
151        A: Tanh,
152    {
153        self.activate(|x| x.tanh_derivative())
154    }
155
156    fn sigmoid_complex(&self) -> Array<A, D>
157    where
158        A: ComplexFloat,
159    {
160        self.activate(|x| A::one() / (A::one() + (-x).exp()))
161    }
162
163    fn sigmoid_complex_derivative(&self) -> Array<A, D>
164    where
165        A: ComplexFloat,
166    {
167        self.activate(|x| {
168            let s = A::one() / (A::one() + (-x).exp());
169            s * (A::one() - s)
170        })
171    }
172
173    fn tanh_complex(&self) -> Array<A, D>
174    where
175        A: ComplexFloat,
176    {
177        self.activate(|x| x.tanh())
178    }
179    fn tanh_complex_derivative(&self) -> Array<A, D>
180    where
181        A: ComplexFloat,
182    {
183        self.activate(|x| {
184            let s = x.tanh();
185            A::one() - s * s
186        })
187    }
188}
189
190pub trait NdActivateMut<A, D>
191where
192    A: ScalarOperand,
193    D: Dimension,
194{
195    type Data: DataMut<Elem = A>;
196
197    fn activate_inplace<'a, F>(&'a mut self, f: F)
198    where
199        A: 'a,
200        F: FnMut(A) -> A;
201}
202/*
203 ************* Implementations *************
204*/
205
206impl<A, S, D> NdActivate<A, D> for ArrayBase<S, D>
207where
208    A: ScalarOperand,
209    D: Dimension,
210    S: Data<Elem = A>,
211{
212    type Data = S;
213
214    fn activate<B, F>(&self, f: F) -> Array<B, D>
215    where
216        F: Fn(A) -> B,
217    {
218        self.mapv(f)
219    }
220}
221
222impl<A, S, D> NdActivateMut<A, D> for ArrayBase<S, D>
223where
224    A: ScalarOperand,
225    D: Dimension,
226    S: DataMut<Elem = A>,
227{
228    type Data = S;
229
230    fn activate_inplace<'a, F>(&'a mut self, f: F)
231    where
232        A: 'a,
233        F: FnMut(A) -> A,
234    {
235        self.mapv_inplace(f)
236    }
237}
238
239impl<'a, A, S, D> NdActivate<A, D> for &'a ArrayBase<S, D>
240where
241    A: ScalarOperand,
242    D: Dimension,
243    S: Data<Elem = A>,
244{
245    type Data = S;
246
247    fn activate<B, F>(&self, f: F) -> Array<B, D>
248    where
249        F: Fn(A) -> B,
250    {
251        self.mapv(f)
252    }
253}
254
255impl<'a, A, S, D> NdActivate<A, D> for &'a mut ArrayBase<S, D>
256where
257    A: ScalarOperand,
258    D: Dimension,
259    S: Data<Elem = A>,
260{
261    type Data = S;
262
263    fn activate<B, F>(&self, f: F) -> Array<B, D>
264    where
265        F: Fn(A) -> B,
266    {
267        self.mapv(f)
268    }
269}
270
271impl<'a, A, S, D> NdActivateMut<A, D> for &'a mut ArrayBase<S, D>
272where
273    A: ScalarOperand,
274    D: Dimension,
275    S: DataMut<Elem = A>,
276{
277    type Data = S;
278
279    fn activate_inplace<'b, F>(&'b mut self, f: F)
280    where
281        A: 'b,
282        F: FnMut(A) -> A,
283    {
284        self.mapv_inplace(f)
285    }
286}