entrenar/eval/classification/
report.rs1use super::average::Average;
4use super::confusion::ConfusionMatrix;
5use super::metrics::MultiClassMetrics;
6
7pub fn confusion_matrix(y_pred: &[usize], y_true: &[usize]) -> ConfusionMatrix {
28 contract_pre_confusion_matrix!();
29 ConfusionMatrix::from_predictions(y_pred, y_true)
30}
31
32pub fn classification_report(y_pred: &[usize], y_true: &[usize]) -> String {
49 let cm = ConfusionMatrix::from_predictions(y_pred, y_true);
50 let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
51
52 let mut report = String::new();
53
54 report.push_str(&format!(
56 "{:>12} {:>10} {:>10} {:>10} {:>10}\n",
57 "", "precision", "recall", "f1-score", "support"
58 ));
59 report.push_str(&"-".repeat(54));
60 report.push('\n');
61
62 for class in 0..metrics.n_classes {
64 report.push_str(&format!(
65 "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}\n",
66 format!("Class {}", class),
67 metrics.precision[class],
68 metrics.recall[class],
69 metrics.f1[class],
70 metrics.support[class]
71 ));
72 }
73
74 report.push_str(&"-".repeat(54));
75 report.push('\n');
76
77 let total_support: usize = metrics.support.iter().sum();
79
80 report.push_str(&format!(
81 "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}\n",
82 "macro avg",
83 metrics.precision_avg(Average::Macro),
84 metrics.recall_avg(Average::Macro),
85 metrics.f1_avg(Average::Macro),
86 total_support
87 ));
88
89 report.push_str(&format!(
90 "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}\n",
91 "weighted avg",
92 metrics.precision_avg(Average::Weighted),
93 metrics.recall_avg(Average::Weighted),
94 metrics.f1_avg(Average::Weighted),
95 total_support
96 ));
97
98 report.push_str(&format!("\nAccuracy: {:.4}\n", cm.accuracy()));
99
100 report
101}