concision_core/activate/impls/impl_nonlinear.rs

/*
    Appellation: impl_nonlinear <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::activate::{ReLU, Sigmoid, Softmax, Tanh, utils::sigmoid_derivative};
6
7use ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
8use num_traits::{Float, One, Zero};
9
10impl<A, S, D> ReLU for ArrayBase<S, D>
11where
12    A: Copy + PartialOrd + Zero + One,
13    S: Data<Elem = A>,
14    D: Dimension,
15{
16    type Output = Array<A, D>;
17
18    fn relu(&self) -> Self::Output {
19        self.map(|&i| if i > A::zero() { i } else { A::zero() })
20    }
21
22    fn relu_derivative(&self) -> Self::Output {
23        self.map(|&i| if i > A::zero() { A::one() } else { A::zero() })
24    }
25}
26
27impl<A, S, D> Sigmoid for ArrayBase<S, D>
28where
29    A: ScalarOperand + Float,
30    S: Data<Elem = A>,
31    D: Dimension,
32{
33    type Output = Array<A, D>;
34
35    fn sigmoid(self) -> Self::Output {
36        let dim = self.dim();
37        let ones = Array::<A, D>::ones(dim);
38
39        (ones + self.map(|&i| i.neg().exp())).recip()
40    }
41
42    fn sigmoid_derivative(self) -> Self::Output {
43        self.mapv(|i| sigmoid_derivative(i))
44    }
45}
46
47impl<A, S, D> Softmax for ArrayBase<S, D>
48where
49    A: ScalarOperand + Float,
50    S: Data<Elem = A>,
51    D: Dimension,
52{
53    type Output = Array<A, D>;
54
55    fn softmax(&self) -> Self::Output {
56        let e = self.exp();
57        &e / e.sum()
58    }
59
60    fn softmax_derivative(&self) -> Self::Output {
61        let e = self.exp();
62        let sum = e.sum();
63        let softmax = &e / sum;
64
65        let ones = Array::<A, D>::ones(self.dim());
66        &softmax * (&ones - &softmax)
67    }
68}
69
70impl<A, S, D> Tanh for ArrayBase<S, D>
71where
72    A: ScalarOperand + Float,
73    S: Data<Elem = A>,
74    D: Dimension,
75{
76    type Output = Array<A, D>;
77
78    fn tanh(&self) -> Self::Output {
79        self.map(|i| i.tanh())
80    }
81
82    fn tanh_derivative(&self) -> Self::Output {
83        self.map(|i| A::one() - i.tanh().powi(2))
84    }
85}