/// Streaming evaluation metrics for a binary classifier (labels `0` and `1`).
///
/// Counters are updated one sample at a time, so a model can be scored
/// online without storing individual predictions.
#[derive(Debug, Clone)]
pub struct ClassificationMetrics {
    /// Total samples seen.
    n_total: u64,
    /// Samples where `predicted == target`.
    n_correct: u64,
    /// True positives (target 1, predicted 1).
    tp: u64,
    /// False positives (target 0, predicted 1).
    fp: u64,
    /// False negatives (target 1, predicted 0); `fn` is a Rust keyword.
    fn_count: u64,
    /// Running sum of per-sample log loss.
    sum_log_loss: f64,
}
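// Predicted probabilities are clamped into [CLIP_MIN, CLIP_MAX] before any
// logarithm is taken: ln(0.0) is negative infinity, so a hard 0.0 or 1.0
// from an overconfident model would otherwise make the running sum infinite.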
const CLIP_MIN: f64 = 1e-15;
const CLIP_MAX: f64 = 1.0 - 1e-15;
impl ClassificationMetrics {
    /// Creates an empty accumulator with all counters at zero.
    pub fn new() -> Self {
        Self {
            n_total: 0,
            n_correct: 0,
            tp: 0,
            fp: 0,
            fn_count: 0,
            sum_log_loss: 0.0,
        }
    }

    /// Records one sample. `target` and `predicted` are the true and
    /// predicted labels (0 or 1); `predicted_proba` is the model's
    /// estimated probability of class 1.
    pub fn update(&mut self, target: usize, predicted: usize, predicted_proba: f64) {
        self.n_total += 1;
        if target == predicted {
            self.n_correct += 1;
        }
        // Confusion-matrix counters; true negatives need no counter because
        // none of the metrics below uses them.
        if target == 1 && predicted == 1 {
            self.tp += 1;
        } else if target == 0 && predicted == 1 {
            self.fp += 1;
        } else if target == 1 && predicted == 0 {
            self.fn_count += 1;
        }
        // Binary cross-entropy for this sample: -(y ln p + (1 - y) ln(1 - p)).
        let p = predicted_proba.clamp(CLIP_MIN, CLIP_MAX);
        let y = if target == 1 { 1.0 } else { 0.0 };
        let sample_loss = -(y * p.ln() + (1.0 - y) * (1.0 - p).ln());
        self.sum_log_loss += sample_loss;
    }
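
    // Worked log-loss example for update() (illustrative): a confident wrong
    // prediction, target = 1 with p = 0.01, adds -ln(0.01) = ln(100) ≈ 4.605
    // to the sum, while a confident correct one, p = 0.99, adds only
    // -ln(0.99) ≈ 0.01005 (see the log_loss_asymmetric test below).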

    /// Fraction of samples classified correctly; 0.0 before any update.
    pub fn accuracy(&self) -> f64 {
        if self.n_total == 0 {
            return 0.0;
        }
        self.n_correct as f64 / self.n_total as f64
    }

    /// tp / (tp + fp); 0.0 when nothing has been predicted positive yet.
    pub fn precision(&self) -> f64 {
        let denom = self.tp + self.fp;
        if denom == 0 {
            return 0.0;
        }
        self.tp as f64 / denom as f64
    }

    /// tp / (tp + fn); 0.0 when no positive samples have been seen yet.
    pub fn recall(&self) -> f64 {
        let denom = self.tp + self.fn_count;
        if denom == 0 {
            return 0.0;
        }
        self.tp as f64 / denom as f64
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub fn f1(&self) -> f64 {
        let p = self.precision();
        let r = self.recall();
        let sum = p + r;
        if sum == 0.0 {
            return 0.0;
        }
        2.0 * p * r / sum
    }

    /// Mean log loss (natural log) over all samples; 0.0 before any update.
    pub fn log_loss(&self) -> f64 {
        if self.n_total == 0 {
            return 0.0;
        }
        self.sum_log_loss / self.n_total as f64
    }

    /// Number of samples recorded so far.
    pub fn n_samples(&self) -> u64 {
        self.n_total
    }

    /// Clears all counters, returning the accumulator to its initial state.
    pub fn reset(&mut self) {
        self.n_total = 0;
        self.n_correct = 0;
        self.tp = 0;
        self.fp = 0;
        self.fn_count = 0;
        self.sum_log_loss = 0.0;
    }
}

impl Default for ClassificationMetrics {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-10;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn empty_state_returns_zeros() {
        let m = ClassificationMetrics::new();
        assert_eq!(m.accuracy(), 0.0);
        assert_eq!(m.precision(), 0.0);
        assert_eq!(m.recall(), 0.0);
        assert_eq!(m.f1(), 0.0);
        assert_eq!(m.log_loss(), 0.0);
        assert_eq!(m.n_samples(), 0);
    }

    #[test]
    fn perfect_accuracy() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.99);
        m.update(0, 0, 0.01);
        m.update(1, 1, 0.95);
        m.update(0, 0, 0.05);
        assert!(approx_eq(m.accuracy(), 1.0));
    }

    #[test]
    fn fifty_percent_accuracy() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.9);
        m.update(0, 1, 0.8);
        m.update(1, 0, 0.2);
        m.update(0, 0, 0.1);
        assert!(approx_eq(m.accuracy(), 0.5));
    }

    #[test]
    fn precision_all_true_positives() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.9);
        m.update(1, 1, 0.8);
        m.update(0, 0, 0.1);
        assert!(approx_eq(m.precision(), 1.0));
    }

    #[test]
    fn precision_all_false_positives() {
        let mut m = ClassificationMetrics::new();
        m.update(0, 1, 0.9);
        m.update(0, 1, 0.8);
        assert!(approx_eq(m.precision(), 0.0));
    }

    #[test]
    fn recall_all_positives_detected() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.9);
        m.update(1, 1, 0.8);
        assert!(approx_eq(m.recall(), 1.0));
    }

    #[test]
    fn recall_no_positives_detected() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 0, 0.3);
        m.update(1, 0, 0.2);
        assert!(approx_eq(m.recall(), 0.0));
    }

    #[test]
    fn f1_mixed_case() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.9);
        m.update(1, 0, 0.3);
        m.update(0, 1, 0.7);
        m.update(0, 0, 0.1);
        assert!(approx_eq(m.f1(), 0.5));
    }

    #[test]
    fn f1_zero_when_no_predictions() {
        let mut m = ClassificationMetrics::new();
        m.update(0, 0, 0.1);
        m.update(0, 0, 0.2);
        assert_eq!(m.f1(), 0.0);
    }

    #[test]
    fn log_loss_near_perfect_predictions() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.999);
        m.update(0, 0, 0.001);
        assert!(m.log_loss() < 0.01);
        assert!(m.log_loss() > 0.0);
    }

    #[test]
    fn log_loss_random_predictions() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.5);
        m.update(0, 0, 0.5);
        let expected = 2.0_f64.ln();
        assert!(approx_eq(m.log_loss(), expected));
    }

    #[test]
    fn log_loss_clips_extreme_probabilities() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 1.0);
        m.update(0, 0, 0.0);
        assert!(m.log_loss().is_finite());
        assert!(m.log_loss() >= 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 1, 0.9);
        m.update(0, 1, 0.8);
        m.update(1, 0, 0.2);
        m.reset();
        assert_eq!(m.n_samples(), 0);
        assert_eq!(m.accuracy(), 0.0);
        assert_eq!(m.precision(), 0.0);
        assert_eq!(m.recall(), 0.0);
        assert_eq!(m.f1(), 0.0);
        assert_eq!(m.log_loss(), 0.0);
    }

    #[test]
    fn default_is_empty() {
        let m = ClassificationMetrics::default();
        assert_eq!(m.n_samples(), 0);
        assert_eq!(m.accuracy(), 0.0);
    }

    #[test]
    fn log_loss_asymmetric() {
        let mut m = ClassificationMetrics::new();
        m.update(1, 0, 0.01);
        let expected = 100.0_f64.ln();
        assert!(approx_eq(m.log_loss(), expected));
    }
}