concision_utils/utils/
activate.rs1use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, ScalarOperand};
6use num_traits::{Float, One, Zero};
7
8pub fn heavyside<T>(x: T) -> T
10where
11    T: One + PartialOrd + Zero,
12{
13    if x > T::zero() { T::one() } else { T::zero() }
14}
15pub fn relu<T>(args: T) -> T
17where
18    T: PartialOrd + Zero,
19{
20    if args > T::zero() { args } else { T::zero() }
21}
22
23pub fn relu_derivative<T>(args: T) -> T
24where
25    T: PartialOrd + One + Zero,
26{
27    if args > T::zero() {
28        T::one()
29    } else {
30        T::zero()
31    }
32}
33pub fn sigmoid<T>(args: T) -> T
35where
36    T: Float,
37{
38    (T::one() + args.neg().exp()).recip()
39}
40pub fn sigmoid_derivative<T>(args: T) -> T
42where
43    T: Float,
44{
45    let s = sigmoid(args);
46    s * (T::one() - s)
47}
48pub fn softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
50where
51    A: Float + ScalarOperand,
52    D: Dimension,
53    S: Data<Elem = A>,
54{
55    let e = args.exp();
56    &e / e.sum()
57}
58pub fn softmax_axis<A, S, D>(args: &ArrayBase<S, D>, axis: usize) -> Array<A, D>
60where
61    A: Float + ScalarOperand,
62    D: RemoveAxis,
63    S: Data<Elem = A>,
64{
65    let axis = Axis(axis);
66    let e = args.exp();
67    &e / &e.sum_axis(axis)
68}
69pub fn tanh<T>(args: T) -> T
71where
72    T: num::traits::Float,
73{
74    args.tanh()
75}
76pub fn tanh_derivative<T>(args: T) -> T
78where
79    T: num::traits::Float,
80{
81    let t = tanh(args);
82    T::one() - t * t
83}