1use std::{collections::HashSet, error::Error};
2
3use nalgebra::{DMatrix, DVector};
4
5use crate::data::dataset::WholeNumber;
6
/// Square class-count matrix: rows index the true class, columns the
/// predicted class, both in ascending order of the observed class labels.
type ConfusionMatrix = DMatrix<usize>;
8
9pub trait ClassificationMetrics<T: WholeNumber> {
10 fn confusion_matrix(
21 &self,
22 y_true: &DVector<T>,
23 y_pred: &DVector<T>,
24 ) -> Result<ConfusionMatrix, Box<dyn Error>> {
25 if y_true.len() != y_pred.len() {
26 return Err("Predictions and labels are of different sizes.".into());
27 }
28
29 let mut classes_set = HashSet::<T>::new();
30 classes_set.extend(y_true);
31 classes_set.extend(y_pred);
32
33 let mut classes = Vec::from_iter(classes_set.iter().cloned());
34 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
35
36 let mut matrix = DMatrix::zeros(classes_set.len(), classes_set.len());
37
38 for (y_t, y_p) in y_true.iter().zip(y_pred.iter()) {
39 let matrix_row = classes.iter().position(|&c| c == *y_t).unwrap();
40 let matrix_col = classes.iter().position(|&c| c == *y_p).unwrap();
41 matrix[(matrix_row, matrix_col)] += 1;
42 }
43
44 Ok(matrix)
45 }
46
47 fn accuracy(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
58 let matrix = self.confusion_matrix(y_true, y_pred)?;
59
60 let mut correct = 0;
61
62 matrix.diagonal().iter().for_each(|e| correct += e);
63
64 Ok(correct as f64 / y_true.len() as f64)
65 }
66
67 fn precision(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
78 let matrix = self.confusion_matrix(y_true, y_pred)?;
79
80 let num_classes = matrix.nrows();
81
82 if num_classes == 2 {
83 let tp = matrix[(1, 1)];
84 let fp = matrix[(0, 1)];
85
86 if tp + fp > 0 {
87 return Ok(tp as f64 / (tp + fp) as f64);
88 }
89 }
90
91 let mut precision_total = 0.0;
92 for class in 0..num_classes {
93 let tp = matrix[(class, class)];
94 let fp = matrix.column(class).sum() - tp;
95
96 if tp + fp > 0 {
97 let precision = tp as f64 / (tp + fp) as f64;
98 precision_total += precision;
99 }
100 }
101
102 Ok(precision_total / num_classes as f64)
103 }
104
105 fn recall(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
116 let matrix = self.confusion_matrix(y_true, y_pred)?;
117
118 let num_classes = matrix.nrows();
119
120 if num_classes == 2 {
121 let tp = matrix[(1, 1)];
122 let fn_ = matrix[(1, 0)];
123
124 if tp + fn_ > 0 {
125 return Ok(tp as f64 / (tp + fn_) as f64);
126 }
127 }
128
129 let mut recall_total = 0.0;
130
131 for class in 0..num_classes {
132 let tp = matrix[(class, class)];
133 let fn_ = matrix.row(class).sum() - tp;
134
135 if tp + fn_ > 0 {
136 let recall = tp as f64 / (tp + fn_) as f64;
137 recall_total += recall;
138 }
139 }
140
141 Ok(recall_total / num_classes as f64)
142 }
143
144 fn f1_score(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
155 let precision = self.precision(y_true, y_pred)?;
156 let recall = self.recall(y_true, y_pred)?;
157
158 match (precision + recall).abs() < std::f64::EPSILON {
159 true => Err("Precision and recall are both 0, F1 score undefined.".into()),
160 false => Ok(2.0 * (precision * recall) / (precision + recall)),
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    /// Minimal type that only exists to pick up the trait's default methods.
    struct MockClassifier;

    impl ClassificationMetrics<u8> for MockClassifier {}

    #[test]
    fn test_confusion_matrix() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.confusion_matrix(&y_true, &y_pred).unwrap();

        // Column-major layout: [[1, 1], [1, 2]].
        let expected = DMatrix::from_vec(2, 2, vec![1, 1, 1, 2]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_confusion_matrix_unequal() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1, 0]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.confusion_matrix(&y_true, &y_pred);

        assert!(result.is_err());
    }

    #[test]
    fn test_confusion_matrix_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.confusion_matrix(&y_true, &y_pred).unwrap();
        let expected = DMatrix::from_vec(3, 3, vec![2, 0, 0, 0, 1, 1, 0, 1, 1]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_accuracy() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.accuracy(&y_true, &y_pred).unwrap();

        // 3 of 5 samples classified correctly.
        assert_eq!(result, 0.6);
    }

    #[test]
    fn test_accuracy_perfect_classification() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 0, 1, 0, 1]);

        let result = classifier.accuracy(&y_true, &y_pred).unwrap();

        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_precision() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.precision(&y_true, &y_pred).unwrap();

        // 2 TP, 1 FP for the positive class.
        assert_eq!(result, 2.0 / 3.0);
    }

    #[test]
    fn test_precision_no_positive_predictions() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 1, 1, 1, 1]);
        let y_pred = DVector::from_vec(vec![0, 0, 0, 0, 0]);

        let result = classifier.precision(&y_true, &y_pred).unwrap();

        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_precision_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.precision(&y_true, &y_pred).unwrap();
        let expected = (2.0 / 2.0 + 1.0 / 2.0 + 1.0 / 2.0) / 3.0;

        assert!((result - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_recall() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.recall(&y_true, &y_pred).unwrap();

        // 2 TP, 1 FN for the positive class.
        assert_eq!(result, 2.0 / 3.0);
    }

    #[test]
    fn test_recall_no_true_positives() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 1, 1, 1, 1]);
        let y_pred = DVector::from_vec(vec![0, 0, 0, 0, 0]);

        let result = classifier.recall(&y_true, &y_pred).unwrap();

        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_recall_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.recall(&y_true, &y_pred).unwrap();
        let expected = (2.0 / 2.0 + 1.0 / 2.0 + 1.0 / 2.0) / 3.0;

        assert!((result - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_f1_score() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.f1_score(&y_true, &y_pred).unwrap();

        // Precision == recall == 2/3, so F1 == 2/3 as well.
        assert_eq!(result, 2.0 / 3.0);
    }

    #[test]
    fn test_f1_score_perfect_classification() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 0, 1, 0, 1]);

        let result = classifier.f1_score(&y_true, &y_pred).unwrap();

        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_f1_score_error() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 1, 1, 1, 1]);
        let y_pred = DVector::from_vec(vec![0, 0, 0, 0, 0]);

        let result = classifier.f1_score(&y_true, &y_pred);

        assert!(result.is_err());
    }

    #[test]
    fn test_f1_score_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.f1_score(&y_true, &y_pred).unwrap();
        let precision = classifier.precision(&y_true, &y_pred).unwrap();
        let recall = classifier.recall(&y_true, &y_pred).unwrap();
        let expected = 2.0 * (precision * recall) / (precision + recall);

        assert!((result - expected).abs() < f64::EPSILON);
    }
}