concision_core/activate/
rho.rs

1/*
2    Appellation: rho <module>
3    Created At: 2026.01.13:18:09:21
4    Contrib: @FL03
5*/
//! This module defines _structural_ implementations of various activation functions.
7
/// A _structural_ activation function: a value (e.g. a unit struct) that can be applied to an
/// input of type `T` to produce an [`Output`](Activator::Output).
///
/// The trait places no constraints on `T`; what "activation" means for a given input type is
/// entirely up to the implementor.
pub trait Activator<T> {
    /// The type produced by [`activate`](Activator::activate).
    type Output;

    /// Applies this activation function to `input`, taking it by value, and returns the result.
    fn activate(&self, input: T) -> Self::Output;
}
16/// The [`ActivatorGradient`] trait extends the [`Activator`] trait to include a method for
17/// computing the gradient of the activation function.
18pub trait ActivatorGradient<T> {
19    type Rel: Activator<T>;
20    type Delta;
21
22    /// compute the gradient of some input
23    fn activate_gradient(&self, input: T) -> Self::Delta;
24}
25
/// Generates unit-struct activators that forward to methods on the input type.
///
/// Each comma-separated entry has the form
///
/// ```text
/// vis struct Name::<T>::method { where <bounds on T> }
/// ```
///
/// The `{ where … }` block is optional. For every entry, the macro emits:
///
/// - a unit struct `Name` deriving `Clone`, `Copy`, `Debug`, `Default`, `Eq`, `Hash`, `Ord`,
///   `PartialEq` and `PartialOrd`, plus `serde`'s `Serialize`/`Deserialize` when the `serde`
///   feature is enabled;
/// - an [`Activator<T>`] impl whose `activate` returns `x.method()`;
/// - an [`ActivatorGradient<T>`] impl with `Rel = Self`, whose `activate_gradient` returns
///   `x.method_derivative()` (the identifier is built with `paste`).
///
/// Both `Output` and `Delta` are set to `<T>::Output`. The supplied bounds are therefore
/// expected to bring an associated `Output` type into scope, along with `method` and
/// `method_derivative`.
macro_rules! activator {
    // Internal rule: expands a single entry.
    (@impl $vis:vis struct $name:ident::<$T:ident>::$method:ident $({where $($where:tt)*})?) => {
        #[doc = concat!(
            "A structural activation function that evaluates `",
            stringify!($method),
            "` on its input; its gradient is given by `",
            stringify!($method),
            "_derivative`."
        )]
        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        $vis struct $name;

        impl<$T> Activator<$T> for $name
        $(where $($where)*)?
        {
            type Output = <$T>::Output;

            fn activate(&self, x: $T) -> Self::Output {
                x.$method()
            }
        }

        paste::paste! {
            // Nothing may follow the optional `where` clause here. A stray `,` would expand to
            // `impl … for Name , { … }` (invalid syntax) whenever the clause is omitted.
            impl<$T> ActivatorGradient<$T> for $name
            $(where $($where)*)?
            {
                type Rel = Self;
                type Delta = <$T>::Output;

                fn activate_gradient(&self, inputs: $T) -> Self::Delta {
                    inputs.[<$method _derivative>]()
                }
            }
        }
    };
    // Public entry point: a comma-separated list of entries, trailing comma optional.
    ($($vis:vis struct $name:ident::<$T:ident>::$method:ident $({where $($where:tt)*})?),* $(,)?) => {
        $(activator! {
            @impl $vis struct $name::<$T>::$method $({where $($where)*})?
        })*
    };
}
61
62activator! {
63    pub struct Linear::<T>::linear { where T: crate::activate::LinearActivation },
64    pub struct ReLU::<T>::relu { where T: crate::activate::ReLUActivation },
65    pub struct Sigmoid::<T>::sigmoid { where T: crate::activate::SigmoidActivation },
66    pub struct HyperbolicTangent::<T>::tanh { where T: crate::activate::TanhActivation },
67    pub struct HeavySide::<T>::heavyside { where T: crate::activate::HeavysideActivation },
68    pub struct Softmax::<T>::softmax { where T: crate::activate::SoftmaxActivation },
69}