//! Classification metrics for `yscv_eval`: accuracy, confusion matrix,
//! precision/recall, F1 averaging, PR curves, average precision, and
//! Cohen's kappa.

1use std::fmt::Write as FmtWrite;
2
3use crate::EvalError;
4
/// Averaging strategy for multi-class F1 score.
///
/// Mirrors scikit-learn's `average` parameter for `f1_score`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum F1Average {
    /// Compute F1 per class and average (unweighted); every class counts equally
    /// regardless of how many samples it has.
    Macro,
    /// Pool TP/FP/FN counts across all classes, then derive a single F1 from the totals.
    Micro,
    /// Compute F1 per class and average weighted by support (class count in targets).
    Weighted,
}
15
16/// Compute classification accuracy as the fraction of correct predictions.
17///
18/// Returns a value in `[0, 1]`. Returns an error if the slices have different lengths.
19pub fn accuracy(predictions: &[usize], targets: &[usize]) -> Result<f32, EvalError> {
20    if predictions.len() != targets.len() {
21        return Err(EvalError::CountLengthMismatch {
22            ground_truth: targets.len(),
23            predictions: predictions.len(),
24        });
25    }
26    if predictions.is_empty() {
27        return Ok(0.0);
28    }
29    let correct = predictions
30        .iter()
31        .zip(targets.iter())
32        .filter(|(p, t)| p == t)
33        .count();
34    Ok(correct as f32 / predictions.len() as f32)
35}
36
37/// Compute a confusion matrix for `num_classes` classes.
38///
39/// `result[actual][predicted]` contains the count of samples with true class
40/// `actual` that were predicted as class `predicted`.
41pub fn confusion_matrix(
42    predictions: &[usize],
43    targets: &[usize],
44    num_classes: usize,
45) -> Result<Vec<Vec<usize>>, EvalError> {
46    if predictions.len() != targets.len() {
47        return Err(EvalError::CountLengthMismatch {
48            ground_truth: targets.len(),
49            predictions: predictions.len(),
50        });
51    }
52    let mut cm = vec![vec![0usize; num_classes]; num_classes];
53    for (&pred, &target) in predictions.iter().zip(targets.iter()) {
54        if target < num_classes && pred < num_classes {
55            cm[target][pred] += 1;
56        }
57    }
58    Ok(cm)
59}
60
/// Compute per-class precision and recall from a confusion matrix.
///
/// Returns a `Vec` of `(precision, recall)` tuples, one per class.
/// If a class has no predictions, precision is `0.0`; if no ground truth, recall is `0.0`.
pub fn per_class_precision_recall(cm: &[Vec<usize>]) -> Vec<(f32, f32)> {
    (0..cm.len())
        .map(|class| {
            let tp = cm[class][class] as f32;
            // Column sum: everything predicted as this class.
            let predicted: f32 = cm.iter().map(|row| row[class] as f32).sum();
            // Row sum: everything actually belonging to this class.
            let actual: f32 = cm[class].iter().sum::<usize>() as f32;
            let precision = if predicted > 0.0 { tp / predicted } else { 0.0 };
            let recall = if actual > 0.0 { tp / actual } else { 0.0 };
            (precision, recall)
        })
        .collect()
}
79
80/// Generate a human-readable classification report (similar to scikit-learn's
81/// `classification_report`).
82///
83/// Example output:
84/// ```text
85///               precision  recall  f1-score  support
86///          cat      0.800   0.889     0.842        9
87///          dog      0.857   0.750     0.800        8
88///     accuracy                        0.824       17
89/// ```
90pub fn classification_report(
91    predictions: &[usize],
92    targets: &[usize],
93    labels: &[&str],
94) -> Result<String, EvalError> {
95    let num_classes = labels.len();
96    let cm = confusion_matrix(predictions, targets, num_classes)?;
97    let pr = per_class_precision_recall(&cm);
98    let acc = accuracy(predictions, targets)?;
99
100    let max_label = labels.iter().map(|l| l.len()).max().unwrap_or(5).max(10);
101    let mut report = String::new();
102
103    writeln!(
104        report,
105        "{:>width$}  precision  recall  f1-score  support",
106        "",
107        width = max_label
108    )
109    .expect("write to String");
110
111    let total_support = targets.len();
112
113    for (i, label) in labels.iter().enumerate() {
114        let (prec, rec) = pr[i];
115        let f1 = if prec + rec > 0.0 {
116            2.0 * prec * rec / (prec + rec)
117        } else {
118            0.0
119        };
120        let support: usize = cm[i].iter().sum();
121        writeln!(
122            report,
123            "{:>width$}    {:.3}   {:.3}     {:.3}     {:>4}",
124            label,
125            prec,
126            rec,
127            f1,
128            support,
129            width = max_label
130        )
131        .expect("write to String");
132    }
133
134    writeln!(
135        report,
136        "{:>width$}                      {:.3}     {:>4}",
137        "accuracy",
138        acc,
139        total_support,
140        width = max_label
141    )
142    .expect("write to String");
143
144    Ok(report)
145}
146
147/// Compute F1 score with the specified averaging strategy.
148///
149/// `num_classes` must be at least as large as the maximum label value + 1.
150pub fn f1_score(
151    predictions: &[usize],
152    targets: &[usize],
153    num_classes: usize,
154    average: F1Average,
155) -> Result<f32, EvalError> {
156    if predictions.len() != targets.len() {
157        return Err(EvalError::CountLengthMismatch {
158            ground_truth: targets.len(),
159            predictions: predictions.len(),
160        });
161    }
162
163    let cm = confusion_matrix(predictions, targets, num_classes)?;
164
165    match average {
166        F1Average::Macro => {
167            let pr = per_class_precision_recall(&cm);
168            let mut sum_f1 = 0.0f32;
169            for &(prec, rec) in &pr {
170                let f1 = if prec + rec > 0.0 {
171                    2.0 * prec * rec / (prec + rec)
172                } else {
173                    0.0
174                };
175                sum_f1 += f1;
176            }
177            Ok(sum_f1 / num_classes as f32)
178        }
179        F1Average::Micro => {
180            let mut tp_total = 0usize;
181            let mut fp_total = 0usize;
182            let mut fn_total = 0usize;
183            for c in 0..num_classes {
184                let tp = cm[c][c];
185                let fp: usize = cm.iter().map(|row| row[c]).sum::<usize>() - tp;
186                let fn_c: usize = cm[c].iter().sum::<usize>() - tp;
187                tp_total += tp;
188                fp_total += fp;
189                fn_total += fn_c;
190            }
191            let precision = if tp_total + fp_total > 0 {
192                tp_total as f32 / (tp_total + fp_total) as f32
193            } else {
194                0.0
195            };
196            let recall = if tp_total + fn_total > 0 {
197                tp_total as f32 / (tp_total + fn_total) as f32
198            } else {
199                0.0
200            };
201            if precision + recall > 0.0 {
202                Ok(2.0 * precision * recall / (precision + recall))
203            } else {
204                Ok(0.0)
205            }
206        }
207        F1Average::Weighted => {
208            let pr = per_class_precision_recall(&cm);
209            let mut weighted_f1 = 0.0f32;
210            let total: usize = targets.len();
211            for c in 0..num_classes {
212                let support: usize = cm[c].iter().sum();
213                let (prec, rec) = pr[c];
214                let f1 = if prec + rec > 0.0 {
215                    2.0 * prec * rec / (prec + rec)
216                } else {
217                    0.0
218                };
219                weighted_f1 += f1 * support as f32;
220            }
221            if total > 0 {
222                Ok(weighted_f1 / total as f32)
223            } else {
224                Ok(0.0)
225            }
226        }
227    }
228}
229
230/// Compute precision-recall curve from binary classification scores and labels.
231///
232/// Returns `(precisions, recalls, thresholds)` sorted by decreasing threshold.
233pub fn precision_recall_curve(
234    scores: &[f32],
235    labels: &[bool],
236) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), EvalError> {
237    if scores.len() != labels.len() {
238        return Err(EvalError::CountLengthMismatch {
239            ground_truth: labels.len(),
240            predictions: scores.len(),
241        });
242    }
243
244    let n = scores.len();
245    let total_pos = labels.iter().filter(|&&l| l).count() as f32;
246
247    // Sort indices by score descending
248    let mut indices: Vec<usize> = (0..n).collect();
249    indices.sort_unstable_by(|&a, &b| {
250        scores[b]
251            .partial_cmp(&scores[a])
252            .unwrap_or(std::cmp::Ordering::Equal)
253    });
254
255    let mut precisions = Vec::with_capacity(n);
256    let mut recalls = Vec::with_capacity(n);
257    let mut thresholds = Vec::with_capacity(n);
258
259    let mut tp = 0.0f32;
260
261    for (rank, &i) in indices.iter().enumerate() {
262        if labels[i] {
263            tp += 1.0;
264        }
265        let predicted_pos = (rank + 1) as f32;
266        precisions.push(tp / predicted_pos);
267        recalls.push(if total_pos > 0.0 { tp / total_pos } else { 0.0 });
268        thresholds.push(scores[i]);
269    }
270
271    Ok((precisions, recalls, thresholds))
272}
273
274/// Compute average precision (area under the precision-recall curve) using the trapezoidal rule.
275///
276/// Prepends the point (recall=0, precision=1) to ensure the full area is captured.
277pub fn average_precision(scores: &[f32], labels: &[bool]) -> Result<f32, EvalError> {
278    let (precisions, recalls, _) = precision_recall_curve(scores, labels)?;
279
280    if recalls.is_empty() {
281        return Ok(0.0);
282    }
283
284    // Prepend (recall=0, precision=1.0) as the starting point of the PR curve.
285    let mut full_recalls = Vec::with_capacity(recalls.len() + 1);
286    let mut full_precisions = Vec::with_capacity(precisions.len() + 1);
287    full_recalls.push(0.0f32);
288    full_precisions.push(1.0f32);
289    full_recalls.extend_from_slice(&recalls);
290    full_precisions.extend_from_slice(&precisions);
291
292    // Trapezoidal rule over recall (which is monotonically non-decreasing)
293    let mut ap = 0.0f32;
294    for i in 1..full_recalls.len() {
295        let dr = full_recalls[i] - full_recalls[i - 1];
296        ap += dr * (full_precisions[i] + full_precisions[i - 1]) / 2.0;
297    }
298    Ok(ap)
299}
300
301/// Cohen's kappa coefficient measuring inter-annotator agreement.
302///
303/// κ = (p_o - p_e) / (1 - p_e) where p_o is observed agreement and p_e is expected agreement.
304pub fn cohens_kappa(
305    predictions: &[usize],
306    targets: &[usize],
307    num_classes: usize,
308) -> Result<f32, EvalError> {
309    if predictions.len() != targets.len() {
310        return Err(EvalError::CountLengthMismatch {
311            ground_truth: targets.len(),
312            predictions: predictions.len(),
313        });
314    }
315
316    let n = predictions.len();
317    if n == 0 {
318        return Ok(0.0);
319    }
320
321    let cm = confusion_matrix(predictions, targets, num_classes)?;
322    let n_f = n as f32;
323
324    // Observed agreement
325    let p_o: f32 = (0..num_classes).map(|c| cm[c][c] as f32).sum::<f32>() / n_f;
326
327    // Expected agreement
328    let mut p_e = 0.0f32;
329    for c in 0..num_classes {
330        let row_sum: f32 = cm[c].iter().sum::<usize>() as f32; // true class c count
331        let col_sum: f32 = cm.iter().map(|row| row[c]).sum::<usize>() as f32; // predicted class c count
332        p_e += (row_sum / n_f) * (col_sum / n_f);
333    }
334
335    if (1.0 - p_e).abs() < 1e-10 {
336        return Ok(1.0); // perfect agreement when p_e ≈ 1
337    }
338
339    Ok((p_o - p_e) / (1.0 - p_e))
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accuracy_perfect() {
        // Identical slices must score exactly 1.
        let labels = [0usize, 1, 2, 0, 1];
        let acc = accuracy(&labels, &labels).unwrap();
        assert!((acc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_accuracy_half() {
        let acc = accuracy(&[0, 0, 1, 1], &[0, 1, 0, 1]).unwrap();
        assert!((acc - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_accuracy_length_mismatch() {
        assert!(accuracy(&[0, 1], &[0]).is_err());
    }

    #[test]
    fn test_confusion_matrix_basic() {
        let cm = confusion_matrix(&[0, 0, 1, 1, 2, 2], &[0, 1, 1, 2, 2, 0], 3).unwrap();

        // One correct prediction per class on the diagonal.
        for c in 0..3 {
            assert_eq!(cm[c][c], 1);
        }

        // One misclassification per class off the diagonal.
        assert_eq!(cm[1][0], 1); // true 1 predicted as 0
        assert_eq!(cm[2][1], 1); // true 2 predicted as 1
        assert_eq!(cm[0][2], 1); // true 0 predicted as 2
    }

    #[test]
    fn test_per_class_precision_recall() {
        // Predictions [0, 0, 1, 1] against targets [0, 1, 0, 1]: each class
        // ends up with TP=1, FP=1, FN=1, so precision = recall = 0.5.
        let cm = confusion_matrix(&[0, 0, 1, 1], &[0, 1, 0, 1], 2).unwrap();
        for &(precision, recall) in &per_class_precision_recall(&cm) {
            assert!((precision - 0.5).abs() < 1e-5);
            assert!((recall - 0.5).abs() < 1e-5);
        }
    }

    #[test]
    fn test_classification_report_format() {
        let report =
            classification_report(&[0, 0, 1, 1, 1], &[0, 1, 1, 1, 0], &["cat", "dog"]).unwrap();
        for needle in ["precision", "recall", "cat", "dog", "accuracy"] {
            assert!(report.contains(needle));
        }
    }

    #[test]
    fn test_accuracy_empty() {
        assert_eq!(accuracy(&[], &[]).unwrap(), 0.0);
    }
}
415}