use std::hash::Hash;
use std::collections::HashMap;
use linalg::Matrix;
/// Builds a confusion matrix counting how often each true label was paired
/// with each predicted label.
///
/// Entry `[i, j]` of the result is the number of samples whose target is
/// `labels[i]` and whose prediction is `labels[j]`: rows index the true
/// label, columns index the predicted label.
///
/// # Arguments
///
/// * `predictions` - The predicted label of each sample.
/// * `targets` - The true label of each sample, aligned with `predictions`.
/// * `labels` - The labels to report, in the order they should appear as
///   rows/columns. Samples whose target or prediction is not in `labels`
///   are ignored. When `None`, every distinct value found in `predictions`
///   or `targets` is used, in ascending order.
///
/// # Panics
///
/// * If `predictions` and `targets` have different lengths.
/// * If `labels` contains the same label more than once.
pub fn confusion_matrix<T>(
    predictions: &[T],
    targets: &[T],
    labels: Option<Vec<T>>,
) -> Matrix<usize>
where
    T: Ord + Eq + Hash + Copy,
{
    assert!(
        predictions.len() == targets.len(),
        "predictions and targets have different lengths"
    );

    let labels = labels.unwrap_or_else(|| ordered_distinct(predictions, targets));
    let n = labels.len();

    // Map each label to its row/column position. Duplicates are rejected so
    // that every label owns exactly one row and one column.
    let mut label_to_index: HashMap<T, usize> = HashMap::with_capacity(n);
    for (i, &label) in labels.iter().enumerate() {
        if label_to_index.insert(label, i).is_some() {
            panic!("labels must be distinct");
        }
    }

    let mut counts = Matrix::new(n, n, vec![0usize; n * n]);
    for (truth, pred) in targets.iter().zip(predictions) {
        // One lookup per key; pairs with a target or prediction outside
        // `labels` are skipped rather than counted.
        if let (Some(&row), Some(&col)) = (label_to_index.get(truth), label_to_index.get(pred)) {
            counts[[row, col]] += 1;
        }
    }
    counts
}
/// Returns every distinct value that appears in `xs` or `ys`, sorted in
/// ascending order.
///
/// This is the default label set used by `confusion_matrix` when the caller
/// does not supply one.
fn ordered_distinct<T: Ord + Eq + Copy>(xs: &[T], ys: &[T]) -> Vec<T> {
    let mut distinct: Vec<T> = xs.iter().chain(ys).copied().collect();
    // `dedup` only removes *adjacent* duplicates, so sorting must come first.
    distinct.sort();
    distinct.dedup();
    distinct
}
#[cfg(test)]
mod tests {
    use super::{confusion_matrix, ordered_distinct};

    #[test]
    fn confusion_matrix_no_labels() {
        let truth = vec![2, 0, 2, 2, 0, 1];
        let predictions = vec![0, 0, 2, 2, 0, 2];
        let confusion = confusion_matrix(&predictions, &truth, None);
        let expected = matrix!(2, 0, 0;
                               0, 0, 1;
                               1, 0, 2);
        assert_eq!(confusion, expected);
    }

    #[test]
    fn confusion_matrix_with_labels_a_permutation_of_classes() {
        let truth = vec![2, 0, 2, 2, 0, 1];
        let predictions = vec![0, 0, 2, 2, 0, 2];
        let labels = vec![2, 1, 0];
        let confusion = confusion_matrix(&predictions, &truth, Some(labels));
        let expected = matrix!(2, 0, 1;
                               1, 0, 0;
                               0, 0, 2);
        assert_eq!(confusion, expected);
    }

    #[test]
    fn confusion_matrix_accepts_labels_intersecting_targets_and_disjoint_from_predictions() {
        let truth = vec![2, 0, 2, 2, 3, 1];
        let predictions = vec![0, 0, 2, 2, 0, 2];
        let labels = vec![1, 3];
        let confusion = confusion_matrix(&predictions, &truth, Some(labels));
        let expected = matrix!(0, 0;
                               0, 0);
        assert_eq!(confusion, expected);
    }

    #[test]
    fn confusion_matrix_accepts_labels_intersecting_predictions_and_disjoint_from_targets() {
        let truth = vec![0, 0, 2, 2, 0, 2];
        let predictions = vec![2, 0, 2, 2, 3, 1];
        let labels = vec![1, 3];
        let confusion = confusion_matrix(&predictions, &truth, Some(labels));
        let expected = matrix!(0, 0;
                               0, 0);
        assert_eq!(confusion, expected);
    }

    #[test]
    fn confusion_matrix_accepts_labels_disjoint_from_predictions_and_targets() {
        let truth = vec![0, 0, 2, 2, 0, 2];
        let predictions = vec![2, 0, 2, 2, 3, 1];
        let labels = vec![4, 5];
        let confusion = confusion_matrix(&predictions, &truth, Some(labels));
        let expected = matrix!(0, 0;
                               0, 0);
        assert_eq!(confusion, expected);
    }

    // `expected` pins the panic message so these tests fail if the function
    // panics for an unrelated reason.
    #[test]
    #[should_panic(expected = "labels must be distinct")]
    fn confusion_matrix_rejects_duplicate_labels() {
        let truth = vec![0, 0, 2, 2, 0, 2];
        let predictions = vec![2, 0, 2, 2, 3, 1];
        let labels = vec![1, 1];
        let _ = confusion_matrix(&predictions, &truth, Some(labels));
    }

    #[test]
    #[should_panic(expected = "predictions and targets have different lengths")]
    fn confusion_matrix_rejects_mismatched_prediction_and_target_lengths() {
        let truth = vec![0, 0, 2, 2, 0, 2];
        let predictions = vec![2, 0, 2, 2];
        let _ = confusion_matrix(&predictions, &truth, None);
    }

    #[test]
    fn ordered_distinct_returns_sorted_union_without_duplicates() {
        let xs = vec![3, 1, 3];
        let ys = vec![2, 1];
        assert_eq!(ordered_distinct(&xs, &ys), vec![1, 2, 3]);
    }
}