use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};
use super::Metric;
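/// A confusion matrix for multi-class classification.
///
/// Rows index the true class and columns the predicted class, so
/// `matrix[t][p]` counts the samples of true class `t` that were
/// predicted as class `p`.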
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
pub(crate) num_classes: usize,
pub(crate) matrix: Vec<Vec<usize>>,
}
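/// Returns the index of the maximum value in row `i` of `data`, i.e. the
/// argmax-decoded class label of a one-hot or probability row. Shared by
/// the confusion-matrix and support computations below.
fn argmax_row(data: &ArrayView<f64, Ix2>, i: usize) -> usize {
    let mut best = 0;
    for j in 1..data.ncols() {
        if data[[i, j]] > data[[i, best]] {
            best = j;
        }
    }
    best
}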
impl ConfusionMatrix {
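    /// Creates an empty `num_classes` x `num_classes` matrix with all
    /// counts set to zero.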
pub fn new(num_classes: usize) -> Self {
Self {
num_classes,
matrix: vec![vec![0; num_classes]; num_classes],
}
}
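    /// Builds the matrix from `predictions` and `targets` of identical
    /// shape `(samples, classes)`, decoding each row to a class label by
    /// argmax (so one-hot targets and probability or logit predictions
    /// both work).
    ///
    /// Returns an error on a shape mismatch or when there are zero
    /// classes.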
pub fn compute(
predictions: &ArrayView<f64, Ix2>,
targets: &ArrayView<f64, Ix2>,
) -> TrainResult<Self> {
if predictions.shape() != targets.shape() {
return Err(TrainError::MetricsError(format!(
"Shape mismatch: predictions {:?} vs targets {:?}",
predictions.shape(),
targets.shape()
)));
}
        let num_classes = predictions.ncols();
        if num_classes == 0 {
            return Err(TrainError::MetricsError(
                "Cannot build a confusion matrix with zero classes".to_string(),
            ));
        }
        let mut matrix = vec![vec![0; num_classes]; num_classes];
        for i in 0..predictions.nrows() {
            // Decode each row by argmax; this handles both one-hot targets
            // and probability (or logit) predictions.
            let pred_class = argmax_row(predictions, i);
            let true_class = argmax_row(targets, i);
            matrix[true_class][pred_class] += 1;
        }
Ok(Self {
num_classes,
matrix,
})
}
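    /// Returns the raw counts, indexed as `[true_class][pred_class]`.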
pub fn matrix(&self) -> &Vec<Vec<usize>> {
&self.matrix
}
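    /// Returns the count of samples with the given true and predicted
    /// classes.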
pub fn get(&self, true_class: usize, pred_class: usize) -> usize {
self.matrix[true_class][pred_class]
}
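    /// Per-class precision, `TP_k / (TP_k + FP_k)`: each diagonal entry
    /// divided by its column sum. A class that is never predicted gets a
    /// precision of 0.0.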
pub fn precision_per_class(&self) -> Vec<f64> {
let mut precisions = Vec::with_capacity(self.num_classes);
for pred_class in 0..self.num_classes {
let mut predicted_positive = 0;
let mut true_positive = 0;
for true_class in 0..self.num_classes {
predicted_positive += self.matrix[true_class][pred_class];
if true_class == pred_class {
true_positive += self.matrix[true_class][pred_class];
}
}
let precision = if predicted_positive == 0 {
0.0
} else {
true_positive as f64 / predicted_positive as f64
};
precisions.push(precision);
}
precisions
}
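    /// Per-class recall, `TP_k / (TP_k + FN_k)`: each diagonal entry
    /// divided by its row sum. A class with no true samples gets a
    /// recall of 0.0.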
pub fn recall_per_class(&self) -> Vec<f64> {
let mut recalls = Vec::with_capacity(self.num_classes);
for true_class in 0..self.num_classes {
let mut actual_positive = 0;
let mut true_positive = 0;
for pred_class in 0..self.num_classes {
actual_positive += self.matrix[true_class][pred_class];
if true_class == pred_class {
true_positive += self.matrix[true_class][pred_class];
}
}
let recall = if actual_positive == 0 {
0.0
} else {
true_positive as f64 / actual_positive as f64
};
recalls.push(recall);
}
recalls
}
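    /// Per-class F1 score, the harmonic mean `2 * p * r / (p + r)` of
    /// precision and recall, or 0.0 when both are zero.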
pub fn f1_per_class(&self) -> Vec<f64> {
let precisions = self.precision_per_class();
let recalls = self.recall_per_class();
precisions
.iter()
.zip(recalls.iter())
.map(|(p, r)| {
if p + r == 0.0 {
0.0
} else {
2.0 * p * r / (p + r)
}
})
.collect()
}
pub fn accuracy(&self) -> f64 {
let mut correct = 0;
let mut total = 0;
for i in 0..self.num_classes {
for j in 0..self.num_classes {
total += self.matrix[i][j];
if i == j {
correct += self.matrix[i][j];
}
}
}
if total == 0 {
0.0
} else {
correct as f64 / total as f64
}
}
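    /// Total number of samples recorded in the matrix.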
pub fn total_predictions(&self) -> usize {
let mut total = 0;
for i in 0..self.num_classes {
for j in 0..self.num_classes {
total += self.matrix[i][j];
}
}
total
}
}
impl std::fmt::Display for ConfusionMatrix {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Confusion Matrix:")?;
        // Five-character pad so the column indices line up with the
        // "{:3}| " prefix printed on each data row below.
        write!(f, "     ")?;
for j in 0..self.num_classes {
write!(f, "{:5}", j)?;
}
writeln!(f)?;
for i in 0..self.num_classes {
write!(f, "{:3}| ", i)?;
for j in 0..self.num_classes {
write!(f, "{:5}", self.matrix[i][j])?;
}
writeln!(f)?;
}
Ok(())
}
}
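/// A receiver operating characteristic (ROC) curve for binary
/// classification, stored as parallel vectors: the false positive rate,
/// true positive rate, and decision threshold at each point.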
#[derive(Debug, Clone)]
pub struct RocCurve {
pub fpr: Vec<f64>,
pub tpr: Vec<f64>,
pub thresholds: Vec<f64>,
}
impl RocCurve {
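    /// Computes the ROC curve from real-valued scores and boolean labels
    /// by sweeping the decision threshold from high to low. The curve
    /// starts at `(0.0, 0.0)` with threshold `+inf`, and tied scores are
    /// grouped into a single point so the result does not depend on their
    /// ordering. Returns an error on a length mismatch.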
pub fn compute(predictions: &[f64], targets: &[bool]) -> TrainResult<Self> {
if predictions.len() != targets.len() {
return Err(TrainError::MetricsError(format!(
"Length mismatch: predictions {} vs targets {}",
predictions.len(),
targets.len()
)));
}
let mut indices: Vec<usize> = (0..predictions.len()).collect();
indices.sort_by(|&a, &b| {
predictions[b]
.partial_cmp(&predictions[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut fpr = Vec::new();
let mut tpr = Vec::new();
let mut thresholds = Vec::new();
let num_positive = targets.iter().filter(|&&x| x).count();
let num_negative = targets.len() - num_positive;
let mut true_positives = 0;
let mut false_positives = 0;
fpr.push(0.0);
tpr.push(0.0);
thresholds.push(f64::INFINITY);
        // Walk the samples in descending score order, grouping tied scores
        // so that each distinct threshold yields exactly one ROC point;
        // emitting one point per sample would make the curve (and its AUC)
        // depend on the arbitrary ordering of tied samples.
        let mut i = 0;
        while i < indices.len() {
            let threshold = predictions[indices[i]];
            while i < indices.len() && predictions[indices[i]] == threshold {
                if targets[indices[i]] {
                    true_positives += 1;
                } else {
                    false_positives += 1;
                }
                i += 1;
            }
            let fpr_val = if num_negative == 0 {
                0.0
            } else {
                false_positives as f64 / num_negative as f64
            };
            let tpr_val = if num_positive == 0 {
                0.0
            } else {
                true_positives as f64 / num_positive as f64
            };
            fpr.push(fpr_val);
            tpr.push(tpr_val);
            thresholds.push(threshold);
        }
Ok(Self {
fpr,
tpr,
thresholds,
})
}
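    /// Area under the ROC curve, integrated with the trapezoidal rule;
    /// 1.0 is a perfect ranking and 0.5 is chance level.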
pub fn auc(&self) -> f64 {
let mut auc = 0.0;
for i in 1..self.fpr.len() {
let width = self.fpr[i] - self.fpr[i - 1];
let height = (self.tpr[i] + self.tpr[i - 1]) / 2.0;
auc += width * height;
}
auc
}
}
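/// Precision, recall, F1 score, and support (the number of true samples)
/// for each class, derived from a [`ConfusionMatrix`].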
#[derive(Debug, Clone)]
pub struct PerClassMetrics {
pub precision: Vec<f64>,
pub recall: Vec<f64>,
pub f1_score: Vec<f64>,
pub support: Vec<usize>,
}
impl PerClassMetrics {
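    /// Computes all per-class metrics from argmax-decoded `predictions`
    /// and `targets` of shape `(samples, classes)`.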
pub fn compute(
predictions: &ArrayView<f64, Ix2>,
targets: &ArrayView<f64, Ix2>,
) -> TrainResult<Self> {
let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
let precision = confusion_matrix.precision_per_class();
let recall = confusion_matrix.recall_per_class();
let f1_score = confusion_matrix.f1_per_class();
let num_classes = targets.ncols();
let mut support = vec![0; num_classes];
        for i in 0..targets.nrows() {
            // Support counts the argmax-decoded true class of each sample.
            support[argmax_row(targets, i)] += 1;
        }
Ok(Self {
precision,
recall,
f1_score,
support,
})
}
}
impl std::fmt::Display for PerClassMetrics {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Per-Class Metrics:")?;
writeln!(f, "Class Precision Recall F1-Score Support")?;
writeln!(f, "----- --------- ------ -------- -------")?;
for i in 0..self.precision.len() {
writeln!(
f,
"{:5} {:9.4} {:6.4} {:8.4} {:7}",
i, self.precision[i], self.recall[i], self.f1_score[i], self.support[i]
)?;
}
let macro_precision: f64 = self.precision.iter().sum::<f64>() / self.precision.len() as f64;
let macro_recall: f64 = self.recall.iter().sum::<f64>() / self.recall.len() as f64;
let macro_f1: f64 = self.f1_score.iter().sum::<f64>() / self.f1_score.len() as f64;
let total_support: usize = self.support.iter().sum();
writeln!(f, "----- --------- ------ -------- -------")?;
writeln!(
f,
"Macro {:9.4} {:6.4} {:8.4} {:7}",
macro_precision, macro_recall, macro_f1, total_support
)?;
Ok(())
}
}
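/// Matthews correlation coefficient (MCC), a balanced measure in
/// `[-1.0, 1.0]` that remains informative under class imbalance.
///
/// For two classes this is
/// `(TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))`; for more
/// classes it is the multi-class generalization (Gorodkin's R_K)
/// computed from the confusion matrix. A zero denominator yields 0.0.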
#[derive(Debug, Clone, Default)]
pub struct MatthewsCorrelationCoefficient;
impl Metric for MatthewsCorrelationCoefficient {
fn compute(
&self,
predictions: &ArrayView<f64, Ix2>,
targets: &ArrayView<f64, Ix2>,
) -> TrainResult<f64> {
let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
let num_classes = confusion_matrix.num_classes;
if num_classes == 2 {
let tp = confusion_matrix.matrix[1][1] as f64;
let tn = confusion_matrix.matrix[0][0] as f64;
let fp = confusion_matrix.matrix[0][1] as f64;
let fn_val = confusion_matrix.matrix[1][0] as f64;
let numerator = (tp * tn) - (fp * fn_val);
let denominator = ((tp + fp) * (tp + fn_val) * (tn + fp) * (tn + fn_val)).sqrt();
if denominator == 0.0 {
Ok(0.0)
} else {
Ok(numerator / denominator)
}
} else {
            // Multi-class generalization of MCC (Gorodkin's R_K statistic):
            //   R_K = (t*c - sum_k p_k*t_k)
            //         / (sqrt(t^2 - sum_k p_k^2) * sqrt(t^2 - sum_k t_k^2))
            // where t is the total sample count, c the trace of the
            // confusion matrix, p_k the column (predicted) sums, and t_k
            // the row (true) sums.
            let t = confusion_matrix.total_predictions() as f64;
            let mut p_k = vec![0.0; num_classes];
            let mut t_k = vec![0.0; num_classes];
            for k in 0..num_classes {
                for l in 0..num_classes {
                    p_k[k] += confusion_matrix.matrix[l][k] as f64;
                    t_k[k] += confusion_matrix.matrix[k][l] as f64;
                }
            }
            let c: f64 = (0..num_classes)
                .map(|k| confusion_matrix.matrix[k][k] as f64)
                .sum();
            let s: f64 = (0..num_classes).map(|k| p_k[k] * t_k[k]).sum();
            let sum_p_sq: f64 = p_k.iter().map(|x| x * x).sum();
            let sum_t_sq: f64 = t_k.iter().map(|x| x * x).sum();
            let numerator = (t * c) - s;
            let denominator =
                ((t * t) - sum_p_sq).sqrt() * ((t * t) - sum_t_sq).sqrt();
            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
}
}
fn name(&self) -> &str {
"mcc"
}
}
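/// Cohen's kappa, the chance-corrected agreement
/// `(p_o - p_e) / (1.0 - p_e)`, where `p_o` is the observed accuracy and
/// `p_e` the agreement expected from the row and column marginals of the
/// confusion matrix.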
#[derive(Debug, Clone, Default)]
pub struct CohensKappa;
impl Metric for CohensKappa {
fn compute(
&self,
predictions: &ArrayView<f64, Ix2>,
targets: &ArrayView<f64, Ix2>,
) -> TrainResult<f64> {
let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
let num_classes = confusion_matrix.num_classes;
        let total = confusion_matrix.total_predictions() as f64;
        // Guard against an empty matrix, where agreement is undefined.
        if total == 0.0 {
            return Ok(0.0);
        }
        let mut observed = 0.0;
        for i in 0..num_classes {
            observed += confusion_matrix.matrix[i][i] as f64;
        }
        observed /= total;
let mut expected = 0.0;
for i in 0..num_classes {
let row_sum: f64 = (0..num_classes)
.map(|j| confusion_matrix.matrix[i][j] as f64)
.sum();
let col_sum: f64 = (0..num_classes)
.map(|j| confusion_matrix.matrix[j][i] as f64)
.sum();
expected += (row_sum / total) * (col_sum / total);
}
if expected >= 1.0 {
Ok(0.0)
} else {
Ok((observed - expected) / (1.0 - expected))
}
}
fn name(&self) -> &str {
"cohens_kappa"
}
}
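/// Balanced accuracy: the unweighted mean of the per-class recalls,
/// which corrects plain accuracy for class imbalance.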
#[derive(Debug, Clone, Default)]
pub struct BalancedAccuracy;
impl Metric for BalancedAccuracy {
fn compute(
&self,
predictions: &ArrayView<f64, Ix2>,
targets: &ArrayView<f64, Ix2>,
) -> TrainResult<f64> {
let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
let recalls = confusion_matrix.recall_per_class();
let sum: f64 = recalls.iter().sum();
Ok(sum / recalls.len() as f64)
}
fn name(&self) -> &str {
"balanced_accuracy"
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_confusion_matrix() {
let predictions = array![
[0.9, 0.1, 0.0],
[0.1, 0.8, 0.1],
[0.2, 0.1, 0.7],
[0.8, 0.1, 0.1]
];
let targets = array![
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0]
];
let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).expect("unwrap");
        assert_eq!(cm.get(0, 0), 2);
        assert_eq!(cm.get(1, 1), 1);
        assert_eq!(cm.get(2, 2), 1);
        assert_eq!(cm.accuracy(), 1.0);
}
#[test]
fn test_confusion_matrix_per_class_metrics() {
let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.9]];
let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).expect("unwrap");
let precision = cm.precision_per_class();
let recall = cm.recall_per_class();
let f1 = cm.f1_per_class();
assert_eq!(precision.len(), 2);
assert_eq!(recall.len(), 2);
assert_eq!(f1.len(), 2);
assert_eq!(precision[0], 1.0);
assert_eq!(precision[1], 1.0);
assert_eq!(recall[0], 1.0);
assert_eq!(recall[1], 1.0);
}
#[test]
fn test_roc_curve() {
let predictions = vec![0.9, 0.8, 0.4, 0.3, 0.1];
let targets = vec![true, true, false, true, false];
let roc = RocCurve::compute(&predictions, &targets).expect("unwrap");
assert!(!roc.fpr.is_empty());
assert!(!roc.tpr.is_empty());
assert!(!roc.thresholds.is_empty());
assert_eq!(roc.fpr.len(), roc.tpr.len());
let auc = roc.auc();
assert!((0.0..=1.0).contains(&auc));
}
#[test]
fn test_roc_auc_perfect() {
let predictions = vec![0.9, 0.8, 0.3, 0.1];
let targets = vec![true, true, false, false];
let roc = RocCurve::compute(&predictions, &targets).expect("unwrap");
let auc = roc.auc();
assert!((auc - 1.0).abs() < 1e-6);
}
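    #[test]
    fn test_roc_curve_ties() {
        // Tied scores must collapse into single ROC points: one positive
        // and one negative tied at each threshold put the curve on the
        // chance diagonal, so AUC = 0.5 regardless of tie ordering.
        let predictions = vec![0.8, 0.8, 0.4, 0.4];
        let targets = vec![true, false, true, false];
        let roc = RocCurve::compute(&predictions, &targets).expect("unwrap");
        // (0, 0) plus one point per distinct threshold.
        assert_eq!(roc.fpr.len(), 3);
        assert!((roc.auc() - 0.5).abs() < 1e-6);
    }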
#[test]
fn test_per_class_metrics() {
let predictions = array![
[0.9, 0.1, 0.0],
[0.1, 0.8, 0.1],
[0.2, 0.1, 0.7],
[0.8, 0.1, 0.1]
];
let targets = array![
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0]
];
let metrics =
PerClassMetrics::compute(&predictions.view(), &targets.view()).expect("unwrap");
assert_eq!(metrics.precision.len(), 3);
assert_eq!(metrics.recall.len(), 3);
assert_eq!(metrics.f1_score.len(), 3);
assert_eq!(metrics.support.len(), 3);
assert_eq!(metrics.support[0], 2);
assert_eq!(metrics.support[1], 1);
assert_eq!(metrics.support[2], 1);
}
#[test]
fn test_matthews_correlation_coefficient() {
let metric = MatthewsCorrelationCoefficient;
let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
let mcc = metric
.compute(&predictions.view(), &targets.view())
.expect("unwrap");
assert!((mcc - 1.0).abs() < 1e-6);
let predictions = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
let mcc = metric
.compute(&predictions.view(), &targets.view())
.expect("unwrap");
assert!(mcc.abs() < 0.1);
}
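    #[test]
    fn test_matthews_correlation_coefficient_multiclass() {
        // Exercises the multi-class (R_K) branch, which the binary tests
        // above never reach: perfect 3-class predictions give MCC = 1.0.
        let metric = MatthewsCorrelationCoefficient;
        let predictions = array![
            [0.9, 0.05, 0.05],
            [0.1, 0.8, 0.1],
            [0.1, 0.2, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];
        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .expect("unwrap");
        assert!((mcc - 1.0).abs() < 1e-6);
    }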
#[test]
fn test_cohens_kappa() {
let metric = CohensKappa;
let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
let kappa = metric
.compute(&predictions.view(), &targets.view())
.expect("unwrap");
assert!((kappa - 1.0).abs() < 1e-6);
let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
let kappa = metric
.compute(&predictions.view(), &targets.view())
.expect("unwrap");
assert!((-1.0..=1.0).contains(&kappa));
}
#[test]
fn test_balanced_accuracy() {
let metric = BalancedAccuracy;
let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
let balanced_acc = metric
.compute(&predictions.view(), &targets.view())
.expect("unwrap");
assert!((balanced_acc - 1.0).abs() < 1e-6);
let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]];
let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
let balanced_acc = metric
.compute(&predictions.view(), &targets.view())
.expect("unwrap");
assert!((balanced_acc - 1.0).abs() < 1e-6);
}
}