1use crate::error::FdarError;
23use crate::matrix::FdMatrix;
24
25pub mod classification;
26pub mod cv;
27pub mod elastic;
28pub mod generic;
29pub mod regression;
30
31#[cfg(test)]
32mod tests;
33
/// Strategy used to obtain calibration scores for a conformal predictor.
///
/// The chosen variant is recorded in the `method` field of the result types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConformalMethod {
    /// Split conformal: one random partition into a proper-training set and a
    /// calibration set (see `conformal_split`).
    Split,
    /// Cross-conformal prediction using `n_folds` folds.
    CrossConformal { n_folds: usize },
    /// Jackknife+ (leave-one-out based) conformal prediction.
    JackknifePlus,
}
48
/// Nonconformity score used for conformal classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassificationScore {
    /// Least Ambiguous set-valued Classifier: the score of the true class is
    /// `1 - p(true class)`.
    Lac,
    /// Adaptive Prediction Sets: the score is the cumulative probability of the
    /// classes ranked by descending probability, up to and including the true class.
    Aps,
}
57
/// Output of a conformal regression procedure: point predictions for the test
/// observations together with prediction intervals.
#[derive(Debug, Clone)]
pub struct ConformalRegressionResult {
    /// Point predictions for the test observations.
    pub predictions: Vec<f64>,
    /// Lower interval bounds. When built by `build_regression_result` these are
    /// `prediction - residual_quantile`.
    pub lower: Vec<f64>,
    /// Upper interval bounds. When built by `build_regression_result` these are
    /// `prediction + residual_quantile`.
    pub upper: Vec<f64>,
    /// Conformal quantile of the calibration residuals (the interval half-width
    /// in the split case). `conformal_quantile` returns `f64::INFINITY` when the
    /// calibration set is too small for the requested `alpha`.
    pub residual_quantile: f64,
    /// Fraction of calibration scores that are `<= residual_quantile`. This is an
    /// in-sample check on the calibration set, not coverage on test data.
    pub coverage: f64,
    /// Calibration nonconformity scores, left sorted by `conformal_quantile`.
    pub calibration_scores: Vec<f64>,
    /// Conformal procedure that produced this result.
    pub method: ConformalMethod,
}
76
/// Output of a conformal classification procedure: point predictions and
/// conformal prediction sets for the test observations.
#[derive(Debug, Clone)]
pub struct ConformalClassificationResult {
    /// Point-predicted class per test observation, as supplied by the caller of
    /// `build_classification_result`.
    pub predicted_classes: Vec<usize>,
    /// Conformal prediction set (class indices) per test observation.
    pub prediction_sets: Vec<Vec<usize>>,
    /// Number of classes in each prediction set.
    pub set_sizes: Vec<usize>,
    /// Mean of `set_sizes`; 0.0 when there are no test observations.
    pub average_set_size: f64,
    /// Fraction of calibration scores that are `<= score_quantile`. This is an
    /// in-sample check on the calibration set, not coverage on test data.
    pub coverage: f64,
    /// Calibration nonconformity scores, left sorted by `conformal_quantile`.
    pub calibration_scores: Vec<f64>,
    /// Conformal quantile of the calibration scores; `f64::INFINITY` when the
    /// calibration set is too small for the requested `alpha`.
    pub score_quantile: f64,
    /// Conformal procedure that produced this result.
    pub method: ConformalMethod,
    /// Nonconformity score (LAC or APS) used to build the prediction sets.
    pub score_type: ClassificationScore,
}
99
/// Tuning knobs for split-conformal prediction.
///
/// `ConformalConfig::default()` yields `cal_fraction = 0.25`, `alpha = 0.1`
/// and `seed = 42`.
#[derive(Debug, Clone)]
pub struct ConformalConfig {
    /// Fraction of observations held out for calibration. `validate_split_inputs`
    /// requires this to lie in (0, 1). Presumably forwarded to `conformal_split`;
    /// confirm at the call sites.
    pub cal_fraction: f64,
    /// Miscoverage level; the conformal quantile targets `1 - alpha` coverage.
    /// Must lie in (0, 1).
    pub alpha: f64,
    /// Seed for the random calibration split (presumably passed to `conformal_split`).
    pub seed: u64,
}

impl Default for ConformalConfig {
    fn default() -> Self {
        // A quarter of the data for calibration, 90% nominal coverage, and a
        // fixed seed so repeated runs reproduce the same split.
        let cal_fraction = 0.25;
        let alpha = 0.1;
        let seed = 42;
        Self {
            cal_fraction,
            alpha,
            seed,
        }
    }
}
133
134pub(super) fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
140 use rand::prelude::*;
141 let mut rng = StdRng::seed_from_u64(seed);
142 let mut all_idx: Vec<usize> = (0..n).collect();
143 all_idx.shuffle(&mut rng);
144 let n_cal = ((n as f64 * cal_fraction).round() as usize)
145 .max(2)
146 .min(n - 2);
147 let n_proper = n - n_cal;
148 let proper_idx = all_idx[..n_proper].to_vec();
149 let cal_idx = all_idx[n_proper..].to_vec();
150 (proper_idx, cal_idx)
151}
152
153pub(super) fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
159 let n = scores.len();
160 if n == 0 {
161 return 0.0;
162 }
163 crate::helpers::sort_nan_safe(scores);
164 let k = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
165 if k > n {
166 return f64::INFINITY;
167 }
168 scores[k.saturating_sub(1)]
169}
170
171pub(super) fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
173 let n = scores.len();
174 if n == 0 {
175 return 0.0;
176 }
177 scores.iter().filter(|&&s| s <= quantile).count() as f64 / n as f64
178}
179
180#[allow(dead_code)]
182pub(super) fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
183 let n = sorted.len();
184 if n == 0 {
185 return 0.0;
186 }
187 if n == 1 {
188 return sorted[0];
189 }
190 let idx = q * (n - 1) as f64;
191 let lo = (idx.floor() as usize).min(n - 1);
192 let hi = (idx.ceil() as usize).min(n - 1);
193 if lo == hi {
194 sorted[lo]
195 } else {
196 let frac = idx - lo as f64;
197 sorted[lo] * (1.0 - frac) + sorted[hi] * frac
198 }
199}
200
201pub(super) fn build_regression_result(
203 mut cal_residuals: Vec<f64>,
204 test_predictions: Vec<f64>,
205 alpha: f64,
206 method: ConformalMethod,
207) -> ConformalRegressionResult {
208 let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
209 let coverage = empirical_coverage(&cal_residuals, residual_quantile);
210 let lower = test_predictions
211 .iter()
212 .map(|&p| p - residual_quantile)
213 .collect();
214 let upper = test_predictions
215 .iter()
216 .map(|&p| p + residual_quantile)
217 .collect();
218 ConformalRegressionResult {
219 predictions: test_predictions,
220 lower,
221 upper,
222 residual_quantile,
223 coverage,
224 calibration_scores: cal_residuals,
225 method,
226 }
227}
228
229pub(super) fn lac_score(probs: &[f64], true_class: usize) -> f64 {
231 if true_class < probs.len() {
232 1.0 - probs[true_class]
233 } else {
234 1.0
235 }
236}
237
238pub(super) fn aps_score(probs: &[f64], true_class: usize) -> f64 {
240 let g = probs.len();
241 let mut order: Vec<usize> = (0..g).collect();
242 order.sort_by(|&a, &b| {
243 probs[b]
244 .partial_cmp(&probs[a])
245 .unwrap_or(std::cmp::Ordering::Equal)
246 });
247 let mut cum = 0.0;
248 for &c in &order {
249 cum += probs[c];
250 if c == true_class {
251 return cum;
252 }
253 }
254 1.0
255}
256
257pub(super) fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
259 (0..probs.len())
260 .filter(|&k| 1.0 - probs[k] <= quantile)
261 .collect()
262}
263
264pub(super) fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
270 let g = probs.len();
271 let mut order: Vec<usize> = (0..g).collect();
272 order.sort_by(|&a, &b| {
273 probs[b]
274 .partial_cmp(&probs[a])
275 .unwrap_or(std::cmp::Ordering::Equal)
276 });
277 let mut cum = 0.0;
278 let mut set = Vec::new();
279 for &c in &order {
280 set.push(c);
281 cum += probs[c];
282 if cum >= quantile {
283 break;
284 }
285 }
286 if set.is_empty() && g > 0 {
287 set.push(order[0]);
288 }
289 set
290}
291
292pub(super) fn build_classification_result(
294 mut cal_scores: Vec<f64>,
295 test_probs: &[Vec<f64>],
296 test_pred_classes: Vec<usize>,
297 alpha: f64,
298 method: ConformalMethod,
299 score_type: ClassificationScore,
300) -> ConformalClassificationResult {
301 let score_quantile = conformal_quantile(&mut cal_scores, alpha);
302 let coverage = empirical_coverage(&cal_scores, score_quantile);
303
304 let prediction_sets: Vec<Vec<usize>> = test_probs
305 .iter()
306 .map(|probs| match score_type {
307 ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
308 ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
309 })
310 .collect();
311
312 let set_sizes: Vec<usize> = prediction_sets.iter().map(std::vec::Vec::len).collect();
313 let average_set_size = if set_sizes.is_empty() {
314 0.0
315 } else {
316 set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
317 };
318
319 ConformalClassificationResult {
320 predicted_classes: test_pred_classes,
321 prediction_sets,
322 set_sizes,
323 average_set_size,
324 coverage,
325 calibration_scores: cal_scores,
326 score_quantile,
327 method,
328 score_type,
329 }
330}
331
332pub(super) fn compute_cal_scores(
334 probs: &[Vec<f64>],
335 true_classes: &[usize],
336 score_type: ClassificationScore,
337) -> Vec<f64> {
338 probs
339 .iter()
340 .zip(true_classes.iter())
341 .map(|(p, &y)| match score_type {
342 ClassificationScore::Lac => lac_score(p, y),
343 ClassificationScore::Aps => aps_score(p, y),
344 })
345 .collect()
346}
347
348pub(super) fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
350 let m = a.ncols();
351 debug_assert_eq!(m, b.ncols());
352 let na = a.nrows();
353 let nb = b.nrows();
354 let mut out = FdMatrix::zeros(na + nb, m);
355 for j in 0..m {
356 for i in 0..na {
357 out[(i, j)] = a[(i, j)];
358 }
359 for i in 0..nb {
360 out[(na + i, j)] = b[(i, j)];
361 }
362 }
363 out
364}
365
366pub(super) fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
368 match (a, b) {
369 (Some(a), Some(b)) => Some(vstack(a, b)),
370 _ => None,
371 }
372}
373
374pub(super) fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
376 indices.iter().map(|&i| v[i]).collect()
377}
378
379pub(super) fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
381 indices.iter().map(|&i| v[i]).collect()
382}
383
384pub(super) fn argmax(probs: &[f64]) -> usize {
386 probs
387 .iter()
388 .enumerate()
389 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
390 .map_or(0, |(i, _)| i)
391}
392
393pub(super) fn validate_split_inputs(
395 n: usize,
396 n_test: usize,
397 cal_fraction: f64,
398 alpha: f64,
399) -> Result<(), FdarError> {
400 if n < 4 {
401 return Err(FdarError::InvalidDimension {
402 parameter: "data",
403 expected: "at least 4 observations".to_string(),
404 actual: format!("{n}"),
405 });
406 }
407 if n_test == 0 {
408 return Err(FdarError::InvalidDimension {
409 parameter: "test_data",
410 expected: "at least 1 observation".to_string(),
411 actual: "0".to_string(),
412 });
413 }
414 if cal_fraction <= 0.0 || cal_fraction >= 1.0 {
415 return Err(FdarError::InvalidParameter {
416 parameter: "cal_fraction",
417 message: format!("must be in (0, 1), got {cal_fraction}"),
418 });
419 }
420 if alpha <= 0.0 || alpha >= 1.0 {
421 return Err(FdarError::InvalidParameter {
422 parameter: "alpha",
423 message: format!("must be in (0, 1), got {alpha}"),
424 });
425 }
426 Ok(())
427}
428
429pub use classification::{conformal_classif, conformal_elastic_logistic, conformal_logistic};
434pub use cv::{cv_conformal_classification, cv_conformal_regression, jackknife_plus_regression};
435pub use elastic::{
436 conformal_elastic_pcr, conformal_elastic_pcr_with_config, conformal_elastic_regression,
437 conformal_elastic_regression_with_config,
438};
439pub use generic::{conformal_generic_classification, conformal_generic_regression};
440pub use regression::{conformal_fregre_lm, conformal_fregre_np};