Skip to main content

entrenar/eval/classification/
metrics.rs

//! Multi-class classification metrics: per-class precision, recall, F1 and support.
2
3use super::average::Average;
4use super::confusion::ConfusionMatrix;
5
/// Per-class precision, recall, F1 and support for a multi-class problem.
///
/// The vectors are indexed by class id; `from_confusion_matrix` fills each of
/// them with exactly `n_classes` entries.
#[derive(Debug, Clone)]
pub struct MultiClassMetrics {
    /// Precision of each class (`TP / (TP + FP)`), indexed by class id.
    pub precision: Vec<f64>,
    /// Recall of each class (`TP / (TP + FN)`), indexed by class id.
    pub recall: Vec<f64>,
    /// F1 score of each class: the harmonic mean of its precision and recall.
    pub f1: Vec<f64>,
    /// Support count of each class, as reported by the confusion matrix.
    pub support: Vec<usize>,
    /// Number of classes covered by the vectors above.
    pub n_classes: usize,
}
20
21impl MultiClassMetrics {
22    /// Compute metrics from confusion matrix
23    pub fn from_confusion_matrix(cm: &ConfusionMatrix) -> Self {
24        let n_classes = cm.n_classes();
25        let mut precision = Vec::with_capacity(n_classes);
26        let mut recall = Vec::with_capacity(n_classes);
27        let mut f1 = Vec::with_capacity(n_classes);
28        let mut support = Vec::with_capacity(n_classes);
29
30        for class in 0..n_classes {
31            let tp = cm.true_positives(class) as f64;
32            let fp = cm.false_positives(class) as f64;
33            let fn_ = cm.false_negatives(class) as f64;
34
35            let p = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
36            let r = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
37            let f = if p + r > 0.0 { 2.0 * p * r / (p + r) } else { 0.0 };
38
39            precision.push(p);
40            recall.push(r);
41            f1.push(f);
42            support.push(cm.support(class));
43        }
44
45        Self { precision, recall, f1, support, n_classes }
46    }
47
48    /// Compute from predictions and ground truth
49    pub fn from_predictions(y_pred: &[usize], y_true: &[usize]) -> Self {
50        let cm = ConfusionMatrix::from_predictions(y_pred, y_true);
51        Self::from_confusion_matrix(&cm)
52    }
53
54    /// Get averaged precision
55    pub fn precision_avg(&self, average: Average) -> f64 {
56        self.average_metric(&self.precision, average)
57    }
58
59    /// Get averaged recall
60    pub fn recall_avg(&self, average: Average) -> f64 {
61        self.average_metric(&self.recall, average)
62    }
63
64    /// Get averaged F1
65    pub fn f1_avg(&self, average: Average) -> f64 {
66        self.average_metric(&self.f1, average)
67    }
68
69    fn average_metric(&self, values: &[f64], average: Average) -> f64 {
70        match average {
71            Average::Macro => {
72                if values.is_empty() {
73                    0.0
74                } else {
75                    values.iter().sum::<f64>() / values.len() as f64
76                }
77            }
78            Average::Micro => {
79                // For micro-averaging, we need to recalculate from totals
80                // Currently uses macro-average as fallback (FUTURE: full micro-avg)
81                self.average_metric(values, Average::Macro)
82            }
83            Average::Weighted => {
84                let total_support: usize = self.support.iter().sum();
85                if total_support == 0 {
86                    return 0.0;
87                }
88                values.iter().zip(self.support.iter()).map(|(&v, &s)| v * s as f64).sum::<f64>()
89                    / total_support as f64
90            }
91            Average::None => {
92                // Return macro as default for single value
93                self.average_metric(values, Average::Macro)
94            }
95        }
96    }
97}