concision_core/activate/
traits.rs

1/*
2    Appellation: traits <activate>
3    Contrib: @FL03
4*/
5
6use super::utils::*;
7
8use ndarray::prelude::*;
9use ndarray::{Data, DataMut, RemoveAxis, ScalarOperand};
10use num::complex::ComplexFloat;
11use num_traits::{Float, One, Zero};
12
/// Generates a family of single-method "unary operation" traits.
///
/// Every `Trait::method(receiver)` entry expands to a public trait named `Trait` that
/// declares an associated `Output` type and one required method `method` returning
/// `Self::Output`. The receiver must be spelled either `self` (by value) or `&self`
/// (by shared reference). Entries are comma-separated; a trailing comma is accepted.
macro_rules! unary {
    // Internal rule: the generated method takes its receiver by value.
    (@impl $trait_name:ident::$method:ident(self)) => {
        pub trait $trait_name {
            type Output;

            fn $method(self) -> Self::Output;
        }
    };
    // Internal rule: the generated method borrows its receiver.
    (@impl $trait_name:ident::$method:ident(&self)) => {
        pub trait $trait_name {
            type Output;

            fn $method(&self) -> Self::Output;
        }
    };
    // Entry point: forward each listed entry to the matching internal rule above.
    ($($trait_name:ident::$method:ident($($receiver:tt)*)),* $(,)?) => {
        $(
            unary!(@impl $trait_name::$method($($receiver)*));
        )*
    };
}
34
35unary! {
36    Heavyside::heavyside(self),
37    LinearActivation::linear(self),
38    Sigmoid::sigmoid(&self),
39    Softmax::softmax(&self),
40    ReLU::relu(&self),
41    Tanh::tanh(&self),
42}
43
44pub trait SoftmaxAxis: Softmax {
45    fn softmax_axis(self, axis: usize) -> Self::Output;
46}
47
48pub trait NdActivate<A, D>
49where
50    A: ScalarOperand,
51    D: Dimension,
52{
53    type Data: Data<Elem = A>;
54
55    fn activate<B, F>(&self, f: F) -> Array<B, D>
56    where
57        F: Fn(A) -> B;
58
59    fn linear(&self) -> Array<A, D>
60    where
61        A: Clone,
62    {
63        self.activate(|x| x.clone())
64    }
65
66    fn linear_derivative(&self) -> Array<A, D>
67    where
68        A: One,
69    {
70        self.activate(|_| A::one())
71    }
72
73    fn heavyside(&self) -> Array<A, D>
74    where
75        A: One + PartialOrd + Zero,
76    {
77        self.activate(heavyside)
78    }
79
80    fn relu(&self) -> Array<A, D>
81    where
82        A: PartialOrd + Zero,
83    {
84        self.activate(relu)
85    }
86
87    fn relu_derivative(&self) -> Array<A, D>
88    where
89        A: PartialOrd + One + Zero,
90    {
91        self.activate(relu_derivative)
92    }
93    ///
94    fn sigmoid(&self) -> Array<A, D>
95    where
96        A: Float,
97    {
98        self.activate(sigmoid)
99    }
100    ///
101    fn sigmoid_derivative(&self) -> Array<A, D>
102    where
103        A: Float,
104    {
105        self.activate(sigmoid_derivative)
106    }
107
108    fn sigmoid_complex(&self) -> Array<A, D>
109    where
110        A: ComplexFloat,
111    {
112        self.activate(|x| A::one() / (A::one() + (-x).exp()))
113    }
114    fn sigmoid_complex_derivative(&self) -> Array<A, D>
115    where
116        A: ComplexFloat,
117    {
118        self.activate(|x| {
119            let s = A::one() / (A::one() + (-x).exp());
120            s * (A::one() - s)
121        })
122    }
123
124    fn softmax(&self) -> Array<A, D>
125    where
126        A: ComplexFloat,
127    {
128        let exp = self.activate(A::exp);
129        &exp / exp.sum()
130    }
131
132    fn softmax_axis(&self, axis: usize) -> Array<A, D>
133    where
134        A: ComplexFloat,
135        D: RemoveAxis,
136    {
137        let exp = self.activate(A::exp);
138        let axis = Axis(axis);
139        &exp / &exp.sum_axis(axis)
140    }
141
142    fn tanh(&self) -> Array<A, D>
143    where
144        A: ComplexFloat,
145    {
146        self.activate(A::tanh)
147    }
148
149    fn tanh_derivative(&self) -> Array<A, D>
150    where
151        A: ComplexFloat,
152    {
153        self.activate(|i| A::one() - A::tanh(i) * A::tanh(i))
154    }
155}
156
157pub trait NdActivateMut<A, D>
158where
159    A: ScalarOperand,
160    D: Dimension,
161{
162    type Data: DataMut<Elem = A>;
163
164    fn activate_inplace<'a, F>(&'a mut self, f: F)
165    where
166        A: 'a,
167        F: FnMut(A) -> A;
168}
169/*
170 ************* Implementations *************
171*/
172
173impl<A, S, D> NdActivate<A, D> for ArrayBase<S, D>
174where
175    A: ScalarOperand,
176    D: Dimension,
177    S: Data<Elem = A>,
178{
179    type Data = S;
180
181    fn activate<B, F>(&self, f: F) -> Array<B, D>
182    where
183        F: Fn(A) -> B,
184    {
185        self.mapv(f)
186    }
187}
188
189impl<A, S, D> NdActivateMut<A, D> for ArrayBase<S, D>
190where
191    A: ScalarOperand,
192    D: Dimension,
193    S: DataMut<Elem = A>,
194{
195    type Data = S;
196
197    fn activate_inplace<'a, F>(&'a mut self, f: F)
198    where
199        A: 'a,
200        F: FnMut(A) -> A,
201    {
202        self.mapv_inplace(f)
203    }
204}
205
206impl<'a, A, S, D> NdActivate<A, D> for &'a ArrayBase<S, D>
207where
208    A: ScalarOperand,
209    D: Dimension,
210    S: Data<Elem = A>,
211{
212    type Data = S;
213
214    fn activate<B, F>(&self, f: F) -> Array<B, D>
215    where
216        F: Fn(A) -> B,
217    {
218        self.mapv(f)
219    }
220}
221
222impl<'a, A, S, D> NdActivate<A, D> for &'a mut ArrayBase<S, D>
223where
224    A: ScalarOperand,
225    D: Dimension,
226    S: Data<Elem = A>,
227{
228    type Data = S;
229
230    fn activate<B, F>(&self, f: F) -> Array<B, D>
231    where
232        F: Fn(A) -> B,
233    {
234        self.mapv(f)
235    }
236}
237
238impl<'a, A, S, D> NdActivateMut<A, D> for &'a mut ArrayBase<S, D>
239where
240    A: ScalarOperand,
241    D: Dimension,
242    S: DataMut<Elem = A>,
243{
244    type Data = S;
245
246    fn activate_inplace<'b, F>(&'b mut self, f: F)
247    where
248        A: 'b,
249        F: FnMut(A) -> A,
250    {
251        self.mapv_inplace(f)
252    }
253}