// rust_metrics 0.1.2
//
// Incremental evaluation metrics for classification and text generation pipelines.
use crate::core::{Metric, MetricError};
use crate::utils::{verify_binary_label, verify_range};

/// Running precision for binary classification, computed from probability scores.
///
/// A score strictly greater than the configured threshold is treated as a
/// positive prediction; precision is the share of those positive predictions
/// whose target is `1.0`.
///
/// # Examples
///
/// ```
/// use rust_metrics::{BinaryPrecision, Metric};
///
/// let mut precision = BinaryPrecision::new(0.5);
/// precision.update((&[0.9, 0.4], &[1.0, 0.0])).unwrap();
/// assert_eq!(precision.compute(), Some(1.0));
/// ```
#[derive(Clone, Debug)]
pub struct BinaryPrecision {
    /// Decision cut-off: a score must exceed it to count as a positive prediction.
    threshold: f64,
    /// Number of positive predictions whose target was `1.0`.
    true_positive_ct: usize,
    /// Number of positive predictions whose target was not `1.0`.
    false_positive_ct: usize,
}

impl Default for BinaryPrecision {
    fn default() -> Self {
        Self::new(0.5)
    }
}

impl BinaryPrecision {
    /// Creates a precision metric with zeroed counters.
    ///
    /// During `update`, a prediction counts as positive only when it is
    /// strictly greater than `threshold`.
    ///
    /// # Panics
    ///
    /// Panics if `verify_range(threshold, 0.0, 1.0)` reports an error
    /// (presumably when `threshold` lies outside the 0.0 to 1.0 range; the
    /// exact bound handling is defined by `verify_range`).
    pub fn new(threshold: f64) -> Self {
        // An unusable threshold is a caller bug, so fail loudly with a message
        // that names the violated precondition instead of a bare `unwrap`.
        verify_range(threshold, 0.0, 1.0)
            .expect("BinaryPrecision threshold failed verify_range(threshold, 0.0, 1.0)");
        Self {
            threshold,
            true_positive_ct: 0,
            false_positive_ct: 0,
        }
    }
}

impl Metric<(&[f64], &[f64])> for BinaryPrecision {
    type Output = f64;

    /// Folds one batch of `(predictions, targets)` into the running counts.
    ///
    /// A prediction strictly greater than the threshold is a positive call: it
    /// is a true positive when its target equals `1.0` and a false positive
    /// otherwise. Predictions at or below the threshold do not change the counts.
    ///
    /// # Errors
    ///
    /// Returns [`MetricError::LengthMismatch`] when the two slices differ in
    /// length, and propagates any error reported by `verify_range` for a
    /// prediction or by `verify_binary_label` for a target. The whole batch is
    /// validated before the counters change, so a failed call leaves the
    /// metric exactly as it was.
    fn update(&mut self, (predictions, targets): (&[f64], &[f64])) -> Result<(), MetricError> {
        if predictions.len() != targets.len() {
            return Err(MetricError::LengthMismatch {
                predictions: predictions.len(),
                targets: targets.len(),
            });
        }

        // Tally into locals first: an invalid element midway through the batch
        // must not leave the earlier elements already counted.
        let mut true_positives = 0_usize;
        let mut false_positives = 0_usize;
        for (&prediction, &target) in predictions.iter().zip(targets.iter()) {
            verify_range(prediction, 0.0, 1.0)?;
            verify_binary_label(target)?;
            if prediction > self.threshold {
                if target == 1.0 {
                    true_positives += 1;
                } else {
                    false_positives += 1;
                }
            }
        }

        // Every element validated; commit the batch in one step.
        self.true_positive_ct += true_positives;
        self.false_positive_ct += false_positives;
        Ok(())
    }

    /// Clears both counters; the threshold is kept.
    fn reset(&mut self) {
        self.true_positive_ct = 0;
        self.false_positive_ct = 0;
    }

    /// Returns `true positives / (true positives + false positives)`.
    ///
    /// Yields `None` while no positive prediction has been counted, since the
    /// ratio would have a zero denominator.
    fn compute(&self) -> Option<Self::Output> {
        let predicted_positives = self.true_positive_ct + self.false_positive_ct;
        if predicted_positives == 0 {
            return None;
        }
        Some(self.true_positive_ct as f64 / predicted_positives as f64)
    }
}

/// Running recall for binary classification, computed from probability scores.
///
/// Only samples whose target is `1.0` are counted; recall is the share of
/// those samples whose score is strictly greater than the configured threshold.
///
/// # Examples
///
/// ```
/// use rust_metrics::{BinaryRecall, Metric};
///
/// let mut recall = BinaryRecall::new(0.5);
/// recall.update((&[0.9, 0.4], &[1.0, 1.0])).unwrap();
/// assert_eq!(recall.compute(), Some(0.5));
/// ```
#[derive(Clone, Debug)]
pub struct BinaryRecall {
    /// Decision cut-off: a score must exceed it to count as a positive prediction.
    threshold: f64,
    /// Number of samples with target `1.0` that were predicted positive.
    true_positive_ct: usize,
    /// Number of samples with target `1.0` that were not predicted positive.
    false_negative_ct: usize,
}

impl Default for BinaryRecall {
    fn default() -> Self {
        Self::new(0.5)
    }
}

impl BinaryRecall {
    /// Creates a recall metric with zeroed counters.
    ///
    /// During `update`, a prediction counts as positive only when it is
    /// strictly greater than `threshold`.
    ///
    /// # Panics
    ///
    /// Panics if `verify_range(threshold, 0.0, 1.0)` reports an error
    /// (presumably when `threshold` lies outside the 0.0 to 1.0 range; the
    /// exact bound handling is defined by `verify_range`).
    pub fn new(threshold: f64) -> Self {
        // An unusable threshold is a caller bug, so fail loudly with a message
        // that names the violated precondition instead of a bare `unwrap`.
        verify_range(threshold, 0.0, 1.0)
            .expect("BinaryRecall threshold failed verify_range(threshold, 0.0, 1.0)");
        Self {
            threshold,
            true_positive_ct: 0,
            false_negative_ct: 0,
        }
    }
}

impl Metric<(&[f64], &[f64])> for BinaryRecall {
    type Output = f64;

    /// Folds one batch of `(predictions, targets)` into the running counts.
    ///
    /// Only samples whose target equals `1.0` are counted: such a sample is a
    /// true positive when its prediction is strictly greater than the
    /// threshold and a false negative otherwise. Samples with any other target
    /// do not change the counts.
    ///
    /// # Errors
    ///
    /// Returns [`MetricError::LengthMismatch`] when the two slices differ in
    /// length, and propagates any error reported by `verify_range` for a
    /// prediction or by `verify_binary_label` for a target. The whole batch is
    /// validated before the counters change, so a failed call leaves the
    /// metric exactly as it was.
    fn update(&mut self, (predictions, targets): (&[f64], &[f64])) -> Result<(), MetricError> {
        if predictions.len() != targets.len() {
            return Err(MetricError::LengthMismatch {
                predictions: predictions.len(),
                targets: targets.len(),
            });
        }

        // Tally into locals first: an invalid element midway through the batch
        // must not leave the earlier elements already counted.
        let mut true_positives = 0_usize;
        let mut false_negatives = 0_usize;
        for (&prediction, &target) in predictions.iter().zip(targets.iter()) {
            verify_range(prediction, 0.0, 1.0)?;
            verify_binary_label(target)?;
            if target == 1.0 {
                if prediction > self.threshold {
                    true_positives += 1;
                } else {
                    false_negatives += 1;
                }
            }
        }

        // Every element validated; commit the batch in one step.
        self.true_positive_ct += true_positives;
        self.false_negative_ct += false_negatives;
        Ok(())
    }

    /// Clears both counters; the threshold is kept.
    fn reset(&mut self) {
        self.true_positive_ct = 0;
        self.false_negative_ct = 0;
    }

    /// Returns `true positives / (true positives + false negatives)`.
    ///
    /// Yields `None` while no sample with target `1.0` has been counted, since
    /// the ratio would have a zero denominator.
    fn compute(&self) -> Option<Self::Output> {
        let actual_positives = self.true_positive_ct + self.false_negative_ct;
        if actual_positives == 0 {
            return None;
        }
        Some(self.true_positive_ct as f64 / actual_positives as f64)
    }
}

#[cfg(test)]
mod tests {
    use super::{BinaryPrecision, BinaryRecall};
    use crate::core::{Metric, MetricError};

    /// Across both batches the scores above 0.5 are 0.8 and 0.7 (target 1.0)
    /// and 0.6 (target 0.0), so precision is 2/3; `reset` returns it to `None`.
    #[test]
    fn binary_precision_computes_over_batches() {
        let mut metric = BinaryPrecision::default();

        let first_batch = metric.update((&[0.8, 0.6, 0.3, 0.1], &[1.0, 0.0, 1.0, 0.0]));
        first_batch.expect("update should succeed");
        let second_batch = metric.update((&[0.7], &[1.0]));
        second_batch.expect("update should succeed");
        assert_eq!(metric.compute(), Some(2.0 / 3.0));

        metric.reset();
        assert!(metric.compute().is_none());
    }

    /// A target of 2.0 must be rejected with the binary-label error.
    #[test]
    fn binary_precision_validates_targets() {
        let wanted = MetricError::IncompatibleInput {
            expected: "target must be 0 or 1",
            got: "other",
        };

        let mut metric = BinaryPrecision::default();
        let outcome = metric.update((&[0.8], &[2.0]));
        let failure = outcome.expect_err("invalid targets should fail");
        assert_eq!(failure, wanted);
    }

    /// Of the three samples with target 1.0, the scores 0.8 and 0.7 exceed 0.5
    /// and 0.3 does not, so recall is 2/3; `reset` returns it to `None`.
    #[test]
    fn binary_recall_computes_over_batches() {
        let mut metric = BinaryRecall::default();

        let first_batch = metric.update((&[0.8, 0.6, 0.3, 0.1], &[1.0, 0.0, 1.0, 0.0]));
        first_batch.expect("update should succeed");
        let second_batch = metric.update((&[0.7], &[1.0]));
        second_batch.expect("update should succeed");
        assert_eq!(metric.compute(), Some(2.0 / 3.0));

        metric.reset();
        assert!(metric.compute().is_none());
    }

    /// A target of 2.0 must be rejected with the binary-label error.
    #[test]
    fn binary_recall_validates_targets() {
        let wanted = MetricError::IncompatibleInput {
            expected: "target must be 0 or 1",
            got: "other",
        };

        let mut metric = BinaryRecall::default();
        let outcome = metric.update((&[0.8], &[2.0]));
        let failure = outcome.expect_err("invalid targets should fail");
        assert_eq!(failure, wanted);
    }
}