use std::collections::HashMap;
use std::fmt;
use ndarray::prelude::*;
use ndarray::Data;
use crate::dataset::AsSingleTargets;
use crate::dataset::{AsTargets, DatasetBase, Label, Labels, Pr, Records};
use crate::error::{Error, Result};
/// Translates parallel prediction / ground-truth label slices into pairs of
/// indices into `classes`.
///
/// Each output entry is `Some((prediction_idx, ground_truth_idx))`, or `None`
/// when either label does not occur in `classes`.
fn map_prediction_to_idx<L: Label>(
    prediction: &[L],
    ground_truth: &[L],
    classes: &[L],
) -> Vec<Option<(usize, usize)>> {
    // Lookup table: class label -> position within `classes`.
    let mut lookup = HashMap::with_capacity(classes.len());
    for (idx, class) in classes.iter().enumerate() {
        lookup.insert(class, idx);
    }

    prediction
        .iter()
        .zip(ground_truth.iter())
        .map(|(pred, truth)| {
            // `?` bails out with `None` on any unknown label.
            let p = lookup.get(pred)?;
            let t = lookup.get(truth)?;
            Some((*p, *t))
        })
        .collect()
}
/// A confusion matrix over the class labels of a classification problem.
///
/// Entries are stored as `f32` counts; `members` holds the class label that
/// corresponds to each row/column index (same order on both axes).
#[derive(Clone, PartialEq)]
pub struct ConfusionMatrix<A> {
    // Square count matrix; indexed as `matrix[(i1, i2)]` where the two
    // indices come from `map_prediction_to_idx` (prediction idx first).
    matrix: Array2<f32>,
    // Class label for each row/column index of `matrix`.
    members: Array1<A>,
}
impl<A> ConfusionMatrix<A> {
    /// Returns `true` when the matrix has shape 2x2, i.e. describes a binary
    /// classification problem.
    fn is_binary(&self) -> bool {
        self.matrix.shape() == [2, 2]
    }

    /// Precision score.
    ///
    /// For a binary matrix this is the top-left entry divided by the sum of
    /// the first column. For more than two classes the matrix is split
    /// one-vs-all and the per-class precisions are macro-averaged.
    pub fn precision(&self) -> f32 {
        if self.is_binary() {
            self.matrix[(0, 0)] / (self.matrix[(0, 0)] + self.matrix[(1, 0)])
        } else {
            self.split_one_vs_all()
                .into_iter()
                .map(|x| x.precision())
                .sum::<f32>()
                / self.members.len() as f32
        }
    }

    /// Recall score.
    ///
    /// For a binary matrix this is the top-left entry divided by the sum of
    /// the first row. For more than two classes the matrix is split
    /// one-vs-all and the per-class recalls are macro-averaged.
    pub fn recall(&self) -> f32 {
        if self.is_binary() {
            self.matrix[(0, 0)] / (self.matrix[(0, 0)] + self.matrix[(0, 1)])
        } else {
            self.split_one_vs_all()
                .into_iter()
                .map(|x| x.recall())
                .sum::<f32>()
                / self.members.len() as f32
        }
    }

    /// Accuracy: the fraction of all counted samples that lie on the
    /// diagonal (i.e. prediction and ground truth agree).
    pub fn accuracy(&self) -> f32 {
        self.matrix.diag().sum() / self.matrix.sum()
    }

    /// F-beta score: `(1 + beta^2) * p * r / (beta^2 * p + r)` built from
    /// [`Self::precision`] and [`Self::recall`].
    ///
    /// Evaluates to NaN when both precision and recall are zero (0/0).
    pub fn f_score(&self, beta: f32) -> f32 {
        let sb = beta * beta;
        let p = self.precision();
        let r = self.recall();
        (1. + sb) * (p * r) / (sb * p + r)
    }

    /// F1 score, i.e. [`Self::f_score`] with `beta = 1`.
    pub fn f1_score(&self) -> f32 {
        self.f_score(1.0)
    }

    /// Matthews correlation coefficient, computed via the covariance
    /// formulation `cov(x, y) / sqrt(cov(x, x) * cov(y, y))` over the matrix
    /// entries and its marginal sums.
    pub fn mcc(&self) -> f32 {
        // cov(x, y): "agreement" minus "confusion" products summed over all
        // index triples.
        let mut cov_xy = 0.0;
        for k in 0..self.members.len() {
            for l in 0..self.members.len() {
                for m in 0..self.members.len() {
                    cov_xy += self.matrix[(k, k)] * self.matrix[(l, m)];
                    cov_xy -= self.matrix[(k, l)] * self.matrix[(m, k)];
                }
            }
        }

        // Marginal sums along each axis.
        let sum = self.matrix.sum();
        let sum_over_cols = self.matrix.sum_axis(Axis(0));
        let sum_over_rows = self.matrix.sum_axis(Axis(1));

        // cov(x, x) and cov(y, y) from the marginals.
        let mut cov_xx: f32 = 0.0;
        let mut cov_yy: f32 = 0.0;
        for k in 0..self.members.len() {
            cov_xx += sum_over_rows[k] * (sum - sum_over_rows[k]);
            cov_yy += sum_over_cols[k] * (sum - sum_over_cols[k]);
        }

        cov_xy / cov_xx.sqrt() / cov_yy.sqrt()
    }

    /// Splits the matrix into one binary (2x2) matrix per class, treating
    /// that class as the `true` label and all other classes together as
    /// `false`.
    pub fn split_one_vs_all(&self) -> Vec<ConfusionMatrix<bool>> {
        let sum = self.matrix.sum();
        (0..self.members.len())
            .map(|i| {
                let tp = self.matrix[(i, i)];
                // Rest of row `i`: assigned to class `i` but counted
                // elsewhere in ground truth.
                let fp = self.matrix.row(i).sum() - tp;
                // Rest of column `i`: ground truth `i` but assigned
                // elsewhere.
                let _fn = self.matrix.column(i).sum() - tp;
                let tn = sum - tp - fp - _fn;

                ConfusionMatrix {
                    matrix: array![[tp, fp], [_fn, tn]],
                    members: Array1::from(vec![true, false]),
                }
            })
            .collect()
    }

    /// Splits the matrix into one binary matrix for every unordered pair of
    /// distinct classes `(i, j)` with `i < j`.
    pub fn split_one_vs_one(&self) -> Vec<ConfusionMatrix<bool>> {
        let n = self.members.len();
        // Exactly one matrix per unordered pair of distinct classes.
        let mut cms = Vec::with_capacity(n * (n - 1) / 2);

        for i in 0..n {
            // Bug fix: start at `i + 1`, not `i`. Starting at `i` paired
            // every class with itself (a degenerate matrix with
            // `matrix[(i, i)]` in all four cells) and produced
            // n * (n + 1) / 2 entries, exceeding the capacity reserved for
            // the n * (n - 1) / 2 distinct pairs above.
            for j in (i + 1)..n {
                let tp = self.matrix[(i, i)];
                let fp = self.matrix[(i, j)];
                let _fn = self.matrix[(j, i)];
                let tn = self.matrix[(j, j)];

                cms.push(ConfusionMatrix {
                    matrix: array![[tp, fp], [_fn, tn]],
                    members: Array1::from(vec![true, false]),
                });
            }
        }

        cms
    }
}
impl<A: fmt::Display> fmt::Debug for ConfusionMatrix<A> {
    /// Renders the matrix as a padded table: a header row with every class
    /// label, then one row of counts per class.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rows = self.matrix.len_of(Axis(0));
        writeln!(f)?;

        // Header row listing the class labels as column titles.
        write!(f, "{: <10}", "classes")?;
        for idx in 0..rows {
            write!(f, " | {: <10}", self.members[idx])?;
        }
        writeln!(f)?;

        // One line per class: its label followed by the counts in its row.
        for row in 0..rows {
            write!(f, "{: <10}", self.members[row])?;
            for col in 0..rows {
                write!(f, " | {: <10}", self.matrix[(row, col)])?;
            }
            writeln!(f)?;
        }

        Ok(())
    }
}
/// Conversion of a prediction into a [`ConfusionMatrix`] against some ground
/// truth representation `T`.
pub trait ToConfusionMatrix<A, T> {
    /// Builds the confusion matrix; implementations may return an error,
    /// e.g. when prediction and ground truth lengths differ.
    fn confusion_matrix(&self, ground_truth: T) -> Result<ConfusionMatrix<A>>;
}
// By-value adapter: borrows the owned array and defers to the `&ArrayBase`
// implementation below.
impl<L: Label, S, T> ToConfusionMatrix<L, ArrayBase<S, Ix1>> for T
where
    S: Data<Elem = L>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    fn confusion_matrix(&self, ground_truth: ArrayBase<S, Ix1>) -> Result<ConfusionMatrix<L>> {
        self.confusion_matrix(&ground_truth)
    }
}
impl<L: Label, S, T> ToConfusionMatrix<L, &ArrayBase<S, Ix1>> for T
where
S: Data<Elem = L>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
fn confusion_matrix(&self, ground_truth: &ArrayBase<S, Ix1>) -> Result<ConfusionMatrix<L>> {
let targets = self.as_single_targets();
if targets.len() != ground_truth.len() {
return Err(Error::MismatchedShapes(targets.len(), ground_truth.len()));
}
let classes = self.labels();
let indices = map_prediction_to_idx(
targets.as_slice().unwrap(),
ground_truth.as_slice().unwrap(),
&classes,
);
let mut confusion_matrix = Array2::zeros((classes.len(), classes.len()));
for (i1, i2) in indices.into_iter().flatten() {
confusion_matrix[(i1, i2)] += 1.0;
}
Ok(ConfusionMatrix {
matrix: confusion_matrix,
members: Array1::from(classes),
})
}
}
impl<L: Label, R, R2, T, T2> ToConfusionMatrix<L, &DatasetBase<R, T>> for DatasetBase<R2, T2>
where
    R: Records,
    R2: Records,
    T: AsSingleTargets<Elem = L>,
    T2: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    /// Compares the targets of two datasets by delegating to the
    /// target-container implementation.
    fn confusion_matrix(&self, ground_truth: &DatasetBase<R, T>) -> Result<ConfusionMatrix<L>> {
        let truth = ground_truth.as_targets();
        self.targets().confusion_matrix(truth)
    }
}
impl<L: Label, S: Data<Elem = L>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>, R: Records>
    ToConfusionMatrix<L, &DatasetBase<R, T>> for ArrayBase<S, Ix1>
{
    /// Delegates by calling `confusion_matrix` on the dataset with a view of
    /// this array as the other operand.
    fn confusion_matrix(&self, ground_truth: &DatasetBase<R, T>) -> Result<ConfusionMatrix<L>> {
        let view = self.view();
        ground_truth.confusion_matrix(view)
    }
}
/// Numerically integrates a polyline of `(x, y)` points using the trapezoid
/// rule. Panics on an empty input slice.
fn trapezoidal<A: NdFloat>(vals: &[(A, A)]) -> A {
    let half_divisor = A::from(2.0).unwrap();

    // First point seeds the running "previous" coordinates.
    let (mut last_x, mut last_y) = vals[0];
    let mut area = A::zero();

    // Each consecutive pair contributes one trapezoid slice.
    for &(x, y) in &vals[1..] {
        area += (x - last_x) * (last_y + y) / half_divisor;
        last_x = x;
        last_y = y;
    }

    area
}
/// A Receiver Operating Characteristic: the curve points and the score
/// thresholds at which they were recorded.
#[derive(Debug, Clone, PartialEq)]
pub struct ReceiverOperatingCharacteristic {
    // (true-positive, false-positive) counts normalized into [0, 1] rates by
    // the `roc` constructor.
    curve: Vec<(f32, f32)>,
    // The probability threshold that produced each recorded curve point.
    thresholds: Vec<f32>,
}
impl ReceiverOperatingCharacteristic {
    /// Returns an owned copy of the curve points.
    pub fn get_curve(&self) -> Vec<(f32, f32)> {
        self.curve.to_vec()
    }

    /// Returns an owned copy of the recorded thresholds.
    pub fn get_thresholds(&self) -> Vec<f32> {
        self.thresholds.to_vec()
    }

    /// Area under the curve via trapezoidal integration of the stored
    /// points.
    pub fn area_under_curve(&self) -> f32 {
        trapezoidal(&self.curve)
    }
}
/// Metrics for binary classifiers that emit probabilities.
pub trait BinaryClassification<T> {
    /// Receiver operating characteristic against boolean ground truth `y`.
    fn roc(&self, y: T) -> Result<ReceiverOperatingCharacteristic>;
    /// Mean binary cross-entropy (log loss) against boolean ground truth.
    fn log_loss(&self, y: T) -> Result<f32>;
}
impl BinaryClassification<&[bool]> for &[Pr] {
    /// Builds the ROC curve by sweeping a threshold over the sorted
    /// probabilities and recording (tp, fp) counts at every distinct score.
    fn roc(&self, y: &[bool]) -> Result<ReceiverOperatingCharacteristic> {
        // Pair each probability with its label, keeping only non-negative
        // scores. NOTE(review): `Pr` values presumably are already >= 0, so
        // this filter looks like a sentinel check — confirm against `Pr`'s
        // constructor.
        let mut tuples = self
            .iter()
            .zip(y.iter())
            .filter_map(|(a, b)| if **a >= 0.0 { Some((*a, *b)) } else { None })
            .collect::<Vec<(Pr, bool)>>();

        // Sort ascending by probability; a NaN score would hit the
        // `unreachable!` since `partial_cmp` returns `None` only for NaN.
        tuples.sort_unstable_by(&|a: &(Pr, _), b: &(Pr, _)| match a.0.partial_cmp(&b.0) {
            Some(ord) => ord,
            None => unreachable!(),
        });

        let (mut tp, mut fp) = (0.0, 0.0);
        let mut tps_fps = Vec::new();
        let mut thresholds = Vec::new();
        let mut s0 = 0.0;

        // Emit a curve point whenever the score changes (before counting the
        // current sample), then update the tp/fp tally for this sample.
        for (s, t) in tuples {
            if (*s - s0).abs() > 1e-10 {
                tps_fps.push((tp, fp));
                thresholds.push(s);
                s0 = *s;
            }
            if t {
                tp += 1.0;
            } else {
                fp += 1.0;
            }
        }
        // Final point holds the grand totals.
        tps_fps.push((tp, fp));

        // Normalize counts into rates within [0, 1].
        let (max_tp, max_fp) = (tp, fp);
        for (tp, fp) in &mut tps_fps {
            *tp /= max_tp;
            *fp /= max_fp;
        }

        Ok(ReceiverOperatingCharacteristic {
            curve: tps_fps,
            thresholds: thresholds.into_iter().map(|x| *x).collect(),
        })
    }

    /// Log loss by wrapping the slice in an array view and delegating to the
    /// `ArrayBase` implementation.
    fn log_loss(&self, y: &[bool]) -> Result<f32> {
        let probabilities = aview1(self);
        probabilities.log_loss(y)
    }
}
impl<D: Data<Elem = Pr>> BinaryClassification<&[bool]> for ArrayBase<D, Ix1> {
fn roc(&self, y: &[bool]) -> Result<ReceiverOperatingCharacteristic> {
self.as_slice().unwrap().roc(y)
}
fn log_loss(&self, y: &[bool]) -> Result<f32> {
assert_eq!(
self.len(),
y.len(),
"The number of predicted points must match the length of target."
);
let len = self.len();
if len == 0 {
Err(Error::NotEnoughSamples)
} else {
let sum: f32 = self
.iter()
.map(|v| (*v).clamp(f32::EPSILON, 1. - f32::EPSILON))
.zip(y.iter())
.map(|(a, b)| if *b { -a.ln() } else { -(1. - a).ln() })
.sum();
Ok(sum / len as f32)
}
}
}
impl<R: Records, R2: Records, T: AsSingleTargets<Elem = bool>, T2: AsSingleTargets<Elem = Pr>>
    BinaryClassification<&DatasetBase<R, T>> for DatasetBase<R2, T2>
{
    /// ROC over dataset targets: extracts both target slices and delegates
    /// to the slice implementation.
    fn roc(&self, y: &DatasetBase<R, T>) -> Result<ReceiverOperatingCharacteristic> {
        let predicted = self.as_targets();
        let truth = y.as_targets();
        predicted
            .as_slice()
            .unwrap()
            .roc(truth.as_slice().unwrap())
    }

    /// Log loss over dataset targets: delegates to the array
    /// implementation.
    fn log_loss(&self, y: &DatasetBase<R, T>) -> Result<f32> {
        let truth = y.as_targets();
        self.as_single_targets()
            .log_loss(truth.as_slice().unwrap())
    }
}
#[cfg(test)]
mod tests {
    use super::{BinaryClassification, ConfusionMatrix, ToConfusionMatrix};
    use super::{Label, Pr};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1, Array2, ArrayView1};
    use rand::{distributions::Uniform, rngs::SmallRng, Rng, SeedableRng};
    use std::collections::HashMap;

    // Maps each class label of `cm` to its row/column index, since the label
    // order inside a confusion matrix is not guaranteed.
    fn get_labels_map<L: Label>(cm: &ConfusionMatrix<L>) -> HashMap<L, usize> {
        cm.members
            .iter()
            .enumerate()
            .map(|(index, label)| (label.clone(), index))
            .collect()
    }

    // Asserts that `cm` equals `expected`, translating positions through the
    // label map so the comparison is independent of member ordering.
    fn assert_cm_eq<L: Label>(cm: &ConfusionMatrix<L>, expected: &Array2<f32>, labels: &Array1<L>) {
        let map = get_labels_map(cm);
        for ((row, column), value) in expected.indexed_iter().map(|((r, c), v)| {
            (
                (*map.get(&labels[r]).unwrap(), *map.get(&labels[c]).unwrap()),
                v,
            )
        }) {
            let cm_value = *cm.matrix.get((row, column)).unwrap();
            assert_abs_diff_eq!(cm_value, value);
        }
    }

    // Evaluates `eval` on every one-vs-all split of `cm` and checks the
    // per-class results against `expected` (label-order independent).
    fn assert_split_eq<L: Label, C: Fn(&ConfusionMatrix<bool>) -> f32>(
        cm: &ConfusionMatrix<L>,
        eval: C,
        expected: &Array1<f32>,
        labels: &Array1<L>,
    ) {
        let map = get_labels_map(cm);
        let evals = cm
            .split_one_vs_all()
            .into_iter()
            .map(|x| eval(&x))
            .collect::<Vec<_>>();
        for (index, value) in expected
            .indexed_iter()
            .map(|(i, v)| (*map.get(&labels[i]).unwrap(), v))
        {
            let evals_value = *evals.get(index).unwrap();
            assert_abs_diff_eq!(evals_value, value);
        }
    }

    // Basic construction: only the first sample disagrees (pred 0, truth 1).
    #[test]
    fn test_confusion_matrix() {
        let ground_truth = ArrayView1::from(&[1, 1, 0, 1, 0, 1]);
        let predicted = ArrayView1::from(&[0, 1, 0, 1, 0, 1]);
        let cm = predicted.confusion_matrix(ground_truth).unwrap();
        let labels = array![0, 1];
        let expected = array![[2., 1.], [0., 3.]];
        assert_cm_eq(&cm, &expected, &labels);
    }

    // Accuracy, MCC, and the per-class precision/recall/F1 values for the
    // same two-class example as above.
    #[test]
    fn test_cm_metrices() {
        let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]);
        let predicted = Array1::from(vec![0, 1, 0, 1, 0, 1]);
        let x = predicted.confusion_matrix(ground_truth).unwrap();
        let labels = array![0, 1];
        assert_abs_diff_eq!(x.accuracy(), 5.0 / 6.0_f32);
        assert_abs_diff_eq!(
            x.mcc(),
            (2. * 3. - 1. * 0.) / (2.0f32 * 3. * 3. * 4.).sqrt()
        );
        assert_split_eq(
            &x,
            ConfusionMatrix::precision,
            &array![1.0, 3. / 4.],
            &labels,
        );
        assert_split_eq(
            &x,
            ConfusionMatrix::recall,
            &array![2.0 / 3.0, 1.0],
            &labels,
        );
        assert_split_eq(
            &x,
            ConfusionMatrix::f1_score,
            &array![4.0 / 5.0, 6.0 / 7.0],
            &labels,
        );
    }

    // Hand-computed curve: counts swept over the six sorted probabilities,
    // then normalized by the totals (4 positives, 2 negatives).
    #[test]
    fn test_roc_curve() {
        let predicted = ArrayView1::from(&[0.1, 0.3, 0.5, 0.7, 0.8, 0.9]).mapv(Pr::new);
        let groundtruth = vec![false, true, false, true, true, true];
        let result = &[
            (0.0, 0.0), (0.0, 0.5), (0.25, 0.5), (0.25, 1.0), (0.5, 1.0), (0.75, 1.0),
            (1., 1.),
        ];
        let roc = predicted.roc(&groundtruth).unwrap();
        assert_eq!(roc.get_curve(), result);
    }

    // Random labels against uniformly spaced scores should give AUC ~= 0.5.
    #[test]
    fn test_roc_auc() {
        let mut rng = SmallRng::seed_from_u64(42);
        let predicted = Array1::linspace(0.0, 1.0, 1000).mapv(Pr::new);
        let range = Uniform::new(0, 2);
        let ground_truth = (0..1000)
            .map(|_| rng.sample(range) == 1)
            .collect::<Vec<_>>();
        let roc = predicted.roc(&ground_truth).unwrap();
        // NOTE(review): no `.abs()` here — an AUC well above 0.54 would still
        // pass. Deterministic with the fixed seed, but worth tightening.
        assert!((roc.area_under_curve() - 0.5) < 0.04);
    }

    // Each expected binary matrix treats one of the four labels as positive.
    #[test]
    fn split_one_vs_all() {
        let ground_truth = array![0, 2, 3, 0, 1, 2, 1, 2, 3, 2];
        let predicted = array![0, 3, 2, 0, 1, 1, 1, 3, 2, 3];
        let cm = predicted.confusion_matrix(ground_truth).unwrap();
        let labels = array![0, 1, 2, 3];
        let bin_labels = array![true, false];
        let map = get_labels_map(&cm);
        let n_cm = cm.split_one_vs_all();
        let result = &[
            array![[2., 0.], [0., 8.]], array![[2., 1.], [0., 7.]], array![[0., 2.], [4., 4.]], array![[0., 3.], [2., 5.]], ];
        for (r, x) in result
            .iter()
            .zip(labels.iter())
            .map(|(r, l)| (r, n_cm.get(*map.get(l).unwrap()).unwrap()))
        {
            assert_cm_eq(x, r, &bin_labels);
        }
    }

    // Mean cross-entropy against a hand-checked reference value.
    #[test]
    fn log_loss() {
        let ground_truth = &[false, false, false, false, true, true, true, true, true];
        let predicted =
            ArrayView1::from(&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]).mapv(Pr::new);
        let logloss = predicted.log_loss(ground_truth).unwrap();
        assert_abs_diff_eq!(logloss, 0.34279516);
    }

    // Empty input yields `Err(NotEnoughSamples)`, so `unwrap` panics.
    #[test]
    #[should_panic]
    fn log_loss_empty() {
        let ground_truth = &[];
        let predicted = ArrayView1::from(&[]).mapv(Pr::new);
        predicted.log_loss(ground_truth).unwrap();
    }

    // Mismatched lengths trip the `assert_eq!` inside `log_loss`.
    #[test]
    #[should_panic]
    fn log_loss_with_different_lengths() {
        let ground_truth = &[false, false, false, false, true, true, true, true];
        let predicted =
            ArrayView1::from(&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]).mapv(Pr::new);
        predicted.log_loss(ground_truth).unwrap();
    }
}