/// An activation function that can be evaluated on an input of type `T`.
///
/// The result type is chosen per implementation through
/// [`Activator::Output`], so an activator may map `T` to a different type.
pub trait Activator<T> {
    /// Evaluates the activation on `input`, which is taken by value.
    fn activate(&self, input: T) -> Self::Output;

    /// The type produced by [`activate`](Activator::activate).
    type Output;
}
/// The gradient (derivative) counterpart of an activation function over `T`.
pub trait ActivatorGradient<T> {
    /// The type produced by
    /// [`activate_gradient`](ActivatorGradient::activate_gradient).
    type Delta;

    /// An [`Activator`] over the same input type `T` that this gradient is
    /// paired with (the types generated in this file set it to `Self`).
    type Rel: Activator<T>;

    /// Evaluates the gradient on `input`, which is taken by value.
    fn activate_gradient(&self, input: T) -> Self::Delta;
}
/// Generates zero-sized activation-function types together with their
/// [`Activator`] and [`ActivatorGradient`] implementations.
///
/// Each comma-separated entry has the form
///
/// ```text
/// pub struct Name::<T>::method { where T: SomeBound }
/// ```
///
/// and expands to a unit struct `Name` where:
/// - `Activator<T>::activate(x)` calls `x.method()`;
/// - `ActivatorGradient<T>::activate_gradient(x)` calls `x.method_derivative()`
///   (the method name with `_derivative` appended via `paste`);
/// - both `Output` and `Delta` are `<T>::Output`, and `Rel` is `Self`.
///
/// The `{ where ... }` clause is optional; when present its tokens are copied
/// verbatim onto both impls. A trailing comma after the last entry is allowed.
macro_rules! activator {
    // Internal rule: expand a single entry. Listed first so the `@impl`
    // marker is tried before the public entry-list rule.
    (@impl $vis:vis struct $name:ident::<$T:ident>::$method:ident $({where $($where:tt)*})? ) => {
        #[doc = concat!(
            "Activation function whose [`Activator`] impl calls `.",
            stringify!($method),
            "()` and whose [`ActivatorGradient`] impl calls `.",
            stringify!($method),
            "_derivative()`."
        )]
        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        $vis struct $name;

        impl<$T> Activator<$T> for $name
        $(where $($where)*)?
        {
            type Output = <$T>::Output;

            fn activate(&self, x: $T) -> Self::Output {
                x.$method()
            }
        }

        paste::paste! {
            // No separator after the optional where-clause: a stray `,` here
            // would expand to `for Name , {` when the clause is omitted, and to
            // `,,` when the supplied clause already ends with a comma.
            impl<$T> ActivatorGradient<$T> for $name
            $(where $($where)*)?
            {
                type Rel = Self;
                type Delta = <$T>::Output;

                fn activate_gradient(&self, inputs: $T) -> Self::Delta {
                    inputs.[<$method _derivative>]()
                }
            }
        }
    };
    // Public rule: a comma-separated list of entries, each forwarded to `@impl`.
    ($($vis:vis struct $name:ident::<$T:ident>::$method:ident $({where $($where:tt)*})?),* $(,)?) => {
        $(activator! {
            @impl $vis struct $name::<$T>::$method $({where $($where)*})?
        })*
    };
}
activator! {
pub struct Linear::<T>::linear { where T: crate::activate::LinearActivation },
pub struct ReLU::<T>::relu { where T: crate::activate::ReLUActivation },
pub struct Sigmoid::<T>::sigmoid { where T: crate::activate::SigmoidActivation },
pub struct HyperbolicTangent::<T>::tanh { where T: crate::activate::TanhActivation },
pub struct HeavySide::<T>::heavyside { where T: crate::activate::HeavysideActivation },
pub struct Softmax::<T>::softmax { where T: crate::activate::SoftmaxActivation },
}