concision_core/activate/traits/activate.rs

1/*
2    appellation: activate <module>
3    authors: @FL03
4*/
5use super::unary::*;
6
7use ndarray::prelude::*;
8use ndarray::{Data, DataMut, ScalarOperand};
9use num::complex::ComplexFloat;
10
/// The [`Activate`] trait establishes a common interface for entities that can be
/// _activated_, i.e. mapped through some function into a new container.
///
/// `Cont<B>` names the container produced when the values of type `A` are mapped to values
/// of type `B`.
pub trait Activate<A> {
    /// The container returned by [`Activate::activate`], holding values of type `B`.
    type Cont<B>;

    /// Maps `f` over `self` and returns a new container holding the results.
    ///
    /// Takes `&self`, so the receiver is only borrowed. The result is a freshly built
    /// container; calling this and ignoring the return value does nothing useful, hence
    /// `#[must_use]`.
    #[must_use]
    fn activate<V, F>(&self, f: F) -> Self::Cont<V>
    where
        F: Fn(A) -> V;
}
/// Common mechanism for activating an entity *in place*.
///
/// Instead of returning a new container like [`Activate`], implementors mutate `self`
/// through `&mut self`, replacing values with the output of the supplied function.
pub trait ActivateMut<A> {
    /// Container type associated with this activation.
    ///
    /// NOTE(review): no method of this trait refers to `Cont`; it is kept because
    /// implementors define it.
    type Cont<T>;

    /// Runs `f` over the values held by `self`, storing each result back in place.
    ///
    /// The closure must map `A` to `A`, because the results are written into the existing
    /// storage rather than into a new container.
    fn activate_inplace<'s, F>(&'s mut self, f: F)
    where
        A: 's,
        F: FnMut(A) -> A;
}
29/// This trait extends the [`Activate`] trait with a number of additional activation functions
30/// and their derivatives. _**Note:**_ this trait is automatically implemented for any type
31/// that implements the [`Activate`] trait eliminating the need to implement it manually.
32pub trait ActivateExt<U>: Activate<U> {
33    fn linear(&self) -> Self::Cont<U::Output>
34    where
35        U: LinearActivation,
36    {
37        self.activate(|x| x.linear())
38    }
39
40    fn linear_derivative(&self) -> Self::Cont<U::Output>
41    where
42        U: LinearActivation,
43    {
44        self.activate(|x| x.linear_derivative())
45    }
46
47    fn heavyside(&self) -> Self::Cont<U::Output>
48    where
49        U: Heavyside,
50    {
51        self.activate(|x| x.heavyside())
52    }
53
54    fn heavyside_derivative(&self) -> Self::Cont<U::Output>
55    where
56        U: Heavyside,
57    {
58        self.activate(|x| x.heavyside_derivative())
59    }
60
61    fn relu(&self) -> Self::Cont<U::Output>
62    where
63        U: ReLU,
64    {
65        self.activate(|x| x.relu())
66    }
67
68    fn relu_derivative(&self) -> Self::Cont<U::Output>
69    where
70        U: ReLU,
71    {
72        self.activate(|x| x.relu_derivative())
73    }
74
75    fn sigmoid(&self) -> Self::Cont<U::Output>
76    where
77        U: Sigmoid,
78    {
79        self.activate(|x| x.sigmoid())
80    }
81
82    fn sigmoid_derivative(&self) -> Self::Cont<U::Output>
83    where
84        U: Sigmoid,
85    {
86        self.activate(|x| x.sigmoid_derivative())
87    }
88
89    fn tanh(&self) -> Self::Cont<U::Output>
90    where
91        U: Tanh,
92    {
93        self.activate(|x| x.tanh())
94    }
95
96    fn tanh_derivative(&self) -> Self::Cont<U::Output>
97    where
98        U: Tanh,
99    {
100        self.activate(|x| x.tanh_derivative())
101    }
102
103    fn sigmoid_complex(&self) -> Self::Cont<U>
104    where
105        U: ComplexFloat,
106    {
107        self.activate(|x| U::one() / (U::one() + (-x).exp()))
108    }
109
110    fn sigmoid_complex_derivative(&self) -> Self::Cont<U>
111    where
112        U: ComplexFloat,
113    {
114        self.activate(|x| {
115            let s = U::one() / (U::one() + (-x).exp());
116            s * (U::one() - s)
117        })
118    }
119
120    fn tanh_complex(&self) -> Self::Cont<U>
121    where
122        U: ComplexFloat,
123    {
124        self.activate(|x| x.tanh())
125    }
126
127    fn tanh_complex_derivative(&self) -> Self::Cont<U>
128    where
129        U: ComplexFloat,
130    {
131        self.activate(|x| {
132            let s = x.tanh();
133            U::one() - s * s
134        })
135    }
136}
137
138pub trait NdActivateMut<A, D>
139where
140    A: ScalarOperand,
141    D: Dimension,
142{
143    type Data: DataMut<Elem = A>;
144}
145/*
146 ************* Implementations *************
147*/
148impl<U, S> ActivateExt<U> for S where S: Activate<U> {}
149
150impl<A, S, D> Activate<A> for ArrayBase<S, D>
151where
152    A: ScalarOperand,
153    D: Dimension,
154    S: Data<Elem = A>,
155{
156    type Cont<V> = Array<V, D>;
157
158    fn activate<V, F>(&self, f: F) -> Self::Cont<V>
159    where
160        F: Fn(A) -> V,
161    {
162        self.mapv(f)
163    }
164}
165
166impl<A, S, D> Activate<A> for &ArrayBase<S, D>
167where
168    A: ScalarOperand,
169    D: Dimension,
170    S: Data<Elem = A>,
171{
172    type Cont<V> = Array<V, D>;
173
174    fn activate<B, F>(&self, f: F) -> Array<B, D>
175    where
176        F: Fn(A) -> B,
177    {
178        self.mapv(f)
179    }
180}
181
182impl<A, S, D> Activate<A> for &mut ArrayBase<S, D>
183where
184    A: ScalarOperand,
185    D: Dimension,
186    S: Data<Elem = A>,
187{
188    type Cont<V> = Array<V, D>;
189
190    fn activate<B, F>(&self, f: F) -> Array<B, D>
191    where
192        F: Fn(A) -> B,
193    {
194        self.mapv(f)
195    }
196}
197
198impl<A, S, D> ActivateMut<A> for ArrayBase<S, D>
199where
200    A: ScalarOperand,
201    D: Dimension,
202    S: DataMut<Elem = A>,
203{
204    type Cont<V> = Array<V, D>;
205
206    fn activate_inplace<'a, F>(&'a mut self, f: F)
207    where
208        A: 'a,
209        F: FnMut(A) -> A,
210    {
211        self.mapv_inplace(f)
212    }
213}
214
215impl<A, S, D> ActivateMut<A> for &mut ArrayBase<S, D>
216where
217    A: ScalarOperand,
218    D: Dimension,
219    S: DataMut<Elem = A>,
220{
221    type Cont<V> = Array<V, D>;
222
223    fn activate_inplace<'b, F>(&'b mut self, f: F)
224    where
225        A: 'b,
226        F: FnMut(A) -> A,
227    {
228        self.mapv_inplace(f)
229    }
230}