eval_metrics/
classification.rs

1//!
2//! Provides support for both binary and multi-class classification metrics
3//!
4
5use std::cmp::Ordering;
6use crate::util;
7use crate::numeric::Scalar;
8use crate::error::EvalError;
9use crate::display;
10
///
/// Confusion matrix for binary classification, holding the four outcome counts of a
/// two-class classifier
///
/// NOTE: the count fields are public, so mutating one after construction leaves the private
/// cached total stale; construct a new matrix instead.
///
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct BinaryConfusionMatrix {
    /// Positive predictions whose label is positive (true positives)
    pub tp_count: usize,
    /// Positive predictions whose label is negative (false positives)
    pub fp_count: usize,
    /// Negative predictions whose label is negative (true negatives)
    pub tn_count: usize,
    /// Negative predictions whose label is positive (false negatives)
    pub fn_count: usize,
    /// Total of the four counts, cached when the matrix is built
    sum: usize,
}
27
28impl BinaryConfusionMatrix {
29
30    ///
31    /// Computes a new binary confusion matrix from the provided scores and labels
32    ///
33    /// # Arguments
34    ///
35    /// * `scores` - vector of scores
36    /// * `labels` - vector of boolean labels
37    /// * `threshold` - decision threshold value for classifying scores
38    ///
39    /// # Errors
40    ///
41    /// An invalid input error will be returned if either scores or labels are empty, or if their
42    /// lengths do not match. An undefined metric error will be returned if scores contain any value
43    /// that is not finite.
44    ///
45    /// # Examples
46    ///
47    /// ```
48    /// # use eval_metrics::error::EvalError;
49    /// # fn main() -> Result<(), EvalError> {
50    /// use eval_metrics::classification::BinaryConfusionMatrix;
51    /// let scores = vec![0.4, 0.7, 0.1, 0.3, 0.9];
52    /// let labels = vec![false, true, false, true, true];
53    /// let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5)?;
54    /// # Ok(())}
55    /// ```
56    ///
57    pub fn compute<T: Scalar>(scores: &Vec<T>,
58                              labels: &Vec<bool>,
59                              threshold: T) -> Result<BinaryConfusionMatrix, EvalError> {
60        util::validate_input_dims(scores, labels).and_then(|()| {
61            let mut counts = [0, 0, 0, 0];
62            for (&score, &label) in scores.iter().zip(labels) {
63                if !score.is_finite() {
64                    return Err(EvalError::infinite_value())
65                } else if score >= threshold && label {
66                    counts[3] += 1;
67                } else if score >= threshold {
68                    counts[2] += 1;
69                } else if score < threshold && !label {
70                    counts[0] += 1;
71                } else {
72                    counts[1] += 1;
73                }
74            };
75            let sum = counts.iter().sum();
76            Ok(BinaryConfusionMatrix {
77                tp_count: counts[3],
78                fp_count: counts[2],
79                tn_count: counts[0],
80                fn_count: counts[1],
81                sum
82            })
83        })
84    }
85
86    ///
87    /// Constructs a binary confusion matrix with the provided counts
88    ///
89    /// # Arguments
90    ///
91    /// * `tp_count` - true positive count
92    /// * `fp_count` - false positive count
93    /// * `tn_count` - true negative count
94    /// * `fn_count` - false negative count
95    ///
96    /// # Errors
97    ///
98    /// An invalid input error will be returned if all provided counts are zero
99    ///
100    pub fn from_counts(tp_count: usize,
101                       fp_count: usize,
102                       tn_count: usize,
103                       fn_count: usize) -> Result<BinaryConfusionMatrix, EvalError> {
104        match tp_count + fp_count + tn_count + fn_count {
105            0 => Err(EvalError::invalid_input("Confusion matrix has all zero counts")),
106            sum => Ok(BinaryConfusionMatrix {tp_count, fp_count, tn_count, fn_count, sum})
107        }
108    }
109
110    ///
111    /// Computes accuracy
112    ///
113    pub fn accuracy(&self) -> Result<f64, EvalError> {
114        let num = self.tp_count + self.tn_count;
115        match self.sum {
116            // This should never happen as long as we prevent empty confusion matrices
117            0 => Err(EvalError::undefined_metric("Accuracy")),
118            sum => Ok(num as f64 / sum as f64)
119        }
120    }
121
122    ///
123    /// Computes precision
124    ///
125    pub fn precision(&self) -> Result<f64, EvalError> {
126        match self.tp_count + self.fp_count {
127            0 => Err(EvalError::undefined_metric("Precision")),
128            den => Ok((self.tp_count as f64) / den as f64)
129        }
130    }
131
132    ///
133    /// Computes recall
134    ///
135    pub fn recall(&self) -> Result<f64, EvalError> {
136        match self.tp_count + self.fn_count {
137            0 => Err(EvalError::undefined_metric("Recall")),
138            den => Ok((self.tp_count as f64) / den as f64)
139        }
140    }
141
142    ///
143    /// Computes F1
144    ///
145    pub fn f1(&self) -> Result<f64, EvalError> {
146        match (self.precision(), self.recall()) {
147            (Ok(p), Ok(r)) if p == 0.0 && r == 0.0 => Ok(0.0),
148            (Ok(p), Ok(r)) => Ok(2.0 * (p * r) / (p + r)),
149            (Err(e), _) => Err(e),
150            (_, Err(e)) => Err(e)
151        }
152    }
153
154    ///
155    /// Computes Matthews correlation coefficient (phi)
156    ///
157    pub fn mcc(&self) -> Result<f64, EvalError> {
158        let n = self.sum as f64;
159        let s = (self.tp_count + self.fn_count) as f64 / n;
160        let p = (self.tp_count + self.fp_count) as f64 / n;
161        match (p * s * (1.0 - s) * (1.0 - p)).sqrt() {
162            den if den == 0.0 => Err(EvalError::undefined_metric("MCC")),
163            den => Ok(((self.tp_count as f64 / n) - s * p) / den)
164        }
165    }
166}
167
168impl std::fmt::Display for BinaryConfusionMatrix {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        let counts = vec![
171            vec![self.tp_count, self.fp_count],
172            vec![self.fn_count, self.tn_count]
173        ];
174        let outcomes = vec![String::from("Positive"), String::from("Negative")];
175        write!(f, "{}", display::stringify_confusion_matrix(&counts, &outcomes))
176    }
177}
178
179///
180/// Represents a single point along a roc curve
181///
182#[derive(Copy, Clone, Debug, PartialEq)]
183pub struct RocPoint<T: Scalar> {
184    /// True positive rate
185    pub tp_rate: T,
186    /// False positive rate
187    pub fp_rate: T,
188    /// Score threshold
189    pub threshold: T
190}
191
192///
193/// Represents a full roc curve
194///
195#[derive(Clone, Debug)]
196pub struct RocCurve<T: Scalar> {
197    /// Roc curve points
198    pub points: Vec<RocPoint<T>>,
199    /// Length
200    dim: usize
201}
202
203impl <T: Scalar> RocCurve<T> {
204
205    ///
206    /// Computes the roc curve from the provided data
207    ///
208    /// # Arguments
209    ///
210    /// * `scores` - vector of scores
211    /// * `labels` - vector of labels
212    ///
213    /// # Errors
214    ///
215    /// An invalid input error will be returned if either scores or labels are empty or contain a
216    /// single data point, or if their lengths do not match. An undefined metric error will be
217    /// returned if scores contain any value that is not finite or if labels are all constant.
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// # use eval_metrics::error::EvalError;
223    /// # fn main() -> Result<(), EvalError> {
224    /// use eval_metrics::classification::RocCurve;
225    /// let scores = vec![0.4, 0.7, 0.1, 0.3, 0.9];
226    /// let labels = vec![false, true, false, true, true];
227    /// let roc = RocCurve::compute(&scores, &labels)?;
228    /// # Ok(())}
229    /// ```
230    ///
231    pub fn compute(scores: &Vec<T>, labels: &Vec<bool>) -> Result<RocCurve<T>, EvalError> {
232        util::validate_input_dims(scores, labels).and_then(|()| {
233            // roc not defined for a single data point
234            let n = match scores.len() {
235                1 => return Err(EvalError::invalid_input(
236                    "Unable to compute roc curve on single data point"
237                )),
238                len => len
239            };
240            let (mut pairs, np) = create_pairs(scores, labels)?;
241            let nn = n - np;
242            sort_pairs_descending(&mut pairs);
243            let mut tpc = if pairs[0].1 {1} else {0};
244            let mut fpc = 1 - tpc;
245            let mut points = Vec::<RocPoint<T>>::new();
246            let mut last_tpr = T::zero();
247            let mut last_fpr = T::zero();
248            let mut trend: Option<RocTrend> = None;
249
250            for i in 1..n {
251                if pairs[i].0 != pairs[i-1].0 {
252                    let tp_rate = T::from_usize(tpc) / T::from_usize(np);
253                    let fp_rate = T::from_usize(fpc) / T::from_usize(nn);
254                    if !tp_rate.is_finite() || !fp_rate.is_finite() {
255                        return Err(EvalError::undefined_metric("ROC"))
256                    }
257                    let threshold = pairs[i-1].0;
258                    match trend {
259                        Some(RocTrend::Horizontal) => if tp_rate > last_tpr {
260                            points.push(RocPoint {tp_rate, fp_rate, threshold});
261                        } else if let Some(mut point) = points.last_mut() {
262                            point.fp_rate = fp_rate;
263                            point.threshold = threshold;
264                        },
265                        Some(RocTrend::Vertical) => if fp_rate > last_fpr {
266                            points.push(RocPoint {tp_rate, fp_rate, threshold})
267                        } else if let Some(mut point) = points.last_mut() {
268                            point.tp_rate = tp_rate;
269                            point.threshold = threshold;
270                        },
271                        _ => points.push(RocPoint {tp_rate, fp_rate, threshold}),
272                    }
273
274                    trend = if fp_rate > last_fpr && tp_rate == last_tpr {
275                        Some(RocTrend::Horizontal)
276                    } else if tp_rate > last_tpr && fp_rate == last_fpr {
277                        Some(RocTrend::Vertical)
278                    } else {
279                        Some(RocTrend::Diagonal)
280                    };
281                    last_tpr = tp_rate;
282                    last_fpr = fp_rate;
283                }
284                if pairs[i].1 {
285                    tpc += 1;
286                } else {
287                    fpc += 1;
288                }
289            }
290
291            if let Some(mut point) = points.last_mut() {
292                if point.tp_rate != T::one() || point.fp_rate != T::one() {
293                    let threshold = pairs.last().unwrap().0;
294                    match trend {
295                        Some(RocTrend::Horizontal) if point.tp_rate == T::one() => {
296                            point.fp_rate = T::one();
297                            point.threshold = threshold;
298                        },
299                        Some(RocTrend::Vertical) if point.fp_rate == T::one() => {
300                            point.tp_rate = T::one();
301                            point.threshold = threshold;
302                        }
303                        _ => points.push(RocPoint {
304                            tp_rate: T::one(), fp_rate: T::one(), threshold
305                        })
306                    }
307                }
308            }
309
310            match points.len() {
311                0 => Err(EvalError::constant_input_data()),
312                dim => Ok(RocCurve {points, dim})
313            }
314        })
315    }
316
317    ///
318    /// Computes AUC from the roc curve
319    ///
320    pub fn auc(&self) -> T {
321        let mut val = self.points[0].tp_rate * self.points[0].fp_rate / T::from_f64(2.0);
322        for i in 1..self.dim {
323            let fpr_diff = self.points[i].fp_rate - self.points[i-1].fp_rate;
324            let a = self.points[i-1].tp_rate * fpr_diff;
325            let tpr_diff = self.points[i].tp_rate - self.points[i-1].tp_rate;
326            let b = tpr_diff * fpr_diff / T::from_f64(2.0);
327            val += a + b;
328        }
329        return val
330    }
331}
332
333///
334/// Represents a single point along a precision-recall curve
335///
336#[derive(Copy, Clone, Debug, PartialEq)]
337pub struct PrPoint<T: Scalar> {
338    /// Precision value
339    pub precision: T,
340    /// Recall value
341    pub recall: T,
342    /// Score threshold
343    pub threshold: T
344}
345
346///
347/// Represents a full precision-recall curve
348///
349#[derive(Clone, Debug)]
350pub struct PrCurve<T: Scalar> {
351    /// PR curve points
352    pub points: Vec<PrPoint<T>>,
353    /// Length
354    dim: usize
355}
356
357impl <T: Scalar> PrCurve<T> {
358
359    ///
360    /// Computes the precision-recall curve from the provided data
361    ///
362    /// # Arguments
363    ///
364    /// * `scores` - vector of scores
365    /// * `labels` - vector of labels
366    ///
367    /// # Errors
368    ///
369    /// An invalid input error will be returned if either scores or labels are empty or contain a
370    /// single data point, or if their lengths do not match. An undefined metric error will be
371    /// returned if scores contain any value that is not finite, or if labels are all false.
372    ///
373    /// # Examples
374    ///
375    /// ```
376    /// # use eval_metrics::error::EvalError;
377    /// # fn main() -> Result<(), EvalError> {
378    /// use eval_metrics::classification::PrCurve;
379    /// let scores = vec![0.4, 0.7, 0.1, 0.3, 0.9];
380    /// let labels = vec![false, true, false, true, true];
381    /// let pr = PrCurve::compute(&scores, &labels)?;
382    /// # Ok(())}
383    /// ```
384    ///
385    pub fn compute(scores: &Vec<T>, labels: &Vec<bool>) -> Result<PrCurve<T>, EvalError> {
386        util::validate_input_dims(scores, labels).and_then(|()| {
387            let n = match scores.len() {
388                1 => return Err(EvalError::invalid_input(
389                    "Unable to compute pr curve on single data point"
390                )),
391                len => len
392            };
393            let (mut pairs, mut fnc) = create_pairs(scores, labels)?;
394            sort_pairs_descending(&mut pairs);
395            let mut tpc = 0;
396            let mut fpc = 0;
397            let mut points = Vec::<PrPoint<T>>::new();
398            let mut last_rec = T::zero();
399
400            for i in 0..n {
401                if pairs[i].1 {
402                    tpc += 1;
403                    fnc -= 1;
404                } else {
405                    fpc += 1;
406                }
407                if (i < n-1 && pairs[i].0 != pairs[i+1].0) || i == n-1 {
408                    let precision = T::from_usize(tpc) / T::from_usize(tpc + fpc);
409                    let recall = T::from_usize(tpc) / T::from_usize(tpc + fnc);
410                    if !precision.is_finite() || !recall.is_finite() {
411                        return Err(EvalError::undefined_metric("PR"))
412                    }
413                    let threshold = pairs[i].0;
414                    if recall != last_rec {
415                        points.push(PrPoint {precision, recall, threshold});
416                    }
417                    last_rec = recall;
418                }
419            }
420
421            let dim = points.len();
422            Ok(PrCurve {points, dim})
423        })
424    }
425
426    ///
427    /// Computes average precision from the PR curve
428    ///
429    pub fn ap(&self) -> T {
430        let mut val = self.points[0].precision * self.points[0].recall;
431        for i in 1..self.dim {
432            let rec_diff = self.points[i].recall - self.points[i-1].recall;
433            val += rec_diff * self.points[i].precision;
434        }
435        return val;
436    }
437}
438
439
///
/// Confusion matrix for multi-class classification, in which rows represent predicted counts and
/// columns represent labeled counts
///
/// NOTE: `dim` and `counts` are public while the cached total is private; editing the counts
/// after construction leaves that total stale.
///
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MultiConfusionMatrix {
    /// Number of classes (the matrix is `dim` x `dim`)
    pub dim: usize,
    /// Count data, indexed as `counts[predicted][labeled]`
    pub counts: Vec<Vec<usize>>,
    /// Total of all counts, cached when the matrix is built
    sum: usize,
}
453
454impl MultiConfusionMatrix {
455
456    ///
457    /// Computes a new confusion matrix from the provided scores and labels
458    ///
459    /// # Arguments
460    ///
461    /// * `scores` - vector of class scores
462    /// * `labels` - vector of class labels (indexed at zero)
463    ///
464    /// # Errors
465    ///
466    /// An invalid input error will be returned if either scores or labels are empty, or if their
467    /// lengths do not match. An undefined metric error will be returned if scores contain any value
468    /// that is not finite.
469    ///
470    /// # Examples
471    ///
472    /// ```
473    /// # use eval_metrics::error::EvalError;
474    /// # fn main() -> Result<(), EvalError> {
475    /// use eval_metrics::classification::MultiConfusionMatrix;
476    /// let scores = vec![
477    ///    vec![0.3, 0.1, 0.6],
478    ///    vec![0.5, 0.2, 0.3],
479    ///    vec![0.2, 0.7, 0.1],
480    ///    vec![0.3, 0.3, 0.4],
481    ///    vec![0.5, 0.1, 0.4],
482    ///    vec![0.8, 0.1, 0.1],
483    ///    vec![0.3, 0.5, 0.2]
484    /// ];
485    /// let labels = vec![2, 1, 1, 2, 0, 2, 0];
486    /// let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;
487    /// # Ok(())}
488    /// ```
489    ///
490    pub fn compute<T: Scalar>(scores: &Vec<Vec<T>>,
491                              labels: &Vec<usize>) -> Result<MultiConfusionMatrix, EvalError> {
492        util::validate_input_dims(scores, labels).and_then(|()| {
493            let dim = scores[0].len();
494            let mut counts = vec![vec![0; dim]; dim];
495            let mut sum = 0;
496            for (i, s) in scores.iter().enumerate() {
497                if s.iter().any(|v| !v.is_finite()) {
498                    return Err(EvalError::infinite_value())
499                } else if s.len() != dim {
500                    return Err(EvalError::invalid_input("Inconsistent score dimension"))
501                } else if labels[i] >= dim {
502                    return Err(EvalError::invalid_input("Labels have more classes than scores"))
503                }
504                let ind = s.iter().enumerate().max_by(|(_, a), (_, b)| {
505                    a.partial_cmp(b).unwrap_or(Ordering::Equal)
506                }).map(|(mi, _)| mi).ok_or(EvalError::constant_input_data())?;
507                counts[ind][labels[i]] += 1;
508                sum += 1;
509            }
510            Ok(MultiConfusionMatrix {dim, counts, sum})
511        })
512    }
513
514    ///
515    /// Constructs a multi confusion matrix with the provided counts
516    ///
517    /// # Arguments
518    ///
519    /// * `counts` - vector of vector of counts, where each inner vector represents a row in the
520    /// confusion matrix
521    ///
522    /// # Errors
523    ///
524    /// An invalid input error will be returned if the counts are not a square matrix, or if the
525    /// counts are all zero
526    ///
527    /// # Examples
528    ///
529    /// ```
530    /// # use eval_metrics::error::EvalError;
531    /// # fn main() -> Result<(), EvalError> {
532    /// use eval_metrics::classification::MultiConfusionMatrix;
533    /// let counts = vec![
534    ///     vec![8, 3, 2],
535    ///     vec![1, 5, 3],
536    ///     vec![2, 1, 9]
537    /// ];
538    /// let matrix = MultiConfusionMatrix::from_counts(counts)?;
539    /// # Ok(())}
540    /// ```
541    ///
542    pub fn from_counts(counts: Vec<Vec<usize>>) -> Result<MultiConfusionMatrix, EvalError> {
543        let dim = counts.len();
544        let mut sum = 0;
545        for row in &counts {
546            sum += row.iter().sum::<usize>();
547            if row.len() != dim {
548                let msg = format!("Inconsistent column length ({})", row.len());
549                return Err(EvalError::invalid_input(msg.as_str()));
550            }
551        }
552        if sum == 0 {
553            Err(EvalError::invalid_input("Confusion matrix has all zero counts"))
554        } else {
555            Ok(MultiConfusionMatrix {dim, counts, sum})
556        }
557    }
558
559    ///
560    /// Computes accuracy
561    ///
562    pub fn accuracy(&self) -> Result<f64, EvalError> {
563        match self.sum {
564            // This should never happen as long as we prevent empty confusion matrices
565            0 => Err(EvalError::undefined_metric("Accuracy")),
566            sum => {
567                let mut correct = 0;
568                for i in 0..self.dim {
569                    correct += self.counts[i][i];
570                }
571                Ok(correct as f64 / sum as f64)
572            }
573        }
574    }
575
576    ///
577    /// Computes precision, which necessarily requires a specified averaging method
578    ///
579    /// # Arguments
580    ///
581    /// * `avg` - averaging method, which can be either 'Macro' or 'Weighted'
582    ///
583    pub fn precision(&self, avg: &Averaging) -> Result<f64, EvalError> {
584        self.agg_metric(&self.per_class_precision(), avg)
585    }
586
587    ///
588    /// Computes recall, which necessarily requires a specified averaging method
589    ///
590    /// # Arguments
591    ///
592    /// * `avg` - averaging method, which can be either 'Macro' or 'Weighted'
593    ///
594    pub fn recall(&self, avg: &Averaging) -> Result<f64, EvalError> {
595        self.agg_metric(&self.per_class_recall(), avg)
596    }
597
598    ///
599    /// Computes F1, which necessarily requires a specified averaging method
600    ///
601    /// # Arguments
602    ///
603    /// * `avg` - averaging method, which can be either 'Macro' or 'Weighted'
604    ///
605    pub fn f1(&self, avg: &Averaging) -> Result<f64, EvalError> {
606        self.agg_metric(&self.per_class_f1(), avg)
607    }
608
609    ///
610    /// Computes Rk, also known as the multi-class Matthews correlation coefficient following the
611    /// approach of Gorodkin in "Comparing two K-category assignments by a K-category correlation
612    /// coefficient" (2004)
613    ///
614    pub fn rk(&self) -> Result<f64, EvalError> {
615        let mut t = vec![0.0; self.dim];
616        let mut p = vec![0.0; self.dim];
617        let mut c = 0.0;
618        let s = self.sum as f64;
619
620        for i in 0..self.dim {
621            c += self.counts[i][i] as f64;
622            for j in 0..self.dim {
623                t[j] += self.counts[i][j] as f64;
624                p[i] += self.counts[i][j] as f64;
625            }
626        }
627
628        let tt = t.iter().fold(0.0, |acc, val| acc + (val * val));
629        let pp = p.iter().fold(0.0, |acc, val| acc + (val * val));
630        let tp = t.iter().zip(p).fold(0.0, |acc, (t_val, p_val)| acc + t_val * p_val);
631        let num = c * s - tp;
632        let den = (s * s - pp).sqrt() * (s * s - tt).sqrt();
633
634        if den == 0.0 {
635            Err(EvalError::undefined_metric("Rk"))
636        } else {
637            Ok(num / den)
638        }
639    }
640
641    ///
642    /// Computes per-class accuracy, resulting in a vector of values for each class
643    ///
644    pub fn per_class_accuracy(&self) -> Vec<Result<f64, EvalError>> {
645        self.per_class_binary_metric("accuracy")
646    }
647
648    ///
649    /// Computes per-class precision, resulting in a vector of values for each class
650    ///
651    pub fn per_class_precision(&self) -> Vec<Result<f64, EvalError>> {
652        self.per_class_binary_metric("precision")
653    }
654
655    ///
656    /// Computes per-class recall, resulting in a vector of values for each class
657    ///
658    pub fn per_class_recall(&self) -> Vec<Result<f64, EvalError>> {
659        self.per_class_binary_metric("recall")
660    }
661
662    ///
663    /// Computes per-class F1, resulting in a vector of values for each class
664    ///
665    pub fn per_class_f1(&self) -> Vec<Result<f64, EvalError>> {
666        self.per_class_binary_metric("f1")
667    }
668
669    ///
670    /// Computes per-class MCC, resulting in a vector of values for each class
671    ///
672    pub fn per_class_mcc(&self) -> Vec<Result<f64, EvalError>> {
673        self.per_class_binary_metric("mcc")
674    }
675
676    fn per_class_binary_metric(&self, metric: &str) -> Vec<Result<f64, EvalError>> {
677        (0..self.dim).map(|k| {
678            let (mut tpc, mut fpc, mut tnc, mut fnc) = (0, 0, 0, 0);
679            for i in 0..self.dim {
680                for j in 0..self.dim {
681                    let count = self.counts[i][j];
682                    if i == k && j == k {
683                        tpc = count;
684                    } else if i == k {
685                        fpc += count;
686                    } else if j == k {
687                        fnc += count;
688                    } else {
689                        tnc += count;
690                    }
691                }
692            }
693            let matrix = BinaryConfusionMatrix::from_counts(tpc, fpc, tnc, fnc)?;
694            match metric {
695                "accuracy" => matrix.accuracy(),
696                "precision" => matrix.precision(),
697                "recall" => matrix.recall(),
698                "f1" => matrix.f1(),
699                "mcc" => matrix.mcc(),
700                other => Err(EvalError::invalid_metric(other))
701            }
702        }).collect()
703    }
704
705    fn agg_metric(&self, pcm: &Vec<Result<f64, EvalError>>,
706                  avg: &Averaging) -> Result<f64, EvalError> {
707        match avg {
708            Averaging::Macro => self.macro_metric(pcm),
709            Averaging::Weighted => self.weighted_metric(pcm)
710        }
711    }
712
713    fn macro_metric(&self, pcm: &Vec<Result<f64, EvalError>>) -> Result<f64, EvalError> {
714        pcm.iter().try_fold(0.0, |sum, metric| {
715            match metric {
716                Ok(m) => Ok(sum + m),
717                Err(e) => Err(e.clone())
718            }
719        }).map(|sum| {sum / pcm.len() as f64})
720    }
721
722    fn weighted_metric(&self, pcm: &Vec<Result<f64, EvalError>>) -> Result<f64, EvalError> {
723        pcm.iter()
724            .zip(self.class_counts().iter())
725            .try_fold(0.0, |val, (metric, &class)| {
726                match metric {
727                    Ok(m) => Ok(val + (m * (class as f64) / (self.sum as f64))),
728                    Err(e) => Err(e.clone())
729                }
730            })
731    }
732
733    fn class_counts(&self) -> Vec<usize> {
734        let mut counts = vec![0; self.dim];
735        for i in 0..self.dim {
736            for j in 0..self.dim {
737                counts[j] += self.counts[i][j];
738            }
739        }
740        counts
741    }
742}
743
744impl std::fmt::Display for MultiConfusionMatrix {
745    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
746        if self.dim <= 25 {
747            let outcomes = (0..self.dim).map(|i| format!("Class-{}", i + 1)).collect();
748            write!(f, "{}", display::stringify_confusion_matrix(&self.counts, &outcomes))
749        } else {
750            write!(f, "[Confusion matrix is too large to display]")
751        }
752    }
753}
754
755///
756/// Computes multi-class AUC as described by Hand and Till in "A Simple Generalisation of the Area
757/// Under the ROC Curve for Multiple Class Classification Problems" (2001)
758///
759/// # Arguments
760///
761/// * `scores` - vector of class scores
762/// * `labels` - vector of class labels
763///
764/// # Errors
765///
766/// An invalid input error will be returned if either scores or labels are empty or contain a
767/// single data point, or if their lengths do not match. An undefined metric error will be
768/// returned if scores contain any value that is not finite, or if any pairwise roc curve is not
769/// defined for all distinct class label pairs.
770///
771/// # Examples
772///
773/// ```
774/// # use eval_metrics::error::EvalError;
775/// # fn main() -> Result<(), EvalError> {
776/// use eval_metrics::classification::m_auc;
777/// let scores = vec![
778///    vec![0.3, 0.1, 0.6],
779///    vec![0.5, 0.2, 0.3],
780///    vec![0.2, 0.7, 0.1],
781///    vec![0.3, 0.3, 0.4],
782///    vec![0.5, 0.1, 0.4],
783///    vec![0.8, 0.1, 0.1],
784///    vec![0.3, 0.5, 0.2]
785/// ];
786/// let labels = vec![2, 1, 1, 2, 0, 2, 0];
787/// let metric = m_auc(&scores, &labels)?;
788/// # Ok(())}
789/// ```
790
791pub fn m_auc<T: Scalar>(scores: &Vec<Vec<T>>, labels: &Vec<usize>) -> Result<T, EvalError> {
792    util::validate_input_dims(scores, labels).and_then(|()| {
793        let dim = scores[0].len();
794        let mut m_sum = T::zero();
795
796        fn subset<T: Scalar>(scr: &Vec<Vec<T>>,
797                             lab: &Vec<usize>,
798                             j: usize,
799                             k: usize) -> (Vec<T>, Vec<bool>) {
800
801            scr.iter().zip(lab.iter()).filter(|(_, &l)| {
802                l == j || l == k
803            }).map(|(s, &l)| {
804                (s[k], l == k)
805            }).unzip()
806        }
807
808        for j in 0..dim {
809            for k in 0..j {
810                let (k_scores, k_labels) = subset(scores, labels, j, k);
811                let ajk = RocCurve::compute(&k_scores, &k_labels)?.auc();
812                let (j_scores, j_labels) = subset(scores, labels, k, j);
813                let akj = RocCurve::compute(&j_scores, &j_labels)?.auc();
814                m_sum += (ajk + akj) / T::from_f64(2.0);
815            }
816        }
817        Ok(m_sum * T::from_f64(2.0) / (T::from_usize(dim) * (T::from_usize(dim) - T::one())))
818    })
819}
820
///
/// Averaging method used to combine per-class values into a single multi-class metric
///
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Averaging {
    /// Macro average: every class contributes with equal weight
    Macro,
    /// Weighted average: each class contributes in proportion to how many observations carry
    /// that class label
    Weighted,
}
832
// Direction of the most recent step taken while building a roc curve; used to merge consecutive
// collinear axis-aligned steps into a single point.
enum RocTrend {
    // False positive rate increased while the true positive rate stayed the same
    Horizontal,
    // True positive rate increased while the false positive rate stayed the same
    Vertical,
    // Any other step
    Diagonal,
}
838
839fn create_pairs<T: Scalar>(scores: &Vec<T>,
840                           labels: &Vec<bool>) -> Result<(Vec<(T, bool)>, usize), EvalError> {
841    let n = scores.len();
842    let mut pairs = Vec::with_capacity(n);
843    let mut num_pos = 0;
844
845    for i in 0..n {
846        if !scores[i].is_finite() {
847            return Err(EvalError::infinite_value())
848        } else if labels[i] {
849            num_pos += 1;
850        }
851        pairs.push((scores[i], labels[i]))
852    }
853    Ok((pairs, num_pos))
854}
855
856fn sort_pairs_descending<T: Scalar>(pairs: &mut Vec<(T, bool)>) {
857    pairs.sort_unstable_by(|(s1, _), (s2, _)| {
858        if s1 > s2 {
859            Ordering::Less
860        } else if s1 < s2 {
861            Ordering::Greater
862        } else {
863            Ordering::Equal
864        }
865    });
866}
867
#[cfg(test)]
mod tests {
    // Unit tests for the binary and multi-class classification metrics in this module.
    use assert_approx_eq::assert_approx_eq;
    use super::*;

    fn binary_data() -> (Vec<f64>, Vec<bool>) {
        let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
        let labels = vec![false, false, true, false, true, false, false, true];
        (scores, labels)
    }

    fn multi_class_data() -> (Vec<Vec<f64>>, Vec<usize>) {

        let scores = vec![
            vec![0.3, 0.1, 0.6],
            vec![0.5, 0.2, 0.3],
            vec![0.2, 0.7, 0.1],
            vec![0.3, 0.3, 0.4],
            vec![0.5, 0.1, 0.4],
            vec![0.8, 0.1, 0.1],
            vec![0.3, 0.5, 0.2]
        ];
        let labels = vec![2, 1, 1, 2, 0, 2, 0];
        (scores, labels)
    }

    #[test]
    fn test_binary_confusion_matrix() {
        let (scores, labels) = binary_data();
        let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
        assert_eq!(matrix.tp_count, 2);
        assert_eq!(matrix.fp_count, 2);
        assert_eq!(matrix.tn_count, 3);
        assert_eq!(matrix.fn_count, 1);
    }

    #[test]
    fn test_binary_confusion_matrix_empty() {
        assert!(BinaryConfusionMatrix::compute(
            &Vec::<f64>::new(),
            &Vec::<bool>::new(),
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_confusion_matrix_unequal_length() {
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.1, 0.2],
            &vec![true, false, true],
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_confusion_matrix_nan() {
        assert!(BinaryConfusionMatrix::compute(
            &vec![f64::NAN, 0.2, 0.4],
            &vec![true, false, true],
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_confusion_matrix_with_counts() {
        let matrix = BinaryConfusionMatrix::from_counts(2, 4, 5, 3).unwrap();
        assert_eq!(matrix.tp_count, 2);
        assert_eq!(matrix.fp_count, 4);
        assert_eq!(matrix.tn_count, 5);
        assert_eq!(matrix.fn_count, 3);
        assert_eq!(matrix.sum, 14);
        assert!(BinaryConfusionMatrix::from_counts(0, 0, 0, 0).is_err())
    }

    #[test]
    fn test_binary_accuracy() {
        let (scores, labels) = binary_data();
        let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
        assert_approx_eq!(matrix.accuracy().unwrap(), 0.625);
    }

    #[test]
    fn test_binary_precision() {
        let (scores, labels) = binary_data();
        let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
        assert_approx_eq!(matrix.precision().unwrap(), 0.5);

        // test edge case where we never predict a positive
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.4, 0.3, 0.1, 0.2, 0.1],
            &vec![true, false, true, false, true],
            0.5
        ).unwrap().precision().is_err());
    }

    #[test]
    fn test_binary_precision_empty() {
        // invalid input is rejected by `compute`, before precision can be evaluated
        assert!(BinaryConfusionMatrix::compute(
            &Vec::<f64>::new(),
            &Vec::<bool>::new(),
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_precision_unequal_length() {
        // invalid input is rejected by `compute`, before precision can be evaluated
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.1, 0.2],
            &vec![true, false, true],
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_recall() {
        let (scores, labels) = binary_data();
        let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
        assert_approx_eq!(matrix.recall().unwrap(), 2.0 / 3.0);

        // test edge case where we have no positive class
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.4, 0.3, 0.1, 0.8, 0.7],
            &vec![false, false, false, false, false],
            0.5
        ).unwrap().recall().is_err());
    }

    #[test]
    fn test_binary_recall_empty() {
        // invalid input is rejected by `compute`, before recall can be evaluated
        assert!(BinaryConfusionMatrix::compute(
            &Vec::<f64>::new(),
            &Vec::<bool>::new(),
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_recall_unequal_length() {
        // invalid input is rejected by `compute`, before recall can be evaluated
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.1, 0.2],
            &vec![true, false, true],
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_f1() {
        let (scores, labels) = binary_data();
        let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
        assert_approx_eq!(matrix.f1().unwrap(), 0.5714285714285715);

        // test edge case where we never predict a positive
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.4, 0.3, 0.1, 0.2, 0.1],
            &vec![true, false, true, false, true],
            0.5
        ).unwrap().f1().is_err());

        // test edge case where we have no positive class
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.4, 0.3, 0.1, 0.8, 0.7],
            &vec![false, false, false, false, false],
            0.5
        ).unwrap().f1().is_err());
    }

    #[test]
    fn test_binary_f1_empty() {
        // invalid input is rejected by `compute`, before F1 can be evaluated
        assert!(BinaryConfusionMatrix::compute(
            &Vec::<f64>::new(),
            &Vec::<bool>::new(),
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_f1_unequal_length() {
        // invalid input is rejected by `compute`, before F1 can be evaluated
        assert!(BinaryConfusionMatrix::compute(
            &vec![0.1, 0.2],
            &vec![true, false, true],
            0.5
        ).is_err());
    }

    #[test]
    fn test_binary_f1_0p_0r() {
        let scores = vec![0.1, 0.2, 0.7, 0.8];
        let labels = vec![false, true, false, false];

        assert_eq!(BinaryConfusionMatrix::compute(&scores, &labels, 0.5)
                       .unwrap()
                       .f1()
                       .unwrap(), 0.0
        )
    }

    #[test]
    fn test_mcc() {
        let (scores, labels) = binary_data();
        let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
        assert_approx_eq!(matrix.mcc().unwrap(), 0.2581988897471611)
    }

    #[test]
    fn test_roc() {
        let (scores, labels) = binary_data();
        let roc = RocCurve::compute(&scores, &labels).unwrap();

        assert_eq!(roc.dim, 5);
        assert_approx_eq!(roc.points[0].tp_rate, 1.0 / 3.0);
        assert_approx_eq!(roc.points[0].fp_rate, 0.0);
        assert_approx_eq!(roc.points[0].threshold, 0.9);
        assert_approx_eq!(roc.points[1].tp_rate, 1.0 / 3.0);
        assert_approx_eq!(roc.points[1].fp_rate, 0.2);
        assert_approx_eq!(roc.points[1].threshold, 0.8);
        assert_approx_eq!(roc.points[2].tp_rate, 2.0 / 3.0);
        assert_approx_eq!(roc.points[2].fp_rate, 0.2);
        assert_approx_eq!(roc.points[2].threshold, 0.7);
        assert_approx_eq!(roc.points[3].tp_rate, 2.0 / 3.0);
        assert_approx_eq!(roc.points[3].fp_rate, 1.0);
        assert_approx_eq!(roc.points[3].threshold, 0.2);
        assert_approx_eq!(roc.points[4].tp_rate, 1.0);
        assert_approx_eq!(roc.points[4].fp_rate, 1.0);
        assert_approx_eq!(roc.points[4].threshold, 0.1);
    }

    #[test]
    fn test_roc_tied_scores() {
        let scores = vec![1.0, 0.1, 1.0, 0.9, 0.5, 0.1, 0.8, 0.9, 1.0, 0.4];
        let labels = vec![true, false, false, false, false, false, true, true, false, false];
        let roc = RocCurve::compute(&scores, &labels).unwrap();
        assert_approx_eq!(roc.points[0].tp_rate, 1.0 / 3.0);
        assert_approx_eq!(roc.points[0].fp_rate, 0.2857142857142857);
        assert_approx_eq!(roc.points[0].threshold, 1.0);
        assert_approx_eq!(roc.points[1].tp_rate, 2.0 / 3.0);
        assert_approx_eq!(roc.points[1].fp_rate, 0.42857142857142855);
        assert_approx_eq!(roc.points[1].threshold, 0.9);
        assert_approx_eq!(roc.points[2].tp_rate, 1.0);
        assert_approx_eq!(roc.points[2].fp_rate, 0.42857142857142855);
        assert_approx_eq!(roc.points[2].threshold, 0.8);
        assert_approx_eq!(roc.points[3].tp_rate, 1.0);
        assert_approx_eq!(roc.points[3].fp_rate, 1.0);
        assert_approx_eq!(roc.points[3].threshold, 0.1);
    }

    #[test]
    fn test_roc_empty() {
        assert!(RocCurve::compute(&Vec::<f64>::new(), &Vec::<bool>::new()).is_err());
    }

    #[test]
    fn test_roc_unequal_length() {
        assert!(RocCurve::compute(
            &vec![0.4, 0.5, 0.2],
            &vec![true, false, true, false]
        ).is_err());
    }

    #[test]
    fn test_roc_nan() {
        assert!(RocCurve::compute(
            &vec![0.4, 0.5, 0.2, f64::NAN],
            &vec![true, false, true, false]
        ).is_err());
    }

    #[test]
    fn test_roc_constant_label() {
        let scores = vec![0.1, 0.4, 0.5, 0.7];
        let labels_true = vec![true; 4];
        let labels_false = vec![false; 4];
        assert!(match RocCurve::compute(&scores, &labels_true) {
            Err(err) if err.msg.contains("Undefined") => true,
            _ => false
        });
        assert!(match RocCurve::compute(&scores, &labels_false) {
            Err(err) if err.msg.contains("Undefined") => true,
            _ => false
        });
    }

    #[test]
    fn test_roc_constant_score() {
        let scores = vec![0.4, 0.4, 0.4, 0.4];
        let labels = vec![true, false, true, false];
        assert!(match RocCurve::compute(&scores, &labels) {
            Err(err) if err.msg.contains("Constant") => true,
            _ => false
        });
    }

    #[test]
    fn test_auc() {
        let (scores, labels) = binary_data();
        assert_approx_eq!(RocCurve::compute(&scores, &labels).unwrap().auc(), 0.6);

        let scores2 = vec![0.2, 0.5, 0.5, 0.3];
        let labels2 = vec![false, true, false, true];
        assert_approx_eq!(RocCurve::compute(&scores2, &labels2).unwrap().auc(), 0.625);
    }

    #[test]
    fn test_auc_tied_scores() {
        let scores = vec![0.1, 0.2, 0.3, 0.3, 0.3, 0.7, 0.8];
        let labels1 = vec![false, false, true, false, true, false, true];
        let labels2 = vec![false, false, true, true, false, false, true];
        let labels3 = vec![false, false, false, true, true, false, true];
        assert_approx_eq!(RocCurve::compute(&scores, &labels1).unwrap().auc(), 0.75);
        assert_approx_eq!(RocCurve::compute(&scores, &labels2).unwrap().auc(), 0.75);
        assert_approx_eq!(RocCurve::compute(&scores, &labels3).unwrap().auc(), 0.75);

        let scores2 = vec![1.0, 0.1, 1.0, 0.9, 0.5, 0.1, 0.8, 0.9, 1.0, 0.4];
        let labels4 = vec![true, false, false, false, false, false, true, true, false, false];
        assert_approx_eq!(RocCurve::compute(&scores2, &labels4).unwrap().auc(), 0.6904761904761905);
    }

    #[test]
    fn test_pr() {
        let (scores, labels) = binary_data();
        let pr = PrCurve::compute(&scores, &labels).unwrap();
        assert_approx_eq!(pr.points[0].precision, 1.0);
        assert_approx_eq!(pr.points[0].recall, 1.0 / 3.0);
        assert_approx_eq!(pr.points[0].threshold, 0.9);
        assert_approx_eq!(pr.points[1].precision, 2.0 / 3.0);
        assert_approx_eq!(pr.points[1].recall, 2.0 / 3.0);
        assert_approx_eq!(pr.points[1].threshold, 0.7);
        assert_approx_eq!(pr.points[2].precision, 0.375);
        assert_approx_eq!(pr.points[2].recall, 1.0);
        assert_approx_eq!(pr.points[2].threshold, 0.1);
    }

    #[test]
    fn test_pr_empty() {
        assert!(PrCurve::compute(&Vec::<f64>::new(), &Vec::<bool>::new()).is_err());
    }

    #[test]
    fn test_pr_unequal_length() {
        assert!(PrCurve::compute(&vec![0.4, 0.5, 0.2], &vec![true, false, true, false]).is_err());
    }

    #[test]
    fn test_pr_nan() {
        assert!(PrCurve::compute(
            &vec![0.4, 0.5, 0.2, f64::NAN],
            &vec![true, false, true, false]
        ).is_err());
    }

    #[test]
    fn test_pr_constant_label() {
        let scores = vec![0.1, 0.4, 0.5, 0.7];
        let labels_true = vec![true; 4];
        let labels_false = vec![false; 4];
        assert!(PrCurve::compute(&scores, &labels_true).is_ok());
        assert!(match PrCurve::compute(&scores, &labels_false) {
            Err(err) if err.msg.contains("Undefined") => true,
            _ => false
        });
    }

    #[test]
    fn test_pr_constant_score() {
        let scores = vec![0.4, 0.4, 0.4, 0.4];
        let labels = vec![true, false, true, false];
        assert!(PrCurve::compute(&scores, &labels).is_ok());
    }

    #[test]
    fn test_ap() {
        let (scores, labels) = binary_data();
        assert_approx_eq!(PrCurve::compute(&scores, &labels).unwrap().ap(), 0.6805555555555556);

        let scores2 = vec![0.2, 0.5, 0.5, 0.3];
        let labels2 = vec![false, true, false, true];
        assert_approx_eq!(PrCurve::compute(&scores2, &labels2).unwrap().ap(), 0.58333333333333);
    }

    #[test]
    fn test_ap_tied_scores() {
        let scores = vec![0.1, 0.2, 0.3, 0.3, 0.3, 0.7, 0.8];
        let labels1 = vec![false, false, true, false, true, false, true];
        let labels2 = vec![false, false, true, true, false, false, true];
        let labels3 = vec![false, false, false, true, true, false, true];
        assert_approx_eq!(PrCurve::compute(&scores, &labels1).unwrap().ap(), 0.7333333333333);
        assert_approx_eq!(PrCurve::compute(&scores, &labels2).unwrap().ap(), 0.7333333333333);
        assert_approx_eq!(PrCurve::compute(&scores, &labels3).unwrap().ap(), 0.7333333333333);
    }

    #[test]
    fn test_multi_confusion_matrix() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_eq!(matrix.counts, vec![vec![1, 1, 1], vec![1, 1, 0], vec![0, 0, 2]]);
        assert_eq!(matrix.dim, 3);
        assert_eq!(matrix.sum, 7);
    }

    #[test]
    fn test_multi_confusion_matrix_empty() {
        let scores: Vec<Vec<f64>> = vec![];
        let labels = Vec::<usize>::new();
        assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_unequal_length() {
        assert!(MultiConfusionMatrix::compute(&vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4]],
                                              &vec![2, 1, 0]).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_nan() {
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.7, f64::NAN]],
            &vec![2, 1, 0]
        ).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_inconsistent_score_dims() {
        let scores = vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.7]];
        let labels = vec![2, 1, 0];
        assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_score_label_dim_mismatch() {
        let scores = vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.2, 0.5]];
        let labels = vec![2, 3, 0];
        assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
    }

    #[test]
    fn test_multi_confusion_matrix_counts() {
        let counts = vec![vec![6, 3, 1], vec![4, 2, 7], vec![5, 2, 8]];
        let matrix = MultiConfusionMatrix::from_counts(counts).unwrap();
        assert_eq!(matrix.dim, 3);
        assert_eq!(matrix.sum, 38);
        assert_eq!(matrix.counts, vec![vec![6, 3, 1], vec![4, 2, 7], vec![5, 2, 8]]);
    }

    #[test]
    fn test_multi_confusion_matrix_bad_counts() {
        let counts = vec![vec![6, 3, 1], vec![4, 2], vec![5, 2, 8]];
        assert!(MultiConfusionMatrix::from_counts(counts).is_err())
    }

    #[test]
    fn test_multi_accuracy() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.accuracy().unwrap(), 0.5714285714285714)
    }

    #[test]
    fn test_multi_precision() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.precision(&Averaging::Macro).unwrap(), 0.611111111111111);
        assert_approx_eq!(matrix.precision(&Averaging::Weighted).unwrap(), 2.0 / 3.0);

        // class 2 is never predicted, so its precision (and hence the macro average) is undefined
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.4, 0.0],
                  vec![0.2, 0.8, 0.0],
                  vec![0.9, 0.1, 0.0],
                  vec![0.3, 0.7, 0.0]],
            &vec![0, 1, 2, 1]
        ).unwrap().precision(&Averaging::Macro).is_err())
    }

    #[test]
    fn test_multi_recall() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.recall(&Averaging::Macro).unwrap(), 0.5555555555555555);
        assert_approx_eq!(matrix.recall(&Averaging::Weighted).unwrap(), 0.5714285714285714);

        // class 2 never occurs in the labels, so its recall (and hence the macro average) is undefined
        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.3, 0.1],
                  vec![0.2, 0.5, 0.3],
                  vec![0.8, 0.1, 0.1],
                  vec![0.3, 0.5, 0.2]],
            &vec![0, 1, 0, 1]
        ).unwrap().recall(&Averaging::Macro).is_err())
    }

    #[test]
    fn test_multi_f1() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.f1(&Averaging::Macro).unwrap(), 0.5666666666666668);
        assert_approx_eq!(matrix.f1(&Averaging::Weighted).unwrap(), 0.6);

        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.4, 0.0],
                  vec![0.2, 0.8, 0.0],
                  vec![0.3, 0.7, 0.0]],
            &vec![0, 2, 1]
        ).unwrap().f1(&Averaging::Macro).is_err());

        assert!(MultiConfusionMatrix::compute(
            &vec![vec![0.6, 0.3, 0.1],
                  vec![0.2, 0.5, 0.3],
                  vec![0.3, 0.5, 0.2]],
            &vec![1, 0, 1]
        ).unwrap().f1(&Averaging::Macro).is_err());
    }

    #[test]
    fn test_multi_f1_0p_0r() {
        let scores = multi_class_data().0;
        // every prediction is wrong
        let labels = vec![1, 2, 0, 0, 1, 1, 0];

        assert_eq!(MultiConfusionMatrix::compute(&scores, &labels)
                       .unwrap()
                       .f1(&Averaging::Macro)
                       .unwrap(), 0.0
        )
    }

    #[test]
    fn test_rk() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        assert_approx_eq!(matrix.rk().unwrap(), 0.375)
    }

    #[test]
    fn test_per_class_accuracy() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pca = matrix.per_class_accuracy();
        assert_eq!(pca.len(), 3);
        assert_approx_eq!(pca[0].as_ref().unwrap(), 0.5714285714285714);
        assert_approx_eq!(pca[1].as_ref().unwrap(), 0.7142857142857143);
        assert_approx_eq!(pca[2].as_ref().unwrap(), 0.8571428571428571);
    }

    #[test]
    fn test_per_class_precision() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcp = matrix.per_class_precision();
        assert_eq!(pcp.len(), 3);
        assert_approx_eq!(pcp[0].as_ref().unwrap(), 0.3333333333333333);
        assert_approx_eq!(pcp[1].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcp[2].as_ref().unwrap(), 1.0);
        // smoke-test the Display implementation without writing debug output to stdout
        assert!(!matrix.to_string().is_empty());
    }

    #[test]
    fn test_per_class_recall() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcr = matrix.per_class_recall();
        assert_eq!(pcr.len(), 3);
        assert_approx_eq!(pcr[0].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcr[1].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcr[2].as_ref().unwrap(), 0.6666666666666666);
    }

    #[test]
    fn test_per_class_f1() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcf = matrix.per_class_f1();
        assert_eq!(pcf.len(), 3);
        assert_approx_eq!(pcf[0].as_ref().unwrap(), 0.4);
        assert_approx_eq!(pcf[1].as_ref().unwrap(), 0.5);
        assert_approx_eq!(pcf[2].as_ref().unwrap(), 0.8);
    }

    #[test]
    fn test_per_class_mcc() {
        let (scores, labels) = multi_class_data();
        let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
        let pcm = matrix.per_class_mcc();
        assert_eq!(pcm.len(), 3);
        assert_approx_eq!(pcm[0].as_ref().unwrap(), 0.09128709291752773);
        assert_approx_eq!(pcm[1].as_ref().unwrap(), 0.3);
        assert_approx_eq!(pcm[2].as_ref().unwrap(), 0.7302967433402215);
    }

    #[test]
    fn test_m_auc() {
        let (scores, labels) = multi_class_data();
        assert_approx_eq!(m_auc(&scores, &labels).unwrap(), 0.673611111111111)
    }

    #[test]
    fn test_m_auc_empty() {
        assert!(m_auc(&Vec::<Vec<f64>>::new(), &Vec::<usize>::new()).is_err());
    }

    #[test]
    fn test_m_auc_unequal_length() {
        assert!(m_auc(&Vec::<Vec<f64>>::new(), &vec![3, 0, 1, 2]).is_err());
    }

    #[test]
    fn test_m_auc_nan() {
        let scores = vec![
            vec![0.3, 0.1, 0.6],
            vec![0.5, f64::NAN, 0.3],
            vec![0.2, 0.7, 0.1],
        ];
        // a single NaN score anywhere in the input must cause an error
        let labels = vec![1, 2, 0];
        assert!(m_auc(&scores, &labels).is_err());
    }

    #[test]
    fn test_m_auc_constant_label() {
        let scores = vec![
            vec![0.3, 0.1, 0.6],
            vec![0.5, 0.2, 0.3],
            vec![0.2, 0.7, 0.1],
            vec![0.8, 0.1, 0.1],
        ];

        let labels = vec![1, 1, 1, 1];
        assert!(m_auc(&scores, &labels).is_err())
    }
}