1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use ndarray::{Array2, Axis};
use ndarray_parallel::prelude::*;
use serde_derive::{Deserialize, Serialize};
use std::cmp::Ordering;

/// Activation function selector.
///
/// A value can be built from a case-insensitive name through the
/// `From<String>` implementation and is (de)serializable with serde.
///
/// Besides the serde traits, the common traits for a field-less enum are
/// derived so values can be printed, copied and compared.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// Selected by the name `sigmoid`.
    Sigmoid,
    /// Selected by the name `linear`.
    Linear,
    /// Selected by the name `tanh`.
    Tanh,
    /// Selected by the name `softmax`.
    Softmax,
}

impl From<String> for Activation {
    /// Build an [`Activation`] from its name, ignoring letter case.
    ///
    /// Recognised names are `sigmoid`, `linear`, `tanh` and `softmax`.
    ///
    /// # Panics
    ///
    /// Panics when the lowercased `name` is none of the recognised names; the
    /// message contains the name exactly as it was passed in.
    fn from(name: String) -> Self {
        // Compare on a lowercased copy so that `name` itself stays available
        // unchanged for the panic message.
        let lowered = name.to_lowercase();
        match &lowered[..] {
            "sigmoid" => Activation::Sigmoid,
            "linear" => Activation::Linear,
            "tanh" => Activation::Tanh,
            "softmax" => Activation::Softmax,
            _unrecognised => panic!("Activation {} not supported", name),
        }
    }
}

impl Default for Activation {
    /// The default activation is [`Activation::Linear`].
    fn default() -> Self {
        Activation::Linear
    }
}

/// Softmax
pub fn softmax(x: &Array2<f32>, deriv: bool) -> Array2<f32> {
    let mut out = x.clone();
    let _ = out
        .axis_iter_mut(Axis(0))
        .map(|ref mut vec| {
            let max = vec
                .iter()
                .max_by(|a, b| a.partial_cmp(&b).unwrap_or_else(|| Ordering::Equal))
                .unwrap();
            let exps = vec.mapv(|v| (v - max).exp());
            let result = &exps / exps.sum();
            vec.zip_mut_with(&result, |v, r| *v = *r);

            // Derivative
            if deriv {
                let _shape = vec.shape();
                let s = vec.to_owned();
                let result = s.diag().to_owned() - s.dot(&s.t());
                vec.zip_mut_with(&result, |v, r| *v = *r);
            }
        })
        .collect::<Vec<()>>();
    out
}

/// Element-wise hyperbolic tangent activation.
///
/// Returns a copy of `x` in which every element has been mapped through
/// `_tanh`, or through `_tanh_prime` when `deriv` is `true`. `x` itself is not
/// modified. The mapping is done with `par_mapv_inplace`.
pub fn tanh(x: &Array2<f32>, deriv: bool) -> Array2<f32> {
    // Pick the per-element function once, then apply it to the whole copy.
    let elementwise: fn(f32) -> f32 = if deriv { _tanh_prime } else { _tanh };

    let mut out = x.clone();
    out.par_mapv_inplace(elementwise);
    out
}

/// Hyperbolic tangent of a single value.
///
/// Uses the standard library implementation rather than the quotient
/// `(1 - e^(-2x)) / (1 + e^(-2x))`: `e^(-2x)` overflows `f32` once `x` drops
/// below roughly `-44`, which turns that quotient into `-inf / inf = NaN`
/// instead of `-1`.
fn _tanh(x: f32) -> f32 {
    x.tanh()
}

/// Derivative of the hyperbolic tangent at `x`, i.e. `1 - tanh(x)^2`.
fn _tanh_prime(x: f32) -> f32 {
    let t = _tanh(x);
    1.0 - t * t
}

/// Element-wise sigmoid activation on an `Array2<f32>`.
///
/// Returns a copy of `x` in which every element has been mapped through
/// `_sigmoid`, or through `_sigmoid_prime` when `deriv` is `true`. `x` itself
/// is not modified. The mapping is done with `par_mapv_inplace`.
pub fn sigmoid(x: &Array2<f32>, deriv: bool) -> Array2<f32> {
    // TODO: Make this an inplace (&mut) process
    // Pick the per-element function once, then apply it to the whole copy.
    let elementwise: fn(f32) -> f32 = if deriv { _sigmoid_prime } else { _sigmoid };

    let mut out = x.clone();
    out.par_mapv_inplace(elementwise);
    out
}

/// Derivative of the logistic function at `x`, computed as `s * (1 - s)` with
/// `s = _sigmoid(x)`.
fn _sigmoid_prime(x: f32) -> f32 {
    let s = _sigmoid(x);
    s * (1.0 - s)
}

/// Logistic function `1 / (1 + e^(-x))` of a single value.
fn _sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    1.0 / denominator
}