1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use core::fmt::Debug;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// A scalar activation function together with its derivative.
///
/// Implementors must be `Debug + Serialize + DeserializeOwned + Clone`; presumably this
/// lets generic types parameterised over an activation derive those traits (confirm
/// against the users of this trait).
pub trait Activation: Debug + Serialize + DeserializeOwned + Clone {
    /// Applies the activation function to a single pre-activation value `input`.
    fn activate(input: f32) -> f32;

    /// Returns the derivative of the activation, expressed in terms of the
    /// *output* of [`Activation::activate`] (`activation`), not its input.
    fn derivate(activation: f32) -> f32;
}
/// Logistic sigmoid activation, `1 / (1 + e^-x)`.
///
/// This is an uninhabited marker type: it selects the activation at the type level and
/// is never instantiated.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Sigmoid {}
impl Activation for Sigmoid {
    /// Computes the sigmoid of `input` after saturating it to `[-20, 20]`.
    ///
    /// Saturation bounds the argument of `exp` to at most 20. Because `f32::max`/`f32::min`
    /// return the non-NaN operand, a NaN input is treated as `-20`.
    #[inline(always)]
    fn activate(input: f32) -> f32 {
        let saturated = f32::min(f32::max(input, -20.), 20.);
        let denominator = 1. + f32::exp(-saturated);
        let output = 1. / denominator;
        debug_assert!(!output.is_nan());
        output
    }

    /// Sigmoid derivative written in terms of the sigmoid output `s`: `s * (1 - s)`.
    #[inline(always)]
    fn derivate(activation: f32) -> f32 {
        debug_assert!(!activation.is_nan());
        let complement = 1. - activation;
        let slope = activation * complement;
        debug_assert!(!slope.is_nan());
        slope
    }
}
/// Rectified linear unit, `max(x, 0)`.
///
/// NOTE(review): the forward pass is a plain ReLU (non-positive inputs map to 0), but
/// `derivate` reports a slope of `0.01` on the non-positive side, which is a leaky-ReLU
/// gradient. This is presumably deliberate, to keep gradients flowing through inactive
/// units. Confirm it: the gradient does not match `activate` on that side.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ReLu {}
impl Activation for ReLu {
    /// Returns `input` when it is positive and `0` otherwise. A NaN input also yields `0`,
    /// because `f32::max` returns the non-NaN operand.
    #[inline(always)]
    fn activate(input: f32) -> f32 {
        f32::max(input, 0.)
    }

    /// Returns `1` for a positive activation and `0.01` otherwise (including NaN).
    #[inline(always)]
    fn derivate(activation: f32) -> f32 {
        const NON_POSITIVE_SLOPE: f32 = 0.01;
        if activation > 0. {
            1.
        } else {
            NON_POSITIVE_SLOPE
        }
    }
}
/// Softmax marker type.
///
/// Softmax normalises across a whole output vector, so it cannot be expressed through
/// the per-element [`Activation`] methods, and both of them panic. The type is hidden
/// from the docs. Presumably the layer code special-cases it and never calls these
/// methods (NOTE(review): confirm against the callers).
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum SoftMax {}
impl Activation for SoftMax {
    /// Not supported: softmax is not an element-wise function.
    ///
    /// # Panics
    ///
    /// Always.
    #[cold]
    fn activate(_input: f32) -> f32 {
        panic!(
            "SoftMax::activate called: softmax depends on the whole output vector \
             and cannot be computed element-wise"
        )
    }

    /// Not supported: softmax is not an element-wise function.
    ///
    /// # Panics
    ///
    /// Always.
    #[cold]
    fn derivate(_activation: f32) -> f32 {
        panic!(
            "SoftMax::derivate called: softmax depends on the whole output vector \
             and cannot be differentiated element-wise"
        )
    }
}