dendritic_metrics/
loss.rs

1use dendritic_ndarray::ndarray::NDArray;
2
3/// Mean squared error function
4pub fn mse(y_true: &NDArray<f64>, y_pred: &NDArray<f64>) -> Result<f64, String>  {
5
6    if y_true.size() != y_pred.size() {
7        return Err("Size of y values do not match".to_string());
8    }
9
10    let mut index = 0;
11    let mut result = 0.0;  
12    for item in y_true.values() {
13        let diff = item - y_pred.values()[index];
14        result += diff.powf(2.0);
15        index += 1; 
16    }
17
18    result = result * 1.0/y_true.size() as f64; 
19    Ok(result)
20}
21
22/// Binary cross entropy for logistic binary classification
23pub fn binary_cross_entropy(y_hat: &NDArray<f64>, y_true: &NDArray<f64>) -> Result<f64, String> {
24    
25    let mut index = 0;
26    let mut result = 0.0;  
27    for y in y_true.values() {
28      
29        let y_pred = y_hat.values()[index];
30        let diff = y * y_pred.ln() + (1.0 - y) * (1.0-y_pred).ln();
31        result += diff;
32        index += 1;
33    }
34
35    result = -(1.0/y_hat.size() as f64) * result;
36    Ok(result)
37}
38
39/// Categorical cross entropy for multi class classification
40pub fn categorical_cross_entropy(y_hat: &NDArray<f64>, y_true: &NDArray<f64>) -> Result<f64, String> {
41
42    let mut index = 0;
43    let mut result = 0.0;  
44    for y in y_true.values() {
45        let y_pred = y_hat.values()[index];
46        let diff = -y * y_pred.ln();
47        result += diff;
48        index += 1;
49    }
50
51    Ok(result * 1.0/y_hat.size() as f64)
52
53}
54