// concision_core/activate/utils/funcs.rs
use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, ScalarOperand};
use num_traits::{Float, One, Zero};

8pub fn relu<T>(args: T) -> T
14where
15 T: PartialOrd + Zero,
16{
17 if args > T::zero() { args } else { T::zero() }
18}
20pub fn relu_derivative<T>(args: T) -> T
25where
26 T: PartialOrd + One + Zero,
27{
28 if args > T::zero() {
29 T::one()
30 } else {
31 T::zero()
32 }
33}
34pub fn sigmoid<T>(args: T) -> T
40where
41 T: Float,
42{
43 (T::one() + args.neg().exp()).recip()
44}
45pub fn sigmoid_derivative<T>(args: T) -> T
47where
48 T: Float,
49{
50 let s = sigmoid(args);
51 s * (T::one() - s)
52}
53pub fn softmax<A, S, D>(args: &ArrayBase<S, D, A>) -> Array<A, D>
59where
60 A: Float + ScalarOperand,
61 D: Dimension,
62 S: Data<Elem = A>,
63{
64 let e = args.exp();
65 &e / e.sum()
66}
67pub fn softmax_axis<A, S, D>(args: &ArrayBase<S, D, A>, axis: usize) -> Array<A, D>
73where
74 A: Float + ScalarOperand,
75 D: RemoveAxis,
76 S: Data<Elem = A>,
77{
78 let axis = Axis(axis);
79 let e = args.exp();
80 &e / &e.sum_axis(axis)
81}
82pub fn tanh<T>(args: T) -> T
88where
89 T: Float,
90{
91 args.tanh()
92}
93pub fn tanh_derivative<T>(args: T) -> T
95where
96 T: Float,
97{
98 let t = tanh(args);
99 T::one() - t * t
100}
/// Identity (linear) activation: returns its input unchanged.
pub const fn linear<T>(x: T) -> T {
    x
}
108pub fn linear_derivative<T>() -> T
111where
112 T: One,
113{
114 <T>::one()
115}
117pub fn heavyside<T>(x: T) -> T
123where
124 T: One + PartialOrd + Zero,
125{
126 if x > T::zero() { T::one() } else { T::zero() }
127}