1use serde::{Deserialize, Serialize};
21
22use crate::capability::RiskClass;
23use crate::error::{Result, SdkError};
24
/// One labelled calibration observation: a score together with a flag saying
/// whether the scored item was actually unsafe.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct CalibrationSample {
    /// Score assigned to the item. `accepted_unsafe_rate` treats a sample as
    /// accepted when this value is strictly greater than the threshold.
    // NOTE(review): presumably expected to lie in [0, 1] (`conformal_threshold`
    // seeds its search at 0.0 and falls back to 1.0) — confirm with callers.
    pub score: f64,
    /// `true` when the item is labelled unsafe.
    pub is_unsafe: bool,
}
32
33impl CalibrationSample {
34 pub fn new(score: f64, is_unsafe: bool) -> Self {
35 Self { score, is_unsafe }
36 }
37}
38
39pub fn accepted_unsafe_rate(samples: &[CalibrationSample], theta: f64) -> f64 {
42 if samples.is_empty() {
43 return 0.0;
44 }
45 let unsafe_accepted = samples
46 .iter()
47 .filter(|s| s.score > theta && s.is_unsafe)
48 .count();
49 unsafe_accepted as f64 / samples.len() as f64
50}
51
52pub fn conformal_threshold(samples: &[CalibrationSample], rho: f64) -> Result<f64> {
58 if !(0.0..=1.0).contains(&rho) {
59 return Err(SdkError::InvalidGate(format!(
60 "rho must be in [0,1]: {rho}"
61 )));
62 }
63 if samples.is_empty() {
64 return Err(SdkError::Domain(
65 "conformal calibration needs samples".into(),
66 ));
67 }
68 let n = samples.len() as f64;
69
70 let mut candidates: Vec<f64> = vec![0.0];
73 let mut scores: Vec<f64> = samples.iter().map(|s| s.score).collect();
74 scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
75 candidates.extend(scores);
76
77 for theta in candidates {
78 let r = accepted_unsafe_rate(samples, theta);
79 let adjusted = (n * r + 1.0) / (n + 1.0);
80 if adjusted <= rho {
81 return Ok(theta);
82 }
83 }
84 Ok(1.0)
85}
86
/// Two-sample Kolmogorov–Smirnov statistic between `live` and `calib`.
///
/// Returns the largest absolute difference between the two empirical CDFs,
/// evaluated at every observed (non-NaN) value. Returns `0.0` when either
/// slice is empty.
///
/// NaN entries are never `<=` any evaluation point, so they do not contribute
/// to a CDF numerator, but they still count towards that sample's size in the
/// denominator.
pub fn ks_statistic(live: &[f64], calib: &[f64]) -> f64 {
    // Sorted copy of `values` without NaN, so the ECDF can be answered with a
    // binary search instead of a linear scan.
    fn sorted_non_nan(values: &[f64]) -> Vec<f64> {
        let mut sorted: Vec<f64> = values.iter().copied().filter(|v| !v.is_nan()).collect();
        // `total_cmp` is a real total order, so the sort cannot panic on
        // an inconsistent comparator.
        sorted.sort_unstable_by(f64::total_cmp);
        sorted
    }

    if live.is_empty() || calib.is_empty() {
        return 0.0;
    }

    let live_sorted = sorted_non_nan(live);
    let calib_sorted = sorted_non_nan(calib);

    // ECDF at `x`: number of values `<= x` over the ORIGINAL sample size.
    let cdf = |sorted: &[f64], total: usize, x: f64| {
        sorted.partition_point(|&v| v <= x) as f64 / total as f64
    };

    live_sorted
        .iter()
        .chain(calib_sorted.iter())
        .map(|&x| {
            (cdf(&live_sorted, live.len(), x) - cdf(&calib_sorted, calib.len(), x)).abs()
        })
        .fold(0.0, f64::max)
}
101
102pub fn is_drifted(live: &[f64], calib: &[f64], alpha_c: f64) -> bool {
106 if live.is_empty() || calib.is_empty() {
107 return false;
108 }
109 let n = live.len() as f64;
110 let m = calib.len() as f64;
111 let critical = alpha_c * ((n + m) / (n * m)).sqrt();
112 ks_statistic(live, calib) > critical
113}
114
/// Calibration status of the acceptance gate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CalibrationState {
    /// A threshold fitted by `conformal_threshold`: `theta_hat` is the fitted
    /// threshold and `rho` is the budget it was fitted for.
    Calibrated { theta_hat: f64, rho: f64 },
    /// The calibration is no longer trusted. `last_theta` is the threshold
    /// carried over from the previous state and `backoff_theta` is the
    /// stricter threshold used while stale (`mark_stale` sets it to the
    /// midpoint between `last_theta` and `1.0`).
    Stale { backoff_theta: f64, last_theta: f64 },
}
124
125impl CalibrationState {
126 pub fn calibrate(samples: &[CalibrationSample], rho: f64) -> Result<Self> {
128 let theta_hat = conformal_threshold(samples, rho)?;
129 Ok(CalibrationState::Calibrated { theta_hat, rho })
130 }
131
132 pub fn mark_stale(&self) -> Self {
135 let last = match self {
136 CalibrationState::Calibrated { theta_hat, .. } => *theta_hat,
137 CalibrationState::Stale { last_theta, .. } => *last_theta,
138 };
139 let backoff = (last + 1.0) / 2.0;
141 CalibrationState::Stale {
142 backoff_theta: backoff,
143 last_theta: last,
144 }
145 }
146
147 pub fn bound_is_asserted(&self) -> bool {
150 matches!(self, CalibrationState::Calibrated { .. })
151 }
152
153 fn active_threshold(&self) -> f64 {
154 match self {
155 CalibrationState::Calibrated { theta_hat, .. } => *theta_hat,
156 CalibrationState::Stale { backoff_theta, .. } => *backoff_theta,
157 }
158 }
159}
160
/// Result of running a score through the gate in `decide`.
///
/// Variant names are serialized in `snake_case` (e.g. `certified_accept`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AcceptOutcome {
    /// The score exceeded the threshold while the state was calibrated.
    CertifiedAccept,
    /// The score exceeded the back-off threshold while the state was stale.
    UncertifiedAccept,
    /// The state was stale and the risk class was high or critical, so the
    /// decision is handed off instead of being made from the score.
    RouteToApproval,
    /// The score did not exceed the active threshold.
    Reject,
}
175
176pub fn decide(state: &CalibrationState, score: f64, risk: RiskClass) -> AcceptOutcome {
183 let threshold = state.active_threshold();
184 match state {
185 CalibrationState::Calibrated { .. } => {
186 if score > threshold {
187 AcceptOutcome::CertifiedAccept
188 } else {
189 AcceptOutcome::Reject
190 }
191 }
192 CalibrationState::Stale { .. } => {
193 if matches!(risk, RiskClass::High | RiskClass::Critical) {
195 return AcceptOutcome::RouteToApproval;
196 }
197 if score > threshold {
198 AcceptOutcome::UncertifiedAccept
199 } else {
200 AcceptOutcome::Reject
201 }
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: three unsafe samples with low scores (0.1–0.3) and
    /// five safe samples with high scores (0.6–0.95).
    fn samples() -> Vec<CalibrationSample> {
        vec![
            // Unsafe, low-scoring samples.
            CalibrationSample::new(0.1, true),
            CalibrationSample::new(0.2, true),
            CalibrationSample::new(0.3, true),
            // Safe, high-scoring samples.
            CalibrationSample::new(0.6, false),
            CalibrationSample::new(0.7, false),
            CalibrationSample::new(0.8, false),
            CalibrationSample::new(0.9, false),
            CalibrationSample::new(0.95, false),
        ]
    }

    #[test]
    fn threshold_excludes_unsafe_low_scores() {
        let theta = conformal_threshold(&samples(), 0.2).unwrap();
        // With n = 8, even one accepted unsafe sample gives (1 + 1) / 9 > 0.2,
        // so the threshold has to sit at or above the highest unsafe score.
        assert!(
            theta >= 0.3,
            "theta={theta} should exclude unsafe scores <= 0.3"
        );
        // Acceptance is strict (`score > theta`), so nothing unsafe passes.
        assert_eq!(accepted_unsafe_rate(&samples(), theta), 0.0);
    }

    #[test]
    fn looser_budget_allows_lower_threshold() {
        let tight = conformal_threshold(&samples(), 0.15).unwrap();
        let loose = conformal_threshold(&samples(), 0.5).unwrap();
        // A larger budget can only be met at the same or an earlier candidate.
        assert!(loose <= tight);
    }

    #[test]
    fn impossible_budget_rejects_everything() {
        let theta = conformal_threshold(&samples(), 0.0).unwrap();
        // The adjusted rate is at least 1 / (n + 1) > 0, so no candidate fits
        // and the fallback threshold is returned.
        assert_eq!(theta, 1.0);
    }

    #[test]
    fn rho_out_of_range_is_error() {
        assert!(conformal_threshold(&samples(), 1.5).is_err());
    }

    #[test]
    fn ks_detects_distribution_shift() {
        // `calib` and `same` are identical grids on [0, 1); `shifted` squeezes
        // the same number of points into [0.5, 1).
        let calib: Vec<f64> = (0..100).map(|i| i as f64 / 100.0).collect();
        let same: Vec<f64> = (0..100).map(|i| i as f64 / 100.0).collect();
        let shifted: Vec<f64> = (0..100).map(|i| 0.5 + i as f64 / 200.0).collect();
        assert!(!is_drifted(&same, &calib, 1.36));
        assert!(is_drifted(&shifted, &calib, 1.36));
    }

    #[test]
    fn calibrated_state_asserts_bound_stale_does_not() {
        let state = CalibrationState::calibrate(&samples(), 0.2).unwrap();
        assert!(state.bound_is_asserted());
        let stale = state.mark_stale();
        assert!(!stale.bound_is_asserted());
    }

    #[test]
    fn stale_window_backs_off_and_does_not_hard_halt() {
        let state = CalibrationState::calibrate(&samples(), 0.3)
            .unwrap()
            .mark_stale();
        // Low risk with a score above the back-off threshold is still
        // accepted, but only as an uncertified accept.
        assert_eq!(
            decide(&state, 0.99, RiskClass::Low),
            AcceptOutcome::UncertifiedAccept
        );
        // High risk is routed to approval regardless of the score.
        assert_eq!(
            decide(&state, 0.99, RiskClass::High),
            AcceptOutcome::RouteToApproval
        );
    }

    #[test]
    fn calibrated_window_certifies_accepts() {
        let state = CalibrationState::calibrate(&samples(), 0.3).unwrap();
        let threshold = match state {
            CalibrationState::Calibrated { theta_hat, .. } => theta_hat,
            _ => unreachable!(),
        };
        // Strictly above the fitted threshold: certified accept.
        assert_eq!(
            decide(&state, threshold + 0.05, RiskClass::Low),
            AcceptOutcome::CertifiedAccept
        );
        // A score of 0.0 is never strictly above any candidate threshold here.
        assert_eq!(decide(&state, 0.0, RiskClass::Low), AcceptOutcome::Reject);
    }
}