concision_core/activate/utils/non_linear.rs
use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, ScalarOperand};
use num_traits::{Float, One, Zero};

/// Rectified linear unit: returns `args` when it is greater than zero, otherwise zero.
pub fn relu<T>(args: T) -> T
where
    T: PartialOrd + Zero,
{
    if args > T::zero() { args } else { T::zero() }
}

/// Derivative of [`relu`]: one for positive inputs, zero otherwise.
pub fn relu_derivative<T>(args: T) -> T
where
    T: PartialOrd + One + Zero,
{
    if args > T::zero() { T::one() } else { T::zero() }
}

/// Logistic sigmoid: `1 / (1 + e^(-x))`.
pub fn sigmoid<T>(args: T) -> T
where
    T: Float,
{
    (T::one() + args.neg().exp()).recip()
}

/// Derivative of [`sigmoid`]: `s * (1 - s)` where `s = sigmoid(x)`.
pub fn sigmoid_derivative<T>(args: T) -> T
where
    T: Float,
{
    let s = sigmoid(args);
    s * (T::one() - s)
}

/// Softmax over the entire array: exponentiates every element and divides by the
/// global sum, so the result sums to one. This is the naive formulation; very
/// large inputs may overflow `exp`.
pub fn softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    let e = args.mapv(|x| x.exp());
    &e / e.sum()
}

/// Softmax along the given axis: each one-dimensional lane along `axis` is
/// normalized independently so that it sums to one.
pub fn softmax_axis<A, S, D>(args: &ArrayBase<S, D>, axis: usize) -> Array<A, D>
where
    A: Float + ScalarOperand,
    D: RemoveAxis,
    S: Data<Elem = A>,
{
    let axis = Axis(axis);
    let mut e = args.mapv(|x| x.exp());
    // divide every lane along `axis` by its own sum
    for mut lane in e.lanes_mut(axis) {
        let sum = lane.sum();
        lane.mapv_inplace(|x| x / sum);
    }
    e
}

/// Hyperbolic tangent.
pub fn tanh<T>(args: T) -> T
where
    T: Float,
{
    args.tanh()
}

/// Derivative of [`tanh`]: `1 - tanh(x)^2`.
pub fn tanh_derivative<T>(args: T) -> T
where
    T: Float,
{
    let t = tanh(args);
    T::one() - t * t
}