use ndarray::ArrayView1;
use crate::error::{RagDriftError, Result};
use crate::stats::ks_two_sample;
use crate::types::{check_min_samples, DriftDimension, DriftScore};
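
/// Configuration for [`ConfidenceDriftDetector`].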
#[derive(Debug, Clone, Copy)]
pub struct ConfidenceDriftConfig {
    /// Threshold the drift score is compared against when flagging drift.
    pub threshold: f64,
    /// Number of bins used when computing the expected calibration error.
    pub ece_bins: usize,
}

impl Default for ConfidenceDriftConfig {
    fn default() -> Self {
        Self {
            threshold: 0.15,
            ece_bins: 10,
        }
    }
}
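
/// Detects drift in confidence score distributions with a two-sample
/// Kolmogorov-Smirnov test, optionally combined with expected calibration error.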
#[derive(Debug, Clone, Copy, Default)]
pub struct ConfidenceDriftDetector {
    config: ConfidenceDriftConfig,
}

impl ConfidenceDriftDetector {
    pub fn new(config: ConfidenceDriftConfig) -> Self {
        Self { config }
    }
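
    /// Compares the baseline and current confidence distributions with a
    /// two-sample Kolmogorov-Smirnov test and uses the KS statistic as the
    /// drift score.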
    pub fn detect(
        &self,
        baseline: &ArrayView1<'_, f64>,
        current: &ArrayView1<'_, f64>,
    ) -> Result<DriftScore> {
        let d = ks_two_sample(baseline, current)?.statistic;
        Ok(DriftScore::new(
            DriftDimension::Confidence,
            d,
            self.config.threshold,
            "ks",
        ))
    }
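
    /// Like [`Self::detect`], but also incorporates calibration: the drift
    /// score is the KS statistic plus the absolute difference between the
    /// baseline and current expected calibration errors.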
    pub fn detect_with_calibration(
        &self,
        baseline_scores: &ArrayView1<'_, f64>,
        baseline_labels: &ArrayView1<'_, f64>,
        current_scores: &ArrayView1<'_, f64>,
        current_labels: &ArrayView1<'_, f64>,
    ) -> Result<DriftScore> {
        let d = ks_two_sample(baseline_scores, current_scores)?.statistic;
        let ece_b =
            expected_calibration_error(baseline_scores, baseline_labels, self.config.ece_bins)?;
        let ece_c =
            expected_calibration_error(current_scores, current_labels, self.config.ece_bins)?;
        let combined = d + (ece_c - ece_b).abs();
        Ok(DriftScore::new(
            DriftDimension::Confidence,
            combined,
            self.config.threshold,
            "ks+ece",
        ))
    }
}
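
/// Computes the expected calibration error (ECE): scores are clamped to
/// [0, 1], grouped into `n_bins` equal-width bins, and each non-empty bin
/// contributes its sample share times |mean score - mean label|.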
fn expected_calibration_error(
    scores: &ArrayView1<'_, f64>,
    labels: &ArrayView1<'_, f64>,
    n_bins: usize,
) -> Result<f64> {
    if scores.len() != labels.len() {
        return Err(RagDriftError::DimensionMismatch {
            baseline: vec![scores.len()],
            current: vec![labels.len()],
        });
    }
    check_min_samples(scores.len(), n_bins)?;
    if n_bins < 2 {
        return Err(RagDriftError::InvalidConfig(format!(
            "ECE needs at least 2 bins, got {n_bins}"
        )));
    }
    let mut bin_scores = vec![0.0_f64; n_bins];
    let mut bin_labels = vec![0.0_f64; n_bins];
    let mut bin_counts = vec![0_usize; n_bins];
    // Assign each score to an equal-width bin over [0, 1] and accumulate sums.
    for (s, y) in scores.iter().zip(labels.iter()) {
        let s = s.clamp(0.0, 1.0);
        let mut idx = (s * n_bins as f64) as usize;
        if idx >= n_bins {
            // A score of exactly 1.0 would index one past the end; fold it into the last bin.
            idx = n_bins - 1;
        }
        bin_scores[idx] += s;
        bin_labels[idx] += y;
        bin_counts[idx] += 1;
    }
    let n = scores.len() as f64;
    let mut ece = 0.0_f64;
    // Weight each non-empty bin's |mean score - mean label| gap by its share of samples.
    for i in 0..n_bins {
        if bin_counts[i] == 0 {
            continue;
        }
        let avg_score = bin_scores[i] / bin_counts[i] as f64;
        let avg_label = bin_labels[i] / bin_counts[i] as f64;
        ece += (bin_counts[i] as f64 / n) * (avg_score - avg_label).abs();
    }
    Ok(ece)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn identical_scores_zero_drift() {
        let a = Array1::from(vec![0.5; 100]);
        let detector = ConfidenceDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert_eq!(s.score, 0.0);
    }

    #[test]
    fn confidence_collapse_flagged() {
        let a = Array1::from(vec![0.95; 100]);
        let b = Array1::from(vec![0.4; 100]);
        let detector = ConfidenceDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded);
    }

    #[test]
    fn ece_perfect_calibration_zero() {
        let scores = Array1::from(vec![0.5; 100]);
        let labels = Array1::from((0..100).map(|i| (i % 2) as f64).collect::<Vec<_>>());
        let ece = expected_calibration_error(&scores.view(), &labels.view(), 2).unwrap();
        assert!(ece < 1e-9);
    }

    #[test]
    fn ece_overconfidence_positive() {
        let scores = Array1::from(vec![0.95; 100]);
        let labels = Array1::from(vec![0.5; 100]);
        let ece = expected_calibration_error(&scores.view(), &labels.view(), 10).unwrap();
        assert!(ece > 0.4);
    }
}