imbalanced_metrics/
lib.rs1#![deny(missing_docs)]
7#![warn(clippy::all, clippy::pedantic)]
8#![allow(clippy::module_name_repetitions)]
9
10use ndarray::ArrayView1;
11use rayon::prelude::*;
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17pub struct ClassificationReport {
18 pub class_metrics: HashMap<usize, ClassMetrics>,
20 pub macro_avg: AggregateMetrics,
22 pub weighted_avg: AggregateMetrics,
24 pub accuracy: f64,
26}
27
/// Scores for a single class in a [`ClassificationReport`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ClassMetrics {
    /// True positives divided by all predictions of this class (`0.0` if never predicted).
    pub precision: f64,
    /// True positives divided by all true samples of this class (`0.0` if it has none).
    pub recall: f64,
    /// Harmonic mean of precision and recall (`0.0` if both are zero).
    pub f1_score: f64,
    /// Number of samples whose true label is this class.
    pub support: usize,
}
41
/// Averaged precision, recall and F1 across classes (macro or support-weighted).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AggregateMetrics {
    /// Averaged per-class precision.
    pub precision: f64,
    /// Averaged per-class recall.
    pub recall: f64,
    /// Averaged per-class F1 score.
    pub f1_score: f64,
}
53
54pub fn confusion_matrix(y_true: ArrayView1<usize>, y_pred: ArrayView1<usize>) -> Array2<usize> {
56 use ndarray::Array2;
57
58 assert_eq!(y_true.len(), y_pred.len(), "Arrays must have same length");
59
60 let n_classes = y_true.iter().chain(y_pred.iter()).max().unwrap() + 1;
61 let mut matrix = Array2::zeros((n_classes, n_classes));
62
63 for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
64 matrix[[true_label, pred_label]] += 1;
65 }
66
67 matrix
68}
69
70pub fn classification_report(
72 y_true: ArrayView1<usize>,
73 y_pred: ArrayView1<usize>
74) -> ClassificationReport {
75 let cm = confusion_matrix(y_true, y_pred);
76 let n_classes = cm.nrows();
77
78 let class_metrics: HashMap<usize, ClassMetrics> = (0..n_classes)
79 .into_par_iter()
80 .map(|class| {
81 let tp = cm[[class, class]] as f64;
82 let fp = cm.column(class).sum() as f64 - tp;
83 let false_neg = cm.row(class).sum() as f64 - tp;
84
85 let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
86 let recall = if tp + false_neg > 0.0 { tp / (tp + false_neg) } else { 0.0 };
87 let f1_score = if precision + recall > 0.0 {
88 2.0 * precision * recall / (precision + recall)
89 } else {
90 0.0
91 };
92
93 let support = cm.row(class).sum();
94
95 (class, ClassMetrics { precision, recall, f1_score, support })
96 })
97 .collect();
98
99 let macro_avg = AggregateMetrics {
101 precision: class_metrics.values().map(|m| m.precision).sum::<f64>() / n_classes as f64,
102 recall: class_metrics.values().map(|m| m.recall).sum::<f64>() / n_classes as f64,
103 f1_score: class_metrics.values().map(|m| m.f1_score).sum::<f64>() / n_classes as f64,
104 };
105
106 let total_support = class_metrics.values().map(|m| m.support).sum::<usize>() as f64;
108 let weighted_avg = AggregateMetrics {
109 precision: class_metrics.values()
110 .map(|m| m.precision * m.support as f64)
111 .sum::<f64>() / total_support,
112 recall: class_metrics.values()
113 .map(|m| m.recall * m.support as f64)
114 .sum::<f64>() / total_support,
115 f1_score: class_metrics.values()
116 .map(|m| m.f1_score * m.support as f64)
117 .sum::<f64>() / total_support,
118 };
119
120 let accuracy = cm.diag().sum() as f64 / cm.sum() as f64;
121
122 ClassificationReport {
123 class_metrics,
124 macro_avg,
125 weighted_avg,
126 accuracy,
127 }
128}
129
130pub fn f1_score(y_true: ArrayView1<usize>, y_pred: ArrayView1<usize>) -> f64 {
132 let report = classification_report(y_true, y_pred);
133 report.weighted_avg.f1_score
134}
135
136pub fn balanced_accuracy(y_true: ArrayView1<usize>, y_pred: ArrayView1<usize>) -> f64 {
138 let report = classification_report(y_true, y_pred);
139 report.macro_avg.recall }
141
142impl std::fmt::Display for ClassificationReport {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 writeln!(f, "\n precision recall f1-score support\n")?;
146
147 let mut classes: Vec<_> = self.class_metrics.keys().copied().collect();
148 classes.sort_unstable();
149
150 for class in classes {
151 let metrics = &self.class_metrics[&class];
152 writeln!(
153 f,
154 " class {:2} {:5.2} {:5.2} {:5.2} {:4}",
155 class, metrics.precision, metrics.recall, metrics.f1_score, metrics.support
156 )?;
157 }
158
159 writeln!(f, "\n accuracy {:5.2}", self.accuracy)?;
160 writeln!(
161 f,
162 " macro avg {:5.2} {:5.2} {:5.2}",
163 self.macro_avg.precision, self.macro_avg.recall, self.macro_avg.f1_score
164 )?;
165 writeln!(
166 f,
167 "weighted avg {:5.2} {:5.2} {:5.2}",
168 self.weighted_avg.precision, self.weighted_avg.recall, self.weighted_avg.f1_score
169 )?;
170
171 Ok(())
172 }
173}
174
175pub mod prelude {
177 pub use crate::{
178 classification_report, confusion_matrix, f1_score, balanced_accuracy,
179 ClassificationReport, ClassMetrics, AggregateMetrics,
180 };
181}
182
183use ndarray::Array2;