use std::collections::VecDeque;
/// Sliding-window estimator of the area under the ROC curve (AUC).
///
/// The estimator retains the most recent `window_size` `(score, label)`
/// observations and computes the AUC over exactly those samples on demand.
/// Older samples are evicted in arrival order.
#[derive(Debug, Clone)]
pub struct StreamingAUC {
    /// Maximum number of samples kept in the window; validated to be non-zero by `new`.
    window_size: usize,
    /// Retained `(score, is_positive)` samples, oldest at the front.
    buffer: VecDeque<(f64, bool)>,
    /// Samples observed since construction or the last `reset`, evicted ones included.
    n_samples_total: u64,
}

impl StreamingAUC {
    /// Creates an empty estimator that retains at most `window_size` samples.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is zero.
    pub fn new(window_size: usize) -> Self {
        assert!(window_size > 0, "window_size must be > 0");
        Self {
            window_size,
            // `update` evicts before inserting, so the buffer never holds
            // more than `window_size` elements.
            buffer: VecDeque::with_capacity(window_size),
            n_samples_total: 0,
        }
    }

    /// Records one observation, evicting the oldest sample if the window is full.
    ///
    /// `score` is the classifier output (higher means "more positive") and
    /// `is_positive` is the ground-truth label.
    pub fn update(&mut self, score: f64, is_positive: bool) {
        if self.buffer.len() == self.window_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back((score, is_positive));
        self.n_samples_total += 1;
    }

    /// Returns the AUC over the current window.
    ///
    /// The value is the fraction of (positive, negative) pairs in which the
    /// positive sample has the strictly higher score, with exactly tied pairs
    /// counting one half (the normalised Mann–Whitney U statistic).
    ///
    /// Returns `None` when the window does not contain at least one positive
    /// and one negative sample.
    ///
    /// Scores are compared exactly: two scores tie only when they are equal,
    /// which also covers equal infinities. Any pair involving a NaN score
    /// contributes nothing to the numerator but still counts in the
    /// denominator.
    pub fn auc(&self) -> Option<f64> {
        let n_pos = self.n_positive();
        let n_neg = self.buffer.len() - n_pos;
        if n_pos == 0 || n_neg == 0 {
            return None;
        }

        // Sorted negative scores. NaNs are left out because they can never be
        // "below" or "tied with" anything, and removing them keeps the slice
        // partitioned with respect to `<` and `<=` for the binary searches.
        let mut negatives: Vec<f64> = self
            .buffer
            .iter()
            .filter(|&&(score, is_positive)| !is_positive && !score.is_nan())
            .map(|&(score, _)| score)
            .collect();
        negatives.sort_unstable_by(f64::total_cmp);

        // Accumulate in half-pair units so the sum stays an exact integer:
        // a win is worth 2, a tie is worth 1. For one positive score that is
        // `2 * below + (at_or_below - below) == below + at_or_below`.
        let half_wins: u64 = self
            .buffer
            .iter()
            .filter(|&&(_, is_positive)| is_positive)
            .map(|&(score, _)| {
                let below = negatives.partition_point(|&neg| neg < score);
                let at_or_below = negatives.partition_point(|&neg| neg <= score);
                (below + at_or_below) as u64
            })
            .sum();

        Some(half_wins as f64 / 2.0 / (n_pos as f64 * n_neg as f64))
    }

    /// Number of samples currently held in the window.
    pub fn window_count(&self) -> usize {
        self.buffer.len()
    }

    /// Total number of samples observed since construction or the last
    /// `reset`, including samples that have since been evicted.
    pub fn n_samples(&self) -> u64 {
        self.n_samples_total
    }

    /// Number of positive-labelled samples currently in the window.
    pub fn n_positive(&self) -> usize {
        self.buffer.iter().filter(|(_, is_pos)| *is_pos).count()
    }

    /// Number of negative-labelled samples currently in the window.
    pub fn n_negative(&self) -> usize {
        self.buffer.iter().filter(|(_, is_pos)| !*is_pos).count()
    }

    /// Returns `true` once the window holds `window_size` samples.
    pub fn is_full(&self) -> bool {
        self.buffer.len() == self.window_size
    }

    /// Maximum number of samples the window retains.
    pub fn window_size(&self) -> usize {
        self.window_size
    }

    /// Discards all retained samples and zeroes the total-sample counter.
    /// The configured window size is kept.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.n_samples_total = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point AUC values.
    const EPS: f64 = 1e-10;

    /// Returns `true` when `a` and `b` differ by less than `EPS`.
    fn approx_eq(a: f64, b: f64) -> bool {
        let gap = (a - b).abs();
        gap < EPS
    }

    /// Pushes every `(score, label)` pair into `tracker`, in slice order.
    fn feed(tracker: &mut StreamingAUC, samples: &[(f64, bool)]) {
        for &(score, label) in samples {
            tracker.update(score, label);
        }
    }

    /// Builds a tracker with the given window and feeds it `samples`.
    fn tracker_with(window: usize, samples: &[(f64, bool)]) -> StreamingAUC {
        let mut tracker = StreamingAUC::new(window);
        feed(&mut tracker, samples);
        tracker
    }

    #[test]
    fn perfect_classifier_auc_is_one() {
        let tracker = tracker_with(
            100,
            &[
                (0.95, true),
                (0.90, true),
                (0.85, true),
                (0.20, false),
                (0.15, false),
                (0.10, false),
            ],
        );
        assert_eq!(tracker.auc(), Some(1.0));
    }

    #[test]
    fn perfectly_wrong_classifier_auc_is_zero() {
        let tracker = tracker_with(
            100,
            &[(0.10, true), (0.15, true), (0.85, false), (0.90, false)],
        );
        assert_eq!(tracker.auc(), Some(0.0));
    }

    #[test]
    fn random_classifier_auc_near_half() {
        // Labels alternate along the score axis, so wins and losses balance.
        let tracker = tracker_with(
            100,
            &[
                (0.1, true),
                (0.2, false),
                (0.3, true),
                (0.4, false),
                (0.5, true),
                (0.6, false),
                (0.7, true),
                (0.8, false),
                (0.9, true),
            ],
        );
        let value = tracker.auc().unwrap();
        assert!(approx_eq(value, 0.5));
    }

    #[test]
    fn known_small_example_exact_auc() {
        let tracker = tracker_with(
            100,
            &[
                (0.8, true),
                (0.4, false),
                (0.6, true),
                (0.3, false),
                (0.7, false),
            ],
        );
        // Five of the six (positive, negative) pairs are ranked correctly.
        let expected = 5.0 / 6.0;
        assert!(approx_eq(tracker.auc().unwrap(), expected));
    }

    #[test]
    fn ties_handled_correctly() {
        let tracker = tracker_with(100, &[(0.5, true), (0.5, false)]);
        assert!(approx_eq(tracker.auc().unwrap(), 0.5));
    }

    #[test]
    fn ties_mixed_with_clear_pairs() {
        let tracker = tracker_with(
            100,
            &[(0.7, true), (0.5, false), (0.5, true), (0.3, false)],
        );
        // Three clear wins plus one tie out of four pairs.
        assert!(approx_eq(tracker.auc().unwrap(), 0.875));
    }

    #[test]
    fn empty_buffer_returns_none() {
        let tracker = tracker_with(100, &[]);
        assert_eq!(tracker.auc(), None);
    }

    #[test]
    fn single_sample_returns_none() {
        let tracker = tracker_with(100, &[(0.9, true)]);
        assert_eq!(tracker.auc(), None);
    }

    #[test]
    fn all_positive_returns_none() {
        let tracker = tracker_with(100, &[(0.9, true), (0.8, true), (0.7, true)]);
        assert_eq!(tracker.auc(), None);
    }

    #[test]
    fn all_negative_returns_none() {
        let tracker = tracker_with(100, &[(0.3, false), (0.2, false)]);
        assert_eq!(tracker.auc(), None);
    }

    #[test]
    fn window_eviction_drops_old_samples() {
        let mut tracker = tracker_with(3, &[(0.9, true), (0.1, false), (0.8, true)]);
        assert_eq!(tracker.window_count(), 3);
        assert!(tracker.is_full());

        // Three more samples push out all of the initial ones.
        feed(&mut tracker, &[(0.1, true), (0.8, false), (0.9, false)]);
        assert_eq!(tracker.window_count(), 3);
        assert_eq!(tracker.n_samples(), 6);
        assert_eq!(tracker.n_positive(), 1);
        assert_eq!(tracker.n_negative(), 2);
        assert_eq!(tracker.auc(), Some(0.0));
    }

    #[test]
    fn n_samples_tracks_total_including_evicted() {
        let mut tracker = tracker_with(2, &[(0.9, true), (0.1, false)]);
        assert_eq!(tracker.n_samples(), 2);
        assert_eq!(tracker.window_count(), 2);

        tracker.update(0.5, true);
        assert_eq!(tracker.n_samples(), 3);
        assert_eq!(tracker.window_count(), 2);

        tracker.update(0.4, false);
        assert_eq!(tracker.n_samples(), 4);
        assert_eq!(tracker.window_count(), 2);
    }

    #[test]
    fn reset_clears_everything() {
        let mut tracker = tracker_with(100, &[(0.9, true), (0.1, false), (0.8, true)]);
        tracker.reset();
        assert_eq!(tracker.window_count(), 0);
        assert_eq!(tracker.n_samples(), 0);
        assert_eq!(tracker.n_positive(), 0);
        assert_eq!(tracker.n_negative(), 0);
        assert!(!tracker.is_full());
        assert_eq!(tracker.auc(), None);
    }

    #[test]
    fn window_size_accessor() {
        let tracker = StreamingAUC::new(42);
        assert_eq!(tracker.window_size(), 42);
    }

    #[test]
    fn is_full_transitions() {
        let mut tracker = StreamingAUC::new(2);
        assert!(!tracker.is_full());

        tracker.update(0.5, true);
        assert!(!tracker.is_full());

        tracker.update(0.5, false);
        assert!(tracker.is_full());

        // Stays full once eviction starts.
        tracker.update(0.7, true);
        assert!(tracker.is_full());
    }

    #[test]
    #[should_panic(expected = "window_size must be > 0")]
    fn zero_window_panics() {
        let _ = StreamingAUC::new(0);
    }

    #[test]
    fn positive_negative_counts() {
        let tracker = tracker_with(100, &[(0.9, true), (0.8, true), (0.3, false)]);
        assert_eq!(tracker.n_positive(), 2);
        assert_eq!(tracker.n_negative(), 1);
        assert_eq!(tracker.window_count(), 3);
    }

    #[test]
    fn window_eviction_updates_class_counts() {
        let mut tracker = tracker_with(2, &[(0.9, true), (0.1, false)]);
        assert_eq!(tracker.n_positive(), 1);
        assert_eq!(tracker.n_negative(), 1);

        // Evicts the only positive sample, leaving a single-class window.
        tracker.update(0.3, false);
        assert_eq!(tracker.n_positive(), 0);
        assert_eq!(tracker.n_negative(), 2);
        assert_eq!(tracker.auc(), None);
    }
}