// concision_neural/layers/traits/activate.rs

/// An object that can be applied to an input of type `T`, such as an
/// activation function.
///
/// The input is taken by value, and the produced value may have a different
/// type, named by [`Activator::Output`].
pub trait Activator<T> {
    /// The value produced by [`activate`](Activator::activate).
    type Output;

    /// Applies this activator to `input`.
    fn activate(&self, input: T) -> Self::Output;
}
13pub trait ActivatorGradient<T>: Activator<T> {
16 type Input;
17 type Delta;
18
19 fn activate_gradient(&self, input: Self::Input) -> Self::Delta;
21}
22
23impl<A, B, T> Activator<A> for &T
27where
28 T: Activator<A, Output = B>,
29{
30 type Output = B;
31
32 fn activate(&self, rhs: A) -> Self::Output {
33 (*self).activate(rhs)
34 }
35}
36
37impl<A, B, C, T> ActivatorGradient<A> for &T
38where
39 T: ActivatorGradient<A, Input = B, Delta = C>,
40{
41 type Input = B;
42 type Delta = C;
43
44 fn activate_gradient(&self, inputs: Self::Input) -> Self::Delta {
45 (*self).activate_gradient(inputs)
46 }
47}
48
49impl<X, Y> Activator<X> for dyn Fn(X) -> Y {
50 type Output = Y;
51
52 fn activate(&self, rhs: X) -> Self::Output {
53 self(rhs)
54 }
55}
56
/// [`Activator`]/[`ActivatorGradient`] forwarding for boxed values; only
/// compiled with the `alloc` feature.
#[cfg(feature = "alloc")]
mod impl_alloc {
    use super::{Activator, ActivatorGradient};
    use alloc::boxed::Box;

    /// Forwards [`Activator`] through a [`Box`].
    ///
    /// `T` may be unsized, so this covers `Box<dyn Activator<X, Output = Y>>`
    /// (including `+ Send`/`+ Sync` and non-`'static` trait objects) as well as
    /// boxed concrete activators.
    impl<X, T> Activator<X> for Box<T>
    where
        T: ?Sized + Activator<X>,
    {
        type Output = <T as Activator<X>>::Output;

        fn activate(&self, rhs: X) -> Self::Output {
            <T as Activator<X>>::activate(&**self, rhs)
        }
    }

    /// Forwards [`ActivatorGradient`] through a [`Box`]; the `Box<T>: Activator<X>`
    /// supertrait bound is provided by the impl above.
    impl<X, T> ActivatorGradient<X> for Box<T>
    where
        T: ?Sized + ActivatorGradient<X>,
    {
        type Input = <T as ActivatorGradient<X>>::Input;
        type Delta = <T as ActivatorGradient<X>>::Delta;

        fn activate_gradient(&self, inputs: Self::Input) -> Self::Delta {
            <T as ActivatorGradient<X>>::activate_gradient(&**self, inputs)
        }
    }
}
/// Declares zero-sized activator types that delegate to a trait method.
///
/// Each entry `vis struct Name::<path::to::Trait>(method)` expands to:
///
/// - a unit struct `Name` deriving `Clone`, `Copy`, `Debug`, `Default`, `Eq`,
///   `Hash`, `Ord`, `PartialEq` and `PartialOrd` (plus `serde` derives when the
///   `serde` feature is enabled);
/// - `impl<U: Trait> Activator<U> for Name`, whose `activate` returns
///   `input.method()` with `Output = U::Output`;
/// - `impl<U: Trait> ActivatorGradient<U> for Name`, whose `activate_gradient`
///   returns `input.method_derivative()` (the identifier is built by `paste`).
///
/// The `@impl` arm expands a single entry; the list arm forwards each
/// `;`-separated entry to it.
macro_rules! activator {
    (@impl $v:vis struct $act:ident::<$($path:ident)::*>($func:ident) ) => {
        #[doc = concat!(
            "Unit activator that applies `",
            stringify!($func),
            "` to its input; its gradient uses `",
            stringify!($func),
            "_derivative`."
        )]
        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        $v struct $act;

        impl<U> Activator<U> for $act
        where
            U: $($path)::*,
        {
            type Output = U::Output;

            fn activate(&self, input: U) -> Self::Output {
                input.$func()
            }
        }

        paste::paste! {
            impl<U> ActivatorGradient<U> for $act
            where
                U: $($path)::*,
            {
                type Input = U;
                type Delta = U::Output;

                fn activate_gradient(&self, input: U) -> Self::Delta {
                    input.[<$func _derivative>]()
                }
            }
        }
    };
    ($($v:vis struct $act:ident::<$($path:ident)::*>($func:ident));* $(;)?) => {
        $(
            activator!(@impl $v struct $act::<$($path)::*>($func));
        )*
    };
}
115activator! {
116 pub struct Linear::<cnc::LinearActivation>(linear);
117 pub struct ReLU::<cnc::ReLU>(relu);
118 pub struct Sigmoid::<cnc::Sigmoid>(sigmoid);
119 pub struct Tanh::<cnc::Tanh>(tanh);
120}