//! tangram_metrics 0.7.0
//!
//! Tangram makes it easy for programmers to train, deploy, and monitor machine learning models.
//! See the crate documentation for details.
use itertools::Itertools;
use std::num::NonZeroUsize;

/// This `Metric` computes the area under the receiver operating characteristic
/// curve, a threshold-independent measure of binary classification quality.
pub struct AucRoc;

impl AucRoc {
	/// Compute the AUC-ROC for a set of predictions.
	///
	/// `input` is a list of `(probability, label)` pairs, where `probability` is
	/// the predicted probability of the positive class and `label` is 1-indexed:
	/// `1` means negative, `2` means positive.
	///
	/// Returns a value in `[0, 1]`. NOTE(review): if the input contains no
	/// positives or no negatives, the rates divide by zero and the result is NaN.
	pub fn compute(mut input: Vec<(f32, NonZeroUsize)>) -> f32 {
		// Sort by probabilities in descending order. `total_cmp` is a total
		// order over f32, so a NaN probability no longer panics (it sorts to
		// one end instead).
		input.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
		// Collect the true_positives and false_positives counts for each unique
		// probability. Equal probabilities are adjacent after the sort, so they
		// can be merged by comparing each pair against the previous point.
		let mut true_positives_false_positives: Vec<TruePositivesFalsePositivesPoint> = Vec::new();
		for (probability, label) in input.iter() {
			// Labels are 1-indexed: 1 => negative (0), 2 => positive (1).
			let label = label.get() - 1;
			// If the classification threshold were this probability, a positive
			// label makes the prediction a true positive...
			let true_positive = label;
			// ...and a negative label makes it a false positive.
			let false_positive = 1 - label;
			match true_positives_false_positives.last_mut() {
				// Same probability as the previous point: merge the counts.
				Some(last) if f32::abs(probability - last.probability) < f32::EPSILON => {
					last.true_positives += true_positive;
					last.false_positives += false_positive;
				}
				// New unique probability: start a new point.
				_ => {
					true_positives_false_positives.push(TruePositivesFalsePositivesPoint {
						probability: *probability,
						true_positives: true_positive,
						false_positives: false_positive,
					});
				}
			}
		}
		// Compute the cumulative sum of true positives and false positives, so
		// point `i` holds the totals for all predictions with probability >= its
		// threshold.
		for i in 1..true_positives_false_positives.len() {
			true_positives_false_positives[i].true_positives +=
				true_positives_false_positives[i - 1].true_positives;
			true_positives_false_positives[i].false_positives +=
				true_positives_false_positives[i - 1].false_positives;
		}
		// Get the total count of positives.
		let count_positives = input.iter().map(|l| l.1.get() - 1).sum::<usize>();
		// Get the total count of negatives.
		let count_negatives = input.len() - count_positives;
		// The true_positive_rate at threshold x is the percent of the total
		// positives that have a prediction probability >= x. At the maximum
		// observed probability, either the true_positive_rate or the
		// false_positive_rate is already nonzero, so no point on the curve would
		// have both rates equal to 0. We therefore prepend a dummy point with an
		// impossible threshold of 2.0: no prediction has probability >= 2.0, so
		// both rates are 0 there and the curve starts at the origin.
		let mut roc_curve = vec![RocCurvePoint {
			threshold: 2.0,
			true_positive_rate: 0.0,
			false_positive_rate: 0.0,
		}];
		for true_positives_false_positives_point in true_positives_false_positives.iter() {
			roc_curve.push(RocCurvePoint {
				// The true positive rate is the number of true positives divided
				// by the total number of positives.
				true_positive_rate: true_positives_false_positives_point.true_positives as f32
					/ count_positives as f32,
				threshold: true_positives_false_positives_point.probability,
				// The false positive rate is the number of false positives
				// divided by the total number of negatives.
				false_positive_rate: true_positives_false_positives_point.false_positives as f32
					/ count_negatives as f32,
			});
		}
		// Compute the Riemann sum using the trapezoidal rule. `windows(2)` from
		// the standard library replaces itertools' `tuple_windows`.
		roc_curve
			.windows(2)
			.map(|pair| {
				let (left, right) = (&pair[0], &pair[1]);
				let y_avg =
					(left.true_positive_rate as f64 + right.true_positive_rate as f64) / 2.0;
				let dx = right.false_positive_rate as f64 - left.false_positive_rate as f64;
				y_avg * dx
			})
			.sum::<f64>() as f32
	}
}

/// A point on the ROC curve, parameterized by thresholds.
#[derive(Debug, PartialEq)]
struct RocCurvePoint {
	/// The classification threshold.
	threshold: f32,
	/// The true positive rate for all predictions with probability >= threshold.
	true_positive_rate: f32,
	/// The false positive rate for all predictions with probability >= threshold.
	false_positive_rate: f32,
}

/// The true/false positive counts contributed at one unique probability.
#[derive(Debug)]
struct TruePositivesFalsePositivesPoint {
	/// The prediction probability, which doubles as the threshold.
	probability: f32,
	/// The true positives for this threshold.
	true_positives: usize,
	/// The false positives for this threshold.
	false_positives: usize,
}

#[test]
fn test_roc_curve() {
	// Two positives (label 2) and two negatives (label 1); the middle two
	// predictions share the same probability and are merged into one point.
	let labels = vec![
		NonZeroUsize::new(2).unwrap(),
		NonZeroUsize::new(2).unwrap(),
		NonZeroUsize::new(1).unwrap(),
		NonZeroUsize::new(1).unwrap(),
	];
	let probabilities = vec![0.9, 0.4, 0.4, 0.2];
	// `Iterator::zip` from the standard library replaces the third-party
	// `tangram_zip::zip!` macro — no extra dependency needed for pairing.
	let input = probabilities.into_iter().zip(labels.into_iter()).collect();
	let actual = AucRoc::compute(input);
	let expected = 0.875;
	assert!(f32::abs(actual - expected) < f32::EPSILON)
}