entrenar/eval/classification/
confusion.rs1use std::fmt;
4
/// Counts of (true label, predicted label) pairs for a multi-class classifier.
///
/// `matrix[t][p]` holds the number of samples whose true label is `t` and whose
/// predicted label is `p`. The constructors in this module always build a square
/// `n_classes × n_classes` table.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConfusionMatrix {
    /// Count table indexed as `[true_label][predicted_label]`.
    matrix: Vec<Vec<usize>>,
    /// Number of classes, i.e. the side length of `matrix`.
    n_classes: usize,
    /// Class labels; `ConfusionMatrix::new` sets these to `0..n_classes`.
    labels: Vec<usize>,
}
17
18impl ConfusionMatrix {
19 pub fn new(n_classes: usize) -> Self {
21 Self {
22 matrix: vec![vec![0; n_classes]; n_classes],
23 n_classes,
24 labels: (0..n_classes).collect(),
25 }
26 }
27
28 pub fn from_predictions(y_pred: &[usize], y_true: &[usize]) -> Self {
30 Self::from_predictions_with_min_classes(y_pred, y_true, 0)
31 }
32
33 pub fn from_predictions_with_min_classes(
35 y_pred: &[usize],
36 y_true: &[usize],
37 min_classes: usize,
38 ) -> Self {
39 assert_eq!(y_pred.len(), y_true.len(), "Predictions and targets must have same length");
40
41 let observed = y_pred.iter().chain(y_true.iter()).max().map_or(0, |&m| m + 1);
43 let n_classes = observed.max(min_classes);
44
45 let mut cm = Self::new(n_classes);
46
47 for (&pred, &true_label) in y_pred.iter().zip(y_true.iter()) {
48 if pred < n_classes && true_label < n_classes {
49 cm.matrix[true_label][pred] += 1;
50 }
51 }
52
53 cm
54 }
55
56 pub fn matrix(&self) -> &Vec<Vec<usize>> {
58 &self.matrix
59 }
60
61 pub fn labels(&self) -> &[usize] {
63 &self.labels
64 }
65
66 pub fn n_classes(&self) -> usize {
68 self.n_classes
69 }
70
71 pub fn get(&self, true_label: usize, predicted_label: usize) -> usize {
73 self.matrix[true_label][predicted_label]
74 }
75
76 pub fn true_positives(&self, class: usize) -> usize {
78 self.matrix[class][class]
79 }
80
81 pub fn false_positives(&self, class: usize) -> usize {
83 (0..self.n_classes).filter(|&i| i != class).map(|i| self.matrix[i][class]).sum()
84 }
85
86 pub fn false_negatives(&self, class: usize) -> usize {
88 (0..self.n_classes).filter(|&j| j != class).map(|j| self.matrix[class][j]).sum()
89 }
90
91 pub fn true_negatives(&self, class: usize) -> usize {
93 let total: usize = self.matrix.iter().flatten().sum();
94 total
95 - self.true_positives(class)
96 - self.false_positives(class)
97 - self.false_negatives(class)
98 }
99
100 pub fn support(&self, class: usize) -> usize {
102 self.matrix[class].iter().sum()
103 }
104
105 pub fn total(&self) -> usize {
107 self.matrix.iter().flatten().sum()
108 }
109
110 pub fn accuracy(&self) -> f64 {
112 contract_pre_accuracy!();
113 let total = self.total();
114 if total == 0 {
115 return 0.0;
116 }
117 let correct: usize = (0..self.n_classes).map(|i| self.matrix[i][i]).sum();
118 correct as f64 / total as f64
119 }
120}
121
122impl fmt::Display for ConfusionMatrix {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 writeln!(f, "Confusion Matrix:")?;
125
126 write!(f, " ")?;
128 for j in 0..self.n_classes {
129 write!(f, "Pred {j} ")?;
130 }
131 writeln!(f)?;
132
133 for i in 0..self.n_classes {
135 write!(f, "True {i}")?;
136 for j in 0..self.n_classes {
137 write!(f, "{:>6} ", self.matrix[i][j])?;
138 }
139 writeln!(f)?;
140 }
141
142 Ok(())
143 }
144}