use std::collections::VecDeque;
/// Tracks freshness predictions against the outcomes actually observed.
///
/// Keeps a bounded window of recent [`Observation`]s plus cumulative
/// confusion-matrix counters in which "stale" is the positive class.
#[derive(Debug)]
pub struct OutcomeTracker {
    /// Recent observations; `record` pushes at the back and evicts from the front.
    observations: VecDeque<Observation>,
    /// Retention bound checked by `record` before each push; when reached, the
    /// oldest observation is evicted.
    max_observations: usize,
    /// Cumulative count: predicted stale, actually stale.
    true_positives: u64,
    /// Cumulative count: predicted stale, actually fresh.
    false_positives: u64,
    /// Cumulative count: predicted fresh, actually fresh.
    true_negatives: u64,
    /// Cumulative count: predicted fresh, actually stale.
    /// All four counters cover every recorded observation (not just the window)
    /// since construction or the last `clear`.
    false_negatives: u64,
}
/// A single freshness prediction together with the outcome it was checked against.
///
/// `PartialEq` is derived (in addition to `Debug`/`Clone`) so observations returned by
/// `OutcomeTracker::collect_recent` can be compared directly. `Eq` is not derivable
/// because of the `f32` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Observation {
    /// Key of the queried entry, as passed to `OutcomeTracker::record`.
    pub key: String,
    /// Timestamp of the observation; when produced by `OutcomeTracker::record` this is
    /// microseconds since the Unix epoch (0 if the system clock is before the epoch).
    pub query_time: u64,
    /// Predicted freshness score; `OutcomeTracker::record` treats values above 0.5
    /// as a "fresh" prediction.
    pub predicted_freshness: f32,
    /// Error actually measured for the entry.
    pub actual_error: f32,
    /// Error tolerance; the entry counts as actually fresh when
    /// `actual_error < acceptable_error`.
    pub acceptable_error: f32,
}
/// Confusion-matrix counts for freshness predictions, with "stale" as the positive class.
///
/// Returned by `OutcomeTracker::cumulative` and `OutcomeTracker::classify_recent`.
/// It is four plain counters, so it is `Copy`, hashable, and defaults to all zeros.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ClassifiedOutcomes {
    /// True positives: predicted stale, actually stale.
    pub tp: u64,
    /// False positives: predicted stale, actually fresh.
    pub fp: u64,
    /// True negatives: predicted fresh, actually fresh.
    pub tn: u64,
    /// False negatives: predicted fresh, actually stale (`fn` is a keyword, hence `fn_`).
    pub fn_: u64,
}
impl OutcomeTracker {
    /// Freshness score above which `record` treats a prediction as "fresh".
    const DEFAULT_THRESHOLD: f32 = 0.5;

    /// Most observation slots reserved eagerly by `with_capacity`.
    ///
    /// The retention bound is a ceiling, not an expected size. Reserving it verbatim
    /// wastes memory for large bounds and panics ("capacity overflow") for values like
    /// `usize::MAX`. The deque still grows on demand up to the bound. The cap is at
    /// least the 10 000 used by `new`, so the default tracker is still fully preallocated.
    const MAX_PREALLOCATED: usize = 16_384;

    /// Creates a tracker that retains the 10 000 most recent observations.
    pub fn new() -> Self {
        Self::with_capacity(10000)
    }

    /// Creates a tracker that retains at most `max_observations` recent observations.
    ///
    /// The bound limits only the window seen by `collect_recent` and `classify_recent`.
    /// The cumulative counters count every recorded observation. A bound of `0`
    /// disables retention entirely.
    pub fn with_capacity(max_observations: usize) -> Self {
        Self {
            observations: VecDeque::with_capacity(max_observations.min(Self::MAX_PREALLOCATED)),
            max_observations,
            true_positives: 0,
            false_positives: 0,
            true_negatives: 0,
            false_negatives: 0,
        }
    }

    /// Records one prediction/outcome pair.
    ///
    /// The prediction is "fresh" when `predicted_freshness > 0.5`. The entry was
    /// actually fresh when `actual_error < acceptable_error`. The matching cumulative
    /// counter ("stale" = positive) is incremented. The observation is then appended
    /// to the recent window, evicting the oldest entry when the window is full.
    pub fn record(
        &mut self,
        key: &str,
        predicted_freshness: f32,
        actual_error: f32,
        acceptable_error: f32,
    ) {
        let predicted_fresh = predicted_freshness > Self::DEFAULT_THRESHOLD;
        let actual_fresh = Self::is_actually_fresh(actual_error, acceptable_error);
        *Self::select_counter(
            predicted_fresh,
            actual_fresh,
            &mut self.true_positives,
            &mut self.false_positives,
            &mut self.true_negatives,
            &mut self.false_negatives,
        ) += 1;

        // A zero-sized window keeps nothing. Without this guard the push below would
        // leave one observation in a window whose bound is 0.
        if self.max_observations == 0 {
            return;
        }
        if self.observations.len() >= self.max_observations {
            self.observations.pop_front();
        }
        self.observations.push_back(Observation {
            key: key.to_string(),
            query_time: now_micros(),
            predicted_freshness,
            actual_error,
            acceptable_error,
        });
    }

    /// Returns a copy of the retained observations, oldest first.
    pub fn collect_recent(&self) -> Vec<Observation> {
        self.observations.iter().cloned().collect()
    }

    /// Re-classifies the retained window using a caller-chosen freshness `threshold`.
    ///
    /// A prediction counts as fresh when `predicted_freshness > threshold`. Only the
    /// retained window is considered, not the cumulative history.
    pub fn classify_recent(&self, threshold: f32) -> ClassifiedOutcomes {
        let mut counts = ClassifiedOutcomes { tp: 0, fp: 0, tn: 0, fn_: 0 };
        for obs in &self.observations {
            let predicted_fresh = obs.predicted_freshness > threshold;
            let actual_fresh = Self::is_actually_fresh(obs.actual_error, obs.acceptable_error);
            *Self::select_counter(
                predicted_fresh,
                actual_fresh,
                &mut counts.tp,
                &mut counts.fp,
                &mut counts.tn,
                &mut counts.fn_,
            ) += 1;
        }
        counts
    }

    /// Returns a snapshot of the cumulative counters, since construction or the last `clear`.
    pub fn cumulative(&self) -> ClassifiedOutcomes {
        ClassifiedOutcomes {
            tp: self.true_positives,
            fp: self.false_positives,
            tn: self.true_negatives,
            fn_: self.false_negatives,
        }
    }

    /// Weighted error cost over the cumulative counters:
    /// `fn_weight * false_negatives + fp_weight * false_positives`.
    pub fn compute_loss(&self, fn_weight: f32, fp_weight: f32) -> f32 {
        fn_weight * self.false_negatives as f32 + fp_weight * self.false_positives as f32
    }

    /// Cumulative `FN / (TP + FN)`: the share of actually-stale observations predicted fresh.
    /// Returns `0.0` when nothing stale has been recorded.
    pub fn false_negative_rate(&self) -> f32 {
        Self::rate(
            self.false_negatives,
            self.true_positives + self.false_negatives,
        )
    }

    /// Cumulative `FP / (TN + FP)`: the share of actually-fresh observations predicted stale.
    /// Returns `0.0` when nothing fresh has been recorded.
    pub fn false_positive_rate(&self) -> f32 {
        Self::rate(
            self.false_positives,
            self.true_negatives + self.false_positives,
        )
    }

    /// Drops the retained window and resets all counters. The retention bound is kept.
    pub fn clear(&mut self) {
        self.observations.clear();
        self.true_positives = 0;
        self.false_positives = 0;
        self.true_negatives = 0;
        self.false_negatives = 0;
    }

    /// Ground truth for one observation: fresh when the measured error is within tolerance.
    fn is_actually_fresh(actual_error: f32, acceptable_error: f32) -> bool {
        actual_error < acceptable_error
    }

    /// Picks the confusion-matrix counter for one observation, with "stale" as positive.
    ///
    /// This single mapping is shared by `record` and `classify_recent` so they can never
    /// disagree.
    fn select_counter<'a>(
        predicted_fresh: bool,
        actual_fresh: bool,
        tp: &'a mut u64,
        fp: &'a mut u64,
        tn: &'a mut u64,
        fn_: &'a mut u64,
    ) -> &'a mut u64 {
        match (predicted_fresh, actual_fresh) {
            (true, true) => tn,
            (true, false) => fn_,
            (false, true) => fp,
            (false, false) => tp,
        }
    }

    /// `numerator / denominator` as `f32`, or `0.0` for an empty denominator.
    fn rate(numerator: u64, denominator: u64) -> f32 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f32 / denominator as f32
        }
    }
}
impl Default for OutcomeTracker {
fn default() -> Self {
Self::new()
}
}
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn now_micros() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_micros() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_outcome_classification() {
        let mut tracker = OutcomeTracker::new();
        tracker.record("key1", 0.3, 0.5, 0.1); // predicted stale, actually stale: TP
        tracker.record("key2", 0.8, 0.5, 0.1); // predicted fresh, actually stale: FN
        tracker.record("key3", 0.9, 0.05, 0.1); // predicted fresh, actually fresh: TN
        let outcomes = tracker.cumulative();
        assert_eq!(outcomes.tp, 1);
        assert_eq!(outcomes.fn_, 1);
        assert_eq!(outcomes.tn, 1);
        assert_eq!(outcomes.fp, 0);
    }

    #[test]
    fn test_false_negative_rate() {
        let mut tracker = OutcomeTracker::new();
        tracker.record("key1", 0.3, 0.5, 0.1);
        tracker.record("key2", 0.8, 0.5, 0.1);
        assert!((tracker.false_negative_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_false_positive_rate() {
        let mut tracker = OutcomeTracker::new();
        tracker.record("key1", 0.2, 0.0, 0.1); // predicted stale, actually fresh: FP
        tracker.record("key2", 0.9, 0.0, 0.1); // predicted fresh, actually fresh: TN
        assert!((tracker.false_positive_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_rates_are_zero_without_data() {
        let tracker = OutcomeTracker::new();
        assert!(tracker.false_negative_rate() == 0.0);
        assert!(tracker.false_positive_rate() == 0.0);
    }

    #[test]
    fn test_window_evicts_oldest_but_counters_accumulate() {
        let mut tracker = OutcomeTracker::with_capacity(2);
        tracker.record("a", 0.9, 0.0, 0.1);
        tracker.record("b", 0.9, 0.0, 0.1);
        tracker.record("c", 0.9, 0.0, 0.1);
        let keys: Vec<String> = tracker.collect_recent().into_iter().map(|o| o.key).collect();
        assert_eq!(keys, vec!["b".to_string(), "c".to_string()]);
        assert_eq!(tracker.cumulative().tn, 3);
        assert_eq!(tracker.classify_recent(0.5).tn, 2);
    }

    #[test]
    fn test_classify_recent_respects_threshold() {
        let mut tracker = OutcomeTracker::new();
        tracker.record("key1", 0.6, 0.0, 0.1);
        let lenient = tracker.classify_recent(0.5);
        assert_eq!(lenient.tn, 1);
        assert_eq!(lenient.fp, 0);
        let strict = tracker.classify_recent(0.7);
        assert_eq!(strict.fp, 1);
        assert_eq!(strict.tn, 0);
    }

    #[test]
    fn test_compute_loss_weights_errors() {
        let mut tracker = OutcomeTracker::new();
        tracker.record("key1", 0.8, 0.5, 0.1); // FN
        tracker.record("key2", 0.2, 0.0, 0.1); // FP
        assert!((tracker.compute_loss(2.0, 3.0) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_clear_resets_window_and_counters() {
        let mut tracker = OutcomeTracker::new();
        tracker.record("key1", 0.3, 0.5, 0.1);
        tracker.clear();
        assert!(tracker.collect_recent().is_empty());
        assert_eq!(
            tracker.cumulative(),
            ClassifiedOutcomes { tp: 0, fp: 0, tn: 0, fn_: 0 }
        );
    }
}