use crate::conformal::types::{conformal_quantile, PredictionSet};
#[derive(Debug, Clone)]
pub struct AciConformal {
pub alpha: f64,
pub alpha_t: f64,
pub gamma: f64,
pub history: Vec<bool>,
}
impl AciConformal {
pub fn new(alpha: f64, gamma: f64) -> Self {
Self {
alpha,
alpha_t: alpha,
gamma,
history: Vec::new(),
}
}
pub fn current_alpha(&self) -> f64 {
self.alpha_t.clamp(1e-6, 1.0 - 1e-6)
}
pub fn predict_with_current(&self, y_hat: f64, scores: &[f64]) -> Option<PredictionSet> {
if scores.is_empty() {
return None;
}
let q = conformal_quantile(scores, self.current_alpha());
Some(PredictionSet::interval(y_hat - q, y_hat + q))
}
pub fn update(&mut self, covered: bool) {
let indicator = if covered { 1.0 } else { 0.0 };
self.alpha_t += self.gamma * (self.alpha - indicator);
self.alpha_t = self.alpha_t.clamp(1e-6, 1.0 - 1e-6);
self.history.push(covered);
}
pub fn predict_and_update(
&mut self,
y_hat: f64,
y_true: f64,
scores: &[f64],
) -> (Option<PredictionSet>, bool) {
let ps = self.predict_with_current(y_hat, scores);
let covered = ps.as_ref().map_or(false, |s| s.contains_value(y_true));
self.update(covered);
(ps, covered)
}
pub fn running_coverage(&self) -> f64 {
running_coverage(&self.history)
}
pub fn recent_coverage(&self, window: usize) -> f64 {
if self.history.is_empty() {
return 0.0;
}
let start = self.history.len().saturating_sub(window);
let slice = &self.history[start..];
if slice.is_empty() {
0.0
} else {
slice.iter().filter(|&&b| b).count() as f64 / slice.len() as f64
}
}
}
pub fn running_coverage(history: &[bool]) -> f64 {
if history.is_empty() {
return 0.0;
}
history.iter().filter(|&&b| b).count() as f64 / history.len() as f64
}
pub fn coverage_drift(history: &[bool], window: usize) -> f64 {
if history.is_empty() || window == 0 {
return 0.0;
}
let overall = running_coverage(history);
let start = history.len().saturating_sub(window);
let slice = &history[start..];
if slice.is_empty() {
return 0.0;
}
let recent = slice.iter().filter(|&&b| b).count() as f64 / slice.len() as f64;
recent - overall
}
#[cfg(test)]
mod tests {
use super::*;
struct Lcg {
state: u64,
}
impl Lcg {
fn new(seed: u64) -> Self {
Self { state: seed }
}
fn next_f64(&mut self) -> f64 {
self.state = self
.state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
(self.state >> 33) as f64 / (u32::MAX as f64)
}
fn next_normal(&mut self) -> f64 {
let u1 = self.next_f64().max(1e-12);
let u2 = self.next_f64();
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
}
#[test]
fn test_aci_alpha_increases_when_missed() {
let mut aci = AciConformal::new(0.1, 0.05);
let alpha_before = aci.current_alpha();
aci.update(false);
let alpha_after = aci.current_alpha();
assert!(
alpha_after > alpha_before,
"Alpha should increase when not covered: {} -> {}",
alpha_before,
alpha_after
);
}
#[test]
fn test_aci_alpha_decreases_when_covered() {
let mut aci = AciConformal::new(0.1, 0.05);
let alpha_before = aci.current_alpha();
aci.update(true);
let alpha_after = aci.current_alpha();
assert!(
alpha_after < alpha_before,
"Alpha should decrease when covered: {} -> {}",
alpha_before,
alpha_after
);
}
#[test]
fn test_aci_running_coverage_tracks() {
let mut aci = AciConformal::new(0.1, 0.05);
for i in 0..100 {
aci.update(i % 5 != 0); }
let cov = aci.running_coverage();
assert!((cov - 0.8).abs() < 0.01, "Coverage = {}", cov);
}
#[test]
fn test_aci_predict_returns_interval() {
let aci = AciConformal::new(0.1, 0.05);
let scores: Vec<f64> = (1..=10).map(|x| x as f64 * 0.1).collect();
let ps = aci.predict_with_current(5.0, &scores);
assert!(ps.is_some());
let ps = ps.expect("interval");
assert!(ps.lower < ps.upper);
}
#[test]
fn test_aci_online_adapts() {
let mut rng = Lcg::new(42);
let mut aci = AciConformal::new(0.1, 0.05);
let tiny_scores = vec![0.01_f64; 20];
let mut cumulative_covered = 0usize;
let n = 200usize;
for _ in 0..n {
let y_true = rng.next_normal() * 2.0; let y_hat = 0.0;
let ps = aci.predict_with_current(y_hat, &tiny_scores);
let covered = ps.map_or(false, |s| s.contains_value(y_true));
if covered {
cumulative_covered += 1;
}
aci.update(covered);
}
let final_alpha = aci.current_alpha();
assert!(aci.history.len() == n);
assert!(final_alpha > 0.0 && final_alpha < 1.0);
let _ = cumulative_covered;
}
#[test]
fn test_running_coverage_util() {
let hist = vec![true, true, false, true, false];
let cov = running_coverage(&hist);
assert!((cov - 0.6).abs() < 1e-10, "cov = {}", cov);
}
#[test]
fn test_coverage_drift_positive() {
let mut hist: Vec<bool> = vec![false; 50];
hist.extend(vec![true; 50]); let drift = coverage_drift(&hist, 50);
assert!(drift > 0.0, "drift = {}", drift);
}
#[test]
fn test_coverage_drift_negative() {
let mut hist: Vec<bool> = vec![true; 50];
hist.extend(vec![false; 50]);
let drift = coverage_drift(&hist, 50);
assert!(drift < 0.0, "drift = {}", drift);
}
}