concision_core/activate/impls/
impl_binary.rs1use crate::activate::{utils::heavyside, Heavyside};
6use ndarray::{Array, ArrayBase, Data, Dimension};
7use num_traits::{One, Zero};
8
// Implements `Heavyside` for one or more primitive numeric types.
//
// Usage: `impl_heavyside!(f32, i32, u8)`; a single optional trailing comma is
// accepted. Each listed type gets an impl whose `Output` is the type itself.
macro_rules! impl_heavyside {
    // Internal rule: emit the impl for a single type. Kept first so `@impl`
    // invocations are dispatched here directly instead of first being tried
    // against the public list rule.
    (@impl $ty:ty) => {
        impl Heavyside for $ty {
            type Output = $ty;

            /// Applies the crate-level `heavyside` helper to `self`.
            fn heavyside(self) -> Self::Output {
                heavyside(self)
            }

            /// Returns one when `self` is strictly greater than zero and zero
            /// otherwise (including at zero itself, and for NaN floats, since
            /// `NaN > 0` is false).
            fn heavyside_derivative(self) -> Self::Output {
                // NOTE(review): this is the 0/1 step, not the mathematical
                // derivative of the step (zero everywhere except the origin);
                // presumably used as a surrogate gradient — confirm intent.
                if self > <$ty>::zero() {
                    <$ty>::one()
                } else {
                    <$ty>::zero()
                }
            }
        }
    };
    // Public rule: fan a comma-separated list of types out to the internal
    // rule. `$(,)?` accepts at most one trailing comma.
    ($($ty:ty),* $(,)?) => {
        $(impl_heavyside!(@impl $ty);)*
    };
}
31
32impl_heavyside!(
33 f32, f64, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize,
34);
35
36impl<A, B, S, D> Heavyside for ArrayBase<S, D>
37where
38 A: Clone + Heavyside<Output = B>,
39 D: Dimension,
40 S: Data<Elem = A>,
41{
42 type Output = Array<B, D>;
43
44 fn heavyside(self) -> Self::Output {
45 self.mapv(Heavyside::heavyside)
46 }
47
48 fn heavyside_derivative(self) -> Self::Output {
49 self.mapv(Heavyside::heavyside_derivative)
50 }
51}
52
53impl<A, B, S, D> Heavyside for &ArrayBase<S, D>
54where
55 A: Clone + Heavyside<Output = B>,
56 D: Dimension,
57 S: Data<Elem = A>,
58{
59 type Output = Array<B, D>;
60
61 fn heavyside(self) -> Self::Output {
62 self.mapv(Heavyside::heavyside)
63 }
64
65 fn heavyside_derivative(self) -> Self::Output {
66 self.mapv(Heavyside::heavyside_derivative)
67 }
68}