concision_core/activate/rho.rs
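
/// Applies an activation function to an input of type `T`, producing `Self::Output`.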
pub trait Activator<T> {
    type Output;

    fn activate(&self, input: T) -> Self::Output;
}
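
/// Extends [`Activator`] with a gradient: `activate_gradient` evaluates the derivative
/// of the associated activation (`Rel`) at the given input.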
pub trait ActivatorGradient<T> {
    type Rel: Activator<T>;
    type Delta;

    fn activate_gradient(&self, input: T) -> Self::Delta;
}
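
// `activator!` expands each `vis struct Name::<T>::method { where ... }` entry into a
// zero-sized unit struct that implements `Activator` by calling `input.method()` and
// `ActivatorGradient` by calling the paste-generated `input.method_derivative()`.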
macro_rules! activator {
    // Entry arm: fan each comma-separated declaration out to the `@impl` arm below.
    ($($vis:vis struct $name:ident::<$T:ident>::$method:ident $({where $($where:tt)*})?),* $(,)?) => {
        $(activator! {
            @impl $vis struct $name::<$T>::$method $({where $($where)*})?
        })*
    };
    // Implementation arm: emit the unit struct plus its `Activator` and
    // `ActivatorGradient` impls, forwarding to `$method` and `[<$method _derivative>]`.
    (@impl $vis:vis struct $name:ident::<$T:ident>::$method:ident $({where $($where:tt)*})?) => {
        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        $vis struct $name;

        impl<$T> Activator<$T> for $name
        $(where $($where)*)?
        {
            type Output = <$T>::Output;

            fn activate(&self, x: $T) -> Self::Output {
                x.$method()
            }
        }

        paste::paste! {
            impl<$T> ActivatorGradient<$T> for $name
            $(where $($where)*)?
            {
                type Rel = Self;
                type Delta = <$T>::Output;

                fn activate_gradient(&self, inputs: $T) -> Self::Delta {
                    inputs.[<$method _derivative>]()
                }
            }
        }
    };
}
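
// Each generated unit struct only implements the traits for types `T` that satisfy the
// corresponding activation trait bound from `crate::activate`.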
activator! {
    pub struct Linear::<T>::linear { where T: crate::activate::LinearActivation },
    pub struct ReLU::<T>::relu { where T: crate::activate::ReLUActivation },
    pub struct Sigmoid::<T>::sigmoid { where T: crate::activate::SigmoidActivation },
    pub struct HyperbolicTangent::<T>::tanh { where T: crate::activate::TanhActivation },
    pub struct HeavySide::<T>::heavyside { where T: crate::activate::HeavysideActivation },
    pub struct Softmax::<T>::softmax { where T: crate::activate::SoftmaxActivation },
}
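
// A minimal usage sketch. Assumption (not shown in this file): `f64` implements
// `crate::activate::ReLUActivation` with `Output = f64` and the usual
// `relu(x) = max(0, x)` / `relu_derivative` semantics; swap in whichever input type
// the crate actually provides these impls for.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn relu_activator_forwards_to_the_trait_methods() {
        // `ReLU` is a zero-sized marker; `activate` simply calls `relu()` on the input,
        // and `activate_gradient` calls the paste-generated `relu_derivative()`.
        let rho = ReLU;
        assert_eq!(rho.activate(-1.0_f64), 0.0);
        assert_eq!(rho.activate(2.0_f64), 2.0);
        assert_eq!(rho.activate_gradient(2.0_f64), 1.0);
    }
}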