scirs2-stats 0.4.1

Statistical functions module for SciRS2 (scirs2-stats)
Documentation
//! Online / adaptive conformal prediction.
//!
//! Implements **Adaptive Conformal Inference (ACI)** (Gibbs & Candès 2021),
//! which adjusts the significance level `α_t` on the fly to track a target
//! marginal coverage of `1 − α` under distribution shift.
//!
//! ## ACI update rule
//!
//! ```text
//! α_{t+1} = α_t + γ · (α − 1_{y_t ∉ Ĉ_t})
//! ```
//!
//! * When the prediction set *misses* the true value (`1_{…} = 0`), the step
//!   `+ γ · α` *increases* α_t, making future sets **narrower** (because we
//!   over-covered on average).
//! * When the set *covers* the true value (`1_{…} = 1`), the step
//!   `+ γ · (α − 1) < 0` *decreases* α_t, making future sets **wider**.
//!
//! Over time this tracks the target coverage `1 − α` even under covariate
//! or label shift.

use crate::conformal::types::{conformal_quantile, PredictionSet};

// ---------------------------------------------------------------------------
// ACI conformal predictor
// ---------------------------------------------------------------------------

/// Adaptive Conformal Inference predictor (Gibbs & Candès 2021).
///
/// Maintains a running `alpha_t` that is updated after each test example
/// to keep empirical coverage near `1 − alpha`.
///
/// # Example
/// ```rust
/// use scirs2_stats::conformal::online_conformal::AciConformal;
///
/// let scores = vec![0.1, 0.2, 0.3, 0.4, 0.5];
/// let mut aci = AciConformal::new(0.1, 0.05);
/// let ps = aci.predict_with_current(3.0, &scores);
/// assert!(ps.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct AciConformal {
    /// Target significance level α ∈ (0, 1).
    pub alpha: f64,
    /// Current (running) significance level α_t.
    pub alpha_t: f64,
    /// Step size γ for the ACI update.
    pub gamma: f64,
    /// Coverage history: `true` = the t-th prediction set covered y_t.
    pub history: Vec<bool>,
}

impl AciConformal {
    /// Create a new ACI predictor.
    ///
    /// * `alpha` — target significance level.
    /// * `gamma` — step size (default 0.05 as recommended by Gibbs & Candès).
    pub fn new(alpha: f64, gamma: f64) -> Self {
        Self {
            alpha,
            alpha_t: alpha,
            gamma,
            history: Vec::new(),
        }
    }

    /// Return the current (running) significance level α_t.
    pub fn current_alpha(&self) -> f64 {
        self.alpha_t.clamp(1e-6, 1.0 - 1e-6)
    }

    /// Build a prediction interval `[ŷ − Q̂, ŷ + Q̂]` using the current α_t
    /// and the provided calibration-like scores.
    ///
    /// `scores` may be a recent rolling window of nonconformity scores.
    ///
    /// Returns `None` if `scores` is empty.
    pub fn predict_with_current(&self, y_hat: f64, scores: &[f64]) -> Option<PredictionSet> {
        if scores.is_empty() {
            return None;
        }
        let q = conformal_quantile(scores, self.current_alpha());
        Some(PredictionSet::interval(y_hat - q, y_hat + q))
    }

    /// Update the running α_t after observing the true value `y_true` and
    /// whether the prediction set covered it.
    ///
    /// Implements:
    /// `α_{t+1} = α_t + γ · (α − 1_{y_t ∈ Ĉ_t})`
    ///
    /// * `covered` — whether the last prediction set contained `y_true`.
    pub fn update(&mut self, covered: bool) {
        let indicator = if covered { 1.0 } else { 0.0 };
        self.alpha_t += self.gamma * (self.alpha - indicator);
        // Clamp to (0, 1) to keep α_t a valid significance level
        self.alpha_t = self.alpha_t.clamp(1e-6, 1.0 - 1e-6);
        self.history.push(covered);
    }

    /// Convenience method: predict, then immediately update with the true value.
    ///
    /// Returns the prediction set and whether it covered `y_true`.
    pub fn predict_and_update(
        &mut self,
        y_hat: f64,
        y_true: f64,
        scores: &[f64],
    ) -> (Option<PredictionSet>, bool) {
        let ps = self.predict_with_current(y_hat, scores);
        let covered = ps.as_ref().map_or(false, |s| s.contains_value(y_true));
        self.update(covered);
        (ps, covered)
    }

    /// Running fraction of time steps that were covered.
    pub fn running_coverage(&self) -> f64 {
        running_coverage(&self.history)
    }

    /// Coverage over the most recent `window` steps.
    pub fn recent_coverage(&self, window: usize) -> f64 {
        if self.history.is_empty() {
            return 0.0;
        }
        let start = self.history.len().saturating_sub(window);
        let slice = &self.history[start..];
        if slice.is_empty() {
            0.0
        } else {
            slice.iter().filter(|&&b| b).count() as f64 / slice.len() as f64
        }
    }
}

// ---------------------------------------------------------------------------
// Coverage tracking utilities
// ---------------------------------------------------------------------------

/// Compute the running (overall) empirical coverage from a boolean history.
///
/// Returns the fraction of `true` values in `history`, or `0.0` if empty.
pub fn running_coverage(history: &[bool]) -> f64 {
    if history.is_empty() {
        return 0.0;
    }
    history.iter().filter(|&&b| b).count() as f64 / history.len() as f64
}

/// Compute the coverage drift: difference between recent-window coverage and
/// overall coverage.
///
/// A positive value means the recent window is better covered (coverage
/// increasing); negative means degrading coverage.
pub fn coverage_drift(history: &[bool], window: usize) -> f64 {
    if history.is_empty() || window == 0 {
        return 0.0;
    }
    let overall = running_coverage(history);
    let start = history.len().saturating_sub(window);
    let slice = &history[start..];
    if slice.is_empty() {
        return 0.0;
    }
    let recent = slice.iter().filter(|&&b| b).count() as f64 / slice.len() as f64;
    recent - overall
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    struct Lcg {
        state: u64,
    }

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_f64(&mut self) -> f64 {
            self.state = self
                .state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.state >> 33) as f64 / (u32::MAX as f64)
        }

        fn next_normal(&mut self) -> f64 {
            let u1 = self.next_f64().max(1e-12);
            let u2 = self.next_f64();
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        }
    }

    #[test]
    fn test_aci_alpha_increases_when_missed() {
        let mut aci = AciConformal::new(0.1, 0.05);
        let alpha_before = aci.current_alpha();
        // Not covered → indicator = 0 → α_t += γ*(α - 0) = 0.05*0.1 = +0.005
        aci.update(false);
        let alpha_after = aci.current_alpha();
        assert!(
            alpha_after > alpha_before,
            "Alpha should increase when not covered: {} -> {}",
            alpha_before,
            alpha_after
        );
    }

    #[test]
    fn test_aci_alpha_decreases_when_covered() {
        let mut aci = AciConformal::new(0.1, 0.05);
        let alpha_before = aci.current_alpha();
        // Covered → indicator = 1 → α_t += γ*(α - 1) = 0.05*(0.1-1) = -0.045
        aci.update(true);
        let alpha_after = aci.current_alpha();
        assert!(
            alpha_after < alpha_before,
            "Alpha should decrease when covered: {} -> {}",
            alpha_before,
            alpha_after
        );
    }

    #[test]
    fn test_aci_running_coverage_tracks() {
        let mut aci = AciConformal::new(0.1, 0.05);
        // Manually push 80% covered
        for i in 0..100 {
            aci.update(i % 5 != 0); // 80% covered
        }
        let cov = aci.running_coverage();
        assert!((cov - 0.8).abs() < 0.01, "Coverage = {}", cov);
    }

    #[test]
    fn test_aci_predict_returns_interval() {
        let aci = AciConformal::new(0.1, 0.05);
        let scores: Vec<f64> = (1..=10).map(|x| x as f64 * 0.1).collect();
        let ps = aci.predict_with_current(5.0, &scores);
        assert!(ps.is_some());
        let ps = ps.expect("interval");
        assert!(ps.lower < ps.upper);
    }

    #[test]
    fn test_aci_online_adapts() {
        // Simulate a scenario where the model is initially miscalibrated (too
        // narrow), then ACI should increase α_t to widen intervals over time.
        let mut rng = Lcg::new(42);
        let mut aci = AciConformal::new(0.1, 0.05);

        // A fixed tiny window of scores → very narrow intervals
        let tiny_scores = vec![0.01_f64; 20];
        let mut cumulative_covered = 0usize;
        let n = 200usize;

        for _ in 0..n {
            let y_true = rng.next_normal() * 2.0; // std dev = 2, not 0.01
            let y_hat = 0.0;
            let ps = aci.predict_with_current(y_hat, &tiny_scores);
            let covered = ps.map_or(false, |s| s.contains_value(y_true));
            if covered {
                cumulative_covered += 1;
            }
            aci.update(covered);
        }

        // After many steps without coverage, α_t should have increased above α
        // (intervals become very wide eventually)
        let final_alpha = aci.current_alpha();
        // α_t may or may not be clamped; just verify the online update ran
        assert!(aci.history.len() == n);
        assert!(final_alpha > 0.0 && final_alpha < 1.0);
        let _ = cumulative_covered;
    }

    #[test]
    fn test_running_coverage_util() {
        let hist = vec![true, true, false, true, false];
        let cov = running_coverage(&hist);
        assert!((cov - 0.6).abs() < 1e-10, "cov = {}", cov);
    }

    #[test]
    fn test_coverage_drift_positive() {
        // Recent window covered more → positive drift
        let mut hist: Vec<bool> = vec![false; 50];
        hist.extend(vec![true; 50]); // recent 50 all covered
        let drift = coverage_drift(&hist, 50);
        assert!(drift > 0.0, "drift = {}", drift);
    }

    #[test]
    fn test_coverage_drift_negative() {
        // Recent window covered less → negative drift
        let mut hist: Vec<bool> = vec![true; 50];
        hist.extend(vec![false; 50]);
        let drift = coverage_drift(&hist, 50);
        assert!(drift < 0.0, "drift = {}", drift);
    }
}