1use crate::error::FdarError;
23use crate::matrix::FdMatrix;
24
25pub mod classification;
26pub mod cv;
27pub mod elastic;
28pub mod generic;
29pub mod regression;
30
31#[cfg(test)]
32mod tests;
33
/// Conformal procedure that produced a result.
///
/// Stored in the `method` field of both `ConformalRegressionResult` and
/// `ConformalClassificationResult`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum ConformalMethod {
    /// Split conformal: one random partition into proper-training and calibration
    /// observations (see `conformal_split`).
    Split,
    /// Cross-conformal prediction.
    CrossConformal {
        /// Number of folds used.
        n_folds: usize,
    },
    /// Jackknife+ conformal prediction (see `jackknife_plus_regression`).
    JackknifePlus,
}
49
/// Nonconformity score used for conformal classification.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum ClassificationScore {
    /// LAC score: one minus the predicted probability of the true class
    /// (1.0 when the class index is out of range).
    Lac,
    /// APS score: cumulative probability mass of the classes ranked, in descending
    /// probability order, up to and including the true class.
    Aps,
}
59
/// Output of a conformal regression procedure: point predictions plus prediction intervals.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ConformalRegressionResult {
    /// Point predictions for the test observations.
    pub predictions: Vec<f64>,
    /// Lower interval bounds; `build_regression_result` sets these to
    /// `prediction - residual_quantile`.
    pub lower: Vec<f64>,
    /// Upper interval bounds; `build_regression_result` sets these to
    /// `prediction + residual_quantile`.
    pub upper: Vec<f64>,
    /// Conformal quantile of the calibration residuals (may be `f64::INFINITY` when the
    /// calibration set is too small for the requested `alpha`).
    pub residual_quantile: f64,
    /// Fraction of calibration scores `<= residual_quantile`, i.e. coverage measured on the
    /// calibration set itself (as computed by `build_regression_result`), not on test data.
    pub coverage: f64,
    /// Calibration nonconformity scores (left sorted by `conformal_quantile` when produced
    /// by `build_regression_result`).
    pub calibration_scores: Vec<f64>,
    /// Procedure that produced this result.
    pub method: ConformalMethod,
}
79
/// Output of a conformal classification procedure: point predictions plus prediction sets.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ConformalClassificationResult {
    /// Point-predicted class index for each test observation.
    pub predicted_classes: Vec<usize>,
    /// Class indices contained in each test observation's prediction set.
    pub prediction_sets: Vec<Vec<usize>>,
    /// `prediction_sets[i].len()` for each test observation.
    pub set_sizes: Vec<usize>,
    /// Mean of `set_sizes` (0.0 when there are no test observations).
    pub average_set_size: f64,
    /// Fraction of calibration scores `<= score_quantile` (coverage on the calibration set,
    /// as computed by `build_classification_result`).
    pub coverage: f64,
    /// Calibration nonconformity scores (left sorted by `conformal_quantile` when produced
    /// by `build_classification_result`).
    pub calibration_scores: Vec<f64>,
    /// Conformal quantile of the calibration scores used to build the prediction sets.
    pub score_quantile: f64,
    /// Procedure that produced this result.
    pub method: ConformalMethod,
    /// Nonconformity score used for calibration and set construction.
    pub score_type: ClassificationScore,
}
103
/// Tuning knobs shared by the split-conformal entry points.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot use a struct
/// literal; start from `ConformalConfig::default()` and assign the fields to change.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConformalConfig {
    /// Fraction of observations reserved for calibration; must lie in (0, 1).
    /// Default: 0.25.
    pub cal_fraction: f64,
    /// Miscoverage level; the conformal quantile targets `1 - alpha`. Must lie in (0, 1).
    /// Default: 0.1.
    pub alpha: f64,
    /// Seed for the random calibration split. Default: 42.
    pub seed: u64,
}
126
127impl Default for ConformalConfig {
128 fn default() -> Self {
129 Self {
130 cal_fraction: 0.25,
131 alpha: 0.1,
132 seed: 42,
133 }
134 }
135}
136
137pub(super) fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
143 use rand::prelude::*;
144 let mut rng = StdRng::seed_from_u64(seed);
145 let mut all_idx: Vec<usize> = (0..n).collect();
146 all_idx.shuffle(&mut rng);
147 let n_cal = ((n as f64 * cal_fraction).round() as usize)
148 .max(2)
149 .min(n - 2);
150 let n_proper = n - n_cal;
151 let proper_idx = all_idx[..n_proper].to_vec();
152 let cal_idx = all_idx[n_proper..].to_vec();
153 (proper_idx, cal_idx)
154}
155
156pub(super) fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
162 let n = scores.len();
163 if n == 0 {
164 return 0.0;
165 }
166 crate::helpers::sort_nan_safe(scores);
167 let k = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
168 if k > n {
169 return f64::INFINITY;
170 }
171 scores[k.saturating_sub(1)]
172}
173
174pub(super) fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
176 let n = scores.len();
177 if n == 0 {
178 return 0.0;
179 }
180 scores.iter().filter(|&&s| s <= quantile).count() as f64 / n as f64
181}
182
183#[allow(unused_imports)]
185pub(super) use crate::helpers::quantile_sorted;
186
187pub(super) fn build_regression_result(
189 mut cal_residuals: Vec<f64>,
190 test_predictions: Vec<f64>,
191 alpha: f64,
192 method: ConformalMethod,
193) -> ConformalRegressionResult {
194 let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
195 let coverage = empirical_coverage(&cal_residuals, residual_quantile);
196 let lower = test_predictions
197 .iter()
198 .map(|&p| p - residual_quantile)
199 .collect();
200 let upper = test_predictions
201 .iter()
202 .map(|&p| p + residual_quantile)
203 .collect();
204 ConformalRegressionResult {
205 predictions: test_predictions,
206 lower,
207 upper,
208 residual_quantile,
209 coverage,
210 calibration_scores: cal_residuals,
211 method,
212 }
213}
214
215pub(super) fn lac_score(probs: &[f64], true_class: usize) -> f64 {
217 if true_class < probs.len() {
218 1.0 - probs[true_class]
219 } else {
220 1.0
221 }
222}
223
224pub(super) fn aps_score(probs: &[f64], true_class: usize) -> f64 {
226 let g = probs.len();
227 let mut order: Vec<usize> = (0..g).collect();
228 order.sort_by(|&a, &b| {
229 probs[b]
230 .partial_cmp(&probs[a])
231 .unwrap_or(std::cmp::Ordering::Equal)
232 });
233 let mut cum = 0.0;
234 for &c in &order {
235 cum += probs[c];
236 if c == true_class {
237 return cum;
238 }
239 }
240 1.0
241}
242
243pub(super) fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
245 (0..probs.len())
246 .filter(|&k| 1.0 - probs[k] <= quantile)
247 .collect()
248}
249
250pub(super) fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
256 let g = probs.len();
257 let mut order: Vec<usize> = (0..g).collect();
258 order.sort_by(|&a, &b| {
259 probs[b]
260 .partial_cmp(&probs[a])
261 .unwrap_or(std::cmp::Ordering::Equal)
262 });
263 let mut cum = 0.0;
264 let mut set = Vec::new();
265 for &c in &order {
266 set.push(c);
267 cum += probs[c];
268 if cum >= quantile {
269 break;
270 }
271 }
272 if set.is_empty() && g > 0 {
273 set.push(order[0]);
274 }
275 set
276}
277
278pub(super) fn build_classification_result(
280 mut cal_scores: Vec<f64>,
281 test_probs: &[Vec<f64>],
282 test_pred_classes: Vec<usize>,
283 alpha: f64,
284 method: ConformalMethod,
285 score_type: ClassificationScore,
286) -> ConformalClassificationResult {
287 let score_quantile = conformal_quantile(&mut cal_scores, alpha);
288 let coverage = empirical_coverage(&cal_scores, score_quantile);
289
290 let prediction_sets: Vec<Vec<usize>> = test_probs
291 .iter()
292 .map(|probs| match score_type {
293 ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
294 ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
295 })
296 .collect();
297
298 let set_sizes: Vec<usize> = prediction_sets.iter().map(std::vec::Vec::len).collect();
299 let average_set_size = if set_sizes.is_empty() {
300 0.0
301 } else {
302 set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
303 };
304
305 ConformalClassificationResult {
306 predicted_classes: test_pred_classes,
307 prediction_sets,
308 set_sizes,
309 average_set_size,
310 coverage,
311 calibration_scores: cal_scores,
312 score_quantile,
313 method,
314 score_type,
315 }
316}
317
318pub(super) fn compute_cal_scores(
320 probs: &[Vec<f64>],
321 true_classes: &[usize],
322 score_type: ClassificationScore,
323) -> Vec<f64> {
324 probs
325 .iter()
326 .zip(true_classes.iter())
327 .map(|(p, &y)| match score_type {
328 ClassificationScore::Lac => lac_score(p, y),
329 ClassificationScore::Aps => aps_score(p, y),
330 })
331 .collect()
332}
333
334pub(super) fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
336 let m = a.ncols();
337 debug_assert_eq!(m, b.ncols());
338 let na = a.nrows();
339 let nb = b.nrows();
340 let mut out = FdMatrix::zeros(na + nb, m);
341 for j in 0..m {
342 for i in 0..na {
343 out[(i, j)] = a[(i, j)];
344 }
345 for i in 0..nb {
346 out[(na + i, j)] = b[(i, j)];
347 }
348 }
349 out
350}
351
352pub(super) fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
354 match (a, b) {
355 (Some(a), Some(b)) => Some(vstack(a, b)),
356 _ => None,
357 }
358}
359
360pub(super) fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
362 indices.iter().map(|&i| v[i]).collect()
363}
364
365pub(super) fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
367 indices.iter().map(|&i| v[i]).collect()
368}
369
370pub(super) fn argmax(probs: &[f64]) -> usize {
372 probs
373 .iter()
374 .enumerate()
375 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
376 .map_or(0, |(i, _)| i)
377}
378
379pub(super) fn validate_split_inputs(
381 n: usize,
382 n_test: usize,
383 cal_fraction: f64,
384 alpha: f64,
385) -> Result<(), FdarError> {
386 if n < 4 {
387 return Err(FdarError::InvalidDimension {
388 parameter: "data",
389 expected: "at least 4 observations".to_string(),
390 actual: format!("{n}"),
391 });
392 }
393 if n_test == 0 {
394 return Err(FdarError::InvalidDimension {
395 parameter: "test_data",
396 expected: "at least 1 observation".to_string(),
397 actual: "0".to_string(),
398 });
399 }
400 if cal_fraction <= 0.0 || cal_fraction >= 1.0 {
401 return Err(FdarError::InvalidParameter {
402 parameter: "cal_fraction",
403 message: format!("must be in (0, 1), got {cal_fraction}"),
404 });
405 }
406 if alpha <= 0.0 || alpha >= 1.0 {
407 return Err(FdarError::InvalidParameter {
408 parameter: "alpha",
409 message: format!("must be in (0, 1), got {alpha}"),
410 });
411 }
412 Ok(())
413}
414
415pub use classification::{conformal_classif, conformal_elastic_logistic, conformal_logistic};
420pub use cv::{cv_conformal_classification, cv_conformal_regression, jackknife_plus_regression};
421pub use elastic::{
422 conformal_elastic_pcr, conformal_elastic_pcr_with_config, conformal_elastic_regression,
423 conformal_elastic_regression_with_config,
424};
425pub use generic::{conformal_generic_classification, conformal_generic_regression};
426pub use regression::{conformal_fregre_lm, conformal_fregre_np};