/// Kind of nonconformity score used when building conformal prediction sets.
///
/// The default is [`ScoreType::AbsResidual`]. The enum is `#[non_exhaustive]`,
/// so `match`es outside this crate need a wildcard arm.
///
/// NOTE(review): the scoring code for each variant is not in this file; the
/// per-variant notes below are inferred from the names — confirm against it.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[non_exhaustive]
pub enum ScoreType {
    /// Absolute-residual score. This is the default variant.
    #[default]
    AbsResidual,
    /// Quantile-regression-based score (presumably CQR-style).
    QuantileRegression,
    /// Residual normalized by a scale estimate (presumably).
    NormalizedResidual,
    /// Highest-density-region score (presumably).
    Hpd,
    /// Regularized adaptive prediction sets; presumably parameterized by `RapsConfig`.
    Raps,
}
/// Top-level settings for a conformal predictor.
///
/// `ConformalConfig::default()` yields `alpha = 0.1` with the absolute-residual score.
#[derive(Debug, Clone)]
pub struct ConformalConfig {
    /// Miscoverage level (default `0.1`). Presumably the target coverage is
    /// `1 - alpha`, matching the `alpha` argument of `conformal_quantile` —
    /// confirm with the callers that consume this config.
    pub alpha: f64,
    /// Which nonconformity score to use (default `ScoreType::AbsResidual`).
    pub score_fn: ScoreType,
}

impl Default for ConformalConfig {
    fn default() -> Self {
        let alpha = 0.1;
        let score_fn = ScoreType::AbsResidual;
        ConformalConfig { alpha, score_fn }
    }
}
/// A conformal prediction set: either a numeric interval or a set of class labels.
///
/// Sets built with [`PredictionSet::interval`] leave `set` empty. Sets built with
/// [`PredictionSet::classification`] carry labels in `set` and use the unbounded
/// range `[-inf, +inf]` for `lower`/`upper`.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionSet {
    /// Lower bound of the interval (`f64::NEG_INFINITY` for classification sets).
    pub lower: f64,
    /// Upper bound of the interval (`f64::INFINITY` for classification sets).
    pub upper: f64,
    /// Class indices included in the set (empty for interval sets).
    pub set: Vec<usize>,
}

impl PredictionSet {
    /// Creates an interval set `[lower, upper]` with no class labels.
    pub fn interval(lower: f64, upper: f64) -> Self {
        let set = Vec::new();
        PredictionSet { set, lower, upper }
    }

    /// Creates a classification set holding `set`; the numeric bounds are unbounded.
    pub fn classification(set: Vec<usize>) -> Self {
        PredictionSet {
            set,
            lower: f64::NEG_INFINITY,
            upper: f64::INFINITY,
        }
    }

    /// Returns `true` when `value` lies in the closed range `[lower, upper]`.
    ///
    /// Both endpoints are inclusive and a NaN `value` is never contained.
    pub fn contains_value(&self, value: f64) -> bool {
        (self.lower..=self.upper).contains(&value)
    }

    /// Returns `true` when `class` is one of the labels in `set`.
    pub fn contains_class(&self, class: usize) -> bool {
        self.set.iter().any(|&label| label == class)
    }

    /// Size of the set.
    ///
    /// A non-empty label set is reported as infinitely wide. Otherwise the
    /// result is `upper - lower`, which is also infinite for an empty
    /// classification set because its bounds are unbounded.
    pub fn width(&self) -> f64 {
        if !self.set.is_empty() {
            return f64::INFINITY;
        }
        self.upper - self.lower
    }
}
/// Output of a conformal prediction run.
///
/// NOTE(review): the code that fills this struct is not in this file; the field
/// notes below describe the presumed meaning — confirm against the producer.
#[derive(Debug, Clone)]
pub struct ConformalResult {
    /// One prediction set per test point (presumably in input order).
    pub sets: Vec<PredictionSet>,
    /// Empirical coverage: presumably the fraction of `sets` containing their target.
    pub coverage: f64,
    /// Presumably the mean of `PredictionSet::width` over `sets`.
    pub avg_width: f64,
}
/// Parameters for the RAPS (regularized adaptive prediction sets) score.
///
/// Defaults: `k_reg = 5`, `lambda = 0.01`.
///
/// NOTE(review): the field notes follow the usual RAPS formulation; the scoring
/// code is not visible here — confirm against it.
#[derive(Debug, Clone)]
pub struct RapsConfig {
    /// Rank threshold; presumably labels ranked beyond `k_reg` are penalized.
    pub k_reg: usize,
    /// Penalty weight; presumably applied per rank beyond `k_reg`.
    pub lambda: f64,
}

impl Default for RapsConfig {
    fn default() -> Self {
        RapsConfig {
            lambda: 0.01,
            k_reg: 5,
        }
    }
}
/// Coverage settings for a conformal predictor.
///
/// Defaults: `coverage_target = 0.9`, `adaptive = false`.
#[derive(Debug, Clone)]
pub struct CpConfig {
    /// Desired coverage level; presumably equivalent to `1 - alpha`.
    pub coverage_target: f64,
    /// Enables an adaptive variant. NOTE(review): the code that reads this flag
    /// is not in this file — confirm what "adaptive" means there.
    pub adaptive: bool,
}

impl Default for CpConfig {
    fn default() -> Self {
        CpConfig {
            adaptive: false,
            coverage_target: 0.9,
        }
    }
}
/// Returns the split-conformal quantile of `scores` at miscoverage level `alpha`.
///
/// Let `n = scores.len()`. The result is the `k`-th smallest score, where
/// `k = ceil((n + 1) * (1 - alpha))` is clamped to `1..=n`:
///
/// * If `k > n` (i.e. `alpha < 1 / (n + 1)`), the largest score is returned. The
///   textbook finite-sample quantile would be `+inf` here; this function clamps instead.
/// * If `k < 1` (e.g. `alpha >= 1`), the smallest score is returned.
/// * A NaN `alpha` returns the largest score.
///
/// NaN scores, of either sign, are ordered after every number, so they count as
/// the most nonconforming values. If the selected rank lands on one, NaN is returned.
///
/// Returns `f64::INFINITY` when `scores` is empty. `scores` is copied, not
/// modified. Runs in expected O(n) time via selection rather than a full sort.
pub fn conformal_quantile(scores: &[f64], alpha: f64) -> f64 {
    if scores.is_empty() {
        return f64::INFINITY;
    }
    let n = scores.len();

    // Compute the rank directly. Going through `level = rank / n` and back via
    // `level * n` adds a rounding round-trip that can push an integral rank
    // just past the integer, so `ceil` would skip to the next order statistic.
    // `f64::min` returns the non-NaN operand, so a NaN `alpha` clamps to `n`.
    let rank = ((n + 1) as f64 * (1.0 - alpha)).ceil().min(n as f64);
    // `as usize` saturates, so a non-positive rank becomes 0 and selects the minimum.
    let idx = (rank as usize).saturating_sub(1).min(n - 1);

    let mut buf = scores.to_vec();
    // A plain `partial_cmp(..).unwrap_or(Equal)` is not transitive when NaNs are
    // present: the result then depends on input order, and std sorting/selection
    // may panic. Putting NaNs last first yields a total order.
    let (_, kth, _) = buf.select_nth_unstable_by(idx, |a, b| {
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    });
    *kth
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_type_default() {
        assert_eq!(ScoreType::default(), ScoreType::AbsResidual);
    }

    #[test]
    fn test_conformal_config_default() {
        let cfg = ConformalConfig::default();
        assert!((cfg.alpha - 0.1).abs() < 1e-10);
        assert_eq!(cfg.score_fn, ScoreType::AbsResidual);
    }

    #[test]
    fn test_cp_config_default() {
        let cfg = CpConfig::default();
        assert!((cfg.coverage_target - 0.9).abs() < 1e-10);
        assert!(!cfg.adaptive);
    }

    #[test]
    fn test_raps_config_default() {
        let cfg = RapsConfig::default();
        assert_eq!(cfg.k_reg, 5);
        assert!(cfg.lambda > 0.0);
    }

    #[test]
    fn test_prediction_set_contains_value() {
        let ps = PredictionSet::interval(1.0, 3.0);
        assert!(ps.contains_value(2.0));
        // Both endpoints are inclusive.
        assert!(ps.contains_value(1.0));
        assert!(ps.contains_value(3.0));
        assert!(!ps.contains_value(0.5));
        assert!(!ps.contains_value(f64::NAN));
        assert!((ps.width() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_prediction_set_classification() {
        let ps = PredictionSet::classification(vec![0, 2]);
        assert!(ps.contains_class(0));
        assert!(ps.contains_class(2));
        assert!(!ps.contains_class(1));
        // Non-empty label sets report infinite width.
        assert!(ps.width().is_infinite());
    }

    #[test]
    fn test_conformal_quantile_basic() {
        let scores: Vec<f64> = (1..=10).map(f64::from).collect();
        // k = ceil(11 * 0.9) = 10 -> the largest score.
        assert_eq!(conformal_quantile(&scores, 0.1), 10.0);
        // k = ceil(11 * 0.5) = 6.
        assert_eq!(conformal_quantile(&scores, 0.5), 6.0);
    }

    #[test]
    fn test_conformal_quantile_unsorted_input() {
        // n = 5, k = ceil(6 * 0.6) = 4 -> fourth smallest.
        assert_eq!(conformal_quantile(&[5.0, 1.0, 4.0, 2.0, 3.0], 0.4), 4.0);
    }

    #[test]
    fn test_conformal_quantile_alpha_extremes() {
        let scores: Vec<f64> = (1..=10).map(f64::from).collect();
        // k > n is clamped to the largest score.
        assert_eq!(conformal_quantile(&scores, 0.0), 10.0);
        // k <= 0 selects the smallest score.
        assert_eq!(conformal_quantile(&scores, 1.0), 1.0);
    }

    #[test]
    fn test_conformal_quantile_single() {
        assert_eq!(conformal_quantile(&[7.0], 0.1), 7.0);
    }

    #[test]
    fn test_conformal_quantile_empty() {
        let q = conformal_quantile(&[], 0.1);
        assert!(q.is_infinite());
    }
}