// concision_neural/layers/traits/activate.rs
/*
    appellation: activate <module>
    authors: @FL03
*/
/// The [`Activator`] trait defines a method for applying an activation function to an input
/// tensor.
pub trait Activator<T> {
    /// The value produced by applying the activation function to an input of type `T`.
    type Output;

    /// Applies the activation function to the input tensor.
    fn activate(&self, input: T) -> Self::Output;
}
/// The [`ActivatorGradient`] trait extends the [`Activator`] trait to include a method for
/// computing the gradient of the activation function.
pub trait ActivatorGradient<T>: Activator<T> {
    /// The input type consumed by the gradient computation.
    /// NOTE(review): declared independently of `T`; presumably the two usually coincide —
    /// confirm against implementors.
    type Input;
    /// The gradient (delta) type produced by [`activate_gradient`](Self::activate_gradient).
    type Delta;

    /// Computes the gradient of the activation function for the given input.
    fn activate_gradient(&self, input: Self::Input) -> Self::Delta;
}
22
/*
 ************* Implementations *************
*/
26
27impl<X, Y, F> Activator<X> for F
28where
29    F: Fn(X) -> Y,
30{
31    type Output = Y;
32
33    fn activate(&self, rhs: X) -> Self::Output {
34        self(rhs)
35    }
36}
37
#[cfg(feature = "alloc")]
mod impl_alloc {
    use super::Activator;
    use alloc::boxed::Box;

    /// Forwards [`Activator`] through a boxed trait object so that a
    /// `Box<dyn Activator<X, Output = Y>>` can be used anywhere an `Activator` is expected.
    impl<X, Y> Activator<X> for Box<dyn Activator<X, Output = Y>> {
        type Output = Y;

        fn activate(&self, rhs: X) -> Self::Output {
            // Fixed: the trait method is `activate`, not `rho` — the original
            // `self.as_ref().rho(rhs)` called a method that does not exist on
            // `Activator` and could not compile.
            self.as_ref().activate(rhs)
        }
    }
}
51
/*
 ************* Macros *************
*/
/// Generates a zero-sized activator struct that implements [`Activator`] and
/// [`ActivatorGradient`] by delegating to a trait method on the input type.
///
/// For `activator!(pub struct Name::<path::Trait>(method))` this emits:
/// - `pub struct Name;` with the common derives (plus serde when enabled);
/// - `Activator<U>` for any `U: path::Trait`, forwarding to `U::method()`;
/// - `ActivatorGradient<U>`, forwarding to `U::method_derivative()` — the
///   derivative method name is built via `paste!` identifier concatenation.
macro_rules! activator {
    // Internal arm: expands a single struct + its two trait impls.
    (@impl $vis:vis struct $name:ident::<$($trait:ident)::*>($method:ident) ) => {

        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        $vis struct $name;

        impl<U> Activator<U> for $name
        where
            U: $($trait)::*,
        {
            type Output = U::Output;

            fn activate(&self, x: U) -> Self::Output {
                // Delegate to the backing trait's activation method on the input.
                x.$method()
            }
        }

        paste::paste! {
            impl<U> ActivatorGradient<U> for $name
            where
                U: $($trait)::*,
            {
                type Input = U;
                type Delta = U::Output;

                fn activate_gradient(&self, inputs: U) -> Self::Delta {
                    // `[<$method _derivative>]` concatenates into e.g. `relu_derivative`;
                    // the backing trait must provide that method.
                    inputs.[<$method _derivative>]()
                }
            }
        }
    };
    // Public arm: accepts a semicolon-separated list (trailing `;` optional)
    // and expands each entry through the `@impl` arm above.
    ($(
        $vis:vis struct $name:ident::<$($trait:ident)::*>($method:ident)
    );* $(;)?) => {
        $(
            activator!(@impl $vis struct $name::<$($trait)::*>($method));
        )*
    };
}
95
// Instantiate the standard activators. Each struct delegates to the named method of
// the corresponding `cnc` trait (e.g. `ReLU.activate(x)` calls `x.relu()`), and its
// gradient to the `_derivative`-suffixed method (e.g. `x.relu_derivative()`).
activator! {
    pub struct Linear::<cnc::LinearActivation>(linear);
    pub struct ReLU::<cnc::ReLU>(relu);
    pub struct Sigmoid::<cnc::Sigmoid>(sigmoid);
    pub struct Tanh::<cnc::Tanh>(tanh);
}
101}