//! Loss functions (`nab_loss.rs`) for the `nabla_ml` crate.

1use crate::nab_array::NDArray;
2
/// Common interface for loss functions used during training.
pub trait Loss {
    /// Computes the scalar loss for a batch of `predictions` against `targets`.
    fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64;
    /// Computes the gradient of the loss with respect to `predictions`,
    /// returned as an array with the same shape as `predictions`.
    fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray;
}
7
8// pub struct CategoricalCrossentropy;
9
10// impl Loss for CategoricalCrossentropy {
11//     fn forward(&self, y_pred: &NDArray, y_true: &NDArray) -> f64 {
12//         // Compute cross-entropy loss
13//         let epsilon = 1e-8;
14//         let clipped_pred = y_pred.clip(epsilon, 1.0 - epsilon);
15//         -y_true.multiply(&clipped_pred.log()).sum() / y_true.shape()[0] as f64
16//     }
17
18//     fn backward(&self, y_pred: &NDArray, y_true: &NDArray) -> NDArray {
19//         NDArray::crossentropy_nabla(y_pred, y_true)
20//     }
21// }
22
23pub struct NabLoss;
24
25impl NabLoss {
26    /// Calculates the Mean Squared Error (MSE) between two arrays
27    ///
28    /// # Arguments
29    ///
30    /// * `y_true` - The true values as an NDArray.
31    /// * `y_pred` - The predicted values as an NDArray.
32    ///
33    /// # Returns
34    ///
35    /// The MSE as a f64.
36    #[allow(dead_code)]
37    pub fn mean_squared_error(y_true: &NDArray, y_pred: &NDArray) -> f64 {
38        assert_eq!(y_true.shape(), y_pred.shape(), "Shapes of y_true and y_pred must match");
39        let diff = y_true.data().iter().zip(y_pred.data().iter()).map(|(t, p)| (t - p).powi(2)).collect::<Vec<f64>>();
40        diff.iter().sum::<f64>() / y_true.data().len() as f64
41    }
42
43    /// Calculates the Cross-Entropy Loss between two arrays
44    ///
45    /// # Arguments
46    ///
47    /// * `y_true` - The true values as an NDArray (one-hot encoded).
48    /// * `y_pred` - The predicted probabilities as an NDArray.
49    ///
50    /// # Returns
51    ///
52    /// The Cross-Entropy Loss as a f64.    
53    #[allow(dead_code)]
54    pub fn cross_entropy_loss(y_true: &NDArray, y_pred: &NDArray) -> f64 {
55        assert_eq!(y_true.shape(), y_pred.shape(), "Shapes of y_true and y_pred must match");
56        let epsilon = 1e-8;
57        let clipped_pred = y_pred.data().iter().map(|&p| p.clamp(epsilon, 1.0 - epsilon)).collect::<Vec<f64>>();
58        let loss = y_true.data().iter().zip(clipped_pred.iter()).map(|(t, p)| t * p.ln()).collect::<Vec<f64>>();
59        -loss.iter().sum::<f64>() / y_true.shape()[0] as f64
60    }
61
62}
63
64impl Loss for NabLoss {
65    fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
66        NabLoss::mean_squared_error(predictions, targets)
67    }
68
69    fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray {
70        // Default to MSE gradient
71        predictions.subtract(targets).multiply_scalar(2.0 / predictions.shape()[0] as f64)
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_squared_error() {
        // Hand-computed: (0.1^2 + 0.2^2 + 0.2^2 + 0.4^2) / 4 = 0.0625
        let truth = NDArray::from_vec(vec![1.0, 0.0, 1.0, 1.0]);
        let predicted = NDArray::from_vec(vec![0.9, 0.2, 0.8, 0.6]);
        let result = NabLoss::mean_squared_error(&truth, &predicted);
        assert!((result - 0.0625).abs() < 1e-4);
    }

    #[test]
    fn test_cross_entropy_loss() {
        // Three one-hot targets vs. softmax-like probability rows;
        // expected value is -(ln 0.7 + ln 0.8 + ln 0.8) / 3.
        let truth = NDArray::from_matrix(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        let predicted = NDArray::from_matrix(vec![
            vec![0.7, 0.2, 0.1],
            vec![0.1, 0.8, 0.1],
            vec![0.05, 0.15, 0.8],
        ]);
        let result = NabLoss::cross_entropy_loss(&truth, &predicted);
        assert!((result - 0.267654016).abs() < 1e-4);
    }
}
102}