1use std::cmp::Ordering;
6use crate::util;
7use crate::numeric::Scalar;
8use crate::error::EvalError;
9use crate::display;
10
/// Outcome counts of a binary classifier.
///
/// A sample is "predicted positive" when its score is `>=` the decision threshold
/// (see `BinaryConfusionMatrix::compute`). Instances are built with `compute` or
/// `from_counts`, both of which fill in the private `sum`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct BinaryConfusionMatrix {
    /// Number of true positives: predicted positive and labelled positive.
    pub tp_count: usize,
    /// Number of false positives: predicted positive but labelled negative.
    pub fp_count: usize,
    /// Number of true negatives: predicted negative and labelled negative.
    pub tn_count: usize,
    /// Number of false negatives: predicted negative but labelled positive.
    pub fn_count: usize,
    /// Total of the four counts, set by the constructors (used as the denominator of
    /// accuracy and MCC).
    sum: usize
}
27
28impl BinaryConfusionMatrix {
29
30 pub fn compute<T: Scalar>(scores: &Vec<T>,
58 labels: &Vec<bool>,
59 threshold: T) -> Result<BinaryConfusionMatrix, EvalError> {
60 util::validate_input_dims(scores, labels).and_then(|()| {
61 let mut counts = [0, 0, 0, 0];
62 for (&score, &label) in scores.iter().zip(labels) {
63 if !score.is_finite() {
64 return Err(EvalError::infinite_value())
65 } else if score >= threshold && label {
66 counts[3] += 1;
67 } else if score >= threshold {
68 counts[2] += 1;
69 } else if score < threshold && !label {
70 counts[0] += 1;
71 } else {
72 counts[1] += 1;
73 }
74 };
75 let sum = counts.iter().sum();
76 Ok(BinaryConfusionMatrix {
77 tp_count: counts[3],
78 fp_count: counts[2],
79 tn_count: counts[0],
80 fn_count: counts[1],
81 sum
82 })
83 })
84 }
85
86 pub fn from_counts(tp_count: usize,
101 fp_count: usize,
102 tn_count: usize,
103 fn_count: usize) -> Result<BinaryConfusionMatrix, EvalError> {
104 match tp_count + fp_count + tn_count + fn_count {
105 0 => Err(EvalError::invalid_input("Confusion matrix has all zero counts")),
106 sum => Ok(BinaryConfusionMatrix {tp_count, fp_count, tn_count, fn_count, sum})
107 }
108 }
109
110 pub fn accuracy(&self) -> Result<f64, EvalError> {
114 let num = self.tp_count + self.tn_count;
115 match self.sum {
116 0 => Err(EvalError::undefined_metric("Accuracy")),
118 sum => Ok(num as f64 / sum as f64)
119 }
120 }
121
122 pub fn precision(&self) -> Result<f64, EvalError> {
126 match self.tp_count + self.fp_count {
127 0 => Err(EvalError::undefined_metric("Precision")),
128 den => Ok((self.tp_count as f64) / den as f64)
129 }
130 }
131
132 pub fn recall(&self) -> Result<f64, EvalError> {
136 match self.tp_count + self.fn_count {
137 0 => Err(EvalError::undefined_metric("Recall")),
138 den => Ok((self.tp_count as f64) / den as f64)
139 }
140 }
141
142 pub fn f1(&self) -> Result<f64, EvalError> {
146 match (self.precision(), self.recall()) {
147 (Ok(p), Ok(r)) if p == 0.0 && r == 0.0 => Ok(0.0),
148 (Ok(p), Ok(r)) => Ok(2.0 * (p * r) / (p + r)),
149 (Err(e), _) => Err(e),
150 (_, Err(e)) => Err(e)
151 }
152 }
153
154 pub fn mcc(&self) -> Result<f64, EvalError> {
158 let n = self.sum as f64;
159 let s = (self.tp_count + self.fn_count) as f64 / n;
160 let p = (self.tp_count + self.fp_count) as f64 / n;
161 match (p * s * (1.0 - s) * (1.0 - p)).sqrt() {
162 den if den == 0.0 => Err(EvalError::undefined_metric("MCC")),
163 den => Ok(((self.tp_count as f64 / n) - s * p) / den)
164 }
165 }
166}
167
168impl std::fmt::Display for BinaryConfusionMatrix {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 let counts = vec![
171 vec![self.tp_count, self.fp_count],
172 vec![self.fn_count, self.tn_count]
173 ];
174 let outcomes = vec![String::from("Positive"), String::from("Negative")];
175 write!(f, "{}", display::stringify_confusion_matrix(&counts, &outcomes))
176 }
177}
178
/// A single point of a ROC curve.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct RocPoint<T: Scalar> {
    /// True-positive rate (recall) at this threshold.
    pub tp_rate: T,
    /// False-positive rate at this threshold.
    pub fp_rate: T,
    /// Score threshold of this point; samples with a score `>=` this value are counted
    /// as predicted positive.
    pub threshold: T
}
191
/// A receiver operating characteristic curve, built by `RocCurve::compute`.
#[derive(Clone, Debug)]
pub struct RocCurve<T: Scalar> {
    /// Corner points of the curve in order of decreasing threshold. The origin `(0, 0)`
    /// is implicit and not stored.
    pub points: Vec<RocPoint<T>>,
    /// Number of points at construction time (`points.len()` when computed).
    dim: usize
}
202
203impl <T: Scalar> RocCurve<T> {
204
205 pub fn compute(scores: &Vec<T>, labels: &Vec<bool>) -> Result<RocCurve<T>, EvalError> {
232 util::validate_input_dims(scores, labels).and_then(|()| {
233 let n = match scores.len() {
235 1 => return Err(EvalError::invalid_input(
236 "Unable to compute roc curve on single data point"
237 )),
238 len => len
239 };
240 let (mut pairs, np) = create_pairs(scores, labels)?;
241 let nn = n - np;
242 sort_pairs_descending(&mut pairs);
243 let mut tpc = if pairs[0].1 {1} else {0};
244 let mut fpc = 1 - tpc;
245 let mut points = Vec::<RocPoint<T>>::new();
246 let mut last_tpr = T::zero();
247 let mut last_fpr = T::zero();
248 let mut trend: Option<RocTrend> = None;
249
250 for i in 1..n {
251 if pairs[i].0 != pairs[i-1].0 {
252 let tp_rate = T::from_usize(tpc) / T::from_usize(np);
253 let fp_rate = T::from_usize(fpc) / T::from_usize(nn);
254 if !tp_rate.is_finite() || !fp_rate.is_finite() {
255 return Err(EvalError::undefined_metric("ROC"))
256 }
257 let threshold = pairs[i-1].0;
258 match trend {
259 Some(RocTrend::Horizontal) => if tp_rate > last_tpr {
260 points.push(RocPoint {tp_rate, fp_rate, threshold});
261 } else if let Some(mut point) = points.last_mut() {
262 point.fp_rate = fp_rate;
263 point.threshold = threshold;
264 },
265 Some(RocTrend::Vertical) => if fp_rate > last_fpr {
266 points.push(RocPoint {tp_rate, fp_rate, threshold})
267 } else if let Some(mut point) = points.last_mut() {
268 point.tp_rate = tp_rate;
269 point.threshold = threshold;
270 },
271 _ => points.push(RocPoint {tp_rate, fp_rate, threshold}),
272 }
273
274 trend = if fp_rate > last_fpr && tp_rate == last_tpr {
275 Some(RocTrend::Horizontal)
276 } else if tp_rate > last_tpr && fp_rate == last_fpr {
277 Some(RocTrend::Vertical)
278 } else {
279 Some(RocTrend::Diagonal)
280 };
281 last_tpr = tp_rate;
282 last_fpr = fp_rate;
283 }
284 if pairs[i].1 {
285 tpc += 1;
286 } else {
287 fpc += 1;
288 }
289 }
290
291 if let Some(mut point) = points.last_mut() {
292 if point.tp_rate != T::one() || point.fp_rate != T::one() {
293 let threshold = pairs.last().unwrap().0;
294 match trend {
295 Some(RocTrend::Horizontal) if point.tp_rate == T::one() => {
296 point.fp_rate = T::one();
297 point.threshold = threshold;
298 },
299 Some(RocTrend::Vertical) if point.fp_rate == T::one() => {
300 point.tp_rate = T::one();
301 point.threshold = threshold;
302 }
303 _ => points.push(RocPoint {
304 tp_rate: T::one(), fp_rate: T::one(), threshold
305 })
306 }
307 }
308 }
309
310 match points.len() {
311 0 => Err(EvalError::constant_input_data()),
312 dim => Ok(RocCurve {points, dim})
313 }
314 })
315 }
316
317 pub fn auc(&self) -> T {
321 let mut val = self.points[0].tp_rate * self.points[0].fp_rate / T::from_f64(2.0);
322 for i in 1..self.dim {
323 let fpr_diff = self.points[i].fp_rate - self.points[i-1].fp_rate;
324 let a = self.points[i-1].tp_rate * fpr_diff;
325 let tpr_diff = self.points[i].tp_rate - self.points[i-1].tp_rate;
326 let b = tpr_diff * fpr_diff / T::from_f64(2.0);
327 val += a + b;
328 }
329 return val
330 }
331}
332
/// A single point of a precision-recall curve.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct PrPoint<T: Scalar> {
    /// Precision `TP / (TP + FP)` at this threshold.
    pub precision: T,
    /// Recall `TP / (number of positive labels)` at this threshold.
    pub recall: T,
    /// Score threshold of this point; samples with a score `>=` this value are counted
    /// as predicted positive.
    pub threshold: T
}
345
/// A precision-recall curve, built by `PrCurve::compute`.
#[derive(Clone, Debug)]
pub struct PrCurve<T: Scalar> {
    /// Curve points in order of decreasing threshold (increasing recall).
    pub points: Vec<PrPoint<T>>,
    /// Number of points at construction time (`points.len()` when computed).
    dim: usize
}
356
357impl <T: Scalar> PrCurve<T> {
358
359 pub fn compute(scores: &Vec<T>, labels: &Vec<bool>) -> Result<PrCurve<T>, EvalError> {
386 util::validate_input_dims(scores, labels).and_then(|()| {
387 let n = match scores.len() {
388 1 => return Err(EvalError::invalid_input(
389 "Unable to compute pr curve on single data point"
390 )),
391 len => len
392 };
393 let (mut pairs, mut fnc) = create_pairs(scores, labels)?;
394 sort_pairs_descending(&mut pairs);
395 let mut tpc = 0;
396 let mut fpc = 0;
397 let mut points = Vec::<PrPoint<T>>::new();
398 let mut last_rec = T::zero();
399
400 for i in 0..n {
401 if pairs[i].1 {
402 tpc += 1;
403 fnc -= 1;
404 } else {
405 fpc += 1;
406 }
407 if (i < n-1 && pairs[i].0 != pairs[i+1].0) || i == n-1 {
408 let precision = T::from_usize(tpc) / T::from_usize(tpc + fpc);
409 let recall = T::from_usize(tpc) / T::from_usize(tpc + fnc);
410 if !precision.is_finite() || !recall.is_finite() {
411 return Err(EvalError::undefined_metric("PR"))
412 }
413 let threshold = pairs[i].0;
414 if recall != last_rec {
415 points.push(PrPoint {precision, recall, threshold});
416 }
417 last_rec = recall;
418 }
419 }
420
421 let dim = points.len();
422 Ok(PrCurve {points, dim})
423 })
424 }
425
426 pub fn ap(&self) -> T {
430 let mut val = self.points[0].precision * self.points[0].recall;
431 for i in 1..self.dim {
432 let rec_diff = self.points[i].recall - self.points[i-1].recall;
433 val += rec_diff * self.points[i].precision;
434 }
435 return val;
436 }
437}
438
439
/// Confusion matrix of a multi-class classifier.
///
/// Built with `MultiConfusionMatrix::compute` or `MultiConfusionMatrix::from_counts`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MultiConfusionMatrix {
    /// Number of classes; `counts` is a `dim x dim` matrix.
    pub dim: usize,
    /// `counts[p][a]` is the number of samples predicted as class `p` whose actual
    /// label is class `a` (rows: predicted, columns: actual).
    pub counts: Vec<Vec<usize>>,
    /// Total of all entries of `counts`, set by the constructors.
    sum: usize
}
453
454impl MultiConfusionMatrix {
455
456 pub fn compute<T: Scalar>(scores: &Vec<Vec<T>>,
491 labels: &Vec<usize>) -> Result<MultiConfusionMatrix, EvalError> {
492 util::validate_input_dims(scores, labels).and_then(|()| {
493 let dim = scores[0].len();
494 let mut counts = vec![vec![0; dim]; dim];
495 let mut sum = 0;
496 for (i, s) in scores.iter().enumerate() {
497 if s.iter().any(|v| !v.is_finite()) {
498 return Err(EvalError::infinite_value())
499 } else if s.len() != dim {
500 return Err(EvalError::invalid_input("Inconsistent score dimension"))
501 } else if labels[i] >= dim {
502 return Err(EvalError::invalid_input("Labels have more classes than scores"))
503 }
504 let ind = s.iter().enumerate().max_by(|(_, a), (_, b)| {
505 a.partial_cmp(b).unwrap_or(Ordering::Equal)
506 }).map(|(mi, _)| mi).ok_or(EvalError::constant_input_data())?;
507 counts[ind][labels[i]] += 1;
508 sum += 1;
509 }
510 Ok(MultiConfusionMatrix {dim, counts, sum})
511 })
512 }
513
514 pub fn from_counts(counts: Vec<Vec<usize>>) -> Result<MultiConfusionMatrix, EvalError> {
543 let dim = counts.len();
544 let mut sum = 0;
545 for row in &counts {
546 sum += row.iter().sum::<usize>();
547 if row.len() != dim {
548 let msg = format!("Inconsistent column length ({})", row.len());
549 return Err(EvalError::invalid_input(msg.as_str()));
550 }
551 }
552 if sum == 0 {
553 Err(EvalError::invalid_input("Confusion matrix has all zero counts"))
554 } else {
555 Ok(MultiConfusionMatrix {dim, counts, sum})
556 }
557 }
558
559 pub fn accuracy(&self) -> Result<f64, EvalError> {
563 match self.sum {
564 0 => Err(EvalError::undefined_metric("Accuracy")),
566 sum => {
567 let mut correct = 0;
568 for i in 0..self.dim {
569 correct += self.counts[i][i];
570 }
571 Ok(correct as f64 / sum as f64)
572 }
573 }
574 }
575
576 pub fn precision(&self, avg: &Averaging) -> Result<f64, EvalError> {
584 self.agg_metric(&self.per_class_precision(), avg)
585 }
586
587 pub fn recall(&self, avg: &Averaging) -> Result<f64, EvalError> {
595 self.agg_metric(&self.per_class_recall(), avg)
596 }
597
598 pub fn f1(&self, avg: &Averaging) -> Result<f64, EvalError> {
606 self.agg_metric(&self.per_class_f1(), avg)
607 }
608
609 pub fn rk(&self) -> Result<f64, EvalError> {
615 let mut t = vec![0.0; self.dim];
616 let mut p = vec![0.0; self.dim];
617 let mut c = 0.0;
618 let s = self.sum as f64;
619
620 for i in 0..self.dim {
621 c += self.counts[i][i] as f64;
622 for j in 0..self.dim {
623 t[j] += self.counts[i][j] as f64;
624 p[i] += self.counts[i][j] as f64;
625 }
626 }
627
628 let tt = t.iter().fold(0.0, |acc, val| acc + (val * val));
629 let pp = p.iter().fold(0.0, |acc, val| acc + (val * val));
630 let tp = t.iter().zip(p).fold(0.0, |acc, (t_val, p_val)| acc + t_val * p_val);
631 let num = c * s - tp;
632 let den = (s * s - pp).sqrt() * (s * s - tt).sqrt();
633
634 if den == 0.0 {
635 Err(EvalError::undefined_metric("Rk"))
636 } else {
637 Ok(num / den)
638 }
639 }
640
641 pub fn per_class_accuracy(&self) -> Vec<Result<f64, EvalError>> {
645 self.per_class_binary_metric("accuracy")
646 }
647
648 pub fn per_class_precision(&self) -> Vec<Result<f64, EvalError>> {
652 self.per_class_binary_metric("precision")
653 }
654
655 pub fn per_class_recall(&self) -> Vec<Result<f64, EvalError>> {
659 self.per_class_binary_metric("recall")
660 }
661
662 pub fn per_class_f1(&self) -> Vec<Result<f64, EvalError>> {
666 self.per_class_binary_metric("f1")
667 }
668
669 pub fn per_class_mcc(&self) -> Vec<Result<f64, EvalError>> {
673 self.per_class_binary_metric("mcc")
674 }
675
676 fn per_class_binary_metric(&self, metric: &str) -> Vec<Result<f64, EvalError>> {
677 (0..self.dim).map(|k| {
678 let (mut tpc, mut fpc, mut tnc, mut fnc) = (0, 0, 0, 0);
679 for i in 0..self.dim {
680 for j in 0..self.dim {
681 let count = self.counts[i][j];
682 if i == k && j == k {
683 tpc = count;
684 } else if i == k {
685 fpc += count;
686 } else if j == k {
687 fnc += count;
688 } else {
689 tnc += count;
690 }
691 }
692 }
693 let matrix = BinaryConfusionMatrix::from_counts(tpc, fpc, tnc, fnc)?;
694 match metric {
695 "accuracy" => matrix.accuracy(),
696 "precision" => matrix.precision(),
697 "recall" => matrix.recall(),
698 "f1" => matrix.f1(),
699 "mcc" => matrix.mcc(),
700 other => Err(EvalError::invalid_metric(other))
701 }
702 }).collect()
703 }
704
705 fn agg_metric(&self, pcm: &Vec<Result<f64, EvalError>>,
706 avg: &Averaging) -> Result<f64, EvalError> {
707 match avg {
708 Averaging::Macro => self.macro_metric(pcm),
709 Averaging::Weighted => self.weighted_metric(pcm)
710 }
711 }
712
713 fn macro_metric(&self, pcm: &Vec<Result<f64, EvalError>>) -> Result<f64, EvalError> {
714 pcm.iter().try_fold(0.0, |sum, metric| {
715 match metric {
716 Ok(m) => Ok(sum + m),
717 Err(e) => Err(e.clone())
718 }
719 }).map(|sum| {sum / pcm.len() as f64})
720 }
721
722 fn weighted_metric(&self, pcm: &Vec<Result<f64, EvalError>>) -> Result<f64, EvalError> {
723 pcm.iter()
724 .zip(self.class_counts().iter())
725 .try_fold(0.0, |val, (metric, &class)| {
726 match metric {
727 Ok(m) => Ok(val + (m * (class as f64) / (self.sum as f64))),
728 Err(e) => Err(e.clone())
729 }
730 })
731 }
732
733 fn class_counts(&self) -> Vec<usize> {
734 let mut counts = vec![0; self.dim];
735 for i in 0..self.dim {
736 for j in 0..self.dim {
737 counts[j] += self.counts[i][j];
738 }
739 }
740 counts
741 }
742}
743
744impl std::fmt::Display for MultiConfusionMatrix {
745 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
746 if self.dim <= 25 {
747 let outcomes = (0..self.dim).map(|i| format!("Class-{}", i + 1)).collect();
748 write!(f, "{}", display::stringify_confusion_matrix(&self.counts, &outcomes))
749 } else {
750 write!(f, "[Confusion matrix is too large to display]")
751 }
752 }
753}
754
755pub fn m_auc<T: Scalar>(scores: &Vec<Vec<T>>, labels: &Vec<usize>) -> Result<T, EvalError> {
792 util::validate_input_dims(scores, labels).and_then(|()| {
793 let dim = scores[0].len();
794 let mut m_sum = T::zero();
795
796 fn subset<T: Scalar>(scr: &Vec<Vec<T>>,
797 lab: &Vec<usize>,
798 j: usize,
799 k: usize) -> (Vec<T>, Vec<bool>) {
800
801 scr.iter().zip(lab.iter()).filter(|(_, &l)| {
802 l == j || l == k
803 }).map(|(s, &l)| {
804 (s[k], l == k)
805 }).unzip()
806 }
807
808 for j in 0..dim {
809 for k in 0..j {
810 let (k_scores, k_labels) = subset(scores, labels, j, k);
811 let ajk = RocCurve::compute(&k_scores, &k_labels)?.auc();
812 let (j_scores, j_labels) = subset(scores, labels, k, j);
813 let akj = RocCurve::compute(&j_scores, &j_labels)?.auc();
814 m_sum += (ajk + akj) / T::from_f64(2.0);
815 }
816 }
817 Ok(m_sum * T::from_f64(2.0) / (T::from_usize(dim) * (T::from_usize(dim) - T::one())))
818 })
819}
820
/// Strategy for combining the per-class metrics of a `MultiConfusionMatrix` into a
/// single value.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Averaging {
    /// Unweighted mean over all classes.
    Macro,
    /// Mean weighted by each class's number of labelled samples relative to the total.
    Weighted
}
832
/// Direction of the most recent segment of a ROC curve under construction.
///
/// `RocCurve::compute` uses it to merge consecutive collinear horizontal or vertical
/// steps into a single corner point.
enum RocTrend {
    /// The false-positive rate increased while the true-positive rate stayed the same.
    Horizontal,
    /// The true-positive rate increased while the false-positive rate stayed the same.
    Vertical,
    /// Any other step, e.g. both rates increased.
    Diagonal
}
838
839fn create_pairs<T: Scalar>(scores: &Vec<T>,
840 labels: &Vec<bool>) -> Result<(Vec<(T, bool)>, usize), EvalError> {
841 let n = scores.len();
842 let mut pairs = Vec::with_capacity(n);
843 let mut num_pos = 0;
844
845 for i in 0..n {
846 if !scores[i].is_finite() {
847 return Err(EvalError::infinite_value())
848 } else if labels[i] {
849 num_pos += 1;
850 }
851 pairs.push((scores[i], labels[i]))
852 }
853 Ok((pairs, num_pos))
854}
855
856fn sort_pairs_descending<T: Scalar>(pairs: &mut Vec<(T, bool)>) {
857 pairs.sort_unstable_by(|(s1, _), (s2, _)| {
858 if s1 > s2 {
859 Ordering::Less
860 } else if s1 < s2 {
861 Ordering::Greater
862 } else {
863 Ordering::Equal
864 }
865 });
866}
867
868#[cfg(test)]
869mod tests {
870 use assert_approx_eq::assert_approx_eq;
871 use super::*;
872
873 fn binary_data() -> (Vec<f64>, Vec<bool>) {
874 let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
875 let labels = vec![false, false, true, false, true, false, false, true];
876 (scores, labels)
877 }
878
879 fn multi_class_data() -> (Vec<Vec<f64>>, Vec<usize>) {
880
881 let scores = vec![
882 vec![0.3, 0.1, 0.6],
883 vec![0.5, 0.2, 0.3],
884 vec![0.2, 0.7, 0.1],
885 vec![0.3, 0.3, 0.4],
886 vec![0.5, 0.1, 0.4],
887 vec![0.8, 0.1, 0.1],
888 vec![0.3, 0.5, 0.2]
889 ];
890 let labels = vec![2, 1, 1, 2, 0, 2, 0];
891 (scores, labels)
892 }
893
894 #[test]
895 fn test_binary_confusion_matrix() {
896 let (scores, labels) = binary_data();
897 let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
898 assert_eq!(matrix.tp_count, 2);
899 assert_eq!(matrix.fp_count, 2);
900 assert_eq!(matrix.tn_count, 3);
901 assert_eq!(matrix.fn_count, 1);
902 }
903
904 #[test]
905 fn test_binary_confusion_matrix_empty() {
906 assert!(BinaryConfusionMatrix::compute(
907 &Vec::<f64>::new(),
908 &Vec::<bool>::new(),
909 0.5
910 ).is_err());
911 }
912
913 #[test]
914 fn test_binary_confusion_matrix_unequal_length() {
915 assert!(BinaryConfusionMatrix::compute(
916 &vec![0.1, 0.2],
917 &vec![true, false, true],
918 0.5
919 ).is_err());
920 }
921
922 #[test]
923 fn test_binary_confusion_matrix_nan() {
924 assert!(BinaryConfusionMatrix::compute(
925 &vec![f64::NAN, 0.2, 0.4],
926 &vec![true, false, true],
927 0.5
928 ).is_err());
929 }
930
931 #[test]
932 fn test_binary_confusion_matrix_with_counts() {
933 let matrix = BinaryConfusionMatrix::from_counts(2, 4, 5, 3).unwrap();
934 assert_eq!(matrix.tp_count, 2);
935 assert_eq!(matrix.fp_count, 4);
936 assert_eq!(matrix.tn_count, 5);
937 assert_eq!(matrix.fn_count, 3);
938 assert_eq!(matrix.sum, 14);
939 assert!(BinaryConfusionMatrix::from_counts(0, 0, 0, 0).is_err())
940 }
941
942 #[test]
943 fn test_binary_accuracy() {
944 let (scores, labels) = binary_data();
945 let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
946 assert_approx_eq!(matrix.accuracy().unwrap(), 0.625);
947 }
948
949 #[test]
950 fn test_binary_precision() {
951 let (scores, labels) = binary_data();
952 let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
953 assert_approx_eq!(matrix.precision().unwrap(), 0.5);
954
955 assert!(BinaryConfusionMatrix::compute(
957 &vec![0.4, 0.3, 0.1, 0.2, 0.1],
958 &vec![true, false, true, false, true],
959 0.5
960 ).unwrap().precision().is_err());
961 }
962
963 #[test]
964 fn test_binary_precision_empty() {
965 assert!(BinaryConfusionMatrix::compute(
966 &Vec::<f64>::new(),
967 &Vec::<bool>::new(),
968 0.5
969 ).is_err());
970 }
971
972 #[test]
973 fn test_binary_precision_unequal_length() {
974 assert!(BinaryConfusionMatrix::compute(
975 &vec![0.1, 0.2],
976 &vec![true, false, true],
977 0.5
978 ).is_err());
979 }
980
981 #[test]
982 fn test_binary_recall() {
983 let (scores, labels) = binary_data();
984 let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
985 assert_approx_eq!(matrix.recall().unwrap(), 2.0 / 3.0);
986
987 assert!(BinaryConfusionMatrix::compute(
989 &vec![0.4, 0.3, 0.1, 0.8, 0.7],
990 &vec![false, false, false, false, false],
991 0.5
992 ).unwrap().recall().is_err());
993 }
994
995 #[test]
996 fn test_binary_recall_empty() {
997 assert!(BinaryConfusionMatrix::compute(
998 &Vec::<f64>::new(),
999 &Vec::<bool>::new(),
1000 0.5
1001 ).is_err());
1002 }
1003
1004 #[test]
1005 fn test_binary_recall_unequal_length() {
1006 assert!(BinaryConfusionMatrix::compute(
1007 &vec![0.1, 0.2],
1008 &vec![true, false, true],
1009 0.5
1010 ).is_err());
1011 }
1012
1013 #[test]
1014 fn test_binary_f1() {
1015 let (scores, labels) = binary_data();
1016 let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
1017 assert_approx_eq!(matrix.f1().unwrap(), 0.5714285714285715);
1018
1019 assert!(BinaryConfusionMatrix::compute(
1021 &vec![0.4, 0.3, 0.1, 0.2, 0.1],
1022 &vec![true, false, true, false, true],
1023 0.5
1024 ).unwrap().f1().is_err());
1025
1026 assert!(BinaryConfusionMatrix::compute(
1028 &vec![0.4, 0.3, 0.1, 0.8, 0.7],
1029 &vec![false, false, false, false, false],
1030 0.5
1031 ).unwrap().f1().is_err());
1032 }
1033
1034 #[test]
1035 fn test_binary_f1_empty() {
1036 assert!(BinaryConfusionMatrix::compute(
1037 &Vec::<f64>::new(),
1038 &Vec::<bool>::new(),
1039 0.5
1040 ).is_err());
1041 }
1042
1043 #[test]
1044 fn test_binary_f1_unequal_length() {
1045 assert!(BinaryConfusionMatrix::compute(
1046 &vec![0.1, 0.2],
1047 &vec![true, false, true],
1048 0.5
1049 ).is_err());
1050 }
1051
1052 #[test]
1053 fn test_binary_f1_0p_0r() {
1054 let scores = vec![0.1, 0.2, 0.7, 0.8];
1055 let labels = vec![false, true, false, false];
1056
1057 assert_eq!(BinaryConfusionMatrix::compute(&scores, &labels, 0.5)
1058 .unwrap()
1059 .f1()
1060 .unwrap(), 0.0
1061 )
1062 }
1063
1064 #[test]
1065 fn test_mcc() {
1066 let (scores, labels) = binary_data();
1067 let matrix = BinaryConfusionMatrix::compute(&scores, &labels, 0.5).unwrap();
1068 assert_approx_eq!(matrix.mcc().unwrap(), 0.2581988897471611)
1069 }
1070
1071 #[test]
1072 fn test_roc() {
1073 let (scores, labels) = binary_data();
1074 let roc = RocCurve::compute(&scores, &labels).unwrap();
1075
1076 assert_eq!(roc.dim, 5);
1077 assert_approx_eq!(roc.points[0].tp_rate, 1.0 / 3.0);
1078 assert_approx_eq!(roc.points[0].fp_rate, 0.0);
1079 assert_approx_eq!(roc.points[0].threshold, 0.9);
1080 assert_approx_eq!(roc.points[1].tp_rate, 1.0 / 3.0);
1081 assert_approx_eq!(roc.points[1].fp_rate, 0.2);
1082 assert_approx_eq!(roc.points[1].threshold, 0.8);
1083 assert_approx_eq!(roc.points[2].tp_rate, 2.0 / 3.0);
1084 assert_approx_eq!(roc.points[2].fp_rate, 0.2);
1085 assert_approx_eq!(roc.points[2].threshold, 0.7);
1086 assert_approx_eq!(roc.points[3].tp_rate, 2.0 / 3.0);
1087 assert_approx_eq!(roc.points[3].fp_rate, 1.0);
1088 assert_approx_eq!(roc.points[3].threshold, 0.2);
1089 assert_approx_eq!(roc.points[4].tp_rate, 1.0);
1090 assert_approx_eq!(roc.points[4].fp_rate, 1.0);
1091 assert_approx_eq!(roc.points[4].threshold, 0.1);
1092 }
1093
1094 #[test]
1095 fn test_roc_tied_scores() {
1096 let scores = vec![1.0, 0.1, 1.0, 0.9, 0.5, 0.1, 0.8, 0.9, 1.0, 0.4];
1097 let labels = vec![true, false, false, false, false, false, true, true, false, false];
1098 let roc = RocCurve::compute(&scores, &labels).unwrap();
1099 assert_approx_eq!(roc.points[0].tp_rate, 1.0 / 3.0);
1100 assert_approx_eq!(roc.points[0].fp_rate, 0.2857142857142857);
1101 assert_approx_eq!(roc.points[0].threshold, 1.0);
1102 assert_approx_eq!(roc.points[1].tp_rate, 2.0 / 3.0);
1103 assert_approx_eq!(roc.points[1].fp_rate, 0.42857142857142855);
1104 assert_approx_eq!(roc.points[1].threshold, 0.9);
1105 assert_approx_eq!(roc.points[2].tp_rate, 1.0);
1106 assert_approx_eq!(roc.points[2].fp_rate, 0.42857142857142855);
1107 assert_approx_eq!(roc.points[2].threshold, 0.8);
1108 assert_approx_eq!(roc.points[3].tp_rate, 1.0);
1109 assert_approx_eq!(roc.points[3].fp_rate, 1.0);
1110 assert_approx_eq!(roc.points[3].threshold, 0.1);
1111 }
1112
1113 #[test]
1114 fn test_roc_empty() {
1115 assert!(RocCurve::compute(&Vec::<f64>::new(), &Vec::<bool>::new()).is_err());
1116 }
1117
1118 #[test]
1119 fn test_roc_unequal_length() {
1120 assert!(RocCurve::compute(
1121 &vec![0.4, 0.5, 0.2],
1122 &vec![true, false, true, false]
1123 ).is_err());
1124 }
1125
1126 #[test]
1127 fn test_roc_nan() {
1128 assert!(RocCurve::compute(
1129 &vec![0.4, 0.5, 0.2, f64::NAN],
1130 &vec![true, false, true, false]
1131 ).is_err());
1132 }
1133
1134 #[test]
1135 fn test_roc_constant_label() {
1136 let scores = vec![0.1, 0.4, 0.5, 0.7];
1137 let labels_true = vec![true; 4];
1138 let labels_false = vec![false; 4];
1139 assert!(match RocCurve::compute(&scores, &labels_true) {
1140 Err(err) if err.msg.contains("Undefined") => true,
1141 _ => false
1142 });
1143 assert!(match RocCurve::compute(&scores, &labels_false) {
1144 Err(err) if err.msg.contains("Undefined") => true,
1145 _ => false
1146 });
1147 }
1148
1149 #[test]
1150 fn test_roc_constant_score() {
1151 let scores = vec![0.4, 0.4, 0.4, 0.4];
1152 let labels = vec![true, false, true, false];
1153 assert!(match RocCurve::compute(&scores, &labels) {
1154 Err(err) if err.msg.contains("Constant") => true,
1155 _ => false
1156 });
1157 }
1158
1159 #[test]
1160 fn test_auc() {
1161 let (scores, labels) = binary_data();
1162 assert_approx_eq!(RocCurve::compute(&scores, &labels).unwrap().auc(), 0.6);
1163
1164 let scores2 = vec![0.2, 0.5, 0.5, 0.3];
1165 let labels2 = vec![false, true, false, true];
1166 assert_approx_eq!(RocCurve::compute(&scores2, &labels2).unwrap().auc(), 0.625);
1167 }
1168
1169 #[test]
1170 fn test_auc_tied_scores() {
1171 let scores = vec![0.1, 0.2, 0.3, 0.3, 0.3, 0.7, 0.8];
1172 let labels1 = vec![false, false, true, false, true, false, true];
1173 let labels2 = vec![false, false, true, true, false, false, true];
1174 let labels3 = vec![false, false, false, true, true, false, true];
1175 assert_approx_eq!(RocCurve::compute(&scores, &labels1).unwrap().auc(), 0.75);
1176 assert_approx_eq!(RocCurve::compute(&scores, &labels2).unwrap().auc(), 0.75);
1177 assert_approx_eq!(RocCurve::compute(&scores, &labels3).unwrap().auc(), 0.75);
1178
1179 let scores2 = vec![1.0, 0.1, 1.0, 0.9, 0.5, 0.1, 0.8, 0.9, 1.0, 0.4];
1180 let labels4 = vec![true, false, false, false, false, false, true, true, false, false];
1181 assert_approx_eq!(RocCurve::compute(&scores2, &labels4).unwrap().auc(), 0.6904761904761905);
1182 }
1183
1184 #[test]
1185 fn test_pr() {
1186 let (scores, labels) = binary_data();
1187 let pr = PrCurve::compute(&scores, &labels).unwrap();
1188 assert_approx_eq!(pr.points[0].precision, 1.0);
1189 assert_approx_eq!(pr.points[0].recall, 1.0 / 3.0);
1190 assert_approx_eq!(pr.points[0].threshold, 0.9);
1191 assert_approx_eq!(pr.points[1].precision, 2.0 / 3.0);
1192 assert_approx_eq!(pr.points[1].recall, 2.0 / 3.0);
1193 assert_approx_eq!(pr.points[1].threshold, 0.7);
1194 assert_approx_eq!(pr.points[2].precision, 0.375);
1195 assert_approx_eq!(pr.points[2].recall, 1.0);
1196 assert_approx_eq!(pr.points[2].threshold, 0.1);
1197 }
1198
1199 #[test]
1200 fn test_pr_empty() {
1201 assert!(PrCurve::compute(&Vec::<f64>::new(), &Vec::<bool>::new()).is_err());
1202 }
1203
1204 #[test]
1205 fn test_pr_unequal_length() {
1206 assert!(PrCurve::compute(&vec![0.4, 0.5, 0.2], &vec![true, false, true, false]).is_err());
1207 }
1208
1209 #[test]
1210 fn test_pr_nan() {
1211 assert!(PrCurve::compute(
1212 &vec![0.4, 0.5, 0.2, f64::NAN],
1213 &vec![true, false, true, false]
1214 ).is_err());
1215 }
1216
1217 #[test]
1218 fn test_pr_constant_label() {
1219 let scores = vec![0.1, 0.4, 0.5, 0.7];
1220 let labels_true = vec![true; 4];
1221 let labels_false = vec![false; 4];
1222 assert!(PrCurve::compute(&scores, &labels_true).is_ok());
1223 assert!(match PrCurve::compute(&scores, &labels_false) {
1224 Err(err) if err.msg.contains("Undefined") => true,
1225 _ => false
1226 });
1227 }
1228
1229 #[test]
1230 fn test_pr_constant_score() {
1231 let scores = vec![0.4, 0.4, 0.4, 0.4];
1232 let labels = vec![true, false, true, false];
1233 assert!(PrCurve::compute(&scores, &labels).is_ok());
1234 }
1235
1236 #[test]
1237 fn test_ap() {
1238 let (scores, labels) = binary_data();
1239 assert_approx_eq!(PrCurve::compute(&scores, &labels).unwrap().ap(), 0.6805555555555556);
1240
1241 let scores2 = vec![0.2, 0.5, 0.5, 0.3];
1242 let labels2 = vec![false, true, false, true];
1243 assert_approx_eq!(PrCurve::compute(&scores2, &labels2).unwrap().ap(), 0.58333333333333);
1244 }
1245
1246 #[test]
1247 fn test_ap_tied_scores() {
1248 let scores = vec![0.1, 0.2, 0.3, 0.3, 0.3, 0.7, 0.8];
1249 let labels1 = vec![false, false, true, false, true, false, true];
1250 let labels2 = vec![false, false, true, true, false, false, true];
1251 let labels3 = vec![false, false, false, true, true, false, true];
1252 assert_approx_eq!(PrCurve::compute(&scores, &labels1).unwrap().ap(), 0.7333333333333);
1253 assert_approx_eq!(PrCurve::compute(&scores, &labels2).unwrap().ap(), 0.7333333333333);
1254 assert_approx_eq!(PrCurve::compute(&scores, &labels3).unwrap().ap(), 0.7333333333333);
1255 }
1256
1257 #[test]
1258 fn test_multi_confusion_matrix() {
1259 let (scores, labels) = multi_class_data();
1260 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1261 assert_eq!(matrix.counts, vec![vec![1, 1, 1], vec![1, 1, 0], vec![0, 0, 2]]);
1262 assert_eq!(matrix.dim, 3);
1263 assert_eq!(matrix.sum, 7);
1264 }
1265
1266 #[test]
1267 fn test_multi_confusion_matrix_empty() {
1268 let scores: Vec<Vec<f64>> = vec![];
1269 let labels = Vec::<usize>::new();
1270 assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
1271 }
1272
1273 #[test]
1274 fn test_multi_confusion_matrix_unequal_length() {
1275 assert!(MultiConfusionMatrix::compute(&vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4]],
1276 &vec![2, 1, 0]).is_err());
1277 }
1278
1279 #[test]
1280 fn test_multi_confusion_matrix_nan() {
1281 assert!(MultiConfusionMatrix::compute(
1282 &vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.7, f64::NAN]],
1283 &vec![2, 1, 0]
1284 ).is_err());
1285 }
1286
1287 #[test]
1288 fn test_multi_confusion_matrix_inconsistent_score_dims() {
1289 let scores = vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.7]];
1290 let labels = vec![2, 1, 0];
1291 assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
1292 }
1293
1294 #[test]
1295 fn test_multi_confusion_matrix_score_label_dim_mismatch() {
1296 let scores = vec![vec![0.2, 0.4, 0.4], vec![0.5, 0.1, 0.4], vec![0.3, 0.2, 0.5]];
1297 let labels = vec![2, 3, 0];
1298 assert!(MultiConfusionMatrix::compute(&scores, &labels).is_err());
1299 }
1300
1301 #[test]
1302 fn test_multi_confusion_matrix_counts() {
1303 let counts = vec![vec![6, 3, 1], vec![4, 2, 7], vec![5, 2, 8]];
1304 let matrix = MultiConfusionMatrix::from_counts(counts).unwrap();
1305 assert_eq!(matrix.dim, 3);
1306 assert_eq!(matrix.sum, 38);
1307 assert_eq!(matrix.counts, vec![vec![6, 3, 1], vec![4, 2, 7], vec![5, 2, 8]]);
1308 }
1309
1310 #[test]
1311 fn test_multi_confusion_matrix_bad_counts() {
1312 let counts = vec![vec![6, 3, 1], vec![4, 2], vec![5, 2, 8]];
1313 assert!(MultiConfusionMatrix::from_counts(counts).is_err())
1314 }
1315
1316 #[test]
1317 fn test_multi_accuracy() {
1318 let (scores, labels) = multi_class_data();
1319 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1320 assert_approx_eq!(matrix.accuracy().unwrap(), 0.5714285714285714)
1321 }
1322
1323 #[test]
1324 fn test_multi_precision() {
1325 let (scores, labels) = multi_class_data();
1326 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1327 assert_approx_eq!(matrix.precision(&Averaging::Macro).unwrap(), 0.611111111111111);
1328 assert_approx_eq!(matrix.precision(&Averaging::Weighted).unwrap(), 2.0 / 3.0);
1329
1330 assert!(MultiConfusionMatrix::compute(
1331 &vec![vec![0.6, 0.4, 0.0],
1332 vec![0.2, 0.8, 0.0],
1333 vec![0.9, 0.1, 0.0],
1334 vec![0.3, 0.7, 0.0]],
1335 &vec![0, 1, 2, 1]
1336 ).unwrap().precision(&Averaging::Macro).is_err())
1337 }
1338
1339 #[test]
1340 fn test_multi_recall() {
1341 let (scores, labels) = multi_class_data();
1342 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1343 assert_approx_eq!(matrix.recall(&Averaging::Macro).unwrap(), 0.5555555555555555);
1344 assert_approx_eq!(matrix.recall(&Averaging::Weighted).unwrap(), 0.5714285714285714);
1345
1346 assert!(MultiConfusionMatrix::compute(
1347 &vec![vec![0.6, 0.3, 0.1],
1348 vec![0.2, 0.5, 0.3],
1349 vec![0.8, 0.1, 0.1],
1350 vec![0.3, 0.5, 0.2]],
1351 &vec![0, 1, 0, 1]
1352 ).unwrap().recall(&Averaging::Macro).is_err())
1353 }
1354
1355 #[test]
1356 fn test_multi_f1() {
1357 let (scores, labels) = multi_class_data();
1358 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1359 assert_approx_eq!(matrix.f1(&Averaging::Macro).unwrap(), 0.5666666666666668);
1360 assert_approx_eq!(matrix.f1(&Averaging::Weighted).unwrap(), 0.6);
1361
1362 assert!(MultiConfusionMatrix::compute(
1363 &vec![vec![0.6, 0.4, 0.0],
1364 vec![0.2, 0.8, 0.0],
1365 vec![0.3, 0.7, 0.0]],
1366 &vec![0, 2, 1]
1367 ).unwrap().f1(&Averaging::Macro).is_err());
1368
1369 assert!(MultiConfusionMatrix::compute(
1370 &vec![vec![0.6, 0.3, 0.1],
1371 vec![0.2, 0.5, 0.3],
1372 vec![0.3, 0.5, 0.2]],
1373 &vec![1, 0, 1]
1374 ).unwrap().f1(&Averaging::Macro).is_err());
1375 }
1376
1377 #[test]
1378 fn test_multi_f1_0p_0r() {
1379 let scores = multi_class_data().0;
1380 let labels = vec![1, 2, 0, 0, 1, 1, 0];
1382
1383 assert_eq!(MultiConfusionMatrix::compute(&scores, &labels)
1384 .unwrap()
1385 .f1(&Averaging::Macro)
1386 .unwrap(), 0.0
1387 )
1388 }
1389
1390 #[test]
1391 fn test_rk() {
1392 let (scores, labels) = multi_class_data();
1393 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1394 assert_approx_eq!(matrix.rk().unwrap(), 0.375)
1395 }
1396
1397 #[test]
1398 fn test_per_class_accuracy() {
1399 let (scores, labels) = multi_class_data();
1400 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1401 let pca = matrix.per_class_accuracy();
1402 assert_eq!(pca.len(), 3);
1403 assert_approx_eq!(pca[0].as_ref().unwrap(), 0.5714285714285714);
1404 assert_approx_eq!(pca[1].as_ref().unwrap(), 0.7142857142857143);
1405 assert_approx_eq!(pca[2].as_ref().unwrap(), 0.8571428571428571);
1406 }
1407
1408 #[test]
1409 fn test_per_class_precision() {
1410 let (scores, labels) = multi_class_data();
1411 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1412 let pcp = matrix.per_class_precision();
1413 assert_eq!(pcp.len(), 3);
1414 assert_approx_eq!(pcp[0].as_ref().unwrap(), 0.3333333333333333);
1415 assert_approx_eq!(pcp[1].as_ref().unwrap(), 0.5);
1416 assert_approx_eq!(pcp[2].as_ref().unwrap(), 1.0);
1417 println!("{}", matrix);
1418 }
1419
1420 #[test]
1421 fn test_per_class_recall() {
1422 let (scores, labels) = multi_class_data();
1423 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1424 let pcr = matrix.per_class_recall();
1425 assert_eq!(pcr.len(), 3);
1426 assert_approx_eq!(pcr[0].as_ref().unwrap(), 0.5);
1427 assert_approx_eq!(pcr[1].as_ref().unwrap(), 0.5);
1428 assert_approx_eq!(pcr[2].as_ref().unwrap(), 0.6666666666666666);
1429 }
1430
1431 #[test]
1432 fn test_per_class_f1() {
1433 let (scores, labels) = multi_class_data();
1434 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1435 let pcf = matrix.per_class_f1();
1436 assert_eq!(pcf.len(), 3);
1437 assert_approx_eq!(pcf[0].as_ref().unwrap(), 0.4);
1438 assert_approx_eq!(pcf[1].as_ref().unwrap(), 0.5);
1439 assert_approx_eq!(pcf[2].as_ref().unwrap(), 0.8);
1440 }
1441
1442 #[test]
1443 fn test_per_class_mcc() {
1444 let (scores, labels) = multi_class_data();
1445 let matrix = MultiConfusionMatrix::compute(&scores, &labels).unwrap();
1446 let pcm = matrix.per_class_mcc();
1447 assert_eq!(pcm.len(), 3);
1448 assert_approx_eq!(pcm[0].as_ref().unwrap(), 0.09128709291752773);
1449 assert_approx_eq!(pcm[1].as_ref().unwrap(), 0.3);
1450 assert_approx_eq!(pcm[2].as_ref().unwrap(), 0.7302967433402215);
1451 }
1452
1453 #[test]
1454 fn test_m_auc() {
1455 let (scores, labels) = multi_class_data();
1456 assert_approx_eq!(m_auc(&scores, &labels).unwrap(), 0.673611111111111)
1457 }
1458
1459 #[test]
1460 fn test_m_auc_empty() {
1461 assert!(m_auc(&Vec::<Vec<f64>>::new(), &Vec::<usize>::new()).is_err());
1462 }
1463
1464 #[test]
1465 fn test_m_auc_unequal_length() {
1466 assert!(m_auc(&Vec::<Vec<f64>>::new(), &vec![3, 0, 1, 2]).is_err());
1467 }
1468
1469 #[test]
1470 fn test_m_auc_nan() {
1471 let scores = vec![
1472 vec![0.3, 0.1, 0.6],
1473 vec![0.5, f64::NAN, 0.3],
1474 vec![0.2, 0.7, 0.1],
1475 ];
1476 let labels = vec![1, 2, 0];
1478 assert!(m_auc(&scores, &labels).is_err());
1479 }
1480
1481 #[test]
1482 fn test_m_auc_constant_label() {
1483 let scores = vec![
1484 vec![0.3, 0.1, 0.6],
1485 vec![0.5, 0.2, 0.3],
1486 vec![0.2, 0.7, 0.1],
1487 vec![0.8, 0.1, 0.1],
1488 ];
1489
1490 let labels = vec![1, 1, 1, 1];
1491 assert!(m_auc(&scores, &labels).is_err())
1492 }
1493}