imbalanced_metrics/
lib.rs

//! Performance metrics for imbalanced learning
//!
//! This crate provides metrics specifically useful for evaluating
//! models on imbalanced datasets.
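//!
//! # Example
//!
//! An illustrative sketch of typical usage on a small imbalanced toy dataset
//! (the labels are arbitrary, and the example assumes the crate is consumed
//! as a dependency named `imbalanced_metrics`):
//!
//! ```
//! use imbalanced_metrics::prelude::*;
//! use ndarray::array;
//!
//! let y_true = array![0, 0, 0, 0, 1, 1];
//! let y_pred = array![0, 0, 0, 1, 1, 0];
//!
//! let report = classification_report(y_true.view(), y_pred.view());
//! println!("{report}");
//! ```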

#![deny(missing_docs)]
#![warn(clippy::all, clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]

use ndarray::{Array2, ArrayView1};
use rayon::prelude::*;
use std::collections::HashMap;

/// Classification report containing precision, recall, F1-score, and support
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ClassificationReport {
    /// Per-class metrics
    pub class_metrics: HashMap<usize, ClassMetrics>,
    /// Macro-averaged metrics
    pub macro_avg: AggregateMetrics,
    /// Weighted-averaged metrics
    pub weighted_avg: AggregateMetrics,
    /// Overall accuracy
    pub accuracy: f64,
}

/// Metrics for a single class
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ClassMetrics {
    /// Precision: TP / (TP + FP)
    pub precision: f64,
    /// Recall: TP / (TP + FN)
    pub recall: f64,
    /// F1-score: 2 * (precision * recall) / (precision + recall)
    pub f1_score: f64,
    /// Number of true instances for this class
    pub support: usize,
}

/// Aggregate metrics (macro/weighted averages)
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AggregateMetrics {
    /// Average precision
    pub precision: f64,
    /// Average recall
    pub recall: f64,
    /// Average F1-score
    pub f1_score: f64,
}

/// Compute the confusion matrix, with rows indexed by true label and columns by predicted label
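///
/// # Examples
///
/// An illustrative example (label values chosen arbitrarily):
///
/// ```
/// use imbalanced_metrics::confusion_matrix;
/// use ndarray::array;
///
/// let y_true = array![0, 1, 1, 2];
/// let y_pred = array![0, 1, 0, 2];
///
/// let cm = confusion_matrix(y_true.view(), y_pred.view());
/// assert_eq!(cm[[1, 0]], 1); // one class-1 sample was predicted as class 0
/// assert_eq!(cm[[2, 2]], 1); // the single class-2 sample was predicted correctly
/// ```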
pub fn confusion_matrix(y_true: ArrayView1<usize>, y_pred: ArrayView1<usize>) -> Array2<usize> {
    assert_eq!(y_true.len(), y_pred.len(), "Arrays must have same length");
    assert!(!y_true.is_empty(), "Arrays must not be empty");

    // Size the matrix by the largest label seen in either array.
    let n_classes = y_true.iter().chain(y_pred.iter()).max().unwrap() + 1;
    let mut matrix = Array2::zeros((n_classes, n_classes));

    for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
        matrix[[true_label, pred_label]] += 1;
    }

    matrix
}

/// Generate a classification report with per-class metrics plus macro and weighted averages
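///
/// # Examples
///
/// An illustrative example on a small imbalanced toy dataset (labels chosen arbitrarily):
///
/// ```
/// use imbalanced_metrics::classification_report;
/// use ndarray::array;
///
/// let y_true = array![0, 0, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 0, 1, 1, 0];
///
/// let report = classification_report(y_true.view(), y_pred.view());
/// assert!((report.accuracy - 4.0 / 6.0).abs() < 1e-12); // 4 of 6 predictions correct
/// println!("{report}");
/// ```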
pub fn classification_report(
    y_true: ArrayView1<usize>,
    y_pred: ArrayView1<usize>,
) -> ClassificationReport {
    let cm = confusion_matrix(y_true, y_pred);
    let n_classes = cm.nrows();

    // Per-class precision, recall, and F1, computed in parallel across classes.
    let class_metrics: HashMap<usize, ClassMetrics> = (0..n_classes)
        .into_par_iter()
        .map(|class| {
            let tp = cm[[class, class]] as f64;
            let fp = cm.column(class).sum() as f64 - tp;
            let false_neg = cm.row(class).sum() as f64 - tp;

            let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
            let recall = if tp + false_neg > 0.0 { tp / (tp + false_neg) } else { 0.0 };
            let f1_score = if precision + recall > 0.0 {
                2.0 * precision * recall / (precision + recall)
            } else {
                0.0
            };

            let support = cm.row(class).sum();

            (class, ClassMetrics { precision, recall, f1_score, support })
        })
        .collect();

    // Macro averages: unweighted mean over classes.
    let macro_avg = AggregateMetrics {
        precision: class_metrics.values().map(|m| m.precision).sum::<f64>() / n_classes as f64,
        recall: class_metrics.values().map(|m| m.recall).sum::<f64>() / n_classes as f64,
        f1_score: class_metrics.values().map(|m| m.f1_score).sum::<f64>() / n_classes as f64,
    };

    // Weighted averages: mean over classes weighted by support.
    let total_support = class_metrics.values().map(|m| m.support).sum::<usize>() as f64;
    let weighted_avg = AggregateMetrics {
        precision: class_metrics.values()
            .map(|m| m.precision * m.support as f64)
            .sum::<f64>() / total_support,
        recall: class_metrics.values()
            .map(|m| m.recall * m.support as f64)
            .sum::<f64>() / total_support,
        f1_score: class_metrics.values()
            .map(|m| m.f1_score * m.support as f64)
            .sum::<f64>() / total_support,
    };

    let accuracy = cm.diag().sum() as f64 / cm.sum() as f64;

    ClassificationReport {
        class_metrics,
        macro_avg,
        weighted_avg,
        accuracy,
    }
}

/// Calculate the weighted-average F1 score for binary or multiclass classification
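///
/// # Examples
///
/// An illustrative example (a perfect prediction yields an F1 of 1.0):
///
/// ```
/// use imbalanced_metrics::f1_score;
/// use ndarray::array;
///
/// let y_true = array![0, 0, 1, 1];
/// let f1 = f1_score(y_true.view(), y_true.view());
/// assert!((f1 - 1.0).abs() < 1e-12);
/// ```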
pub fn f1_score(y_true: ArrayView1<usize>, y_pred: ArrayView1<usize>) -> f64 {
    let report = classification_report(y_true, y_pred);
    report.weighted_avg.f1_score
}

/// Calculate balanced accuracy (the macro-average of per-class recall)
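///
/// # Examples
///
/// An illustrative example: predicting only the majority class gives a plain
/// accuracy of 0.75 here, but a balanced accuracy of only 0.5:
///
/// ```
/// use imbalanced_metrics::balanced_accuracy;
/// use ndarray::array;
///
/// let y_true = array![0, 0, 0, 1];
/// let y_pred = array![0, 0, 0, 0];
///
/// let bacc = balanced_accuracy(y_true.view(), y_pred.view());
/// assert!((bacc - 0.5).abs() < 1e-12);
/// ```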
pub fn balanced_accuracy(y_true: ArrayView1<usize>, y_pred: ArrayView1<usize>) -> f64 {
    let report = classification_report(y_true, y_pred);
    report.macro_avg.recall // Balanced accuracy is macro-averaged recall
}

/// Display classification report
impl std::fmt::Display for ClassificationReport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "\n              precision    recall  f1-score   support\n")?;

        let mut classes: Vec<_> = self.class_metrics.keys().copied().collect();
        classes.sort_unstable();

        for class in classes {
            let metrics = &self.class_metrics[&class];
            writeln!(
                f,
                "   class {:2}      {:5.2}      {:5.2}      {:5.2}      {:4}",
                class, metrics.precision, metrics.recall, metrics.f1_score, metrics.support
            )?;
        }

        writeln!(f, "\n   accuracy                          {:5.2}", self.accuracy)?;
        writeln!(
            f,
            "  macro avg      {:5.2}      {:5.2}      {:5.2}",
            self.macro_avg.precision, self.macro_avg.recall, self.macro_avg.f1_score
        )?;
        writeln!(
            f,
            "weighted avg      {:5.2}      {:5.2}      {:5.2}",
            self.weighted_avg.precision, self.weighted_avg.recall, self.weighted_avg.f1_score
        )?;

        Ok(())
    }
}


/// Prelude module for convenient imports
pub mod prelude {
    pub use crate::{
        classification_report, confusion_matrix, f1_score, balanced_accuracy,
        ClassificationReport, ClassMetrics, AggregateMetrics,
    };
}
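
// A minimal smoke test, included as an illustrative sketch: it checks the report
// against hand-computed values on a small imbalanced toy dataset (the labels are
// chosen arbitrarily for this example).
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn report_matches_hand_computed_values() {
        // Four majority-class samples, two minority-class samples.
        let y_true = array![0, 0, 0, 0, 1, 1];
        let y_pred = array![0, 0, 0, 1, 1, 0];

        let report = classification_report(y_true.view(), y_pred.view());

        // Accuracy: 4 of 6 predictions are correct.
        assert!((report.accuracy - 4.0 / 6.0).abs() < 1e-12);

        // Class 0: TP = 3, FP = 1, FN = 1 -> precision = recall = 0.75.
        let class0 = report.class_metrics[&0];
        assert!((class0.precision - 0.75).abs() < 1e-12);
        assert!((class0.recall - 0.75).abs() < 1e-12);

        // Balanced accuracy is the macro-average recall: (0.75 + 0.5) / 2 = 0.625.
        let bacc = balanced_accuracy(y_true.view(), y_pred.view());
        assert!((bacc - 0.625).abs() < 1e-12);
    }
}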