use super::average::Average;
use super::confusion::ConfusionMatrix;
/// Per-class precision, recall and F1 scores of a multi-class classifier,
/// plus the support of each class.
///
/// When built with `MultiClassMetrics::from_confusion_matrix` (or
/// `MultiClassMetrics::from_predictions`), every vector is indexed by class
/// id and has exactly `n_classes` entries.
#[derive(Clone, Debug)]
pub struct MultiClassMetrics {
    /// Per-class precision, `tp / (tp + fp)`; `0.0` when `tp + fp == 0`.
    pub precision: Vec<f64>,
    /// Per-class recall, `tp / (tp + fn)`; `0.0` when `tp + fn == 0`.
    pub recall: Vec<f64>,
    /// Per-class F1 (harmonic mean of precision and recall); `0.0` when both are zero.
    pub f1: Vec<f64>,
    /// Per-class support, as reported by `ConfusionMatrix::support`.
    pub support: Vec<usize>,
    /// Number of classes, as reported by `ConfusionMatrix::n_classes`.
    pub n_classes: usize,
}

impl MultiClassMetrics {
    /// Computes per-class precision, recall, F1 and support from `cm`.
    ///
    /// Every ratio whose denominator is zero is reported as `0.0` instead of
    /// `NaN`, so a class that never occurs (or is never predicted) cannot
    /// poison the averages.
    pub fn from_confusion_matrix(cm: &ConfusionMatrix) -> Self {
        let n_classes = cm.n_classes();
        let mut precision = Vec::with_capacity(n_classes);
        let mut recall = Vec::with_capacity(n_classes);
        let mut f1 = Vec::with_capacity(n_classes);
        let mut support = Vec::with_capacity(n_classes);
        for class in 0..n_classes {
            let tp = cm.true_positives(class) as f64;
            let fp = cm.false_positives(class) as f64;
            let fn_ = cm.false_negatives(class) as f64;
            let p = Self::ratio_or_zero(tp, tp + fp);
            let r = Self::ratio_or_zero(tp, tp + fn_);
            let f = Self::ratio_or_zero(2.0 * p * r, p + r);
            precision.push(p);
            recall.push(r);
            f1.push(f);
            support.push(cm.support(class));
        }
        Self {
            precision,
            recall,
            f1,
            support,
            n_classes,
        }
    }

    /// Builds a confusion matrix via `ConfusionMatrix::from_predictions` and
    /// computes the per-class metrics from it.
    ///
    /// Note the argument order: predictions first, ground truth second.
    pub fn from_predictions(y_pred: &[usize], y_true: &[usize]) -> Self {
        let cm = ConfusionMatrix::from_predictions(y_pred, y_true);
        Self::from_confusion_matrix(&cm)
    }

    /// Averages the per-class precision according to `average`.
    pub fn precision_avg(&self, average: Average) -> f64 {
        self.average_metric(&self.precision, average)
    }

    /// Averages the per-class recall according to `average`.
    pub fn recall_avg(&self, average: Average) -> f64 {
        self.average_metric(&self.recall, average)
    }

    /// Averages the per-class F1 according to `average`.
    pub fn f1_avg(&self, average: Average) -> f64 {
        self.average_metric(&self.f1, average)
    }

    /// Reduces a per-class metric to a single number.
    ///
    /// - `Macro`: unweighted mean over classes.
    /// - `Micro`: pooled over all samples (see `micro_average`); the result is
    ///   the same for precision, recall and F1, so `values` is not used.
    /// - `Weighted`: mean weighted by per-class support.
    /// - `None`: a per-class result cannot be expressed as one `f64`, so this
    ///   falls back to the macro mean.
    fn average_metric(&self, values: &[f64], average: Average) -> f64 {
        match average {
            Average::Macro | Average::None => Self::mean(values),
            Average::Micro => self.micro_average(),
            Average::Weighted => self.weighted_mean(values),
        }
    }

    /// Unweighted arithmetic mean; `0.0` for an empty slice.
    fn mean(values: &[f64]) -> f64 {
        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<f64>() / values.len() as f64
        }
    }

    /// Support-weighted mean of `values`; `0.0` when the total support is zero.
    fn weighted_mean(&self, values: &[f64]) -> f64 {
        let total_support: usize = self.support.iter().sum();
        if total_support == 0 {
            return 0.0;
        }
        values
            .iter()
            .zip(&self.support)
            .map(|(&v, &s)| v * s as f64)
            .sum::<f64>()
            / total_support as f64
    }

    /// Micro-averaged score: pooled true positives divided by pooled samples.
    ///
    /// Raw counts are not stored, so per-class true positives are recovered
    /// as `recall[c] * support[c]`, rounded to undo floating-point error in
    /// `tp / (tp + fn) * (tp + fn)`.
    ///
    /// This assumes `support[c]` counts the samples whose true label is `c`
    /// (i.e. `tp + fn`). TODO: confirm against `ConfusionMatrix::support`.
    /// In single-label data every sample is predicted as exactly one class,
    /// so pooled `tp + fp` and pooled `tp + fn` both equal the sample count.
    /// Micro precision, recall and F1 therefore coincide (accuracy).
    /// Returns `0.0` when the total support is zero.
    fn micro_average(&self) -> f64 {
        let total_support: usize = self.support.iter().sum();
        if total_support == 0 {
            return 0.0;
        }
        let total_tp: f64 = self
            .recall
            .iter()
            .zip(&self.support)
            .map(|(&r, &s)| (r * s as f64).round())
            .sum();
        total_tp / total_support as f64
    }

    /// Returns `num / den`, or `0.0` when `den` is not positive.
    fn ratio_or_zero(num: f64, den: f64) -> f64 {
        if den > 0.0 {
            num / den
        } else {
            0.0
        }
    }
}