1use crate::nab_array::NDArray;
2
3pub trait Loss {
4 fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64;
5 fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray;
6}
7
/// Stateless collection of loss functions.
///
/// Its [`Loss`] implementation uses mean squared error; the individual losses
/// are also available as associated functions. Being a zero-sized unit type, it
/// derives the common traits so it can be printed, copied freely, and built via
/// `Default` in generic code.
#[derive(Debug, Clone, Copy, Default)]
pub struct NabLoss;
24
25impl NabLoss {
26 #[allow(dead_code)]
37 pub fn mean_squared_error(y_true: &NDArray, y_pred: &NDArray) -> f64 {
38 assert_eq!(y_true.shape(), y_pred.shape(), "Shapes of y_true and y_pred must match");
39 let diff = y_true.data().iter().zip(y_pred.data().iter()).map(|(t, p)| (t - p).powi(2)).collect::<Vec<f64>>();
40 diff.iter().sum::<f64>() / y_true.data().len() as f64
41 }
42
43 #[allow(dead_code)]
54 pub fn cross_entropy_loss(y_true: &NDArray, y_pred: &NDArray) -> f64 {
55 assert_eq!(y_true.shape(), y_pred.shape(), "Shapes of y_true and y_pred must match");
56 let epsilon = 1e-8;
57 let clipped_pred = y_pred.data().iter().map(|&p| p.clamp(epsilon, 1.0 - epsilon)).collect::<Vec<f64>>();
58 let loss = y_true.data().iter().zip(clipped_pred.iter()).map(|(t, p)| t * p.ln()).collect::<Vec<f64>>();
59 -loss.iter().sum::<f64>() / y_true.shape()[0] as f64
60 }
61
62}
63
64impl Loss for NabLoss {
65 fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
66 NabLoss::mean_squared_error(predictions, targets)
67 }
68
69 fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray {
70 predictions.subtract(targets).multiply_scalar(2.0 / predictions.shape()[0] as f64)
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point loss comparisons below.
    const TOLERANCE: f64 = 1e-4;

    /// Returns `true` when `actual` lies strictly within `TOLERANCE` of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    #[test]
    fn test_mean_squared_error() {
        let targets = NDArray::from_vec(vec![1.0, 0.0, 1.0, 1.0]);
        let predictions = NDArray::from_vec(vec![0.9, 0.2, 0.8, 0.6]);
        // Squared errors 0.01 + 0.04 + 0.04 + 0.16 = 0.25, divided by 4 elements.
        let actual = NabLoss::mean_squared_error(&targets, &predictions);
        assert!(close_to(actual, 0.0625));
    }

    #[test]
    fn test_cross_entropy_loss() {
        // One-hot targets: each row's true class sits on the diagonal.
        let targets = NDArray::from_matrix(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        let predictions = NDArray::from_matrix(vec![
            vec![0.7, 0.2, 0.1],
            vec![0.1, 0.8, 0.1],
            vec![0.05, 0.15, 0.8],
        ]);
        // -(ln 0.7 + ln 0.8 + ln 0.8) / 3 rows, approximately 0.267654016.
        let actual = NabLoss::cross_entropy_loss(&targets, &predictions);
        assert!(close_to(actual, 0.267654016));
    }
}