entrenar/eval/classification/
metrics.rs1use super::average::Average;
4use super::confusion::ConfusionMatrix;
5
/// Per-class classification metrics for a multi-class problem.
///
/// Every vector is indexed by class id and has `n_classes` entries when built via
/// `MultiClassMetrics::from_confusion_matrix`. Ratios with a zero denominator are
/// stored as `0.0` rather than NaN.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiClassMetrics {
    /// Precision per class: `TP / (TP + FP)`.
    pub precision: Vec<f64>,
    /// Recall per class: `TP / (TP + FN)`.
    pub recall: Vec<f64>,
    /// F1 per class: harmonic mean of `precision` and `recall`.
    pub f1: Vec<f64>,
    /// Per-class sample count as reported by `ConfusionMatrix::support`; used as the
    /// weights for support-weighted averaging.
    pub support: Vec<usize>,
    /// Number of classes the per-class vectors cover.
    pub n_classes: usize,
}
20
21impl MultiClassMetrics {
22 pub fn from_confusion_matrix(cm: &ConfusionMatrix) -> Self {
24 let n_classes = cm.n_classes();
25 let mut precision = Vec::with_capacity(n_classes);
26 let mut recall = Vec::with_capacity(n_classes);
27 let mut f1 = Vec::with_capacity(n_classes);
28 let mut support = Vec::with_capacity(n_classes);
29
30 for class in 0..n_classes {
31 let tp = cm.true_positives(class) as f64;
32 let fp = cm.false_positives(class) as f64;
33 let fn_ = cm.false_negatives(class) as f64;
34
35 let p = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
36 let r = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
37 let f = if p + r > 0.0 { 2.0 * p * r / (p + r) } else { 0.0 };
38
39 precision.push(p);
40 recall.push(r);
41 f1.push(f);
42 support.push(cm.support(class));
43 }
44
45 Self { precision, recall, f1, support, n_classes }
46 }
47
48 pub fn from_predictions(y_pred: &[usize], y_true: &[usize]) -> Self {
50 let cm = ConfusionMatrix::from_predictions(y_pred, y_true);
51 Self::from_confusion_matrix(&cm)
52 }
53
54 pub fn precision_avg(&self, average: Average) -> f64 {
56 self.average_metric(&self.precision, average)
57 }
58
59 pub fn recall_avg(&self, average: Average) -> f64 {
61 self.average_metric(&self.recall, average)
62 }
63
64 pub fn f1_avg(&self, average: Average) -> f64 {
66 self.average_metric(&self.f1, average)
67 }
68
69 fn average_metric(&self, values: &[f64], average: Average) -> f64 {
70 match average {
71 Average::Macro => {
72 if values.is_empty() {
73 0.0
74 } else {
75 values.iter().sum::<f64>() / values.len() as f64
76 }
77 }
78 Average::Micro => {
79 self.average_metric(values, Average::Macro)
82 }
83 Average::Weighted => {
84 let total_support: usize = self.support.iter().sum();
85 if total_support == 0 {
86 return 0.0;
87 }
88 values.iter().zip(self.support.iter()).map(|(&v, &s)| v * s as f64).sum::<f64>()
89 / total_support as f64
90 }
91 Average::None => {
92 self.average_metric(values, Average::Macro)
94 }
95 }
96 }
97}