concision_core/activate/impls/
impl_activate_linear.rs1use crate::activate::{HeavysideActivation, LinearActivation};
7use ndarray::{Array, ArrayBase, Data, DataMut, Dimension};
8use num_traits::{One, Zero};
9
/// Implements [`HeavysideActivation`] for each listed primitive numeric type.
///
/// The step is defined as `H(x) = 1` when `x > 0` and `H(x) = 0` otherwise,
/// so `H(0) = 0` (and, for floats, any `x` that does not compare greater than
/// zero, such as `NaN`, also maps to `0`).
///
/// The derivative of the step is `0` everywhere except at the discontinuity
/// `x = 0`, where it is undefined (a Dirac delta in the distributional sense).
/// `0` is returned there as well, so `heavyside_derivative` is identically zero.
macro_rules! impl_heavyside {
    ($($T:ty),* $(,)*) => {
        $(
            impl $crate::activate::HeavysideActivation for $T {
                type Output = $T;

                fn heavyside(self) -> Self::Output {
                    if self > <$T>::zero() {
                        <$T>::one()
                    } else {
                        <$T>::zero()
                    }
                }

                fn heavyside_derivative(self) -> Self::Output {
                    // The step is constant on both sides of zero, so its slope
                    // vanishes almost everywhere. Returning `H(x)` here (as was
                    // done previously) is the derivative of ReLU, not of the step.
                    <$T>::zero()
                }
            }
        )*
    };
}
35
/// Implements [`LinearActivation`] (the identity map) for each listed
/// primitive numeric type.
///
/// `linear(x)` returns `x` unchanged and `linear_derivative(x)` returns the
/// constant slope `1` of the identity function.
macro_rules! impl_linear {
    // Internal rule: emit the impl for a single type.
    (@impl $ty:ty) => {
        impl $crate::activate::LinearActivation for $ty {
            type Output = Self;

            fn linear(self) -> Self {
                self
            }

            fn linear_derivative(self) -> Self {
                <$ty>::one()
            }
        }
    };
    // Public entry point: a comma-separated list of types.
    ($($ty:ty),* $(,)*) => {
        $( impl_linear!(@impl $ty); )*
    };
}
53
54impl_heavyside! {
55 i8, i16, i32, i64, i128, isize,
56 u8, u16, u32, u64, u128, usize,
57 f32, f64,
58}
59
60impl_linear! {
61 i8, i16, i32, i64, i128, isize,
62 u8, u16, u32, u64, u128, usize,
63 f32, f64,
64}
65
66impl<A, B, S, D> HeavysideActivation for ArrayBase<S, D, A>
67where
68 A: Clone + HeavysideActivation<Output = B>,
69 D: Dimension,
70 S: Data<Elem = A>,
71{
72 type Output = Array<B, D>;
73
74 fn heavyside(self) -> Self::Output {
75 self.mapv(HeavysideActivation::heavyside)
76 }
77
78 fn heavyside_derivative(self) -> Self::Output {
79 self.mapv(HeavysideActivation::heavyside_derivative)
80 }
81}
82
83impl<A, B, S, D> HeavysideActivation for &ArrayBase<S, D, A>
84where
85 A: Clone + HeavysideActivation<Output = B>,
86 D: Dimension,
87 S: Data<Elem = A>,
88{
89 type Output = Array<B, D>;
90
91 fn heavyside(self) -> Self::Output {
92 self.mapv(HeavysideActivation::heavyside)
93 }
94
95 fn heavyside_derivative(self) -> Self::Output {
96 self.mapv(HeavysideActivation::heavyside_derivative)
97 }
98}
99
100impl<A, S, D> LinearActivation for ArrayBase<S, D, A>
101where
102 A: Clone + One,
103 D: Dimension,
104 S: DataMut<Elem = A>,
105{
106 type Output = ArrayBase<S, D, A>;
107
108 fn linear(self) -> Self::Output {
109 self
110 }
111
112 fn linear_derivative(self) -> Self::Output {
113 self.mapv_into(|_| <A>::one())
114 }
115}