/// Step-size schedule used by the adaptive conformal update.
///
/// The schedule decides the learning rate applied at update number `t` (1-based, counted
/// by the tracker since construction or its last reset):
///
/// * [`StepSchedule::Fixed`] — the constant base step `gamma`.
/// * [`StepSchedule::Decaying`] — `gamma / t^beta`, so later updates move `alpha_t` less.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
#[non_exhaustive]
pub enum StepSchedule {
    /// Constant step size `gamma` on every update (the default).
    #[default]
    Fixed,
    /// Step size `gamma / t.powf(beta)`, where `t` is the number of observations seen so far.
    ///
    /// `beta = 0` yields the same steps as [`StepSchedule::Fixed`]; larger values decay faster.
    Decaying {
        /// Decay exponent. `AdaptiveConformalInterval::with_decaying_step` rejects negative
        /// values; constructing the variant directly performs no validation.
        beta: f64,
    },
}
/// Online Adaptive Conformal Inference (ACI) tracker.
///
/// Holds a miscoverage level `alpha_t` that is nudged after every observation: it rises
/// when the observed target falls inside the reported interval and falls when it lands
/// outside, so that the long-run miss rate is pushed toward `target_alpha`. The current
/// level is turned into a pair of quantile levels by `effective_quantiles`.
#[derive(Debug, Clone)]
pub struct AdaptiveConformalInterval {
// Desired long-run miscoverage rate; `new` asserts it lies in (0, 1).
target_alpha: f64,
// Current adapted miscoverage level. Starts at `target_alpha`, is moved and clamped by
// `update`, and is restored to `target_alpha` by `reset`.
alpha_t: f64,
// Base step size (> 0, asserted by `new`); scaled per update according to `schedule`.
gamma: f64,
// How `gamma` is scaled on each update; `Fixed` unless `with_decaying_step` was called.
schedule: StepSchedule,
// Number of updates whose target was inside the inclusive range [lower, upper].
n_covered: u64,
// Number of updates since construction or the last `reset`.
n_total: u64,
}
impl AdaptiveConformalInterval {
    /// Default lower clamp for `alpha_t` after an update. It is lowered to `target_alpha`
    /// when the target is smaller, so the requested level always stays reachable.
    const ALPHA_FLOOR: f64 = 0.001;
    /// Default upper clamp for `alpha_t` after an update. It is raised to `target_alpha`
    /// when the target is larger.
    const ALPHA_CEIL: f64 = 0.999;

    /// Creates a tracker that aims for a long-run miscoverage rate of `target_alpha`.
    ///
    /// The adapted level starts at `target_alpha`. The step schedule starts as
    /// [`StepSchedule::Fixed`] with base step size `gamma`.
    ///
    /// # Panics
    ///
    /// Panics if `target_alpha` is not in the open interval `(0, 1)`. Also panics if `gamma`
    /// is not a finite, strictly positive number; `NaN` is rejected as well.
    pub fn new(target_alpha: f64, gamma: f64) -> Self {
        assert!(
            target_alpha > 0.0 && target_alpha < 1.0,
            "target_alpha must be in (0, 1), got {target_alpha}"
        );
        assert!(gamma > 0.0, "gamma must be > 0, got {gamma}");
        // An infinite step would throw alpha_t straight onto a clamp bound on every update.
        assert!(gamma.is_finite(), "gamma must be finite, got {gamma}");
        Self {
            target_alpha,
            alpha_t: target_alpha,
            gamma,
            schedule: StepSchedule::Fixed,
            n_covered: 0,
            n_total: 0,
        }
    }

    /// Switches to a decaying step size `gamma / t^beta`.
    ///
    /// Here `t` is the 1-based index of the current update, counted since construction or
    /// the last [`reset`](Self::reset). `beta = 0` gives the same steps as the fixed schedule.
    ///
    /// # Panics
    ///
    /// Panics if `beta` is negative, `NaN`, or infinite.
    #[must_use = "this consumes the tracker and returns the reconfigured one"]
    pub fn with_decaying_step(mut self, beta: f64) -> Self {
        assert!(beta >= 0.0, "beta must be >= 0, got {beta}");
        assert!(beta.is_finite(), "beta must be finite, got {beta}");
        self.schedule = StepSchedule::Decaying { beta };
        self
    }

    /// Returns the configured step-size schedule.
    pub fn step_schedule(&self) -> StepSchedule {
        self.schedule
    }

    /// Records one observation and adapts `alpha_t`.
    ///
    /// The observation counts as covered when `lower <= target <= upper`, with both ends
    /// inclusive. A `NaN` in any argument, or `lower > upper`, makes that test false, so the
    /// observation is counted as a miss.
    ///
    /// The update is `alpha_t <- alpha_t + step * (target_alpha - err)`, where `err` is 1 on
    /// a miss and 0 on a hit. A hit raises `alpha_t`, moving the levels returned by
    /// [`effective_quantiles`](Self::effective_quantiles) toward the median. A miss lowers
    /// `alpha_t`, moving those levels toward the tails.
    ///
    /// The result is clamped to `[min(0.001, target_alpha), max(0.999, target_alpha)]`. For
    /// targets inside `[0.001, 0.999]` this is exactly `[0.001, 0.999]`.
    pub fn update(&mut self, target: f64, lower: f64, upper: f64) {
        self.n_total += 1;
        let covered = target >= lower && target <= upper;
        if covered {
            self.n_covered += 1;
        }
        let err = if covered { 0.0 } else { 1.0 };
        // step_size() reads the already-incremented n_total, so the first update uses t = 1.
        let step = self.step_size();
        let (lo, hi) = self.alpha_bounds();
        self.alpha_t = (self.alpha_t + step * (self.target_alpha - err)).clamp(lo, hi);
    }

    /// Returns the `(lower, upper)` quantile levels `(alpha_t / 2, 1 - alpha_t / 2)`.
    ///
    /// These split the current miscoverage budget equally between the two tails.
    pub fn effective_quantiles(&self) -> (f64, f64) {
        let half = self.alpha_t / 2.0;
        (half, 1.0 - half)
    }

    /// Fraction of recorded observations that were covered.
    ///
    /// Returns `0.0` before any observation has been recorded.
    pub fn empirical_coverage(&self) -> f64 {
        if self.n_total == 0 {
            return 0.0;
        }
        self.n_covered as f64 / self.n_total as f64
    }

    /// Current adapted miscoverage level `alpha_t`.
    pub fn current_alpha(&self) -> f64 {
        self.alpha_t
    }

    /// Miscoverage rate the tracker was constructed with.
    pub fn target_alpha(&self) -> f64 {
        self.target_alpha
    }

    /// Number of observations recorded since construction or the last reset.
    pub fn n_samples(&self) -> u64 {
        self.n_total
    }

    /// Restores `alpha_t` to `target_alpha` and clears the coverage counters.
    ///
    /// The target, base step size, and step schedule are kept.
    pub fn reset(&mut self) {
        self.alpha_t = self.target_alpha;
        self.n_covered = 0;
        self.n_total = 0;
    }

    /// Step size for the current update according to the configured schedule.
    fn step_size(&self) -> f64 {
        match self.schedule {
            StepSchedule::Fixed => self.gamma,
            // `update` increments n_total before calling this, so the base is always >= 1.
            // The u64 -> f64 conversion only loses precision past 2^53 observations.
            StepSchedule::Decaying { beta } => self.gamma / (self.n_total as f64).powf(beta),
        }
    }

    /// Clamp range for `alpha_t`, widened when necessary so it always contains `target_alpha`.
    ///
    /// `target_alpha` is in (0, 1), so `lo <= ALPHA_FLOOR < ALPHA_CEIL <= hi`. That ordering
    /// keeps `f64::clamp` from panicking.
    fn alpha_bounds(&self) -> (f64, f64) {
        (
            Self::ALPHA_FLOOR.min(self.target_alpha),
            Self::ALPHA_CEIL.max(self.target_alpha),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn aci_perfect_coverage() {
        let mut aci = AdaptiveConformalInterval::new(0.1, 0.01);
        let initial = aci.current_alpha();
        for _ in 0..100 {
            aci.update(5.0, 0.0, 10.0);
        }
        // Every observation covered: alpha_t rises above its starting value.
        assert!(aci.current_alpha() > initial);
        assert!((aci.empirical_coverage() - 1.0).abs() < EPS);
    }

    #[test]
    fn aci_no_coverage() {
        let mut aci = AdaptiveConformalInterval::new(0.1, 0.01);
        let initial = aci.current_alpha();
        for _ in 0..100 {
            aci.update(100.0, 0.0, 10.0);
        }
        // Every observation missed: alpha_t falls below its starting value.
        assert!(aci.current_alpha() < initial);
        assert!(aci.empirical_coverage().abs() < EPS);
    }

    #[test]
    fn aci_converges_to_target() {
        // Misses occur exactly 1 time in 10, matching the 0.1 target, so alpha_t should
        // hover around the target.
        let mut aci = AdaptiveConformalInterval::new(0.1, 0.005);
        for i in 0..10_000 {
            if i % 10 == 0 {
                aci.update(100.0, 0.0, 10.0);
            } else {
                aci.update(5.0, 0.0, 10.0);
            }
        }
        assert!((aci.current_alpha() - 0.1).abs() < 0.05);
    }

    #[test]
    fn aci_effective_quantiles() {
        let aci = AdaptiveConformalInterval::new(0.1, 0.01);
        let (lo, hi) = aci.effective_quantiles();
        assert!((lo - 0.05).abs() < EPS);
        assert!((hi - 0.95).abs() < EPS);
    }

    #[test]
    fn aci_clamp_lower() {
        let mut aci = AdaptiveConformalInterval::new(0.1, 1.0);
        for _ in 0..1000 {
            aci.update(100.0, 0.0, 10.0);
        }
        assert!((aci.current_alpha() - 0.001).abs() < EPS);
    }

    #[test]
    fn aci_clamp_upper() {
        let mut aci = AdaptiveConformalInterval::new(0.9, 1.0);
        for _ in 0..1000 {
            aci.update(5.0, 0.0, 10.0);
        }
        assert!((aci.current_alpha() - 0.999).abs() < EPS);
    }

    #[test]
    fn aci_reset() {
        let mut aci = AdaptiveConformalInterval::new(0.1, 0.01);
        for _ in 0..50 {
            aci.update(5.0, 0.0, 10.0);
        }
        aci.reset();
        assert_eq!(aci.n_samples(), 0);
        assert!((aci.current_alpha() - 0.1).abs() < EPS);
        assert!(aci.empirical_coverage().abs() < EPS);
    }

    #[test]
    #[should_panic(expected = "target_alpha must be in (0, 1)")]
    fn aci_invalid_alpha() {
        let _ = AdaptiveConformalInterval::new(0.0, 0.01);
    }

    #[test]
    #[should_panic(expected = "gamma must be > 0")]
    fn aci_invalid_gamma() {
        let _ = AdaptiveConformalInterval::new(0.1, 0.0);
    }

    #[test]
    fn decaying_step_has_smaller_updates_over_time() {
        // gamma must be small enough that neither step reaches the clamp floor. With a large
        // gamma both steps get clamped, the second delta is 0, and the comparison would pass
        // even for a fixed schedule.
        let mut aci = AdaptiveConformalInterval::new(0.1, 0.01).with_decaying_step(0.5);
        aci.update(100.0, 0.0, 10.0);
        let alpha_after_1 = aci.current_alpha();
        aci.update(100.0, 0.0, 10.0);
        let alpha_after_2 = aci.current_alpha();
        let delta_1 = (0.1 - alpha_after_1).abs();
        let delta_2 = (alpha_after_1 - alpha_after_2).abs();
        assert!(
            delta_2 < delta_1,
            "second step ({delta_2}) should be smaller than first ({delta_1})"
        );
        // With beta = 0.5 the step at t = 2 is exactly 1/sqrt(2) of the step at t = 1.
        assert!(
            (delta_2 - delta_1 / 2f64.sqrt()).abs() < EPS,
            "expected second step {} to be first step / sqrt(2), got {delta_2}",
            delta_1 / 2f64.sqrt()
        );
    }

    #[test]
    fn fixed_schedule_unchanged() {
        // The default schedule is Fixed and applies the constant step `gamma` on every update.
        // This is checked against an independent recomputation of the ACI recursion.
        let target: f64 = 0.1;
        let gamma: f64 = 0.01;
        let mut aci = AdaptiveConformalInterval::new(target, gamma);
        assert!(matches!(aci.step_schedule(), StepSchedule::Fixed));
        let mut expected = target;
        for i in 0..100 {
            let miss = i % 5 == 0;
            if miss {
                aci.update(100.0, 0.0, 10.0);
            } else {
                aci.update(5.0, 0.0, 10.0);
            }
            let err = if miss { 1.0 } else { 0.0 };
            expected = (expected + gamma * (target - err)).clamp(0.001, 0.999);
        }
        assert!(
            (aci.current_alpha() - expected).abs() < EPS,
            "fixed schedule should apply a constant step: {} vs {}",
            aci.current_alpha(),
            expected
        );
    }

    #[test]
    fn decaying_beta_zero_equals_fixed() {
        let mut aci_fixed = AdaptiveConformalInterval::new(0.1, 0.01);
        let mut aci_decay0 = AdaptiveConformalInterval::new(0.1, 0.01).with_decaying_step(0.0);
        for i in 0..200 {
            if i % 7 == 0 {
                aci_fixed.update(100.0, 0.0, 10.0);
                aci_decay0.update(100.0, 0.0, 10.0);
            } else {
                aci_fixed.update(5.0, 0.0, 10.0);
                aci_decay0.update(5.0, 0.0, 10.0);
            }
        }
        assert!(
            (aci_fixed.current_alpha() - aci_decay0.current_alpha()).abs() < EPS,
            "beta=0 decay should equal fixed: {} vs {}",
            aci_fixed.current_alpha(),
            aci_decay0.current_alpha()
        );
    }
}