use ndarray::{Array2, ArrayView1, ArrayView2};
/// Classification metrics derived from a confusion matrix.
///
/// Construct with [`MetricsCalculator::new`]; an empty (0x0) matrix acts as the
/// "invalid input" state, making every metric return the sentinel `-1.0`.
#[derive(Debug)]
pub struct MetricsCalculator {
// Square matrix indexed as [true_class, predicted_class]; counts stored as f32
// so metric ratios can be computed without casts.
confusion_matrix: Array2<f32>,
}
impl MetricsCalculator {
    /// Builds a confusion matrix from true labels and predictions.
    ///
    /// `labels` is expected to hold one true class index per example (as `f32`,
    /// e.g. a column vector of shape `(n, 1)`); `predictions` holds the
    /// predicted class index for each example. If the example counts disagree,
    /// or either input is empty, the calculator holds an empty (0x0) matrix and
    /// every metric returns the sentinel `-1.0`.
    ///
    /// Values are cast with `as usize`, which saturates: NaN and negative
    /// values map to class 0.
    pub fn new(labels: ArrayView2<f32>, predictions: ArrayView1<f32>) -> Self {
        // Invalid input -> empty matrix sentinel.
        if labels.shape()[0] != predictions.shape()[0]
            || labels.is_empty()
            || predictions.is_empty()
        {
            return Self {
                confusion_matrix: Array2::zeros((0, 0)),
            };
        }
        // Size the matrix from BOTH labels and predictions: sizing from labels
        // alone would panic on the indexing below whenever a prediction names a
        // class never seen in the ground truth.
        let num_classes = labels
            .iter()
            .chain(predictions.iter())
            .map(|&e| e as usize)
            .max()
            .unwrap_or(0)
            + 1;
        let mut confusion_matrix = Array2::zeros((num_classes, num_classes));
        // Rows index the true class, columns the predicted class.
        for (true_label, pred_label) in labels.iter().zip(predictions.iter()) {
            confusion_matrix[(*true_label as usize, *pred_label as usize)] += 1.0;
        }
        Self { confusion_matrix }
    }

    /// Returns the confusion matrix (rows = true class, columns = predicted class).
    #[inline]
    pub fn confusion_matrix(&self) -> &Array2<f32> {
        &self.confusion_matrix
    }

    /// Overall accuracy: correct predictions (the matrix trace) over all examples.
    ///
    /// Returns `-1.0` for the empty-matrix sentinel or when the ratio is NaN.
    pub fn accuracy(&self) -> f32 {
        if self.confusion_matrix.is_empty() {
            return -1.;
        }
        let total_examples = self.confusion_matrix.sum();
        let correct_predictions = self.confusion_matrix.diag().sum();
        let result = correct_predictions / total_examples;
        if result.is_nan() {
            return -1.;
        }
        result
    }

    /// Macro-averaged precision: per-class TP / predicted-positives, averaged
    /// uniformly over classes.
    ///
    /// Returns `-1.0` for the empty-matrix sentinel, or when any class has zero
    /// predicted positives (the 0/0 division yields NaN, which poisons the sum).
    pub fn precision(&self) -> f32 {
        if self.confusion_matrix.is_empty() {
            return -1.;
        }
        let num_classes = self.confusion_matrix.shape()[0];
        // Predicted positives for class i live in column i.
        let precision_sum: f32 = (0..num_classes)
            .map(|i| self.confusion_matrix[[i, i]] / self.confusion_matrix.column(i).sum())
            .sum();
        let result = precision_sum / num_classes as f32;
        if result.is_nan() {
            return -1.;
        }
        result
    }

    /// Macro-averaged recall: per-class TP / actual-positives, averaged
    /// uniformly over classes.
    ///
    /// Returns `-1.0` for the empty-matrix sentinel, or when any class has zero
    /// actual positives (the 0/0 division yields NaN, which poisons the sum).
    pub fn recall(&self) -> f32 {
        if self.confusion_matrix.is_empty() {
            return -1.;
        }
        let num_classes = self.confusion_matrix.shape()[0];
        // Actual positives for class i live in row i.
        let recall_sum: f32 = (0..num_classes)
            .map(|i| self.confusion_matrix[[i, i]] / self.confusion_matrix.row(i).sum())
            .sum();
        let result = recall_sum / num_classes as f32;
        if result.is_nan() {
            return -1.;
        }
        result
    }

    /// Macro-averaged F1 score: the per-class harmonic mean of precision and
    /// recall, averaged uniformly over classes.
    ///
    /// Returns `-1.0` for the empty-matrix sentinel, or when any per-class F1
    /// is NaN (zero predicted or actual positives, or precision + recall == 0).
    pub fn f1_score(&self) -> f32 {
        if self.confusion_matrix.is_empty() {
            return -1.;
        }
        let num_classes = self.confusion_matrix.shape()[0];
        let f1_sum: f32 = (0..num_classes)
            .map(|i| {
                let true_positives = self.confusion_matrix[[i, i]];
                let precision = true_positives / self.confusion_matrix.column(i).sum();
                let recall = true_positives / self.confusion_matrix.row(i).sum();
                (2.0 * precision * recall) / (precision + recall)
            })
            .sum();
        let result = f1_sum / num_classes as f32;
        if result.is_nan() {
            return -1.;
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_confusion_matrix_binary_classification() {
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![0., 1., 1., 0., 0., 1.];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        let cm = metrics.confusion_matrix();
        // Two true negatives, one false positive, one false negative, two true positives.
        assert_eq!(cm[[0, 0]], 2.0);
        assert_eq!(cm[[0, 1]], 1.0);
        assert_eq!(cm[[1, 0]], 1.0);
        assert_eq!(cm[[1, 1]], 2.0);
    }

    #[test]
    fn test_confusion_matrix_multi_class_classification() {
        let labels = array![[0.], [1.], [2.], [0.], [1.], [2.]];
        let predictions = array![0., 2., 1., 0., 1., 2.];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        let cm = metrics.confusion_matrix();
        // Table-driven check of every cell: rows are true classes, columns predicted.
        let expected = [[2.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 1.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &count) in row.iter().enumerate() {
                assert_eq!(cm[[i, j]], count);
            }
        }
    }

    #[test]
    fn test_empty_confusion_matrix() {
        let labels = Array2::<f32>::zeros((0, 0));
        let predictions = array![];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        let cm = metrics.confusion_matrix();
        assert_eq!(cm.shape(), [0, 0]);
        assert_eq!(cm, Array2::<f32>::zeros((0, 0)));
    }

    #[test]
    fn test_accuracy() {
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![0., 1., 1., 0., 0., 1.];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        // 4 of 6 predictions are correct.
        assert_eq!(metrics.accuracy(), 0.6666666666666666);
    }

    #[test]
    fn test_precision() {
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![0., 1., 1., 0., 0., 1.];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.precision(), 0.6666666666666666);
    }

    #[test]
    fn test_precision_nan() {
        // NaN predictions saturate to class 0, leaving class 1's column empty,
        // so the macro-average is NaN and the sentinel is returned.
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![f32::NAN, f32::NAN, f32::NAN, f32::NAN, f32::NAN, f32::NAN];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.precision(), -1.);
    }

    #[test]
    fn test_recall() {
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![0., 1., 1., 0., 0., 1.];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.recall(), 0.6666666666666666);
    }

    #[test]
    fn test_f1_score() {
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![0., 1., 1., 0., 0., 1.];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.f1_score(), 0.6666666666666666);
    }

    #[test]
    fn test_f1_score_nan() {
        let labels = array![[0.], [0.], [1.], [0.], [1.], [1.]];
        let predictions = array![f32::NAN, f32::NAN, f32::NAN, f32::NAN, f32::NAN, f32::NAN];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.f1_score(), -1.);
    }

    #[test]
    fn test_empty_cm_accuracy() {
        let labels = Array2::<f32>::zeros((0, 0));
        let predictions = array![];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.accuracy(), -1.);
    }

    #[test]
    fn test_empty_cm_precision() {
        let labels = Array2::<f32>::zeros((0, 0));
        let predictions = array![];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.precision(), -1.);
    }

    #[test]
    fn test_empty_cm_recall() {
        let labels = Array2::<f32>::zeros((0, 0));
        let predictions = array![];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.recall(), -1.);
    }

    #[test]
    fn test_empty_cm_f1_score() {
        let labels = Array2::<f32>::zeros((0, 0));
        let predictions = array![];
        let metrics = MetricsCalculator::new(labels.view(), predictions.view());
        assert_eq!(metrics.f1_score(), -1.);
    }
}