// concision_utils/utils/activate.rs

/*
    Appellation: utils <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, ScalarOperand};
6use num_traits::{Float, One, Zero};
7
8/// Heaviside activation function
9pub fn heavyside<T>(x: T) -> T
10where
11    T: One + PartialOrd + Zero,
12{
13    if x > T::zero() { T::one() } else { T::zero() }
14}
15/// the relu activation function: $f(x) = \max(0, x)$
16pub fn relu<T>(args: T) -> T
17where
18    T: PartialOrd + Zero,
19{
20    if args > T::zero() { args } else { T::zero() }
21}
22
23pub fn relu_derivative<T>(args: T) -> T
24where
25    T: PartialOrd + One + Zero,
26{
27    if args > T::zero() {
28        T::one()
29    } else {
30        T::zero()
31    }
32}
33/// the sigmoid activation function: $f(x) = \frac{1}{1 + e^{-x}}$
34pub fn sigmoid<T>(args: T) -> T
35where
36    T: Float,
37{
38    (T::one() + args.neg().exp()).recip()
39}
40/// the derivative of the sigmoid function
41pub fn sigmoid_derivative<T>(args: T) -> T
42where
43    T: Float,
44{
45    let s = sigmoid(args);
46    s * (T::one() - s)
47}
48/// Softmax function: $f(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$
49pub fn softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
50where
51    A: Float + ScalarOperand,
52    D: Dimension,
53    S: Data<Elem = A>,
54{
55    let e = args.exp();
56    &e / e.sum()
57}
58/// Softmax function along a specific axis: $f(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$
59pub fn softmax_axis<A, S, D>(args: &ArrayBase<S, D>, axis: usize) -> Array<A, D>
60where
61    A: Float + ScalarOperand,
62    D: RemoveAxis,
63    S: Data<Elem = A>,
64{
65    let axis = Axis(axis);
66    let e = args.exp();
67    &e / &e.sum_axis(axis)
68}
69/// the tanh activation function: $f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$
70pub fn tanh<T>(args: T) -> T
71where
72    T: num::traits::Float,
73{
74    args.tanh()
75}
76/// the derivative of the tanh function
77pub fn tanh_derivative<T>(args: T) -> T
78where
79    T: num::traits::Float,
80{
81    let t = tanh(args);
82    T::one() - t * t
83}