//! Activation functions and their derivatives.
use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, ScalarOperand};
use num_traits::{Float, One, Zero};

8pub fn heavyside<T>(x: T) -> T
10where
11 T: One + PartialOrd + Zero,
12{
13 if x > T::zero() { T::one() } else { T::zero() }
14}
15pub fn relu<T>(args: T) -> T
17where
18 T: PartialOrd + Zero,
19{
20 if args > T::zero() { args } else { T::zero() }
21}
23pub fn relu_derivative<T>(args: T) -> T
24where
25 T: PartialOrd + One + Zero,
26{
27 if args > T::zero() {
28 T::one()
29 } else {
30 T::zero()
31 }
32}
33pub fn sigmoid<T>(args: T) -> T
35where
36 T: Float,
37{
38 (T::one() + args.neg().exp()).recip()
39}
40pub fn sigmoid_derivative<T>(args: T) -> T
42where
43 T: Float,
44{
45 let s = sigmoid(args);
46 s * (T::one() - s)
47}
48pub fn softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
50where
51 A: Float + ScalarOperand,
52 D: Dimension,
53 S: Data<Elem = A>,
54{
55 let e = args.exp();
56 &e / e.sum()
57}
58pub fn softmax_axis<A, S, D>(args: &ArrayBase<S, D>, axis: usize) -> Array<A, D>
60where
61 A: Float + ScalarOperand,
62 D: RemoveAxis,
63 S: Data<Elem = A>,
64{
65 let axis = Axis(axis);
66 let e = args.exp();
67 &e / &e.sum_axis(axis)
68}
69pub fn tanh<T>(args: T) -> T
71where
72 T: num::traits::Float,
73{
74 args.tanh()
75}
76pub fn tanh_derivative<T>(args: T) -> T
78where
79 T: num::traits::Float,
80{
81 let t = tanh(args);
82 T::one() - t * t
83}