#![forbid(unsafe_code)]
use std::collections::{HashMap, VecDeque};
use std::fmt;
use ftui_render::diff_strategy::DiffStrategy;
use crate::terminal_writer::ScreenMode;
#[derive(Debug, Clone)]
pub struct ConformalConfig {
pub alpha: f64,
pub min_samples: usize,
pub window_size: usize,
pub q_default: f64,
}
impl Default for ConformalConfig {
fn default() -> Self {
Self {
alpha: 0.05,
min_samples: 20,
window_size: 256,
q_default: 10_000.0,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BucketKey {
pub mode: ModeBucket,
pub diff: DiffBucket,
pub size_bucket: u8,
}
impl BucketKey {
pub fn from_context(
screen_mode: ScreenMode,
diff_strategy: DiffStrategy,
cols: u16,
rows: u16,
) -> Self {
Self {
mode: ModeBucket::from_screen_mode(screen_mode),
diff: DiffBucket::from(diff_strategy),
size_bucket: size_bucket(cols, rows),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModeBucket {
Inline,
InlineAuto,
AltScreen,
}
impl ModeBucket {
pub fn as_str(self) -> &'static str {
match self {
Self::Inline => "inline",
Self::InlineAuto => "inline_auto",
Self::AltScreen => "altscreen",
}
}
pub fn from_screen_mode(mode: ScreenMode) -> Self {
match mode {
ScreenMode::Inline { .. } => Self::Inline,
ScreenMode::InlineAuto { .. } => Self::InlineAuto,
ScreenMode::AltScreen => Self::AltScreen,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiffBucket {
Full,
DirtyRows,
FullRedraw,
}
impl DiffBucket {
pub fn as_str(self) -> &'static str {
match self {
Self::Full => "full",
Self::DirtyRows => "dirty",
Self::FullRedraw => "redraw",
}
}
}
impl From<DiffStrategy> for DiffBucket {
fn from(strategy: DiffStrategy) -> Self {
match strategy {
DiffStrategy::Full => Self::Full,
DiffStrategy::DirtyRows => Self::DirtyRows,
DiffStrategy::FullRedraw => Self::FullRedraw,
}
}
}
impl fmt::Display for BucketKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}:{}:{}",
self.mode.as_str(),
self.diff.as_str(),
self.size_bucket
)
}
}
#[derive(Debug, Clone)]
pub struct ConformalPrediction {
pub upper_us: f64,
pub risk: bool,
pub confidence: f64,
pub bucket: BucketKey,
pub sample_count: usize,
pub quantile: f64,
pub fallback_level: u8,
pub window_size: usize,
pub reset_count: u64,
pub y_hat: f64,
pub budget_us: f64,
}
impl ConformalPrediction {
#[must_use]
pub fn to_jsonl(&self) -> String {
format!(
r#"{{"schema":"conformal-v1","upper_us":{:.1},"risk":{},"confidence":{:.4},"bucket":"{}","samples":{},"quantile":{:.2},"fallback_level":{},"window":{},"resets":{},"y_hat":{:.1},"budget_us":{:.1}}}"#,
self.upper_us,
self.risk,
self.confidence,
self.bucket,
self.sample_count,
self.quantile,
self.fallback_level,
self.window_size,
self.reset_count,
self.y_hat,
self.budget_us,
)
}
}
#[derive(Debug, Clone)]
pub struct ConformalUpdate {
pub residual: f64,
pub bucket: BucketKey,
pub sample_count: usize,
}
#[derive(Debug, Default)]
struct BucketState {
residuals: VecDeque<f64>,
}
impl BucketState {
fn push(&mut self, residual: f64, window_size: usize) {
self.residuals.push_back(residual);
while self.residuals.len() > window_size {
self.residuals.pop_front();
}
}
}
#[derive(Debug)]
pub struct ConformalPredictor {
config: ConformalConfig,
buckets: HashMap<BucketKey, BucketState>,
reset_count: u64,
}
impl ConformalPredictor {
pub fn new(config: ConformalConfig) -> Self {
Self {
config,
buckets: HashMap::new(),
reset_count: 0,
}
}
pub fn config(&self) -> &ConformalConfig {
&self.config
}
pub fn bucket_samples(&self, key: BucketKey) -> usize {
self.buckets
.get(&key)
.map(|state| state.residuals.len())
.unwrap_or(0)
}
pub fn reset_all(&mut self) {
self.buckets.clear();
self.reset_count += 1;
}
pub fn reset_bucket(&mut self, key: BucketKey) {
if let Some(state) = self.buckets.get_mut(&key) {
state.residuals.clear();
self.reset_count += 1;
}
}
pub fn observe(&mut self, key: BucketKey, y_hat_us: f64, observed_us: f64) -> ConformalUpdate {
let residual = observed_us - y_hat_us;
if !residual.is_finite() {
return ConformalUpdate {
residual,
bucket: key,
sample_count: self.bucket_samples(key),
};
}
let window_size = self.config.window_size.max(1);
let state = self.buckets.entry(key).or_default();
state.push(residual, window_size);
ConformalUpdate {
residual,
bucket: key,
sample_count: state.residuals.len(),
}
}
pub fn predict(&self, key: BucketKey, y_hat_us: f64, budget_us: f64) -> ConformalPrediction {
let span = tracing::info_span!(
"conformal.predict",
calibration_set_size = tracing::field::Empty,
predicted_upper_bound_us = tracing::field::Empty,
frame_budget_us = budget_us,
coverage_alpha = self.config.alpha,
gate_triggered = tracing::field::Empty,
);
let _guard = span.enter();
let QuantileDecision {
quantile,
sample_count,
fallback_level,
} = self.quantile_for(key);
let upper_us = y_hat_us + quantile.max(0.0);
let risk = upper_us > budget_us;
span.record("calibration_set_size", sample_count);
span.record("predicted_upper_bound_us", upper_us);
span.record("gate_triggered", risk);
tracing::debug!(
bucket = %key,
y_hat_us,
quantile,
interval_width_us = quantile.max(0.0),
fallback_level,
sample_count,
"prediction interval"
);
ConformalPrediction {
upper_us,
risk,
confidence: 1.0 - self.config.alpha,
bucket: key,
sample_count,
quantile,
fallback_level,
window_size: self.config.window_size,
reset_count: self.reset_count,
y_hat: y_hat_us,
budget_us,
}
}
fn quantile_for(&self, key: BucketKey) -> QuantileDecision {
let min_samples = self.config.min_samples.max(1);
let exact = self.collect_exact(key);
if exact.len() >= min_samples {
return QuantileDecision::new(self.config.alpha, exact, 0);
}
let mode_diff = self.collect_mode_diff(key.mode, key.diff);
if mode_diff.len() >= min_samples {
return QuantileDecision::new(self.config.alpha, mode_diff, 1);
}
let mode_only = self.collect_mode(key.mode);
if mode_only.len() >= min_samples {
return QuantileDecision::new(self.config.alpha, mode_only, 2);
}
let global = self.collect_all();
if !global.is_empty() {
return QuantileDecision::new(self.config.alpha, global, 3);
}
QuantileDecision {
quantile: self.config.q_default,
sample_count: 0,
fallback_level: 3,
}
}
fn collect_exact(&self, key: BucketKey) -> Vec<f64> {
self.buckets
.get(&key)
.map(|state| state.residuals.iter().copied().collect())
.unwrap_or_default()
}
fn collect_mode_diff(&self, mode: ModeBucket, diff: DiffBucket) -> Vec<f64> {
let mut values = Vec::new();
for (key, state) in &self.buckets {
if key.mode == mode && key.diff == diff {
values.extend(state.residuals.iter().copied());
}
}
values
}
fn collect_mode(&self, mode: ModeBucket) -> Vec<f64> {
let mut values = Vec::new();
for (key, state) in &self.buckets {
if key.mode == mode {
values.extend(state.residuals.iter().copied());
}
}
values
}
fn collect_all(&self) -> Vec<f64> {
let mut values = Vec::new();
for state in self.buckets.values() {
values.extend(state.residuals.iter().copied());
}
values
}
}
#[derive(Debug)]
struct QuantileDecision {
quantile: f64,
sample_count: usize,
fallback_level: u8,
}
impl QuantileDecision {
fn new(alpha: f64, mut residuals: Vec<f64>, fallback_level: u8) -> Self {
let quantile = conformal_quantile(alpha, &mut residuals);
Self {
quantile,
sample_count: residuals.len(),
fallback_level,
}
}
}
fn conformal_quantile(alpha: f64, residuals: &mut [f64]) -> f64 {
if residuals.is_empty() {
return 0.0;
}
residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = residuals.len();
let rank = ((n as f64 + 1.0) * (1.0 - alpha)).ceil() as usize;
let idx = rank.saturating_sub(1).min(n - 1);
residuals[idx]
}
fn size_bucket(cols: u16, rows: u16) -> u8 {
let area = cols as u32 * rows as u32;
if area == 0 {
return 0;
}
(31 - area.leading_zeros()) as u8
}
#[cfg(test)]
mod tests {
use super::*;
fn test_key(cols: u16, rows: u16) -> BucketKey {
BucketKey::from_context(
ScreenMode::Inline { ui_height: 4 },
DiffStrategy::Full,
cols,
rows,
)
}
#[test]
fn quantile_n_plus_1_rule() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.2,
min_samples: 1,
window_size: 10,
q_default: 0.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 1.0);
predictor.observe(key, 0.0, 2.0);
predictor.observe(key, 0.0, 3.0);
let decision = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(decision.quantile, 3.0);
}
#[test]
fn fallback_hierarchy_mode_diff() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 4,
window_size: 16,
q_default: 0.0,
});
let key_a = test_key(80, 24);
for value in [1.0, 2.0, 3.0, 4.0] {
predictor.observe(key_a, 0.0, value);
}
let key_b = test_key(120, 40);
let decision = predictor.predict(key_b, 0.0, 1_000.0);
assert_eq!(decision.fallback_level, 1);
assert_eq!(decision.sample_count, 4);
}
#[test]
fn fallback_hierarchy_mode_only() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 3,
window_size: 16,
q_default: 0.0,
});
let key_dirty = BucketKey::from_context(
ScreenMode::Inline { ui_height: 4 },
DiffStrategy::DirtyRows,
80,
24,
);
for value in [10.0, 20.0, 30.0] {
predictor.observe(key_dirty, 0.0, value);
}
let key_full = BucketKey::from_context(
ScreenMode::Inline { ui_height: 4 },
DiffStrategy::Full,
120,
40,
);
let decision = predictor.predict(key_full, 0.0, 1_000.0);
assert_eq!(decision.fallback_level, 2);
assert_eq!(decision.sample_count, 3);
}
#[test]
fn window_enforced() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 3,
q_default: 0.0,
});
let key = test_key(80, 24);
for value in [1.0, 2.0, 3.0, 4.0, 5.0] {
predictor.observe(key, 0.0, value);
}
assert_eq!(predictor.bucket_samples(key), 3);
}
#[test]
fn predict_uses_default_when_empty() {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 2,
window_size: 4,
q_default: 42.0,
});
let key = test_key(120, 40);
let prediction = predictor.predict(key, 5.0, 10_000.0);
assert_eq!(prediction.quantile, 42.0);
assert_eq!(prediction.sample_count, 0);
assert_eq!(prediction.fallback_level, 3);
}
#[test]
fn bucket_isolation_by_size() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.2,
min_samples: 2,
window_size: 10,
q_default: 0.0,
});
let small = test_key(40, 10);
predictor.observe(small, 0.0, 1.0);
predictor.observe(small, 0.0, 2.0);
let large = test_key(200, 60);
predictor.observe(large, 0.0, 10.0);
predictor.observe(large, 0.0, 12.0);
let prediction = predictor.predict(large, 0.0, 1_000.0);
assert_eq!(prediction.fallback_level, 0);
assert_eq!(prediction.sample_count, 2);
assert_eq!(prediction.quantile, 12.0);
}
#[test]
fn reset_clears_bucket_and_raises_reset_count() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 7.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 3.0);
assert_eq!(predictor.bucket_samples(key), 1);
predictor.reset_bucket(key);
assert_eq!(predictor.bucket_samples(key), 0);
let prediction = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(prediction.quantile, 7.0);
assert_eq!(prediction.reset_count, 1);
}
#[test]
fn reset_all_forces_conservative_fallback() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 9.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 2.0);
predictor.reset_all();
let prediction = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(prediction.quantile, 9.0);
assert_eq!(prediction.sample_count, 0);
assert_eq!(prediction.fallback_level, 3);
assert_eq!(prediction.reset_count, 1);
}
#[test]
fn size_bucket_log2_area() {
let a = size_bucket(8, 8); let b = size_bucket(8, 16); assert_eq!(a, 6);
assert_eq!(b, 7);
}
#[test]
fn size_bucket_zero_area() {
assert_eq!(size_bucket(0, 0), 0);
assert_eq!(size_bucket(0, 24), 0);
assert_eq!(size_bucket(80, 0), 0);
}
#[test]
fn size_bucket_one_by_one() {
assert_eq!(size_bucket(1, 1), 0); }
#[test]
fn size_bucket_typical_terminals() {
let b80 = size_bucket(80, 24); let b120 = size_bucket(120, 40); assert_eq!(b80, 10);
assert_eq!(b120, 12);
}
#[test]
fn conformal_quantile_empty() {
let mut data: Vec<f64> = vec![];
assert_eq!(conformal_quantile(0.1, &mut data), 0.0);
}
#[test]
fn conformal_quantile_single_element() {
let mut data = vec![42.0];
assert_eq!(conformal_quantile(0.1, &mut data), 42.0);
}
#[test]
fn conformal_quantile_sorted_data() {
let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let q = conformal_quantile(0.5, &mut data);
assert_eq!(q, 3.0);
}
#[test]
fn conformal_quantile_alpha_half() {
let mut data = vec![10.0, 20.0, 30.0, 40.0];
let q = conformal_quantile(0.5, &mut data);
assert_eq!(q, 30.0);
}
#[test]
fn mode_bucket_as_str_all_variants() {
assert_eq!(ModeBucket::Inline.as_str(), "inline");
assert_eq!(ModeBucket::InlineAuto.as_str(), "inline_auto");
assert_eq!(ModeBucket::AltScreen.as_str(), "altscreen");
}
#[test]
fn diff_bucket_as_str_all_variants() {
assert_eq!(DiffBucket::Full.as_str(), "full");
assert_eq!(DiffBucket::DirtyRows.as_str(), "dirty");
assert_eq!(DiffBucket::FullRedraw.as_str(), "redraw");
}
#[test]
fn diff_bucket_from_strategy() {
assert_eq!(DiffBucket::from(DiffStrategy::Full), DiffBucket::Full);
assert_eq!(
DiffBucket::from(DiffStrategy::DirtyRows),
DiffBucket::DirtyRows
);
assert_eq!(
DiffBucket::from(DiffStrategy::FullRedraw),
DiffBucket::FullRedraw
);
}
#[test]
fn bucket_key_display_format() {
let key = BucketKey {
mode: ModeBucket::AltScreen,
diff: DiffBucket::DirtyRows,
size_bucket: 12,
};
assert_eq!(format!("{key}"), "altscreen:dirty:12");
}
#[test]
fn observe_nan_residual_not_stored() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 5.0,
});
let key = test_key(80, 24);
let update = predictor.observe(key, 0.0, f64::NAN);
assert!(!update.residual.is_finite());
assert_eq!(predictor.bucket_samples(key), 0);
}
#[test]
fn observe_infinity_residual_not_stored() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 5.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, f64::INFINITY);
assert_eq!(predictor.bucket_samples(key), 0);
}
#[test]
fn prediction_risk_flag() {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 50.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 0.0, 100.0);
assert!(!p.risk); let p2 = predictor.predict(key, 0.0, 30.0);
assert!(p2.risk); }
#[test]
fn prediction_confidence() {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.05,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 0.0, 100.0);
assert!((p.confidence - 0.95).abs() < 1e-10);
}
#[test]
fn global_fallback_with_data() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 100, window_size: 256,
q_default: 999.0,
});
let alt_key = BucketKey::from_context(ScreenMode::AltScreen, DiffStrategy::Full, 80, 24);
predictor.observe(alt_key, 0.0, 5.0);
let inline_key = test_key(80, 24);
let p = predictor.predict(inline_key, 0.0, 1000.0);
assert_eq!(p.fallback_level, 3);
assert_eq!(p.sample_count, 1);
assert_eq!(p.quantile, 5.0);
}
#[test]
fn mode_bucket_from_screen_modes() {
assert_eq!(
ModeBucket::from_screen_mode(ScreenMode::Inline { ui_height: 4 }),
ModeBucket::Inline
);
assert_eq!(
ModeBucket::from_screen_mode(ScreenMode::InlineAuto {
min_height: 4,
max_height: 24
}),
ModeBucket::InlineAuto
);
assert_eq!(
ModeBucket::from_screen_mode(ScreenMode::AltScreen),
ModeBucket::AltScreen
);
}
#[test]
fn config_defaults() {
let config = ConformalConfig::default();
assert!((config.alpha - 0.05).abs() < 1e-10);
assert_eq!(config.min_samples, 20);
assert_eq!(config.window_size, 256);
assert!((config.q_default - 10_000.0).abs() < 1e-10);
}
#[test]
fn predictor_config_accessor() {
let config = ConformalConfig {
alpha: 0.2,
min_samples: 5,
window_size: 32,
q_default: 100.0,
};
let predictor = ConformalPredictor::new(config);
assert!((predictor.config().alpha - 0.2).abs() < 1e-10);
assert_eq!(predictor.config().min_samples, 5);
}
#[test]
fn negative_residual_clamped_in_prediction() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
predictor.observe(key, 10.0, 5.0);
let p = predictor.predict(key, 10.0, 100.0);
assert_eq!(p.upper_us, 10.0);
}
#[test]
fn observe_returns_correct_update() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
let update = predictor.observe(key, 3.0, 10.0);
assert!((update.residual - 7.0).abs() < 1e-10);
assert_eq!(update.bucket, key);
assert_eq!(update.sample_count, 1);
}
#[test]
fn prediction_preserves_yhat_and_budget() {
let predictor = ConformalPredictor::new(ConformalConfig::default());
let key = test_key(80, 24);
let p = predictor.predict(key, 42.5, 16666.0);
assert!((p.y_hat - 42.5).abs() < 1e-10);
assert!((p.budget_us - 16666.0).abs() < 1e-10);
}
#[test]
fn predict_emits_conformal_predict_span() {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
struct SpanChecker {
saw_conformal_predict: Arc<AtomicBool>,
}
impl tracing::Subscriber for SpanChecker {
fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
true
}
fn new_span(&self, span: &tracing::span::Attributes<'_>) -> tracing::span::Id {
if span.metadata().name() == "conformal.predict" {
self.saw_conformal_predict.store(true, Ordering::Relaxed);
}
tracing::span::Id::from_u64(1)
}
fn record(&self, _span: &tracing::span::Id, _values: &tracing::span::Record<'_>) {}
fn record_follows_from(&self, _span: &tracing::span::Id, _follows: &tracing::span::Id) {
}
fn event(&self, _event: &tracing::Event<'_>) {}
fn enter(&self, _span: &tracing::span::Id) {}
fn exit(&self, _span: &tracing::span::Id) {}
}
let saw_it = Arc::new(AtomicBool::new(false));
let subscriber = SpanChecker {
saw_conformal_predict: Arc::clone(&saw_it),
};
let _guard = tracing::subscriber::set_default(subscriber);
let predictor = ConformalPredictor::new(ConformalConfig::default());
let key = test_key(80, 24);
let _ = predictor.predict(key, 100.0, 16666.0);
assert!(
saw_it.load(Ordering::Relaxed),
"predict() must emit a 'conformal.predict' tracing span"
);
}
#[test]
fn predict_span_records_gate_triggered_true() {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
struct GateChecker {
saw_gate_true: Arc<AtomicBool>,
}
struct GateVisitor(Arc<AtomicBool>);
impl tracing::field::Visit for GateVisitor {
fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
if field.name() == "gate_triggered" && value {
self.0.store(true, Ordering::Relaxed);
}
}
fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn fmt::Debug) {}
}
impl tracing::Subscriber for GateChecker {
fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
true
}
fn new_span(&self, _span: &tracing::span::Attributes<'_>) -> tracing::span::Id {
tracing::span::Id::from_u64(1)
}
fn record(&self, _span: &tracing::span::Id, values: &tracing::span::Record<'_>) {
let mut visitor = GateVisitor(Arc::clone(&self.saw_gate_true));
values.record(&mut visitor);
}
fn record_follows_from(&self, _span: &tracing::span::Id, _follows: &tracing::span::Id) {
}
fn event(&self, _event: &tracing::Event<'_>) {}
fn enter(&self, _span: &tracing::span::Id) {}
fn exit(&self, _span: &tracing::span::Id) {}
}
let saw_gate = Arc::new(AtomicBool::new(false));
let subscriber = GateChecker {
saw_gate_true: Arc::clone(&saw_gate),
};
let _guard = tracing::subscriber::set_default(subscriber);
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 50_000.0, });
let key = test_key(80, 24);
let p = predictor.predict(key, 0.0, 100.0);
assert!(p.risk, "prediction should be risky");
assert!(
saw_gate.load(Ordering::Relaxed),
"predict() must record gate_triggered=true when risk"
);
}
#[test]
fn calibration_uniform_distribution_quantile() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.05,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 0..100 {
predictor.observe(key, 0.0, i as f64);
}
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.fallback_level, 0);
assert_eq!(p.sample_count, 100);
assert!((p.quantile - 95.0).abs() < 1e-10);
}
#[test]
fn calibration_gaussian_like_distribution() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(120, 40);
for i in 0..50 {
let residual = (i as f64) - 24.5;
predictor.observe(key, 100.0, 100.0 + residual);
}
let p = predictor.predict(key, 100.0, 1_000.0);
assert_eq!(p.fallback_level, 0);
assert_eq!(p.sample_count, 50);
assert!((p.quantile - 20.5).abs() < 1e-10);
assert!((p.upper_us - 120.5).abs() < 1e-10);
}
#[test]
fn calibration_constant_residuals() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.05,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(80, 24);
for _ in 0..30 {
predictor.observe(key, 100.0, 105.0); }
let p = predictor.predict(key, 100.0, 1_000.0);
assert!((p.quantile - 5.0).abs() < 1e-10);
assert!((p.upper_us - 105.0).abs() < 1e-10);
}
#[test]
fn coverage_property_uniform_residuals() {
let alpha = 0.1;
let n_calibrate = 100;
let n_test = 200;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 0..n_calibrate {
predictor.observe(key, 0.0, i as f64);
}
let prediction = predictor.predict(key, 0.0, f64::MAX);
let upper_bound = prediction.upper_us;
let mut covered = 0;
for i in 0..n_test {
let new_obs = (i as f64) * (n_calibrate as f64) / (n_test as f64);
if new_obs <= upper_bound {
covered += 1;
}
}
let empirical_coverage = covered as f64 / n_test as f64;
let target_coverage = 1.0 - alpha - 0.05; assert!(
empirical_coverage >= target_coverage,
"Empirical coverage {empirical_coverage:.3} should be >= {target_coverage:.3}"
);
}
#[test]
fn coverage_property_with_shifted_test_distribution() {
let alpha = 0.05;
let n = 200;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha,
min_samples: 1,
window_size: 512,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 1..=n {
predictor.observe(key, 0.0, i as f64);
}
let p = predictor.predict(key, 0.0, f64::MAX);
assert!((p.quantile - 191.0).abs() < 1e-10);
let covered = (1..=n).filter(|&i| (i as f64) <= p.upper_us).count();
let coverage = covered as f64 / n as f64;
assert!(
coverage >= 1.0 - alpha,
"Coverage {coverage:.3} should be >= {:.3}",
1.0 - alpha
);
}
#[test]
fn gate_trigger_exact_boundary() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 50.0);
let p = predictor.predict(key, 0.0, 50.0);
assert!(
!p.risk,
"upper_us ({}) == budget_us ({}) should NOT trigger risk",
p.upper_us, p.budget_us
);
}
#[test]
fn gate_trigger_just_above_boundary() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 50.0);
let p = predictor.predict(key, 0.0, 49.999);
assert!(p.risk, "upper_us > budget should trigger risk");
}
#[test]
fn gate_trigger_just_below_boundary() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 50.0);
let p = predictor.predict(key, 0.0, 50.001);
assert!(!p.risk, "upper_us < budget should NOT trigger risk");
}
#[test]
fn gate_trigger_zero_budget() {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 1.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 0.0, 0.0);
assert!(p.risk, "positive upper_us with zero budget should be risky");
}
#[test]
fn gate_trigger_very_large_budget() {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 100_000.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 1_000.0, f64::MAX);
assert!(!p.risk, "huge budget should never trigger risk");
}
#[test]
fn alpha_sensitivity_wider_interval_with_lower_alpha() {
let key = test_key(80, 24);
let mut predictor_tight = ConformalPredictor::new(ConformalConfig {
alpha: 0.5, min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let mut predictor_wide = ConformalPredictor::new(ConformalConfig {
alpha: 0.01, min_samples: 1,
window_size: 256,
q_default: 0.0,
});
for i in 0..100 {
let obs = i as f64;
predictor_tight.observe(key, 0.0, obs);
predictor_wide.observe(key, 0.0, obs);
}
let p_tight = predictor_tight.predict(key, 0.0, 10_000.0);
let p_wide = predictor_wide.predict(key, 0.0, 10_000.0);
assert!(
p_wide.quantile > p_tight.quantile,
"Lower alpha ({}) should produce wider interval: quantile {} vs {}",
0.01,
p_wide.quantile,
p_tight.quantile
);
assert!(
p_wide.upper_us > p_tight.upper_us,
"Lower alpha should produce higher upper bound"
);
}
#[test]
fn alpha_sensitivity_confidence_reflects_alpha() {
for &alpha in &[0.01, 0.05, 0.1, 0.2, 0.5] {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 0.0, 1_000.0);
let expected_confidence = 1.0 - alpha;
assert!(
(p.confidence - expected_confidence).abs() < 1e-10,
"confidence should be 1-alpha for alpha={alpha}"
);
}
}
#[test]
fn alpha_sensitivity_extreme_alpha_zero() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.001,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 0..100 {
predictor.observe(key, 0.0, i as f64);
}
let p = predictor.predict(key, 0.0, 10_000.0);
assert!((p.quantile - 99.0).abs() < 1e-10);
}
#[test]
fn alpha_sensitivity_extreme_alpha_one() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.99,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 0..100 {
predictor.observe(key, 0.0, i as f64);
}
let p = predictor.predict(key, 0.0, 10_000.0);
assert!((p.quantile - 1.0).abs() < 1e-10);
}
#[test]
fn empty_calibration_uses_default() {
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.05,
min_samples: 20,
window_size: 256,
q_default: 10_000.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 100.0, 16_666.0);
assert_eq!(p.sample_count, 0);
assert_eq!(p.fallback_level, 3);
assert!((p.quantile - 10_000.0).abs() < 1e-10);
assert!((p.upper_us - 10_100.0).abs() < 1e-10);
}
#[test]
fn one_sample_below_min_samples_uses_fallback() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.05,
min_samples: 20,
window_size: 256,
q_default: 999.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 5.0);
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.fallback_level, 3);
assert_eq!(p.sample_count, 1);
}
#[test]
fn exactly_min_samples_minus_one_uses_fallback() {
let min_samples = 5;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples,
window_size: 256,
q_default: 999.0,
});
let key = test_key(80, 24);
for i in 0..(min_samples - 1) {
predictor.observe(key, 0.0, (i as f64) * 10.0);
}
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.fallback_level, 3);
assert_eq!(p.sample_count, min_samples - 1);
}
#[test]
fn exactly_min_samples_uses_exact_bucket() {
let min_samples = 5;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples,
window_size: 256,
q_default: 999.0,
});
let key = test_key(80, 24);
for i in 0..min_samples {
predictor.observe(key, 0.0, (i as f64) * 10.0);
}
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.fallback_level, 0);
assert_eq!(p.sample_count, min_samples);
}
#[test]
fn min_samples_plus_one_uses_exact_bucket() {
let min_samples = 5;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples,
window_size: 256,
q_default: 999.0,
});
let key = test_key(80, 24);
for i in 0..=min_samples {
predictor.observe(key, 0.0, (i as f64) * 10.0);
}
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.fallback_level, 0);
assert_eq!(p.sample_count, min_samples + 1);
}
#[test]
fn min_samples_one_allows_single_observation() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 256,
q_default: 999.0,
});
let key = test_key(80, 24);
predictor.observe(key, 0.0, 42.0);
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.fallback_level, 0);
assert_eq!(p.sample_count, 1);
assert!((p.quantile - 42.0).abs() < 1e-10);
}
#[test]
fn predict_span_records_calibration_set_size() {
use std::sync::Arc;
use std::sync::Mutex;
struct FieldRecorder {
calibration_size: Arc<Mutex<Option<u64>>>,
}
struct SizeVisitor(Arc<Mutex<Option<u64>>>);
impl tracing::field::Visit for SizeVisitor {
fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
if field.name() == "calibration_set_size" {
*self.0.lock().unwrap() = Some(value);
}
}
fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn fmt::Debug) {}
}
impl tracing::Subscriber for FieldRecorder {
fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
true
}
fn new_span(&self, _span: &tracing::span::Attributes<'_>) -> tracing::span::Id {
tracing::span::Id::from_u64(1)
}
fn record(&self, _span: &tracing::span::Id, values: &tracing::span::Record<'_>) {
let mut v = SizeVisitor(Arc::clone(&self.calibration_size));
values.record(&mut v);
}
fn record_follows_from(&self, _: &tracing::span::Id, _: &tracing::span::Id) {}
fn event(&self, _: &tracing::Event<'_>) {}
fn enter(&self, _: &tracing::span::Id) {}
fn exit(&self, _: &tracing::span::Id) {}
}
let size = Arc::new(Mutex::new(None));
let subscriber = FieldRecorder {
calibration_size: Arc::clone(&size),
};
let _guard = tracing::subscriber::set_default(subscriber);
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 0..5 {
predictor.observe(key, 0.0, i as f64);
}
let _ = predictor.predict(key, 0.0, 1_000.0);
let recorded = size.lock().unwrap();
assert_eq!(*recorded, Some(5), "calibration_set_size should be 5");
}
#[test]
fn predict_span_records_predicted_upper_bound() {
use std::sync::Arc;
use std::sync::Mutex;
struct UpperBoundRecorder {
upper_bound: Arc<Mutex<Option<f64>>>,
}
struct UpperVisitor(Arc<Mutex<Option<f64>>>);
impl tracing::field::Visit for UpperVisitor {
fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
if field.name() == "predicted_upper_bound_us" {
*self.0.lock().unwrap() = Some(value);
}
}
fn record_debug(&mut self, _: &tracing::field::Field, _: &dyn fmt::Debug) {}
}
impl tracing::Subscriber for UpperBoundRecorder {
fn enabled(&self, _metadata: &tracing::Metadata<'_>) -> bool {
true
}
fn new_span(&self, _: &tracing::span::Attributes<'_>) -> tracing::span::Id {
tracing::span::Id::from_u64(1)
}
fn record(&self, _: &tracing::span::Id, values: &tracing::span::Record<'_>) {
let mut v = UpperVisitor(Arc::clone(&self.upper_bound));
values.record(&mut v);
}
fn record_follows_from(&self, _: &tracing::span::Id, _: &tracing::span::Id) {}
fn event(&self, _: &tracing::Event<'_>) {}
fn enter(&self, _: &tracing::span::Id) {}
fn exit(&self, _: &tracing::span::Id) {}
}
let upper = Arc::new(Mutex::new(None));
let subscriber = UpperBoundRecorder {
upper_bound: Arc::clone(&upper),
};
let _guard = tracing::subscriber::set_default(subscriber);
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 42.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 10.0, 1_000.0);
let recorded = upper.lock().unwrap();
assert!(
recorded.is_some(),
"predicted_upper_bound_us should be recorded"
);
assert!(
(recorded.unwrap() - p.upper_us).abs() < 1e-10,
"recorded upper bound should match prediction"
);
}
#[test]
fn predict_span_records_gate_triggered_false() {
use std::sync::Arc;
use std::sync::Mutex;
struct GateFalseChecker {
gate_value: Arc<Mutex<Option<bool>>>,
}
struct GateFalseVisitor(Arc<Mutex<Option<bool>>>);
impl tracing::field::Visit for GateFalseVisitor {
fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
if field.name() == "gate_triggered" {
*self.0.lock().unwrap() = Some(value);
}
}
fn record_debug(&mut self, _: &tracing::field::Field, _: &dyn fmt::Debug) {}
}
impl tracing::Subscriber for GateFalseChecker {
fn enabled(&self, _: &tracing::Metadata<'_>) -> bool {
true
}
fn new_span(&self, _: &tracing::span::Attributes<'_>) -> tracing::span::Id {
tracing::span::Id::from_u64(1)
}
fn record(&self, _: &tracing::span::Id, values: &tracing::span::Record<'_>) {
let mut v = GateFalseVisitor(Arc::clone(&self.gate_value));
values.record(&mut v);
}
fn record_follows_from(&self, _: &tracing::span::Id, _: &tracing::span::Id) {}
fn event(&self, _: &tracing::Event<'_>) {}
fn enter(&self, _: &tracing::span::Id) {}
fn exit(&self, _: &tracing::span::Id) {}
}
let gate = Arc::new(Mutex::new(None));
let subscriber = GateFalseChecker {
gate_value: Arc::clone(&gate),
};
let _guard = tracing::subscriber::set_default(subscriber);
let predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 8,
q_default: 1.0,
});
let key = test_key(80, 24);
let p = predictor.predict(key, 0.0, 1_000_000.0);
assert!(!p.risk);
let recorded = gate.lock().unwrap();
assert_eq!(
*recorded,
Some(false),
"gate_triggered should be recorded as false"
);
}
#[test]
fn jsonl_output_contains_required_fields() {
let prediction = ConformalPrediction {
upper_us: 150.5,
risk: true,
confidence: 0.95,
bucket: BucketKey {
mode: ModeBucket::Inline,
diff: DiffBucket::Full,
size_bucket: 10,
},
sample_count: 42,
quantile: 50.5,
fallback_level: 0,
window_size: 256,
reset_count: 1,
y_hat: 100.0,
budget_us: 140.0,
};
let jsonl = prediction.to_jsonl();
assert!(jsonl.contains("\"schema\":\"conformal-v1\""));
assert!(jsonl.contains("\"upper_us\":150.5"));
assert!(jsonl.contains("\"risk\":true"));
assert!(jsonl.contains("\"confidence\":0.9500"));
assert!(jsonl.contains("\"bucket\":\"inline:full:10\""));
assert!(jsonl.contains("\"samples\":42"));
assert!(jsonl.contains("\"quantile\":50.50"));
assert!(jsonl.contains("\"fallback_level\":0"));
assert!(jsonl.contains("\"window\":256"));
assert!(jsonl.contains("\"resets\":1"));
assert!(jsonl.contains("\"y_hat\":100.0"));
assert!(jsonl.contains("\"budget_us\":140.0"));
}
#[test]
fn property_empirical_coverage_deterministic_sequences() {
for alpha in [0.05, 0.1, 0.2] {
let n_calibrate = 100;
let n_test = 100;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let key = test_key(80, 24);
for i in 1..=n_calibrate {
predictor.observe(key, 0.0, i as f64);
}
let p = predictor.predict(key, 0.0, f64::MAX);
let covered = (1..=n_test).filter(|&i| (i as f64) <= p.upper_us).count();
let coverage = covered as f64 / n_test as f64;
assert!(
coverage >= 1.0 - alpha - 0.02,
"alpha={alpha}: coverage {coverage:.3} should be >= {:.3}",
1.0 - alpha - 0.02
);
}
}
#[test]
fn property_monotone_quantile_with_more_extreme_data() {
let key = test_key(80, 24);
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size: 256,
q_default: 0.0,
});
for i in 0..50 {
predictor.observe(key, 0.0, i as f64);
}
let q_moderate = predictor.predict(key, 0.0, f64::MAX).quantile;
for _ in 0..50 {
predictor.observe(key, 0.0, 1000.0);
}
let q_extreme = predictor.predict(key, 0.0, f64::MAX).quantile;
assert!(
q_extreme >= q_moderate,
"Adding extreme data should not decrease quantile: {q_extreme} vs {q_moderate}"
);
}
#[test]
fn property_quantile_bounded_by_max_residual() {
let key = test_key(80, 24);
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.001, min_samples: 1,
window_size: 256,
q_default: 0.0,
});
let max_residual = 100.0;
for i in 0..50 {
predictor.observe(key, 0.0, (i as f64) * 2.0); }
let p = predictor.predict(key, 0.0, f64::MAX);
assert!(
p.quantile <= max_residual,
"quantile {} should be <= max residual {max_residual}",
p.quantile
);
}
#[test]
fn property_window_eviction_changes_quantile() {
let key = test_key(80, 24);
let window_size = 10;
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 1,
window_size,
q_default: 0.0,
});
for _ in 0..window_size {
predictor.observe(key, 0.0, 1000.0);
}
let q_large = predictor.predict(key, 0.0, f64::MAX).quantile;
for _ in 0..window_size {
predictor.observe(key, 0.0, 1.0);
}
let q_small = predictor.predict(key, 0.0, f64::MAX).quantile;
assert!(
q_small < q_large,
"After eviction, quantile should decrease: {q_small} vs {q_large}"
);
}
#[test]
fn cross_mode_fallback_does_not_mix_modes() {
let mut predictor = ConformalPredictor::new(ConformalConfig {
alpha: 0.1,
min_samples: 5,
window_size: 256,
q_default: 999.0,
});
let alt_key = BucketKey::from_context(ScreenMode::AltScreen, DiffStrategy::Full, 80, 24);
for i in 0..10 {
predictor.observe(alt_key, 0.0, (i as f64) * 100.0);
}
let inline_key = BucketKey::from_context(
ScreenMode::Inline { ui_height: 4 },
DiffStrategy::Full,
120,
40,
);
let p = predictor.predict(inline_key, 0.0, 1_000_000.0);
assert_eq!(
p.fallback_level, 3,
"Cross-mode query should fall to global"
);
}
#[test]
fn reset_count_accumulates_across_resets() {
let mut predictor = ConformalPredictor::new(ConformalConfig::default());
let key = test_key(80, 24);
predictor.observe(key, 0.0, 1.0);
predictor.reset_bucket(key);
predictor.observe(key, 0.0, 2.0);
predictor.reset_all();
predictor.observe(key, 0.0, 3.0);
predictor.reset_bucket(key);
let p = predictor.predict(key, 0.0, 1_000.0);
assert_eq!(p.reset_count, 3, "reset_count should accumulate");
}
}