1use crate::error::FdarError;
23use crate::matrix::FdMatrix;
24
25pub mod classification;
26pub mod cv;
27pub mod elastic;
28pub mod generic;
29pub mod regression;
30
31#[cfg(test)]
32mod tests;
33
/// Conformal procedure that produced a result.
///
/// The value is carried verbatim in the `method` field of the result structs
/// so callers can tell how the calibration scores were obtained.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum ConformalMethod {
    /// Split conformal: a single held-out calibration subset.
    Split,
    /// Cross-conformal variant parameterised by a fold count.
    CrossConformal {
        /// Number of folds.
        n_folds: usize,
    },
    /// Jackknife+ variant.
    JackknifePlus,
}
49
/// Non-conformity score used for conformal classification.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum ClassificationScore {
    /// Score is one minus the probability assigned to the true class.
    Lac,
    /// Score is the cumulative probability, taken over classes ranked by
    /// decreasing probability, up to and including the true class.
    Aps,
}
59
60#[derive(Debug, Clone)]
62#[non_exhaustive]
63pub struct ConformalRegressionResult {
64 pub predictions: Vec<f64>,
66 pub lower: Vec<f64>,
68 pub upper: Vec<f64>,
70 pub residual_quantile: f64,
72 pub coverage: f64,
74 pub calibration_scores: Vec<f64>,
76 pub method: ConformalMethod,
78}
79
80#[derive(Debug, Clone)]
82#[non_exhaustive]
83pub struct ConformalClassificationResult {
84 pub predicted_classes: Vec<usize>,
86 pub prediction_sets: Vec<Vec<usize>>,
88 pub set_sizes: Vec<usize>,
90 pub average_set_size: f64,
92 pub coverage: f64,
94 pub calibration_scores: Vec<f64>,
96 pub score_quantile: f64,
98 pub method: ConformalMethod,
100 pub score_type: ClassificationScore,
102}
103
/// Tuning knobs shared by the conformal entry points.
///
/// `Default` yields `cal_fraction = 0.25`, `alpha = 0.1`, `seed = 42`.
#[derive(Clone, Debug)]
pub struct ConformalConfig {
    /// Fraction of the observations reserved for calibration.
    pub cal_fraction: f64,
    /// Miscoverage level; the quantile helpers target coverage `1 - alpha`.
    pub alpha: f64,
    /// Seed for the random number generator used when splitting the data.
    pub seed: u64,
}
127
128impl Default for ConformalConfig {
129 fn default() -> Self {
130 Self {
131 cal_fraction: 0.25,
132 alpha: 0.1,
133 seed: 42,
134 }
135 }
136}
137
138pub(super) fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
144 use rand::prelude::*;
145 let mut rng = StdRng::seed_from_u64(seed);
146 let mut all_idx: Vec<usize> = (0..n).collect();
147 all_idx.shuffle(&mut rng);
148 let n_cal = ((n as f64 * cal_fraction).round() as usize)
149 .max(2)
150 .min(n - 2);
151 let n_proper = n - n_cal;
152 let proper_idx = all_idx[..n_proper].to_vec();
153 let cal_idx = all_idx[n_proper..].to_vec();
154 (proper_idx, cal_idx)
155}
156
157pub(super) fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
163 let n = scores.len();
164 if n == 0 {
165 return 0.0;
166 }
167 crate::helpers::sort_nan_safe(scores);
168 let k = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
169 if k > n {
170 return f64::INFINITY;
171 }
172 scores[k.saturating_sub(1)]
173}
174
175pub(super) fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
177 let n = scores.len();
178 if n == 0 {
179 return 0.0;
180 }
181 scores.iter().filter(|&&s| s <= quantile).count() as f64 / n as f64
182}
183
184#[allow(dead_code)]
186pub(super) fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
187 let n = sorted.len();
188 if n == 0 {
189 return 0.0;
190 }
191 if n == 1 {
192 return sorted[0];
193 }
194 let idx = q * (n - 1) as f64;
195 let lo = (idx.floor() as usize).min(n - 1);
196 let hi = (idx.ceil() as usize).min(n - 1);
197 if lo == hi {
198 sorted[lo]
199 } else {
200 let frac = idx - lo as f64;
201 sorted[lo] * (1.0 - frac) + sorted[hi] * frac
202 }
203}
204
205pub(super) fn build_regression_result(
207 mut cal_residuals: Vec<f64>,
208 test_predictions: Vec<f64>,
209 alpha: f64,
210 method: ConformalMethod,
211) -> ConformalRegressionResult {
212 let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
213 let coverage = empirical_coverage(&cal_residuals, residual_quantile);
214 let lower = test_predictions
215 .iter()
216 .map(|&p| p - residual_quantile)
217 .collect();
218 let upper = test_predictions
219 .iter()
220 .map(|&p| p + residual_quantile)
221 .collect();
222 ConformalRegressionResult {
223 predictions: test_predictions,
224 lower,
225 upper,
226 residual_quantile,
227 coverage,
228 calibration_scores: cal_residuals,
229 method,
230 }
231}
232
233pub(super) fn lac_score(probs: &[f64], true_class: usize) -> f64 {
235 if true_class < probs.len() {
236 1.0 - probs[true_class]
237 } else {
238 1.0
239 }
240}
241
242pub(super) fn aps_score(probs: &[f64], true_class: usize) -> f64 {
244 let g = probs.len();
245 let mut order: Vec<usize> = (0..g).collect();
246 order.sort_by(|&a, &b| {
247 probs[b]
248 .partial_cmp(&probs[a])
249 .unwrap_or(std::cmp::Ordering::Equal)
250 });
251 let mut cum = 0.0;
252 for &c in &order {
253 cum += probs[c];
254 if c == true_class {
255 return cum;
256 }
257 }
258 1.0
259}
260
261pub(super) fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
263 (0..probs.len())
264 .filter(|&k| 1.0 - probs[k] <= quantile)
265 .collect()
266}
267
268pub(super) fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
274 let g = probs.len();
275 let mut order: Vec<usize> = (0..g).collect();
276 order.sort_by(|&a, &b| {
277 probs[b]
278 .partial_cmp(&probs[a])
279 .unwrap_or(std::cmp::Ordering::Equal)
280 });
281 let mut cum = 0.0;
282 let mut set = Vec::new();
283 for &c in &order {
284 set.push(c);
285 cum += probs[c];
286 if cum >= quantile {
287 break;
288 }
289 }
290 if set.is_empty() && g > 0 {
291 set.push(order[0]);
292 }
293 set
294}
295
296pub(super) fn build_classification_result(
298 mut cal_scores: Vec<f64>,
299 test_probs: &[Vec<f64>],
300 test_pred_classes: Vec<usize>,
301 alpha: f64,
302 method: ConformalMethod,
303 score_type: ClassificationScore,
304) -> ConformalClassificationResult {
305 let score_quantile = conformal_quantile(&mut cal_scores, alpha);
306 let coverage = empirical_coverage(&cal_scores, score_quantile);
307
308 let prediction_sets: Vec<Vec<usize>> = test_probs
309 .iter()
310 .map(|probs| match score_type {
311 ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
312 ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
313 })
314 .collect();
315
316 let set_sizes: Vec<usize> = prediction_sets.iter().map(std::vec::Vec::len).collect();
317 let average_set_size = if set_sizes.is_empty() {
318 0.0
319 } else {
320 set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
321 };
322
323 ConformalClassificationResult {
324 predicted_classes: test_pred_classes,
325 prediction_sets,
326 set_sizes,
327 average_set_size,
328 coverage,
329 calibration_scores: cal_scores,
330 score_quantile,
331 method,
332 score_type,
333 }
334}
335
336pub(super) fn compute_cal_scores(
338 probs: &[Vec<f64>],
339 true_classes: &[usize],
340 score_type: ClassificationScore,
341) -> Vec<f64> {
342 probs
343 .iter()
344 .zip(true_classes.iter())
345 .map(|(p, &y)| match score_type {
346 ClassificationScore::Lac => lac_score(p, y),
347 ClassificationScore::Aps => aps_score(p, y),
348 })
349 .collect()
350}
351
352pub(super) fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
354 let m = a.ncols();
355 debug_assert_eq!(m, b.ncols());
356 let na = a.nrows();
357 let nb = b.nrows();
358 let mut out = FdMatrix::zeros(na + nb, m);
359 for j in 0..m {
360 for i in 0..na {
361 out[(i, j)] = a[(i, j)];
362 }
363 for i in 0..nb {
364 out[(na + i, j)] = b[(i, j)];
365 }
366 }
367 out
368}
369
370pub(super) fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
372 match (a, b) {
373 (Some(a), Some(b)) => Some(vstack(a, b)),
374 _ => None,
375 }
376}
377
378pub(super) fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
380 indices.iter().map(|&i| v[i]).collect()
381}
382
383pub(super) fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
385 indices.iter().map(|&i| v[i]).collect()
386}
387
388pub(super) fn argmax(probs: &[f64]) -> usize {
390 probs
391 .iter()
392 .enumerate()
393 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
394 .map_or(0, |(i, _)| i)
395}
396
397pub(super) fn validate_split_inputs(
399 n: usize,
400 n_test: usize,
401 cal_fraction: f64,
402 alpha: f64,
403) -> Result<(), FdarError> {
404 if n < 4 {
405 return Err(FdarError::InvalidDimension {
406 parameter: "data",
407 expected: "at least 4 observations".to_string(),
408 actual: format!("{n}"),
409 });
410 }
411 if n_test == 0 {
412 return Err(FdarError::InvalidDimension {
413 parameter: "test_data",
414 expected: "at least 1 observation".to_string(),
415 actual: "0".to_string(),
416 });
417 }
418 if cal_fraction <= 0.0 || cal_fraction >= 1.0 {
419 return Err(FdarError::InvalidParameter {
420 parameter: "cal_fraction",
421 message: format!("must be in (0, 1), got {cal_fraction}"),
422 });
423 }
424 if alpha <= 0.0 || alpha >= 1.0 {
425 return Err(FdarError::InvalidParameter {
426 parameter: "alpha",
427 message: format!("must be in (0, 1), got {alpha}"),
428 });
429 }
430 Ok(())
431}
432
433pub use classification::{conformal_classif, conformal_elastic_logistic, conformal_logistic};
438pub use cv::{cv_conformal_classification, cv_conformal_regression, jackknife_plus_regression};
439pub use elastic::{
440 conformal_elastic_pcr, conformal_elastic_pcr_with_config, conformal_elastic_regression,
441 conformal_elastic_regression_with_config,
442};
443pub use generic::{conformal_generic_classification, conformal_generic_regression};
444pub use regression::{conformal_fregre_lm, conformal_fregre_np};