Skip to main content

entrenar/eval/classification/
confusion.rs

//! Confusion matrix for multi-class classification.
2
3use std::fmt;
4
/// Confusion matrix for multi-class classification.
///
/// Entry `[i][j]` counts the samples whose true label is `i` and whose predicted label is `j`:
/// rows index ground truth, columns index predictions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConfusionMatrix {
    /// Count table indexed as `matrix[true_label][predicted_label]`; `new` allocates it as
    /// `n_classes` rows of `n_classes` zeros.
    matrix: Vec<Vec<usize>>,
    /// Number of classes (the side length of `matrix`).
    n_classes: usize,
    /// Class labels (indices); `new` fills these with `0..n_classes`.
    labels: Vec<usize>,
}
17
18impl ConfusionMatrix {
19    /// Create a new confusion matrix with given number of classes
20    pub fn new(n_classes: usize) -> Self {
21        Self {
22            matrix: vec![vec![0; n_classes]; n_classes],
23            n_classes,
24            labels: (0..n_classes).collect(),
25        }
26    }
27
28    /// Create from predictions and ground truth
29    pub fn from_predictions(y_pred: &[usize], y_true: &[usize]) -> Self {
30        Self::from_predictions_with_min_classes(y_pred, y_true, 0)
31    }
32
33    /// Create from predictions and ground truth, ensuring at least `min_classes` classes
34    pub fn from_predictions_with_min_classes(
35        y_pred: &[usize],
36        y_true: &[usize],
37        min_classes: usize,
38    ) -> Self {
39        assert_eq!(y_pred.len(), y_true.len(), "Predictions and targets must have same length");
40
41        // Determine number of classes (at least min_classes)
42        let observed = y_pred.iter().chain(y_true.iter()).max().map_or(0, |&m| m + 1);
43        let n_classes = observed.max(min_classes);
44
45        let mut cm = Self::new(n_classes);
46
47        for (&pred, &true_label) in y_pred.iter().zip(y_true.iter()) {
48            if pred < n_classes && true_label < n_classes {
49                cm.matrix[true_label][pred] += 1;
50            }
51        }
52
53        cm
54    }
55
56    /// Get the raw matrix
57    pub fn matrix(&self) -> &Vec<Vec<usize>> {
58        &self.matrix
59    }
60
61    /// Get the class labels
62    pub fn labels(&self) -> &[usize] {
63        &self.labels
64    }
65
66    /// Get number of classes
67    pub fn n_classes(&self) -> usize {
68        self.n_classes
69    }
70
71    /// Get element at [true_label][predicted_label]
72    pub fn get(&self, true_label: usize, predicted_label: usize) -> usize {
73        self.matrix[true_label][predicted_label]
74    }
75
76    /// Calculate true positives for a class
77    pub fn true_positives(&self, class: usize) -> usize {
78        self.matrix[class][class]
79    }
80
81    /// Calculate false positives for a class (predicted as class but wasn't)
82    pub fn false_positives(&self, class: usize) -> usize {
83        (0..self.n_classes).filter(|&i| i != class).map(|i| self.matrix[i][class]).sum()
84    }
85
86    /// Calculate false negatives for a class (was class but predicted differently)
87    pub fn false_negatives(&self, class: usize) -> usize {
88        (0..self.n_classes).filter(|&j| j != class).map(|j| self.matrix[class][j]).sum()
89    }
90
91    /// Calculate true negatives for a class
92    pub fn true_negatives(&self, class: usize) -> usize {
93        let total: usize = self.matrix.iter().flatten().sum();
94        total
95            - self.true_positives(class)
96            - self.false_positives(class)
97            - self.false_negatives(class)
98    }
99
100    /// Calculate support (total true instances) for a class
101    pub fn support(&self, class: usize) -> usize {
102        self.matrix[class].iter().sum()
103    }
104
105    /// Total number of samples
106    pub fn total(&self) -> usize {
107        self.matrix.iter().flatten().sum()
108    }
109
110    /// Calculate accuracy
111    pub fn accuracy(&self) -> f64 {
112        contract_pre_accuracy!();
113        let total = self.total();
114        if total == 0 {
115            return 0.0;
116        }
117        let correct: usize = (0..self.n_classes).map(|i| self.matrix[i][i]).sum();
118        correct as f64 / total as f64
119    }
120}
121
122impl fmt::Display for ConfusionMatrix {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        writeln!(f, "Confusion Matrix:")?;
125
126        // Header
127        write!(f, "      ")?;
128        for j in 0..self.n_classes {
129            write!(f, "Pred {j} ")?;
130        }
131        writeln!(f)?;
132
133        // Rows
134        for i in 0..self.n_classes {
135            write!(f, "True {i}")?;
136            for j in 0..self.n_classes {
137                write!(f, "{:>6} ", self.matrix[i][j])?;
138            }
139            writeln!(f)?;
140        }
141
142        Ok(())
143    }
144}