concision_core/activate/traits/
activate.rs1use super::unary::*;
6
7use ndarray::prelude::*;
8use ndarray::{Data, DataMut, ScalarOperand};
9use num::complex::ComplexFloat;
10
/// A container whose elements can be mapped through an activation function.
///
/// `A` is the element type handed to the activation function, and
/// [`Activate::Cont`] names the container produced once every element has
/// been mapped to some new type `B`.
pub trait Activate<A> {
    /// The container returned by [`Activate::activate`], parameterised by the
    /// element type produced by the activation function.
    type Cont<B>;

    /// Applies `f` to the elements of `self` and returns the mapped values
    /// as a `Self::Cont<V>`.
    fn activate<V, F: Fn(A) -> V>(&self, f: F) -> Self::Cont<V>;
}
/// A container whose elements can be rewritten in place by an activation
/// function that maps `A` back to `A`.
pub trait ActivateMut<A> {
    /// Container type associated with the implementor. It is not referenced
    /// by [`ActivateMut::activate_inplace`].
    type Cont<B>;

    /// Applies `f` to the elements of `self`, storing the results in place.
    ///
    /// The `'a` lifetime and the `A: 'a` bound are part of the trait's
    /// signature, so implementors declare them as well.
    fn activate_inplace<'a, F: FnMut(A) -> A>(&'a mut self, f: F)
    where
        A: 'a;
}
29pub trait ActivateExt<U>: Activate<U> {
33 fn linear(&self) -> Self::Cont<U::Output>
34 where
35 U: LinearActivation,
36 {
37 self.activate(|x| x.linear())
38 }
39
40 fn linear_derivative(&self) -> Self::Cont<U::Output>
41 where
42 U: LinearActivation,
43 {
44 self.activate(|x| x.linear_derivative())
45 }
46
47 fn heavyside(&self) -> Self::Cont<U::Output>
48 where
49 U: Heavyside,
50 {
51 self.activate(|x| x.heavyside())
52 }
53
54 fn heavyside_derivative(&self) -> Self::Cont<U::Output>
55 where
56 U: Heavyside,
57 {
58 self.activate(|x| x.heavyside_derivative())
59 }
60
61 fn relu(&self) -> Self::Cont<U::Output>
62 where
63 U: ReLU,
64 {
65 self.activate(|x| x.relu())
66 }
67
68 fn relu_derivative(&self) -> Self::Cont<U::Output>
69 where
70 U: ReLU,
71 {
72 self.activate(|x| x.relu_derivative())
73 }
74
75 fn sigmoid(&self) -> Self::Cont<U::Output>
76 where
77 U: Sigmoid,
78 {
79 self.activate(|x| x.sigmoid())
80 }
81
82 fn sigmoid_derivative(&self) -> Self::Cont<U::Output>
83 where
84 U: Sigmoid,
85 {
86 self.activate(|x| x.sigmoid_derivative())
87 }
88
89 fn tanh(&self) -> Self::Cont<U::Output>
90 where
91 U: Tanh,
92 {
93 self.activate(|x| x.tanh())
94 }
95
96 fn tanh_derivative(&self) -> Self::Cont<U::Output>
97 where
98 U: Tanh,
99 {
100 self.activate(|x| x.tanh_derivative())
101 }
102
103 fn sigmoid_complex(&self) -> Self::Cont<U>
104 where
105 U: ComplexFloat,
106 {
107 self.activate(|x| U::one() / (U::one() + (-x).exp()))
108 }
109
110 fn sigmoid_complex_derivative(&self) -> Self::Cont<U>
111 where
112 U: ComplexFloat,
113 {
114 self.activate(|x| {
115 let s = U::one() / (U::one() + (-x).exp());
116 s * (U::one() - s)
117 })
118 }
119
120 fn tanh_complex(&self) -> Self::Cont<U>
121 where
122 U: ComplexFloat,
123 {
124 self.activate(|x| x.tanh())
125 }
126
127 fn tanh_complex_derivative(&self) -> Self::Cont<U>
128 where
129 U: ComplexFloat,
130 {
131 self.activate(|x| {
132 let s = x.tanh();
133 U::one() - s * s
134 })
135 }
136}
137
138pub trait NdActivateMut<A, D>
139where
140 A: ScalarOperand,
141 D: Dimension,
142{
143 type Data: DataMut<Elem = A>;
144}
145impl<U, S> ActivateExt<U> for S where S: Activate<U> {}
149
150impl<A, S, D> Activate<A> for ArrayBase<S, D>
151where
152 A: ScalarOperand,
153 D: Dimension,
154 S: Data<Elem = A>,
155{
156 type Cont<V> = Array<V, D>;
157
158 fn activate<V, F>(&self, f: F) -> Self::Cont<V>
159 where
160 F: Fn(A) -> V,
161 {
162 self.mapv(f)
163 }
164}
165
166impl<A, S, D> Activate<A> for &ArrayBase<S, D>
167where
168 A: ScalarOperand,
169 D: Dimension,
170 S: Data<Elem = A>,
171{
172 type Cont<V> = Array<V, D>;
173
174 fn activate<B, F>(&self, f: F) -> Array<B, D>
175 where
176 F: Fn(A) -> B,
177 {
178 self.mapv(f)
179 }
180}
181
182impl<A, S, D> Activate<A> for &mut ArrayBase<S, D>
183where
184 A: ScalarOperand,
185 D: Dimension,
186 S: Data<Elem = A>,
187{
188 type Cont<V> = Array<V, D>;
189
190 fn activate<B, F>(&self, f: F) -> Array<B, D>
191 where
192 F: Fn(A) -> B,
193 {
194 self.mapv(f)
195 }
196}
197
198impl<A, S, D> ActivateMut<A> for ArrayBase<S, D>
199where
200 A: ScalarOperand,
201 D: Dimension,
202 S: DataMut<Elem = A>,
203{
204 type Cont<V> = Array<V, D>;
205
206 fn activate_inplace<'a, F>(&'a mut self, f: F)
207 where
208 A: 'a,
209 F: FnMut(A) -> A,
210 {
211 self.mapv_inplace(f)
212 }
213}
214
215impl<A, S, D> ActivateMut<A> for &mut ArrayBase<S, D>
216where
217 A: ScalarOperand,
218 D: Dimension,
219 S: DataMut<Elem = A>,
220{
221 type Cont<V> = Array<V, D>;
222
223 fn activate_inplace<'b, F>(&'b mut self, f: F)
224 where
225 A: 'b,
226 F: FnMut(A) -> A,
227 {
228 self.mapv_inplace(f)
229 }
230}