concision_core/func/activate/nl.rs

/*
    Appellation: nl <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::math::Exp;
use ndarray::*;
use num::complex::{Complex, ComplexFloat};
use num::traits::Zero;

/// Rectified linear unit: the identity for positive values, zero otherwise.
fn _relu<T>(args: T) -> T
where
    T: PartialOrd + Zero,
{
    if args > T::zero() {
        return args;
    }
    T::zero()
}

/// Logistic sigmoid: `1 / (1 + e^(-x))`.
fn _sigmoid<T>(args: T) -> T
where
    T: ComplexFloat,
{
    (T::one() + args.neg().exp()).recip()
}

/// Softmax over the whole array: exponentiate elementwise, then normalize by
/// the sum so that the result sums to one.
fn _softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
where
    A: ComplexFloat + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    let e = args.exp();
    &e / e.sum()
}
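
// A numerically stable variant (sketch): shifting by the maximum before
// exponentiating avoids overflow for large inputs. It is restricted to real
// floats, since complex values have no total order; the helper name
// `_softmax_stable` is illustrative, not part of the crate's API.
#[allow(dead_code)]
fn _softmax_stable<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
where
    A: num::traits::Float + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    // softmax(x) == softmax(x - max(x)), and the shifted exponents are <= 1.
    let max = args.fold(A::neg_infinity(), |acc, &x| acc.max(x));
    let e = args.mapv(|x| (x - max).exp());
    &e / e.sum()
}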

// An iterator-based generalization of `_softmax`; left disabled pending
// workable trait bounds.
// fn __softmax<T, I>(args: &I) -> I
// where
//     I: Clone + core::ops::Div<T, Output = I> + Exp<Output = I>,
//     T: Exp<Output = T> + core::iter::Sum,
//     for<'a> I: IntoIterator<Item = &'a T>,
// {
//     let e = args.exp();
//     e.clone() / e.into_iter().sum::<T>()
// }

/// Hyperbolic tangent, delegating to [`ComplexFloat::tanh`].
fn _tanh<T>(args: T) -> T
where
    T: ComplexFloat,
{
    args.tanh()
}

unary!(
    ReLU::relu(self),
    Sigmoid::sigmoid(self),
    Softmax::softmax(self),
    Tanh::tanh(self),
);
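
// `unary!` is declared elsewhere in the crate; for the impls below to
// compile, each entry is assumed to expand to a trait of roughly this shape
// (an illustrative sketch, not the macro's verified output):
//
// pub trait Sigmoid {
//     type Output;
//
//     fn sigmoid(self) -> Self::Output;
// }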

/*
 ********** Implementations **********
*/
/// Implements an activation trait for each listed scalar type, then blankets
/// it over `ndarray` arrays of those scalars.
macro_rules! nonlinear {
    ($($rho:ident::$call:ident<[$($T:ty),* $(,)?]>),* $(,)?) => {
        $(
            nonlinear!(@loop $rho::$call<[$($T),*]>);
        )*
    };
    (@loop $rho:ident::$call:ident<[$($T:ty),* $(,)?]>) => {
        $(
            nonlinear!(@impl $rho::$call<$T>);
        )*

        nonlinear!(@arr $rho::$call);
    };
    // Scalar impls: both `T` and `&T` delegate to the free function `_$call`.
    (@impl $rho:ident::$call:ident<$T:ty>) => {
        paste::paste! {
            impl $rho for $T {
                type Output = $T;

                fn $call(self) -> Self::Output {
                    [<_ $call>](self)
                }
            }

            impl<'a> $rho for &'a $T {
                type Output = $T;

                fn $call(self) -> Self::Output {
                    [<_ $call>](*self)
                }
            }
        }
    };
    // Array impls: apply the scalar activation elementwise via `mapv`.
    (@arr $name:ident::$call:ident) => {
        impl<A, S, D> $name for ArrayBase<S, D>
        where
            A: Clone + $name,
            D: Dimension,
            S: Data<Elem = A>,
        {
            type Output = Array<<A as $name>::Output, D>;

            fn $call(self) -> Self::Output {
                self.mapv($name::$call)
            }
        }

        impl<'a, A, S, D> $name for &'a ArrayBase<S, D>
        where
            A: Clone + $name,
            D: Dimension,
            S: Data<Elem = A>,
        {
            type Output = Array<<A as $name>::Output, D>;

            fn $call(self) -> Self::Output {
                self.mapv($name::$call)
            }
        }
    };
}
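
// For reference, `nonlinear!(@impl Sigmoid::sigmoid<f32>)` expands to:
//
// impl Sigmoid for f32 {
//     type Output = f32;
//
//     fn sigmoid(self) -> Self::Output {
//         _sigmoid(self)
//     }
// }
//
// plus the matching impl for `&'a f32`.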

nonlinear!(
    ReLU::relu<[
        f32,
        f64,
        i8,
        i16,
        i32,
        i64,
        i128,
        isize,
        u8,
        u16,
        u32,
        u64,
        u128,
        usize
    ]>,
    Sigmoid::sigmoid<[
        f32,
        f64,
        Complex<f32>,
        Complex<f64>
    ]>,
    Tanh::tanh<[
        f32,
        f64,
        Complex<f32>,
        Complex<f64>
    ]>,
);

// `Softmax` is implemented by hand rather than through `nonlinear!`: it is
// not an elementwise map, since each output is normalized by the sum over
// the entire array.
impl<A, S, D> Softmax for ArrayBase<S, D>
where
    A: ComplexFloat + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    type Output = Array<A, D>;

    fn softmax(self) -> Self::Output {
        _softmax(&self)
    }
}

impl<'a, A, S, D> Softmax for &'a ArrayBase<S, D>
where
    A: ComplexFloat + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    type Output = Array<A, D>;

    fn softmax(self) -> Self::Output {
        _softmax(self)
    }
}
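
// Minimal sanity checks (a sketch): assumes the traits generated by `unary!`
// are public and reachable via `super::*`.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn relu_clamps_negatives() {
        assert_eq!((-1i32).relu(), 0);
        assert_eq!(2i32.relu(), 2);
    }

    #[test]
    fn sigmoid_of_zero_is_one_half() {
        assert!((0f64.sigmoid() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn softmax_sums_to_one() {
        let s = array![1.0f64, 2.0, 3.0].softmax();
        assert!((s.sum() - 1.0).abs() < 1e-12);
    }
}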