concision_core/activate/traits/
activate.rs

1/*
2    appellation: activate <module>
3    authors: @FL03
4*/
5use super::unary::*;
6
7use crate::Apply;
8#[cfg(feature = "complex")]
9use num_complex::ComplexFloat;
10use num_traits::One;
11
12/// The [`Rho`] trait defines a set of activation functions that can be applied to an
13/// implementor of the [`Apply`] trait. It provides methods for common activation functions
14/// such as linear, heavyside, ReLU, sigmoid, and tanh, along with their derivatives.
15/// The trait is generic over a type `U`, which represents the data type of the input to the
16/// activation functions. The trait also inherits a type alias `Cont<U>` to allow for variance
17/// w.r.t. the outputs of defined methods.
18pub trait Rho<U>: Apply<U> {
19    /// the linear activation function is essentially a passthrough function, simply cloning
20    /// the content.
21    fn linear(&self) -> Self::Cont<U> {
22        self.apply(|x| x)
23    }
24
25    fn linear_derivative(&self) -> Self::Cont<U::Output>
26    where
27        U: One,
28    {
29        self.apply(|_| <U>::one())
30    }
31
32    fn heavyside(&self) -> Self::Cont<U::Output>
33    where
34        U: Heavyside,
35    {
36        self.apply(|x| x.heavyside())
37    }
38
39    fn heavyside_derivative(&self) -> Self::Cont<U::Output>
40    where
41        U: Heavyside,
42    {
43        self.apply(|x| x.heavyside_derivative())
44    }
45
46    fn relu(&self) -> Self::Cont<U::Output>
47    where
48        U: ReLU,
49    {
50        self.apply(|x| x.relu())
51    }
52
53    fn relu_derivative(&self) -> Self::Cont<U::Output>
54    where
55        U: ReLU,
56    {
57        self.apply(|x| x.relu_derivative())
58    }
59
60    fn sigmoid(&self) -> Self::Cont<U::Output>
61    where
62        U: Sigmoid,
63    {
64        self.apply(|x| x.sigmoid())
65    }
66
67    fn sigmoid_derivative(&self) -> Self::Cont<U::Output>
68    where
69        U: Sigmoid,
70    {
71        self.apply(|x| x.sigmoid_derivative())
72    }
73
74    fn tanh(&self) -> Self::Cont<U::Output>
75    where
76        U: Tanh,
77    {
78        self.apply(|x| x.tanh())
79    }
80
81    fn tanh_derivative(&self) -> Self::Cont<U::Output>
82    where
83        U: Tanh,
84    {
85        self.apply(|x| x.tanh_derivative())
86    }
87}
88
/// The [`RhoComplex`] trait mirrors [`Rho`] by providing activation functions for
/// implementors of the [`Apply`] trait, but instead of being fully generic over `U` it
/// requires `U` to implement [`ComplexFloat`]. This makes the activations usable with
/// complex numbers, which is particularly useful for signal-based workloads.
///
/// **note**: [`Rho`] and [`RhoComplex`] are not intended to be used together, which is why
/// their methods share the same names. If both traits are imported into the same file, calls
/// will most likely need fully qualified syntax to disambiguate them. Should this become a
/// problem, the _complex_ methods may be renamed to set them apart from the _standard_ ones.
#[cfg(feature = "complex")]
pub trait RhoComplex<U>: Apply<U>
where
    U: ComplexFloat,
{
    /// Logistic sigmoid, `1 / (1 + exp(-z))`, evaluated for every value `z`.
    fn sigmoid(&self) -> Self::Cont<U> {
        self.apply(|z| {
            let one = U::one();
            one / (one + (-z).exp())
        })
    }

    /// Derivative of the logistic sigmoid, `s * (1 - s)` with `s = 1 / (1 + exp(-z))`.
    fn sigmoid_derivative(&self) -> Self::Cont<U> {
        self.apply(|z| {
            let one = U::one();
            let sig = one / (one + (-z).exp());
            sig * (one - sig)
        })
    }

    /// Hyperbolic tangent of every value.
    fn tanh(&self) -> Self::Cont<U> {
        self.apply(|z| z.tanh())
    }

    /// Derivative of the hyperbolic tangent, `1 - tanh(z)^2`.
    fn tanh_derivative(&self) -> Self::Cont<U> {
        self.apply(|z| {
            let th = z.tanh();
            U::one() - th * th
        })
    }
}
128
129/*
130 ************* Implementations *************
131*/
132impl<U, S> Rho<U> for S where S: Apply<U> {}
133
// Blanket implementation: every `Apply<U>` implementor whose values are `ComplexFloat`
// picks up the default `RhoComplex<U>` activation methods.
#[cfg(feature = "complex")]
impl<U: ComplexFloat, S: Apply<U>> RhoComplex<U> for S {}