concision_core/activate/
traits.rs

1/*
2    Appellation: traits <activate>
3    Contrib: @FL03
4*/
5
6use ndarray::prelude::*;
7use ndarray::{Data, DataMut, ScalarOperand};
8use num::complex::ComplexFloat;
9
// Generates one public trait per `Name::method(receiver)` entry.
//
// Every generated trait declares an associated `Output` type, the named method, and a
// companion `<method>_derivative` method taking the same receiver. The receiver is either
// `self` (by value) or `&self` (by reference); any other receiver matches no internal arm.
macro_rules! unary {
    // Internal arm: the generated methods borrow the receiver.
    (@impl $label:ident::$method:ident(&self)) => {
        paste::paste! {
            pub trait $label {
                type Output;

                fn $method(&self) -> Self::Output;

                // `paste` concatenates the method name with the `_derivative` suffix.
                fn [<$method _derivative>](&self) -> Self::Output;
            }
        }
    };

    // Internal arm: the generated methods consume the receiver.
    (@impl $label:ident::$method:ident(self)) => {
        paste::paste! {
            pub trait $label {
                type Output;

                fn $method(self) -> Self::Output;

                // `paste` concatenates the method name with the `_derivative` suffix.
                fn [<$method _derivative>](self) -> Self::Output;
            }
        }
    };

    // Entry point: a comma-separated list (trailing comma allowed) of
    // `Name::method(receiver)` items, each forwarded to the matching internal arm.
    ($($label:ident::$method:ident($($recv:tt)*)),* $(,)?) => {
        $(
            unary!(@impl $label::$method($($recv)*));
        )*
    };
}
41
42unary! {
43    Heavyside::heavyside(self),
44    LinearActivation::linear(self),
45    Sigmoid::sigmoid(self),
46    Softmax::softmax(&self),
47    ReLU::relu(&self),
48    Tanh::tanh(&self),
49}
50
51pub trait SoftmaxAxis: Softmax {
52    fn softmax_axis(self, axis: usize) -> Self::Output;
53}
54
/// Describes how a value holding elements of type `A` is activated by a caller-supplied
/// function.
pub trait Activate<A> {
    /// The container returned by [`Activate::activate`], generic over its element type.
    type Cont<T>;

    /// Activates `self` with `f`, returning a [`Self::Cont`] of the values `f` produces.
    ///
    /// The receiver is only borrowed. The `ArrayBase` implementations in this file map `f`
    /// over every element.
    fn activate<V, F: Fn(A) -> V>(&self, f: F) -> Self::Cont<V>;
}
/// Common interface for activating a value in place, overwriting its own elements instead
/// of building a new container.
pub trait ActivateMut<A> {
    /// A container type generic over its element type; no method of this trait uses it.
    type Cont<T>;

    /// Activates `self` in place with `f`.
    ///
    /// Nothing is returned. The `ArrayBase` implementations in this file replace every
    /// element with the value `f` yields for it.
    fn activate_inplace<'a, F: FnMut(A) -> A>(&'a mut self, f: F)
    where
        A: 'a;
}
72/// This trait extends the [`Activate`] trait with a number of additional activation functions
73/// and their derivatives. _**Note:**_ this trait is automatically implemented for any type
74/// that implements the [`Activate`] trait eliminating the need to implement it manually.
75pub trait ActivateExt<U>: Activate<U> {
76    fn linear(&self) -> Self::Cont<U::Output>
77    where
78        U: LinearActivation,
79    {
80        self.activate(|x| x.linear())
81    }
82
83    fn linear_derivative(&self) -> Self::Cont<U::Output>
84    where
85        U: LinearActivation,
86    {
87        self.activate(|x| x.linear_derivative())
88    }
89
90    fn heavyside(&self) -> Self::Cont<U::Output>
91    where
92        U: Heavyside,
93    {
94        self.activate(|x| x.heavyside())
95    }
96
97    fn heavyside_derivative(&self) -> Self::Cont<U::Output>
98    where
99        U: Heavyside,
100    {
101        self.activate(|x| x.heavyside_derivative())
102    }
103
104    fn relu(&self) -> Self::Cont<U::Output>
105    where
106        U: ReLU,
107    {
108        self.activate(|x| x.relu())
109    }
110
111    fn relu_derivative(&self) -> Self::Cont<U::Output>
112    where
113        U: ReLU,
114    {
115        self.activate(|x| x.relu_derivative())
116    }
117
118    fn sigmoid(&self) -> Self::Cont<U::Output>
119    where
120        U: Sigmoid,
121    {
122        self.activate(|x| x.sigmoid())
123    }
124
125    fn sigmoid_derivative(&self) -> Self::Cont<U::Output>
126    where
127        U: Sigmoid,
128    {
129        self.activate(|x| x.sigmoid_derivative())
130    }
131
132    fn tanh(&self) -> Self::Cont<U::Output>
133    where
134        U: Tanh,
135    {
136        self.activate(|x| x.tanh())
137    }
138
139    fn tanh_derivative(&self) -> Self::Cont<U::Output>
140    where
141        U: Tanh,
142    {
143        self.activate(|x| x.tanh_derivative())
144    }
145
146    fn sigmoid_complex(&self) -> Self::Cont<U>
147    where
148        U: ComplexFloat,
149    {
150        self.activate(|x| U::one() / (U::one() + (-x).exp()))
151    }
152
153    fn sigmoid_complex_derivative(&self) -> Self::Cont<U>
154    where
155        U: ComplexFloat,
156    {
157        self.activate(|x| {
158            let s = U::one() / (U::one() + (-x).exp());
159            s * (U::one() - s)
160        })
161    }
162
163    fn tanh_complex(&self) -> Self::Cont<U>
164    where
165        U: ComplexFloat,
166    {
167        self.activate(|x| x.tanh())
168    }
169
170    fn tanh_complex_derivative(&self) -> Self::Cont<U>
171    where
172        U: ComplexFloat,
173    {
174        self.activate(|x| {
175            let s = x.tanh();
176            U::one() - s * s
177        })
178    }
179}
180
181pub trait NdActivateMut<A, D>
182where
183    A: ScalarOperand,
184    D: Dimension,
185{
186    type Data: DataMut<Elem = A>;
187}
188/*
189 ************* Implementations *************
190*/
191impl<U, S> ActivateExt<U> for S where S: Activate<U> {}
192
193impl<A, S, D> Activate<A> for ArrayBase<S, D>
194where
195    A: ScalarOperand,
196    D: Dimension,
197    S: Data<Elem = A>,
198{
199    type Cont<V> = Array<V, D>;
200
201    fn activate<V, F>(&self, f: F) -> Self::Cont<V>
202    where
203        F: Fn(A) -> V,
204    {
205        self.mapv(f)
206    }
207}
208
209impl<A, S, D> Activate<A> for &ArrayBase<S, D>
210where
211    A: ScalarOperand,
212    D: Dimension,
213    S: Data<Elem = A>,
214{
215    type Cont<V> = Array<V, D>;
216
217    fn activate<B, F>(&self, f: F) -> Array<B, D>
218    where
219        F: Fn(A) -> B,
220    {
221        self.mapv(f)
222    }
223}
224
225impl<A, S, D> Activate<A> for &mut ArrayBase<S, D>
226where
227    A: ScalarOperand,
228    D: Dimension,
229    S: Data<Elem = A>,
230{
231    type Cont<V> = Array<V, D>;
232
233    fn activate<B, F>(&self, f: F) -> Array<B, D>
234    where
235        F: Fn(A) -> B,
236    {
237        self.mapv(f)
238    }
239}
240
241impl<A, S, D> ActivateMut<A> for ArrayBase<S, D>
242where
243    A: ScalarOperand,
244    D: Dimension,
245    S: DataMut<Elem = A>,
246{
247    type Cont<V> = Array<V, D>;
248
249    fn activate_inplace<'a, F>(&'a mut self, f: F)
250    where
251        A: 'a,
252        F: FnMut(A) -> A,
253    {
254        self.mapv_inplace(f)
255    }
256}
257
258impl<A, S, D> ActivateMut<A> for &mut ArrayBase<S, D>
259where
260    A: ScalarOperand,
261    D: Dimension,
262    S: DataMut<Elem = A>,
263{
264    type Cont<V> = Array<V, D>;
265
266    fn activate_inplace<'b, F>(&'b mut self, f: F)
267    where
268        A: 'b,
269        F: FnMut(A) -> A,
270    {
271        self.mapv_inplace(f)
272    }
273}