1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use serde::{Deserialize, Serialize};
use arrayfire::{log, pow, sum_all, Array};
/// Selects which cost function [`Cost::run`] and [`Cost::derivative`] evaluate.
///
/// The variants carry no data, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Cost {
    /// Squared-error cost: `sum((y - a)^2) / (2 * n)`.
    Quadratic,
    /// Cross-entropy cost: `-sum(y * ln(a) + (1 - y) * ln(1 - a)) / n`.
    Crossentropy,
}
impl Cost {
    /// Evaluates the selected cost for targets `y` and activations `a`.
    ///
    /// Both variants reduce over every element with `sum_all` and then divide
    /// by `n`, the extent of the first dimension of `a`.
    /// NOTE(review): presumably that extent is the number of samples — confirm
    /// against the array layout used by the caller.
    ///
    /// * `Quadratic`: `sum((y - a)^2) / (2 * n)`
    /// * `Crossentropy`: `-sum(y * ln(a + eps) + (1 - y) * ln(1 - a + eps)) / n`
    ///   with `eps = 1e-20`, which keeps the logarithms finite when an
    ///   activation is exactly 0 or 1.
    pub fn run(&self, y: &Array<f32>, a: &Array<f32>) -> f32 {
        /// Squared error summed over all elements, divided by `2 * n`.
        fn quadratic(y: &Array<f32>, a: &Array<f32>) -> f32 {
            sum_all(&pow(&(y - a), &2, false)).0 as f32 / (2f32 * a.dims().get()[0] as f32)
        }

        /// Cross-entropy summed over all elements, divided by `-n`.
        fn cross_entropy(y: &Array<f32>, a: &Array<f32>) -> f32 {
            // The small offset keeps `log` away from an argument of exactly zero.
            let positive_term = log(&(a + 1e-20)) * y;
            let negative_term = log(&(1f32 - a + 1e-20)) * (1f32 - y);
            let total = sum_all(&(positive_term + negative_term)).0 as f32;
            total / -(a.dims().get()[0] as f32)
        }

        match self {
            Self::Quadratic => quadratic(y, a),
            Self::Crossentropy => cross_entropy(y, a),
        }
    }

    /// Element-wise derivative of the selected cost with respect to `a`.
    ///
    /// * `Quadratic`: `a - y`
    /// * `Crossentropy`: `-y / (a + eps) + (1 - y) / (1 - a + eps)`
    ///
    /// Unlike [`Cost::run`], the result is not divided by the extent of the
    /// first dimension.
    pub fn derivative(&self, y: &Array<f32>, a: &Array<f32>) -> Array<f32> {
        // Same magnitude as the offset `run` adds inside its logarithms. Without
        // it an activation of exactly 0 or 1 divides by zero and produces
        // inf (or NaN from 0/0) in the gradient. Typed as `f32` so the
        // arithmetic stays in `Array<f32>`.
        const EPSILON: f32 = 1e-20;

        match self {
            Self::Quadratic => a - y,
            Self::Crossentropy => {
                (-1 * y) / (a + EPSILON) + (1f32 - y) / (1f32 - a + EPSILON)
            }
        }
    }
}