/// Computes one F1 score per class label.
///
/// Labels are treated as dense indices: the number of classes is one more
/// than the largest label appearing in either `y_true` or `y_pred`. Element
/// `i` of the returned vector is `class_f1` applied to class `i`'s
/// true-positive, false-positive and false-negative counts.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` have different lengths, or if they are
/// empty.
#[must_use]
pub fn f1_per_class(y_pred: &[usize], y_true: &[usize]) -> Vec<f32> {
    assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
    assert!(!y_true.is_empty(), "Vectors cannot be empty");

    // Labels index the count vectors directly, so the class count is max label + 1.
    let n_classes = y_true
        .iter()
        .chain(y_pred)
        .copied()
        .max()
        .map_or(0, |largest| largest + 1);

    let (tps, fps, fns, _) = compute_tp_fp_fn(y_pred, y_true, n_classes);

    // All three count vectors have length `n_classes`, so zipping walks every class in order.
    tps.iter()
        .zip(&fps)
        .zip(&fns)
        .map(|((&tp, &fp), &false_neg)| class_f1(tp, fp, false_neg))
        .collect()
}
/// Builds a square confusion matrix from predicted and true labels.
///
/// The side length is one more than the largest label found in either
/// slice. Each `(true_label, pred_label)` pair increments the flat buffer at
/// index `true_label * n_classes + pred_label`. That buffer is handed to
/// `Matrix::from_vec(n_classes, n_classes, ..)`. NOTE(review): this yields
/// rows = true label, columns = predicted label, assuming `from_vec` reads
/// data row-major. Confirm against `Matrix`.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` have different lengths, or if they are
/// empty.
#[must_use]
#[provable_contracts_macros::contract("metrics-classification-v1", equation = "confusion_matrix")]
pub fn confusion_matrix(y_pred: &[usize], y_true: &[usize]) -> Matrix<usize> {
    contract_pre_confusion_matrix!();
    assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
    assert!(!y_true.is_empty(), "Vectors cannot be empty");

    // Labels are used as indices, so the matrix side is max label + 1.
    let side = y_true
        .iter()
        .chain(y_pred)
        .copied()
        .max()
        .map_or(0, |largest| largest + 1);

    let mut counts = vec![0usize; side * side];
    y_true
        .iter()
        .zip(y_pred)
        .for_each(|(&actual, &predicted)| counts[actual * side + predicted] += 1);

    Matrix::from_vec(side, side, counts)
        .expect("Confusion matrix dimensions match data length")
}
/// Tallies per-class counts used by the classification metrics.
///
/// Returns `(tp, fp, fn_counts, support)`, each of length `n_classes`:
/// - `tp[c]`: samples whose true label is `c` and were predicted as `c`;
/// - `fp[c]`: samples predicted as `c` whose true label is different;
/// - `fn_counts[c]`: samples whose true label is `c` but were predicted as
///   another label;
/// - `support[c]`: samples whose true label is `c`.
///
/// Samples are paired by zipping `y_true` with `y_pred`, so any trailing
/// elements of the longer slice are ignored.
///
/// # Panics
///
/// Panics with an index-out-of-bounds error if any visited label is
/// `>= n_classes`.
fn compute_tp_fp_fn(
    y_pred: &[usize],
    y_true: &[usize],
    n_classes: usize,
) -> (Vec<usize>, Vec<usize>, Vec<usize>, Vec<usize>) {
    let zeros = || vec![0usize; n_classes];
    let mut hits = zeros();
    let mut false_alarms = zeros();
    let mut misses = zeros();
    let mut totals = zeros();

    for (actual, predicted) in y_true.iter().copied().zip(y_pred.iter().copied()) {
        totals[actual] += 1;
        if actual == predicted {
            hits[actual] += 1;
            continue;
        }
        // A mismatch is a false positive for the predicted class
        // and a false negative for the true class.
        false_alarms[predicted] += 1;
        misses[actual] += 1;
    }

    (hits, false_alarms, misses, totals)
}
#[path = "classification_report.rs"]
mod classification_report;
pub use classification_report::classification_report;