concision_core/activate/impls/
impl_binary.rs1use crate::activate::Heavyside;
6use crate::utils::heavyside;
7use ndarray::{Array, ArrayBase, Data, Dimension};
8use num_traits::{One, Zero};
9
10macro_rules! impl_heavyside {
11 ($($ty:ty),* $(,)*) => {
12 $(impl_heavyside!(@impl $ty);)*
13 };
14 (@impl $ty:ty) => {
15 impl Heavyside for $ty {
16 type Output = $ty;
17
18 fn heavyside(self) -> Self::Output {
19 heavyside(self)
20 }
21
22 fn heavyside_derivative(self) -> Self::Output {
23 if self > <$ty>::zero() {
24 <$ty>::one()
25 } else {
26 <$ty>::zero()
27 }
28 }
29 }
30 };
31}
32
33impl_heavyside!(
34 f32, f64, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize,
35);
36
37impl<A, B, S, D> Heavyside for ArrayBase<S, D>
38where
39 A: Clone + Heavyside<Output = B>,
40 D: Dimension,
41 S: Data<Elem = A>,
42{
43 type Output = Array<B, D>;
44
45 fn heavyside(self) -> Self::Output {
46 self.mapv(Heavyside::heavyside)
47 }
48
49 fn heavyside_derivative(self) -> Self::Output {
50 self.mapv(Heavyside::heavyside_derivative)
51 }
52}
53
54impl<A, B, S, D> Heavyside for &ArrayBase<S, D>
55where
56 A: Clone + Heavyside<Output = B>,
57 D: Dimension,
58 S: Data<Elem = A>,
59{
60 type Output = Array<B, D>;
61
62 fn heavyside(self) -> Self::Output {
63 self.mapv(Heavyside::heavyside)
64 }
65
66 fn heavyside_derivative(self) -> Self::Output {
67 self.mapv(Heavyside::heavyside_derivative)
68 }
69}