// rusty-machine 0.5.4
//
// A machine learning library.
// Documentation
//! Functions for scoring a set of predictions, i.e. evaluating
//! how close predictions and truth are. All functions in this
//! module obey the convention that higher is better.

use libnum::{Zero, One};

use linalg::{BaseMatrix, Matrix};
use learning::toolkit::cost_fn::{CostFunc, MeanSqError};

// ************************************
// Classification Scores
// ************************************

/// Computes the share of outputs that are equal to their corresponding target.
///
/// Outputs and targets are paired up position by position; the result is the
/// number of equal pairs divided by the number of outputs.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels.
/// * `targets` - Iterator of expected (actual) labels.
///
/// # Returns
///
/// The fraction of matching pairs. When both iterators are empty the
/// division is `0 / 0` and the result is NaN.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::accuracy;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(accuracy(outputs.iter(), targets.iter()), 0.5);
/// ```
///
/// # Panics
///
/// - outputs and targets have different length
pub fn accuracy<I>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator,
          I::Item: PartialEq
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    // Capture the length up front: zipping consumes the iterators.
    let total = outputs.len();

    let mut matches = 0usize;
    for (predicted, expected) in outputs.zip(targets) {
        if predicted == expected {
            matches += 1;
        }
    }

    matches as f64 / total as f64
}

/// Computes the share of output rows that are equal to their target row.
///
/// This is `accuracy` applied to the row iterators of the two matrices, so
/// a row only counts as correct when the whole row compares equal.
///
/// # Arguments
///
/// * `outputs` - Matrix of output (predicted) values, one sample per row.
/// * `targets` - Matrix of expected (actual) values, one sample per row.
///
/// # Panics
///
/// - the two row iterators report different lengths (the length check
///   performed by `accuracy`); presumably this means the matrices have a
///   different number of rows
pub fn row_accuracy(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
    let predicted_rows = outputs.iter_rows();
    let expected_rows = targets.iter_rows();
    accuracy(predicted_rows, expected_rows)
}

/// Computes the precision score of a 2 class classification.
///
/// Precision is true-positive / (true-positive + false-positive), i.e. the
/// fraction of outputs labelled 1 whose target is also 1,
/// see [Precision and Recall](https://en.wikipedia.org/wiki/Precision_and_recall) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels which only contains 0 or 1.
/// * `targets` - Iterator of expected (actual) labels which only contains 0 or 1.
///
/// # Returns
///
/// The precision as `f64`. When no output is labelled 1 the division is
/// `0 / 0` and the result is NaN.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::precision;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
/// ```
///
/// # Panics
///
/// - outputs and targets have different length
/// - outputs or targets contains a value which is not 0 or 1
pub fn precision<'a, I, T>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator<Item=&'a T>,
          T: 'a + PartialEq + Zero + One
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let zero = T::zero();
    let one = T::one();
    // A label is acceptable only if it is exactly the zero or the one of `T`.
    let is_binary = |label: &T| *label == zero || *label == one;

    let mut predicted_positive = 0.0f64;
    let mut true_positive = 0.0f64;

    for (predicted, actual) in outputs.zip(targets) {
        if !is_binary(predicted) || !is_binary(actual) {
            panic!("precision must be used for 2 class classification");
        }
        if *predicted == one {
            predicted_positive += 1.0f64;
            if *actual == one {
                true_positive += 1.0f64;
            }
        }
    }

    true_positive / predicted_positive
}

/// Computes the recall score of a 2 class classification.
///
/// Recall is true-positive / (true-positive + false-negative), i.e. the
/// fraction of targets labelled 1 whose output is also 1,
/// see [Precision and Recall](https://en.wikipedia.org/wiki/Precision_and_recall) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels which only contains 0 or 1.
/// * `targets` - Iterator of expected (actual) labels which only contains 0 or 1.
///
/// # Returns
///
/// The recall as `f64`. When no target is labelled 1 the division is
/// `0 / 0` and the result is NaN.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::recall;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(recall(outputs.iter(), targets.iter()), 0.5);
/// ```
///
/// # Panics
///
/// - outputs and targets have different length
/// - outputs or targets contains a value which is not 0 or 1
pub fn recall<'a, I, T>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator<Item=&'a T>,
          T: 'a + PartialEq + Zero + One
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let (zero, one) = (T::zero(), T::one());

    let mut actual_positive = 0.0f64;
    let mut true_positive = 0.0f64;

    for (predicted, actual) in outputs.zip(targets) {
        // Reject any label that is neither the zero nor the one of `T`.
        let predicted_ok = *predicted == zero || *predicted == one;
        let actual_ok = *actual == zero || *actual == one;
        if !(predicted_ok && actual_ok) {
            panic!("recall must be used for 2 class classification");
        }

        // Only samples whose target is positive contribute to recall.
        if *actual != one {
            continue;
        }
        actual_positive += 1.0f64;
        if *predicted == one {
            true_positive += 1.0f64;
        }
    }

    true_positive / actual_positive
}

/// Returns the f1 score for 2 class classification.
///
/// F1-score is calculated with 2 * precision * recall / (precision + recall),
/// which is evaluated here in the equivalent count form
/// 2 * true-positive / (2 * true-positive + false-positive + false-negative),
/// see [F1 score](https://en.wikipedia.org/wiki/F1_score) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels which only contains 0 or 1.
/// * `targets` - Iterator of expected (actual) labels which only contains 0 or 1.
///
/// # Returns
///
/// The f1 score as `f64`. When neither outputs nor targets contain a 1 the
/// division is `0 / 0` and the result is NaN.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::f1;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(f1(outputs.iter(), targets.iter()), 0.5714285714285714);
/// ```
///
/// # Panics
///
/// - outputs and targets have different length
/// - outputs or targets contains a value which is not 0 or 1
pub fn f1<'a, I, T>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator<Item=&'a T>,
          T: 'a + PartialEq + Zero + One
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let zero = T::zero();
    let one = T::one();

    let mut tpos = 0.0f64;
    let mut fpos = 0.0f64;
    let mut fneg = 0.0f64;

    for (o, t) in outputs.zip(targets) {
        // Reject any label that is neither the zero nor the one of `T`.
        if (*o != zero && *o != one) || (*t != zero && *t != one) {
            panic!("f1-score must be used for 2 class classification");
        }

        match (*o == one, *t == one) {
            // Predicted positive, actually positive.
            (true, true) => tpos += 1.0f64,
            // Predicted positive, actually negative.
            (true, false) => fpos += 1.0f64,
            // Predicted negative, actually positive.
            (false, true) => fneg += 1.0f64,
            // True negatives do not take part in the f1 score.
            (false, false) => {}
        }
    }

    2.0f64 * tpos / (2.0f64 * tpos + fpos + fneg)
}

// ************************************
// Regression Scores
// ************************************

// TODO: generalise to accept arbitrary iterators of diff-able things
/// Computes the negated mean-squared-error between outputs and targets.
///
/// Negating the error keeps this function in line with the module's
/// "higher is better" convention. The returned value is never positive;
/// it is zero when the cost reported by `MeanSqError` is zero.
///
/// # Arguments
///
/// * `outputs` - Matrix of output (predicted) values.
/// * `targets` - Matrix of expected (actual) values.
pub fn neg_mean_squared_error(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64
{
    // MeanSqError reports half of the actual mean squared error, so the
    // factor of two restores the full value before the sign is flipped.
    let halved_mse = MeanSqError::cost(outputs, targets);
    -2f64 * halved_mse
}

#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::{accuracy, precision, recall, f1, neg_mean_squared_error};

    // The classification score tests below are table driven: every entry is
    // (outputs, targets, expected score).

    #[test]
    fn test_accuracy() {
        let cases = [
            ([1, 2, 3, 4, 5, 6], [1, 2, 3, 3, 5, 1], 2f64 / 3f64),
            ([1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 1], 5.0f64 / 6.0f64),
        ];
        for &(ref outputs, ref targets, expected) in cases.iter() {
            assert_eq!(accuracy(outputs.iter(), targets.iter()), expected);
        }
    }

    #[test]
    fn test_precision() {
        let cases = [
            ([1, 1, 1, 0, 0, 0], [1, 1, 0, 0, 1, 1], 2.0f64 / 3.0f64),
            ([1, 1, 1, 0, 1, 1], [1, 1, 0, 0, 1, 1], 0.8),
            ([0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], 2.0f64 / 3.0f64),
            ([1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 1], 0.4),
        ];
        for &(ref outputs, ref targets, expected) in cases.iter() {
            assert_eq!(precision(outputs.iter(), targets.iter()), expected);
        }
    }

    // A label of 2 among the outputs must be rejected.
    #[test]
    #[should_panic]
    fn test_precision_outputs_not_2class() {
        precision([1, 2, 1, 0, 0, 0].iter(), [1, 1, 0, 0, 1, 1].iter());
    }

    // A label of 2 among the targets must be rejected.
    #[test]
    #[should_panic]
    fn test_precision_targets_not_2class() {
        precision([1, 0, 1, 0, 0, 0].iter(), [1, 2, 0, 0, 1, 1].iter());
    }

    #[test]
    fn test_recall() {
        let cases = [
            ([1, 1, 1, 0, 0, 0], [1, 1, 0, 0, 1, 1], 0.5),
            ([1, 1, 1, 0, 1, 1], [1, 1, 0, 0, 1, 1], 1.0),
            ([0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], 0.4),
            ([1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 1], 2.0f64 / 3.0f64),
        ];
        for &(ref outputs, ref targets, expected) in cases.iter() {
            assert_eq!(recall(outputs.iter(), targets.iter()), expected);
        }
    }

    // A label of 2 among the outputs must be rejected.
    #[test]
    #[should_panic]
    fn test_recall_outputs_not_2class() {
        recall([1, 2, 1, 0, 0, 0].iter(), [1, 1, 0, 0, 1, 1].iter());
    }

    // A label of 2 among the targets must be rejected.
    #[test]
    #[should_panic]
    fn test_recall_targets_not_2class() {
        recall([1, 0, 1, 0, 0, 0].iter(), [1, 2, 0, 0, 1, 1].iter());
    }

    #[test]
    fn test_f1() {
        let cases = [
            ([1, 1, 1, 0, 0, 0], [1, 1, 0, 0, 1, 1], 0.5714285714285714),
            ([1, 1, 1, 0, 1, 1], [1, 1, 0, 0, 1, 1], 0.8888888888888888),
            ([0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], 0.5),
            ([1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 1], 0.5),
        ];
        for &(ref outputs, ref targets, expected) in cases.iter() {
            assert_eq!(f1(outputs.iter(), targets.iter()), expected);
        }
    }

    // A label of 2 among the outputs must be rejected.
    #[test]
    #[should_panic]
    fn test_f1_outputs_not_2class() {
        f1([1, 2, 1, 0, 0, 0].iter(), [1, 1, 0, 0, 1, 1].iter());
    }

    // A label of 2 among the targets must be rejected.
    #[test]
    #[should_panic]
    fn test_f1_targets_not_2class() {
        f1([1, 0, 1, 0, 0, 0].iter(), [1, 2, 0, 0, 1, 1].iter());
    }

    #[test]
    fn test_neg_mean_squared_error_1d() {
        let predicted = Matrix::new(3, 1, vec![1f64, 2f64, 3f64]);
        let actual = Matrix::new(3, 1, vec![2f64, 4f64, 3f64]);
        // Squared differences are 1, 4 and 0, so their mean is 5/3.
        let expected = -5f64/3f64;
        assert_eq!(neg_mean_squared_error(&predicted, &actual), expected);
    }

    #[test]
    fn test_neg_mean_squared_error_2d() {
        let predicted = Matrix::new(3, 2, vec![
            1f64, 2f64,
            3f64, 4f64,
            5f64, 6f64
            ]);
        let actual = Matrix::new(3, 2, vec![
            1.5f64, 2.5f64,
            5f64,   6f64,
            5.5f64, 6.5f64
            ]);
        // Squared differences sum to 9; the expected score of -3 corresponds
        // to dividing that sum by the 3 rows.
        let expected = -3f64;
        assert_eq!(neg_mean_squared_error(&predicted, &actual), expected);
    }
}