use std::fmt::Write as FmtWrite;

use crate::EvalError;

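/// Averaging strategy for multi-class F1 scores.
///
/// `Macro` takes the unweighted mean of per-class F1 scores, `Micro`
/// pools true positives, false positives, and false negatives across
/// classes before computing a single F1, and `Weighted` averages
/// per-class F1 scores weighted by class support.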
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum F1Average {
    Macro,
    Micro,
    Weighted,
}

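/// Fraction of positions where the prediction equals the target.
///
/// Empty inputs yield `0.0` by convention; mismatched lengths produce
/// `EvalError::CountLengthMismatch`.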
pub fn accuracy(predictions: &[usize], targets: &[usize]) -> Result<f32, EvalError> {
    if predictions.len() != targets.len() {
        return Err(EvalError::CountLengthMismatch {
            ground_truth: targets.len(),
            predictions: predictions.len(),
        });
    }
    if predictions.is_empty() {
        return Ok(0.0);
    }
    let correct = predictions
        .iter()
        .zip(targets.iter())
        .filter(|(p, t)| p == t)
        .count();
    Ok(correct as f32 / predictions.len() as f32)
}

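/// Builds a `num_classes` x `num_classes` confusion matrix where rows
/// index the target (ground-truth) class and columns the predicted
/// class. Pairs with an index outside `0..num_classes` are silently
/// skipped rather than counted.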
pub fn confusion_matrix(
    predictions: &[usize],
    targets: &[usize],
    num_classes: usize,
) -> Result<Vec<Vec<usize>>, EvalError> {
    if predictions.len() != targets.len() {
        return Err(EvalError::CountLengthMismatch {
            ground_truth: targets.len(),
            predictions: predictions.len(),
        });
    }
    let mut cm = vec![vec![0usize; num_classes]; num_classes];
    for (&pred, &target) in predictions.iter().zip(targets.iter()) {
        if target < num_classes && pred < num_classes {
            cm[target][pred] += 1;
        }
    }
    Ok(cm)
}

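/// Per-class `(precision, recall)` pairs derived from a confusion
/// matrix laid out as in [`confusion_matrix`] (rows = targets,
/// columns = predictions). Classes with no predictions or no true
/// instances score `0.0` instead of dividing by zero.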
pub fn per_class_precision_recall(cm: &[Vec<usize>]) -> Vec<(f32, f32)> {
    let n = cm.len();
    let mut result = Vec::with_capacity(n);
    for c in 0..n {
        let tp = cm[c][c] as f32;
        let col_sum: f32 = cm.iter().map(|row| row[c] as f32).sum();
        let row_sum: f32 = cm[c].iter().sum::<usize>() as f32;

        let precision = if col_sum > 0.0 { tp / col_sum } else { 0.0 };
        let recall = if row_sum > 0.0 { tp / row_sum } else { 0.0 };
        result.push((precision, recall));
    }
    result
}

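/// Renders a plain-text table of per-class precision, recall, F1, and
/// support (row sums of the confusion matrix), followed by an overall
/// accuracy line, in the spirit of scikit-learn's report. Classes are
/// indexed by position in `labels`, so `labels.len()` sets the number
/// of classes.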
pub fn classification_report(
    predictions: &[usize],
    targets: &[usize],
    labels: &[&str],
) -> Result<String, EvalError> {
    let num_classes = labels.len();
    let cm = confusion_matrix(predictions, targets, num_classes)?;
    let pr = per_class_precision_recall(&cm);
    let acc = accuracy(predictions, targets)?;

    // Pad the label column to the longest label, with a sensible minimum.
    let max_label = labels.iter().map(|l| l.len()).max().unwrap_or(0).max(10);
    let mut report = String::new();

    writeln!(
        report,
        "{:>width$}  precision  recall  f1-score  support",
        "",
        width = max_label
    )
    .expect("write to String");

    let total_support = targets.len();

    for (i, label) in labels.iter().enumerate() {
        let (prec, rec) = pr[i];
        let f1 = if prec + rec > 0.0 {
            2.0 * prec * rec / (prec + rec)
        } else {
            0.0
        };
        let support: usize = cm[i].iter().sum();
        writeln!(
            report,
            "{:>width$}  {:>9.3}  {:>6.3}  {:>8.3}  {:>7}",
            label,
            prec,
            rec,
            f1,
            support,
            width = max_label
        )
        .expect("write to String");
    }

    // The accuracy row leaves the precision/recall columns blank and
    // places the value under the f1-score column, matching the header.
    writeln!(
        report,
        "{:>width$}  {:>9}  {:>6}  {:>8.3}  {:>7}",
        "accuracy",
        "",
        "",
        acc,
        total_support,
        width = max_label
    )
    .expect("write to String");

    Ok(report)
}

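/// F1 score over multi-class predictions, combined per [`F1Average`].
///
/// Note that the macro average divides by `num_classes`, so classes
/// that never appear still drag the mean toward zero; for single-label
/// data the micro average coincides with accuracy.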
pub fn f1_score(
    predictions: &[usize],
    targets: &[usize],
    num_classes: usize,
    average: F1Average,
) -> Result<f32, EvalError> {
    if predictions.len() != targets.len() {
        return Err(EvalError::CountLengthMismatch {
            ground_truth: targets.len(),
            predictions: predictions.len(),
        });
    }
    // Guard against a 0/0 division (NaN) in the macro average below.
    if num_classes == 0 {
        return Ok(0.0);
    }

    let cm = confusion_matrix(predictions, targets, num_classes)?;

    match average {
        F1Average::Macro => {
            let pr = per_class_precision_recall(&cm);
            let mut sum_f1 = 0.0f32;
            for &(prec, rec) in &pr {
                let f1 = if prec + rec > 0.0 {
                    2.0 * prec * rec / (prec + rec)
                } else {
                    0.0
                };
                sum_f1 += f1;
            }
            Ok(sum_f1 / num_classes as f32)
        }
        F1Average::Micro => {
            let mut tp_total = 0usize;
            let mut fp_total = 0usize;
            let mut fn_total = 0usize;
            for c in 0..num_classes {
                let tp = cm[c][c];
                let fp: usize = cm.iter().map(|row| row[c]).sum::<usize>() - tp;
                let fn_c: usize = cm[c].iter().sum::<usize>() - tp;
                tp_total += tp;
                fp_total += fp;
                fn_total += fn_c;
            }
            let precision = if tp_total + fp_total > 0 {
                tp_total as f32 / (tp_total + fp_total) as f32
            } else {
                0.0
            };
            let recall = if tp_total + fn_total > 0 {
                tp_total as f32 / (tp_total + fn_total) as f32
            } else {
                0.0
            };
            if precision + recall > 0.0 {
                Ok(2.0 * precision * recall / (precision + recall))
            } else {
                Ok(0.0)
            }
        }
        F1Average::Weighted => {
            let pr = per_class_precision_recall(&cm);
            let mut weighted_f1 = 0.0f32;
            let total: usize = targets.len();
            for c in 0..num_classes {
                let support: usize = cm[c].iter().sum();
                let (prec, rec) = pr[c];
                let f1 = if prec + rec > 0.0 {
                    2.0 * prec * rec / (prec + rec)
                } else {
                    0.0
                };
                weighted_f1 += f1 * support as f32;
            }
            if total > 0 {
                Ok(weighted_f1 / total as f32)
            } else {
                Ok(0.0)
            }
        }
    }
}

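/// Precision, recall, and threshold values computed at every rank of
/// the scores sorted in descending order (one point per sample; tied
/// scores are not collapsed into a single threshold). `NaN` scores
/// compare as equal and end up in arbitrary rank order.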
pub fn precision_recall_curve(
    scores: &[f32],
    labels: &[bool],
) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), EvalError> {
    if scores.len() != labels.len() {
        return Err(EvalError::CountLengthMismatch {
            ground_truth: labels.len(),
            predictions: scores.len(),
        });
    }

    let n = scores.len();
    let total_pos = labels.iter().filter(|&&l| l).count() as f32;

    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_unstable_by(|&a, &b| {
        scores[b]
            .partial_cmp(&scores[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut precisions = Vec::with_capacity(n);
    let mut recalls = Vec::with_capacity(n);
    let mut thresholds = Vec::with_capacity(n);

    let mut tp = 0.0f32;

    for (rank, &i) in indices.iter().enumerate() {
        if labels[i] {
            tp += 1.0;
        }
        let predicted_pos = (rank + 1) as f32;
        precisions.push(tp / predicted_pos);
        recalls.push(if total_pos > 0.0 { tp / total_pos } else { 0.0 });
        thresholds.push(scores[i]);
    }

    Ok((precisions, recalls, thresholds))
}

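/// Area under the precision-recall curve via the trapezoidal rule,
/// with an anchor point at `(recall = 0, precision = 1)`. This is an
/// interpolated approximation; step-wise summation (as used by, e.g.,
/// scikit-learn's `average_precision_score`) can give different values.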
pub fn average_precision(scores: &[f32], labels: &[bool]) -> Result<f32, EvalError> {
    let (precisions, recalls, _) = precision_recall_curve(scores, labels)?;

    if recalls.is_empty() {
        return Ok(0.0);
    }

    let mut full_recalls = Vec::with_capacity(recalls.len() + 1);
    let mut full_precisions = Vec::with_capacity(precisions.len() + 1);
    full_recalls.push(0.0f32);
    full_precisions.push(1.0f32);
    full_recalls.extend_from_slice(&recalls);
    full_precisions.extend_from_slice(&precisions);

    let mut ap = 0.0f32;
    for i in 1..full_recalls.len() {
        let dr = full_recalls[i] - full_recalls[i - 1];
        ap += dr * (full_precisions[i] + full_precisions[i - 1]) / 2.0;
    }
    Ok(ap)
}

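/// Cohen's kappa: `(p_o - p_e) / (1 - p_e)`, where `p_o` is the
/// observed agreement and `p_e` the agreement expected by chance from
/// the marginal distributions. When `p_e` is numerically 1 the
/// statistic is undefined and this implementation returns `1.0`.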
pub fn cohens_kappa(
    predictions: &[usize],
    targets: &[usize],
    num_classes: usize,
) -> Result<f32, EvalError> {
    if predictions.len() != targets.len() {
        return Err(EvalError::CountLengthMismatch {
            ground_truth: targets.len(),
            predictions: predictions.len(),
        });
    }

    let n = predictions.len();
    if n == 0 {
        return Ok(0.0);
    }

    let cm = confusion_matrix(predictions, targets, num_classes)?;
    let n_f = n as f32;

    let p_o: f32 = (0..num_classes).map(|c| cm[c][c] as f32).sum::<f32>() / n_f;

    let mut p_e = 0.0f32;
    for c in 0..num_classes {
        let row_sum: f32 = cm[c].iter().sum::<usize>() as f32;
        let col_sum: f32 = cm.iter().map(|row| row[c]).sum::<usize>() as f32;
        p_e += (row_sum / n_f) * (col_sum / n_f);
    }

    if (1.0 - p_e).abs() < 1e-10 {
        return Ok(1.0);
    }

    Ok((p_o - p_e) / (1.0 - p_e))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accuracy_perfect() {
        let preds = vec![0, 1, 2, 0, 1];
        let targets = vec![0, 1, 2, 0, 1];
        let acc = accuracy(&preds, &targets).unwrap();
        assert!((acc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_accuracy_half() {
        let preds = vec![0, 0, 1, 1];
        let targets = vec![0, 1, 0, 1];
        let acc = accuracy(&preds, &targets).unwrap();
        assert!((acc - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_accuracy_length_mismatch() {
        assert!(accuracy(&[0, 1], &[0]).is_err());
    }

    #[test]
    fn test_confusion_matrix_basic() {
        let preds = vec![0, 0, 1, 1, 2, 2];
        let targets = vec![0, 1, 1, 2, 2, 0];
        let cm = confusion_matrix(&preds, &targets, 3).unwrap();

        assert_eq!(cm[0][0], 1);
        assert_eq!(cm[1][1], 1);
        assert_eq!(cm[2][2], 1);
        assert_eq!(cm[1][0], 1);
        assert_eq!(cm[2][1], 1);
        assert_eq!(cm[0][2], 1);
    }

    #[test]
    fn test_per_class_precision_recall() {
        let cm = confusion_matrix(&[0, 0, 1, 1], &[0, 1, 0, 1], 2).unwrap();
        let pr = per_class_precision_recall(&cm);
        assert!((pr[0].0 - 0.5).abs() < 1e-5);
        assert!((pr[0].1 - 0.5).abs() < 1e-5);
        assert!((pr[1].0 - 0.5).abs() < 1e-5);
        assert!((pr[1].1 - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_classification_report_format() {
        let preds = vec![0, 0, 1, 1, 1];
        let targets = vec![0, 1, 1, 1, 0];
        let report = classification_report(&preds, &targets, &["cat", "dog"]).unwrap();

        assert!(report.contains("precision"));
        assert!(report.contains("recall"));
        assert!(report.contains("cat"));
        assert!(report.contains("dog"));
        assert!(report.contains("accuracy"));
    }

    #[test]
    fn test_accuracy_empty() {
        let acc = accuracy(&[], &[]).unwrap();
        assert_eq!(acc, 0.0);
    }
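
    // Illustrative checks added alongside the original suite. The expected
    // values below are hand-derived from this module's own conventions
    // (per-rank PR curve, trapezoidal AP with a (0, 1) anchor).

    #[test]
    fn test_micro_f1_equals_accuracy() {
        // For single-label multi-class data, every false positive for one
        // class is a false negative for another, so micro F1 == accuracy.
        let preds = vec![0, 0, 1, 1];
        let targets = vec![0, 1, 0, 1];
        let micro = f1_score(&preds, &targets, 2, F1Average::Micro).unwrap();
        let acc = accuracy(&preds, &targets).unwrap();
        assert!((micro - acc).abs() < 1e-6);
    }

    #[test]
    fn test_cohens_kappa_chance_agreement() {
        // Predictions independent of the targets: p_o == p_e == 0.5,
        // so kappa should be 0.
        let preds = vec![0, 0, 1, 1];
        let targets = vec![0, 1, 0, 1];
        let kappa = cohens_kappa(&preds, &targets, 2).unwrap();
        assert!(kappa.abs() < 1e-6);
    }

    #[test]
    fn test_average_precision_trapezoidal() {
        // Ranked by score the labels are [true, false, true], giving
        // (recall, precision) points (0.5, 1.0), (0.5, 0.5), (1.0, 2/3).
        // With the (0, 1) anchor, the trapezoids sum to
        // 0.5 * 1.0 + 0.0 + 0.5 * (0.5 + 2/3) / 2.
        let scores = vec![0.9f32, 0.8, 0.7];
        let labels = vec![true, false, true];
        let ap = average_precision(&scores, &labels).unwrap();
        let expected = 0.5f32 + 0.5 * (0.5 + 2.0 / 3.0) / 2.0;
        assert!((ap - expected).abs() < 1e-5);
    }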
}