use super::average::Average;
use super::confusion::ConfusionMatrix;
use super::metrics::MultiClassMetrics;
/// Computes the confusion matrix for a set of multi-class predictions.
///
/// `y_pred` and `y_true` are parallel slices of class indices
/// (predicted vs. ground-truth labels).
///
/// NOTE(review): presumably the two slices must be the same length —
/// the contract macro below likely enforces preconditions like that,
/// but it is defined elsewhere; confirm what it asserts.
pub fn confusion_matrix(y_pred: &[usize], y_true: &[usize]) -> ConfusionMatrix {
// Project-level precondition contract, expanded from a macro defined
// outside this file. May panic/abort on violated preconditions — verify.
contract_pre_confusion_matrix!();
// Delegate the actual counting to the ConfusionMatrix constructor.
ConfusionMatrix::from_predictions(y_pred, y_true)
}
/// Renders a human-readable classification report: one row of
/// precision / recall / F1 / support per class, followed by macro- and
/// weighted-average rows and the overall accuracy.
///
/// `y_pred` and `y_true` are parallel slices of class indices
/// (predicted vs. ground-truth labels). Returns the formatted table
/// as a `String`.
pub fn classification_report(y_pred: &[usize], y_true: &[usize]) -> String {
    // Function-scope import so `write!`/`writeln!` can append formatted
    // text directly into the String without intermediate allocations.
    use std::fmt::Write as _;

    // Every row is five right-aligned fields: 12 + 4 * (1 space + 10) = 56
    // columns. The previous hard-coded 54 left the rule 2 chars short of
    // the rows it separates.
    const RULE_WIDTH: usize = 56;

    let cm = ConfusionMatrix::from_predictions(y_pred, y_true);
    let metrics = MultiClassMetrics::from_confusion_matrix(&cm);

    let mut report = String::new();
    // `fmt::Write` for `String` is infallible, so `.unwrap()` cannot panic.
    writeln!(
        report,
        "{:>12} {:>10} {:>10} {:>10} {:>10}",
        "", "precision", "recall", "f1-score", "support"
    )
    .unwrap();
    writeln!(report, "{}", "-".repeat(RULE_WIDTH)).unwrap();

    // One row per class with fixed two-decimal metric columns.
    for class in 0..metrics.n_classes {
        writeln!(
            report,
            "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}",
            format!("Class {}", class),
            metrics.precision[class],
            metrics.recall[class],
            metrics.f1[class],
            metrics.support[class]
        )
        .unwrap();
    }

    writeln!(report, "{}", "-".repeat(RULE_WIDTH)).unwrap();

    // Both average rows report the total support (sum over all classes).
    let total_support: usize = metrics.support.iter().sum();
    writeln!(
        report,
        "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}",
        "macro avg",
        metrics.precision_avg(Average::Macro),
        metrics.recall_avg(Average::Macro),
        metrics.f1_avg(Average::Macro),
        total_support
    )
    .unwrap();
    writeln!(
        report,
        "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}",
        "weighted avg",
        metrics.precision_avg(Average::Weighted),
        metrics.recall_avg(Average::Weighted),
        metrics.f1_avg(Average::Weighted),
        total_support
    )
    .unwrap();
    writeln!(report, "\nAccuracy: {:.4}", cm.accuracy()).unwrap();

    report
}