use std::collections::VecDeque;
/// Online conformal-style calibrator for probability-like predictions.
///
/// Keeps a sliding window of nonconformity scores `|predicted - outcome|`
/// (where the outcome is `1.0` for a completed event and `0.0` otherwise) and
/// uses a quantile of those scores as the half-width of a symmetric
/// prediction interval, clamped to `[0, 1]`.
#[derive(Debug, Clone)]
pub struct ConformalCalibrator {
    /// Recorded nonconformity scores, oldest at the front.
    scores: VecDeque<f64>,
    /// Maximum number of scores retained; the oldest score is evicted first.
    max_scores: usize,
    /// Target coverage level (e.g. `0.9` for 90% intervals).
    coverage: f64,
    /// Memoized quantile of `scores`; reset whenever a new score is recorded.
    cached_quantile: Option<f64>,
}

impl ConformalCalibrator {
    /// Window size used by [`ConformalCalibrator::with_defaults`].
    const DEFAULT_MAX_SCORES: usize = 1000;
    /// Coverage level used by [`ConformalCalibrator::with_defaults`].
    const DEFAULT_COVERAGE: f64 = 0.9;

    /// Creates a calibrator with a 1000-score window and 90% target coverage.
    pub fn with_defaults() -> Self {
        Self {
            scores: VecDeque::with_capacity(Self::DEFAULT_MAX_SCORES),
            max_scores: Self::DEFAULT_MAX_SCORES,
            coverage: Self::DEFAULT_COVERAGE,
            cached_quantile: None,
        }
    }

    /// Records one observation.
    ///
    /// The score stored is `|predicted - actual|`, with `actual` being `1.0`
    /// when `completed` is true and `0.0` otherwise. Once the window exceeds
    /// `max_scores`, the oldest score is dropped. The cached quantile is
    /// invalidated on every call.
    pub fn record_outcome(&mut self, predicted: f64, completed: bool) {
        let actual = if completed { 1.0 } else { 0.0 };
        self.scores.push_back((predicted - actual).abs());
        if self.scores.len() > self.max_scores {
            self.scores.pop_front();
        }
        self.cached_quantile = None;
    }

    /// Returns `(lower, upper)` bounds around `predicted`, both within `[0, 1]`.
    ///
    /// With no recorded scores the uninformative interval `(0.0, 1.0)` is
    /// returned. Otherwise the interval is `predicted ± q`, where `q` is the
    /// calibrated score quantile, and both bounds are clamped into `[0, 1]`,
    /// so `lower <= upper` holds even when `predicted` lies outside `[0, 1]`.
    pub fn prediction_interval(&mut self, predicted: f64) -> (f64, f64) {
        if self.scores.is_empty() {
            return (0.0, 1.0);
        }
        let q = self.quantile();
        // `f64::max`/`f64::min` return the non-NaN operand, so with this
        // ordering a NaN intermediate degrades to the full interval (0, 1)
        // instead of leaking NaN to the caller.
        let lower = (predicted - q).max(0.0).min(1.0);
        let upper = (predicted + q).min(1.0).max(0.0);
        (lower, upper)
    }

    /// Returns the conformal quantile of the recorded scores, memoizing it.
    ///
    /// The quantile is the `ceil((n + 1) * coverage)`-th smallest score. When
    /// that rank exceeds `n` there are too few scores to support the requested
    /// coverage, and `f64::INFINITY` is returned so that the resulting
    /// interval clamps to the full `[0, 1]` range. With no scores at all,
    /// `1.0` is returned (and not cached).
    fn quantile(&mut self) -> f64 {
        if let Some(q) = self.cached_quantile {
            return q;
        }
        let n = self.scores.len();
        if n == 0 {
            return 1.0;
        }
        // Float-to-usize `as` saturates (and maps NaN to 0), so this cast
        // cannot wrap; the value is range-checked against `n` just below.
        let rank = ((n + 1) as f64 * self.coverage).ceil() as usize;
        let q = if rank > n {
            f64::INFINITY
        } else {
            let mut sorted: Vec<f64> = self.scores.iter().copied().collect();
            // `total_cmp` is a genuine total order: NaN scores sort last
            // (treated as the worst score) and the comparator never yields an
            // inconsistent ordering.
            sorted.sort_unstable_by(f64::total_cmp);
            // `rank.max(1)` guards degenerate coverage values (<= 0 or NaN).
            sorted[rank.max(1) - 1]
        };
        self.cached_quantile = Some(q);
        q
    }

    /// Returns the number of scores currently held in the window.
    pub fn sample_count(&self) -> usize {
        self.scores.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing interval bounds to exact values.
    const EPS: f64 = 1e-10;

    /// Records the same `(predicted, completed)` observation `times` times.
    fn feed(cal: &mut ConformalCalibrator, times: usize, predicted: f64, completed: bool) {
        (0..times).for_each(|_| cal.record_outcome(predicted, completed));
    }

    /// A calibrator without any recorded outcome reports the whole unit interval.
    #[test]
    fn test_empty_returns_full_interval() {
        let mut cal = ConformalCalibrator::with_defaults();
        let (lower, upper) = cal.prediction_interval(0.5);
        assert!(
            lower.abs() < EPS,
            "Empty calibrator lower should be 0.0, got {lower}"
        );
        assert!(
            (upper - 1.0).abs() < EPS,
            "Empty calibrator upper should be 1.0, got {upper}"
        );
    }

    /// Predictions that are close to the outcomes produce a tight interval.
    #[test]
    fn test_perfect_predictions_narrow_interval() {
        let mut cal = ConformalCalibrator::with_defaults();
        feed(&mut cal, 100, 0.95, true);
        feed(&mut cal, 100, 0.05, false);
        let (lower, upper) = cal.prediction_interval(0.7);
        let width = upper - lower;
        assert!(
            width < 0.5,
            "Perfect predictions should yield narrow interval, got width={width}"
        );
    }

    /// Predictions that are far from the outcomes produce a wide interval.
    #[test]
    fn test_bad_predictions_wide_interval() {
        let mut cal = ConformalCalibrator::with_defaults();
        feed(&mut cal, 100, 0.9, false);
        feed(&mut cal, 100, 0.1, true);
        let (lower, upper) = cal.prediction_interval(0.5);
        let width = upper - lower;
        assert!(
            width > 0.5,
            "Bad predictions should yield wide interval, got width={width}"
        );
    }

    /// Recording more outcomes than the window holds keeps only the newest ones.
    #[test]
    fn test_buffer_bounded() {
        let mut cal = ConformalCalibrator::with_defaults();
        (0..2000_i32)
            .map(|i| (f64::from(i) / 2000.0, i % 2 == 0))
            .for_each(|(predicted, completed)| cal.record_outcome(predicted, completed));
        assert_eq!(
            cal.sample_count(),
            1000,
            "Score buffer should be bounded at max_scores"
        );
    }

    /// Bounds stay ordered and inside the unit interval across a grid of predictions.
    #[test]
    fn test_lower_le_upper_invariant() {
        let mut cal = ConformalCalibrator::with_defaults();
        (0..50_i32)
            .map(|i| (f64::from(i) / 50.0, i % 3 == 0))
            .for_each(|(predicted, completed)| cal.record_outcome(predicted, completed));

        let probes = [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0];
        for p in probes {
            let (lower, upper) = cal.prediction_interval(p);
            assert!(
                lower <= upper,
                "lower ({lower}) should be <= upper ({upper}) for predicted={p}"
            );
            assert!(lower >= 0.0, "lower should be >= 0.0, got {lower}");
            assert!(upper <= 1.0, "upper should be <= 1.0, got {upper}");
        }
    }

    /// Even maximal scores never push the bounds outside the unit interval.
    #[test]
    fn test_interval_clamped_to_unit() {
        let mut cal = ConformalCalibrator::with_defaults();
        feed(&mut cal, 100, 1.0, false);
        let (lower, upper) = cal.prediction_interval(0.5);
        assert!(lower >= 0.0, "Lower should be clamped >= 0.0, got {lower}");
        assert!(upper <= 1.0, "Upper should be clamped <= 1.0, got {upper}");
    }
}