concision_neural/layers/traits/
activate.rs

1/*
2    appellation: activate <module>
3    authors: @FL03
4*/
/// The [`Activator`] trait defines a method for applying an activation function to an input
/// tensor.
pub trait Activator<T> {
    /// The type produced by applying the activation function to an input of type `T`.
    type Output;

    /// Applies the activation function to the input tensor.
    fn activate(&self, input: T) -> Self::Output;
}
/// The [`ActivatorGradient`] trait extends the [`Activator`] trait to include a method for
/// computing the gradient of the activation function.
pub trait ActivatorGradient<T>: Activator<T> {
    /// The input type consumed by the gradient computation.
    /// NOTE(review): `Input` is declared independently of the `T` used by
    /// [`Activator::activate`]; the macro-generated impls below set `Input = T` — confirm
    /// whether the two are always intended to coincide.
    type Input;
    /// The type of the computed gradient (delta).
    type Delta;

    /// compute the gradient of some input
    fn activate_gradient(&self, input: Self::Input) -> Self::Delta;
}
22
23/*
24 ************* Implementations *************
25*/
26impl<A, B, T> Activator<A> for &T
27where
28    T: Activator<A, Output = B>,
29{
30    type Output = B;
31
32    fn activate(&self, rhs: A) -> Self::Output {
33        (*self).activate(rhs)
34    }
35}
36
37impl<A, B, C, T> ActivatorGradient<A> for &T
38where
39    T: ActivatorGradient<A, Input = B, Delta = C>,
40{
41    type Input = B;
42    type Delta = C;
43
44    fn activate_gradient(&self, inputs: Self::Input) -> Self::Delta {
45        (*self).activate_gradient(inputs)
46    }
47}
48
49impl<X, Y> Activator<X> for dyn Fn(X) -> Y {
50    type Output = Y;
51
52    fn activate(&self, rhs: X) -> Self::Output {
53        self(rhs)
54    }
55}
56
#[cfg(feature = "alloc")]
mod impl_alloc {
    use super::Activator;
    use alloc::boxed::Box;

    /// A boxed activator trait object is itself an activator, forwarding to
    /// its heap-allocated contents; this enables owned, dynamically
    /// dispatched activators.
    impl<A, B> Activator<A> for Box<dyn Activator<A, Output = B>> {
        type Output = B;

        fn activate(&self, rhs: A) -> Self::Output {
            // Double deref through the Box to reach the `dyn Activator`,
            // then dispatch through the vtable.
            (**self).activate(rhs)
        }
    }
}
70
/*
 ************* Macros *************
*/
/// Generates unit structs that implement [`Activator`] and [`ActivatorGradient`]
/// by delegating to a method on the input type (e.g. `U::relu()`), and to a
/// `_derivative`-suffixed sibling for the gradient (e.g. `U::relu_derivative()`).
///
/// Usage: `activator! { pub struct Name::<path::ToTrait>(method); ... }`
macro_rules! activator {
    // Internal arm: emits one unit struct plus its two trait impls.
    (@impl $vis:vis struct $name:ident::<$($trait:ident)::*>($method:ident) ) => {

        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        $vis struct $name;

        // Forward activation to the bound trait's method on the input itself.
        impl<U> Activator<U> for $name
        where
            U: $($trait)::*,
        {
            type Output = U::Output;

            fn activate(&self, x: U) -> Self::Output {
                x.$method()
            }
        }

        paste::paste! {
            // `paste!` concatenates the method name with `_derivative` to form
            // the gradient method's identifier (e.g. `relu` -> `relu_derivative`).
            // NOTE(review): assumes the bound trait provides that method — confirm.
            impl<U> ActivatorGradient<U> for $name
            where
                U: $($trait)::*,
            {
                type Input = U;
                type Delta = U::Output;

                fn activate_gradient(&self, inputs: U) -> Self::Delta {
                    inputs.[<$method _derivative>]()
                }
            }
        }
    };
    // Public arm: accepts a semicolon-separated list (optional trailing `;`)
    // and expands each entry through the `@impl` arm above.
    ($(
        $vis:vis struct $name:ident::<$($trait:ident)::*>($method:ident)
    );* $(;)?) => {
        $(
            activator!(@impl $vis struct $name::<$($trait)::*>($method));
        )*
    };
}
114
// Generate the standard activation functors; each delegates to the named
// method of the corresponding `cnc` trait (and its `_derivative` counterpart).
activator! {
    pub struct Linear::<cnc::LinearActivation>(linear);
    pub struct ReLU::<cnc::ReLU>(relu);
    pub struct Sigmoid::<cnc::Sigmoid>(sigmoid);
    pub struct Tanh::<cnc::Tanh>(tanh);
}