use ndarray::{Array1, Array2};

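/// A loss function maps predictions and targets to a scalar loss and to the
/// gradient of that loss with respect to the predictions. Samples are laid
/// out one per column: `predictions` and `targets` have shape
/// `(outputs, batch_size)`.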
pub trait LossFunction {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64;

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64>;

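    /// Default batch loss: the mean of the per-sample (per-column) losses.
    /// Implementors may override this with a vectorized version.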
    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let batch_size = predictions.ncols();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let pred_col = predictions.column(i).to_owned().insert_axis(ndarray::Axis(1));
            let target_col = targets.column(i).to_owned().insert_axis(ndarray::Axis(1));
            total_loss += self.compute_loss(&pred_col, &target_col);
        }

        total_loss / batch_size as f64
    }

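    /// Default batch gradient: per-sample gradients assembled column by
    /// column and averaged over the batch, so that the result is the
    /// gradient of `compute_batch_loss`.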
    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let batch_size = predictions.ncols();
        let mut batch_gradients = Array2::zeros(predictions.raw_dim());

        for i in 0..batch_size {
            let pred_col = predictions.column(i).to_owned().insert_axis(ndarray::Axis(1));
            let target_col = targets.column(i).to_owned().insert_axis(ndarray::Axis(1));
            let grad = self.compute_gradient(&pred_col, &target_col);
            batch_gradients.column_mut(i).assign(&grad.column(0));
        }

        // Divide by the batch size so this matches the gradient of the
        // *averaged* batch loss above (and the vectorized overrides below).
        batch_gradients / batch_size as f64
    }
}

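/// Mean squared error, averaged over every element of the batch.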
pub struct MSELoss;

impl LossFunction for MSELoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        let squared_diff = &diff * &diff;
        squared_diff.sum() / (predictions.len() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let diff = predictions - targets;
        2.0 * diff / (predictions.len() as f64)
    }

    // Vectorized overrides: equivalent to the trait defaults, but they avoid
    // the per-column loop by operating on the whole batch at once.
    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        let squared_diff = &diff * &diff;
        squared_diff.sum() / (predictions.len() as f64)
    }

    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let diff = predictions - targets;
        2.0 * diff / (predictions.len() as f64)
    }
}

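/// Mean absolute error; its (sub)gradient is the sign of the error,
/// taken as 0 where the error is exactly 0.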
pub struct MAELoss;

impl LossFunction for MAELoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        diff.map(|x| x.abs()).sum() / (predictions.len() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let diff = predictions - targets;
        diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 })
            / (predictions.len() as f64)
    }

    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        diff.map(|x| x.abs()).sum() / (predictions.len() as f64)
    }

    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let diff = predictions - targets;
        diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 })
            / (predictions.len() as f64)
    }
}

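/// Softmax cross-entropy. Expects raw logits and applies `softmax` itself;
/// `compute_loss` and `compute_gradient` already average over the columns
/// (the batch), so they can be called on a whole batch directly.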
pub struct CrossEntropyLoss;

impl LossFunction for CrossEntropyLoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let softmax_preds = softmax(predictions);
        // Small offset so `ln` never sees an exact zero.
        let epsilon = 1e-15;
        let log_preds = softmax_preds.map(|x| (x + epsilon).ln());
        -(targets * log_preds).sum() / (predictions.shape()[1] as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        // For one-hot (or probability) targets, the gradient of softmax
        // cross-entropy with respect to the logits is softmax(pred) - target.
        let softmax_preds = softmax(predictions);
        (softmax_preds - targets) / (predictions.shape()[1] as f64)
    }
}

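/// Column-wise softmax. The column maximum is subtracted before
/// exponentiating to avoid overflow; the result is mathematically unchanged.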
pub fn softmax(x: &Array2<f64>) -> Array2<f64> {
    let mut result = Array2::zeros(x.raw_dim());

    for (i, col) in x.axis_iter(ndarray::Axis(1)).enumerate() {
        let max_val = col.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
        let exp_vals: Array1<f64> = col.map(|&val| (val - max_val).exp());
        let sum_exp = exp_vals.sum();

        for (j, &exp_val) in exp_vals.iter().enumerate() {
            result[[j, i]] = exp_val / sum_exp;
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss;
        let predictions = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let targets = arr2(&[[1.5, 2.5], [2.5, 3.5]]);

        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 0.25).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }

    #[test]
    fn test_mae_loss() {
        let loss_fn = MAELoss;
        let predictions = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let targets = arr2(&[[1.5, 2.5], [2.5, 3.5]]);

        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 0.5).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }

    #[test]
    fn test_softmax() {
        let input = arr2(&[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
        let output = softmax(&input);

        for col in output.axis_iter(ndarray::Axis(1)) {
            let sum: f64 = col.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }
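
    // Sanity check for the cross-entropy loss: with uniform logits the
    // softmax is 0.5 per class, so the loss against a one-hot target is ln 2.
    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        // One sample (column), two classes (rows); inputs are raw logits.
        let predictions = arr2(&[[0.0], [0.0]]);
        let targets = arr2(&[[1.0], [0.0]]);

        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 2.0_f64.ln()).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }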
}