Skip to main content

entrenar/eval/classification/
report.rs

//! Classification report functions: confusion-matrix construction and
//! scikit-learn-style text reports.
2
3use super::average::Average;
4use super::confusion::ConfusionMatrix;
5use super::metrics::MultiClassMetrics;
6
7/// Compute confusion matrix from predictions and ground truth
8///
9/// # Arguments
10/// * `y_pred` - Predicted class labels
11/// * `y_true` - Ground truth class labels
12///
13/// # Returns
14/// A ConfusionMatrix where element [i][j] is count of true label i predicted as j
15///
16/// # Example
17/// ```ignore
18/// use entrenar::eval::confusion_matrix;
19///
20/// let y_pred = vec![0, 1, 1, 2, 0];
21/// let y_true = vec![0, 1, 0, 2, 1];
22/// let cm = confusion_matrix(&y_pred, &y_true);
23///
24/// assert_eq!(cm.get(0, 0), 1);  // True 0, predicted 0
25/// assert_eq!(cm.get(0, 1), 1);  // True 0, predicted 1
26/// ```
27pub fn confusion_matrix(y_pred: &[usize], y_true: &[usize]) -> ConfusionMatrix {
28    contract_pre_confusion_matrix!();
29    ConfusionMatrix::from_predictions(y_pred, y_true)
30}
31
32/// Generate sklearn-style classification report
33///
34/// # Arguments
35/// * `y_pred` - Predicted class labels
36/// * `y_true` - Ground truth class labels
37///
38/// # Returns
39/// A formatted string containing per-class and overall metrics
40///
41/// # Example
42/// ```ignore
43/// use entrenar::eval::classification_report;
44///
45/// let report = classification_report(&y_pred, &y_true);
46/// println!("{}", report);
47/// ```
48pub fn classification_report(y_pred: &[usize], y_true: &[usize]) -> String {
49    let cm = ConfusionMatrix::from_predictions(y_pred, y_true);
50    let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
51
52    let mut report = String::new();
53
54    // Header
55    report.push_str(&format!(
56        "{:>12} {:>10} {:>10} {:>10} {:>10}\n",
57        "", "precision", "recall", "f1-score", "support"
58    ));
59    report.push_str(&"-".repeat(54));
60    report.push('\n');
61
62    // Per-class metrics
63    for class in 0..metrics.n_classes {
64        report.push_str(&format!(
65            "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}\n",
66            format!("Class {}", class),
67            metrics.precision[class],
68            metrics.recall[class],
69            metrics.f1[class],
70            metrics.support[class]
71        ));
72    }
73
74    report.push_str(&"-".repeat(54));
75    report.push('\n');
76
77    // Averages
78    let total_support: usize = metrics.support.iter().sum();
79
80    report.push_str(&format!(
81        "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}\n",
82        "macro avg",
83        metrics.precision_avg(Average::Macro),
84        metrics.recall_avg(Average::Macro),
85        metrics.f1_avg(Average::Macro),
86        total_support
87    ));
88
89    report.push_str(&format!(
90        "{:>12} {:>10.2} {:>10.2} {:>10.2} {:>10}\n",
91        "weighted avg",
92        metrics.precision_avg(Average::Weighted),
93        metrics.recall_avg(Average::Weighted),
94        metrics.f1_avg(Average::Weighted),
95        total_support
96    ));
97
98    report.push_str(&format!("\nAccuracy: {:.4}\n", cm.accuracy()));
99
100    report
101}