1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use serde::{Deserialize, Serialize};

use arrayfire::{log, pow, sum_all, Array};

/// Defines cost function of a neural network.
///
/// Each variant takes a target output `y` and an actual output `a`: [`Cost::run`] evaluates the
/// cost and [`Cost::derivative`] evaluates `∂C/∂a`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cost {
    /// Quadratic cost function.
    ///
    /// $ C(w,b)=\frac{1}{2n}\sum_{x} \lVert y(x)-a(x) \rVert^2 $
    Quadratic,
    /// Crossentropy cost function.
    ///
    /// $ C(w,b) = -\frac{1}{n} \sum_{x} \left( y(x) \ln{(a(x))} + (1-y(x)) \ln{(1-a(x))} \right) $
    Crossentropy,
}
impl Cost {
    /// Offset added to `log` arguments and to derivative denominators.
    ///
    /// Its purpose is to make outputs of exactly `0.0` or `1.0` produce large-but-finite values
    /// instead of `inf` or `NaN`. `1e-20` is a normal `f32`, so it still shifts an exact `0.0`.
    /// For outputs of ordinary magnitude it is lost to rounding and changes nothing.
    const EPSILON: f32 = 1e-20;

    /// Runs cost functions.
    ///
    /// y: Target out, a: Actual out.
    ///
    /// The element-wise cost is summed over the whole array and divided by `a.dims()[0]` (the
    /// size of the first dimension), which plays the role of `n` in the formulas on [`Cost`].
    pub fn run(&self, y: &Array<f32>, a: &Array<f32>) -> f32 {
        match self {
            Self::Quadratic => Self::quadratic(y, a),
            Self::Crossentropy => Self::cross_entropy(y, a),
        }
    }

    /// Quadratic cost: `sum((y - a)^2) / (2n)` with `n = a.dims()[0]`.
    fn quadratic(y: &Array<f32>, a: &Array<f32>) -> f32 {
        let n = a.dims().get()[0] as f32;
        sum_all(&pow(&(y - a), &2, false)).0 as f32 / (2f32 * n)
    }

    /// Cross-entropy cost: `-sum(y ln(a) + (1 - y) ln(1 - a)) / n` with `n = a.dims()[0]`.
    fn cross_entropy(y: &Array<f32>, a: &Array<f32>) -> f32 {
        // Without the offset, an output of exactly 0 or 1 gives `log(0) = -inf`. When it is
        // multiplied by a zero target term, that yields `0 * -inf = NaN`, poisoning the whole sum.
        let part1 = log(&(a + Self::EPSILON)) * y;
        let part2 = log(&(1f32 - a + Self::EPSILON)) * (1f32 - y);

        let n = a.dims().get()[0] as f32;
        -(sum_all(&(part1 + part2)).0 as f32) / n
    }

    /// Derivative w.r.t. layer output (∂C/∂a).
    ///
    /// y: Target out, a: Actual out.
    ///
    /// - `Quadratic`: `a - y`
    /// - `Crossentropy`: `-y/a + (1-y)/(1-a)`, with the same small epsilon that [`Cost::run`]
    ///   uses added to both denominators.
    pub fn derivative(&self, y: &Array<f32>, a: &Array<f32>) -> Array<f32> {
        match self {
            Self::Quadratic => a - y,
            Self::Crossentropy => {
                // An output of exactly 0.0 or 1.0 would otherwise divide by zero. That gives
                // `inf`, or `NaN` (0/0) when the target equals the output, and would propagate
                // into every downstream gradient. With the offset, the 0/0 case becomes 0 and
                // the x/0 case becomes a large finite value.
                (-1 * y) / (a + Self::EPSILON) + (1f32 - y) / (1f32 - a + Self::EPSILON)
            }
        }
    }
}