concision_core/func/activate/
nl.rs1use crate::math::Exp;
6use ndarray::*;
7use num::complex::{Complex, ComplexFloat};
8use num::traits::Zero;
9
10fn _relu<T>(args: T) -> T
11where
12 T: PartialOrd + Zero,
13{
14 if args > T::zero() {
15 return args;
16 }
17 T::zero()
18}
19
20fn _sigmoid<T>(args: T) -> T
21where
22 T: ComplexFloat,
23{
24 (T::one() + args.neg().exp()).recip()
25}
26
27fn _softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
28where
29 A: ComplexFloat + ScalarOperand,
30 D: Dimension,
31 S: Data<Elem = A>,
32{
33 let e = args.exp();
34 &e / e.sum()
35}
36
37fn _tanh<T>(args: T) -> T
47where
48 T: ComplexFloat,
49{
50 args.tanh()
51}
52
53unary!(
54 ReLU::relu(self),
55 Sigmoid::sigmoid(self),
56 Softmax::softmax(self),
57 Tanh::tanh(self),
58);
59
/// Generates scalar and array implementations of the element-wise activation traits.
///
/// Input is a comma-separated list of `Trait::method<[Type, ...]>` entries. For each entry:
/// - `@loop` emits one `@impl` per listed scalar type, followed by a single `@arr`;
/// - `@impl` implements `Trait` for `Type` (forwarding to the private helper `_method`,
///   whose name is assembled with `paste`) and for `&Type`. The by-reference impl copies
///   the value out with `*self`, so every listed type must be `Copy`;
/// - `@arr` implements `Trait` for `ArrayBase<S, D>` and `&ArrayBase<S, D>` whenever the
///   element type implements `Trait`, applying the scalar method to each element via `mapv`.
macro_rules! nonlinear {
    ($($act:ident::$func:ident<[$($scalar:ty),* $(,)?]>),* $(,)?) => {
        $(nonlinear!(@loop $act::$func<[$($scalar),*]>);)*
    };
    (@loop $act:ident::$func:ident<[$($scalar:ty),* $(,)?]>) => {
        $(nonlinear!(@impl $act::$func<$scalar>);)*
        nonlinear!(@arr $act::$func);
    };
    (@impl $act:ident::$func:ident<$scalar:ty>) => {
        paste::paste! {
            impl $act for $scalar {
                type Output = $scalar;

                fn $func(self) -> Self::Output {
                    [<_ $func>](self)
                }
            }
        }

        impl $act for &$scalar {
            type Output = $scalar;

            fn $func(self) -> Self::Output {
                // Copy the scalar out and reuse the by-value impl above.
                <$scalar as $act>::$func(*self)
            }
        }
    };
    (@arr $act:ident::$func:ident) => {
        impl<A, S, D> $act for ArrayBase<S, D>
        where
            A: Clone + $act,
            D: Dimension,
            S: Data<Elem = A>,
        {
            type Output = Array<<A as $act>::Output, D>;

            fn $func(self) -> Self::Output {
                self.mapv(<A as $act>::$func)
            }
        }

        impl<A, S, D> $act for &ArrayBase<S, D>
        where
            A: Clone + $act,
            D: Dimension,
            S: Data<Elem = A>,
        {
            type Output = Array<<A as $act>::Output, D>;

            fn $func(self) -> Self::Output {
                self.mapv(<A as $act>::$func)
            }
        }
    };
}
125
126nonlinear!(
127 ReLU::relu<[
128 f32,
129 f64,
130 i8,
131 i16,
132 i32,
133 i64,
134 i128,
135 isize,
136 u8,
137 u16,
138 u32,
139 u64,
140 u128,
141 usize
142 ]>,
143 Sigmoid::sigmoid<[
144 f32,
145 f64,
146 Complex<f32>,
147 Complex<f64>
148 ]>,
149 Tanh::tanh<[
150 f32,
151 f64,
152 Complex<f32>,
153 Complex<f64>
154 ]>,
155);
156
157impl<A, S, D> Softmax for ArrayBase<S, D>
158where
159 A: ComplexFloat + ScalarOperand,
160 D: Dimension,
161 S: Data<Elem = A>,
162{
163 type Output = Array<A, D>;
164
165 fn softmax(self) -> Self::Output {
166 _softmax(&self)
167 }
168}
169
170impl<'a, A, S, D> Softmax for &'a ArrayBase<S, D>
171where
172 A: ComplexFloat + ScalarOperand,
173 D: Dimension,
174 S: Data<Elem = A>,
175{
176 type Output = Array<A, D>;
177
178 fn softmax(self) -> Self::Output {
179 _softmax(self)
180 }
181}