1use crate::error::FdarError;
4use crate::explain::subsample_rows;
5use crate::explain_generic::{FpcPredictor, TaskType};
6use crate::matrix::FdMatrix;
7
8use super::{
9 argmax, build_classification_result, build_regression_result, compute_cal_scores,
10 conformal_split, subset_vec_usize, validate_split_inputs, ClassificationScore,
11 ConformalClassificationResult, ConformalMethod, ConformalRegressionResult,
12};
13
14#[must_use = "expensive computation whose result should not be discarded"]
45pub fn conformal_generic_regression(
46 model: &dyn FpcPredictor,
47 data: &FdMatrix,
48 y: &[f64],
49 test_data: &FdMatrix,
50 scalar_train: Option<&FdMatrix>,
51 scalar_test: Option<&FdMatrix>,
52 calibration_indices: Option<&[usize]>,
53 cal_fraction: f64,
54 alpha: f64,
55 seed: u64,
56) -> Result<ConformalRegressionResult, FdarError> {
57 let n = data.nrows();
58 validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)?;
59 if y.len() != n {
60 return Err(FdarError::InvalidDimension {
61 parameter: "y",
62 expected: format!("{n}"),
63 actual: format!("{}", y.len()),
64 });
65 }
66
67 let cal_idx = resolve_calibration_indices(calibration_indices, n, cal_fraction, seed)?;
68
69 let cal_data = subsample_rows(data, &cal_idx);
71 let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
72 let cal_scores_mat = model.project(&cal_data);
73 let ncomp = model.ncomp();
74
75 let cal_preds: Vec<f64> = (0..cal_idx.len())
76 .map(|i| {
77 let s: Vec<f64> = (0..ncomp).map(|k| cal_scores_mat[(i, k)]).collect();
78 let sc_row: Option<Vec<f64>> = cal_sc
79 .as_ref()
80 .map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
81 model.predict_from_scores(&s, sc_row.as_deref())
82 })
83 .collect();
84
85 let cal_residuals: Vec<f64> = cal_idx
86 .iter()
87 .enumerate()
88 .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
89 .collect();
90
91 let test_scores_mat = model.project(test_data);
93 let test_preds: Vec<f64> = (0..test_data.nrows())
94 .map(|i| {
95 let s: Vec<f64> = (0..ncomp).map(|k| test_scores_mat[(i, k)]).collect();
96 let sc_row: Option<Vec<f64>> =
97 scalar_test.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
98 model.predict_from_scores(&s, sc_row.as_deref())
99 })
100 .collect();
101
102 Ok(build_regression_result(
103 cal_residuals,
104 test_preds,
105 alpha,
106 ConformalMethod::Split,
107 ))
108}
109
110#[must_use = "expensive computation whose result should not be discarded"]
145pub fn conformal_generic_classification(
146 model: &dyn FpcPredictor,
147 data: &FdMatrix,
148 y: &[usize],
149 test_data: &FdMatrix,
150 scalar_train: Option<&FdMatrix>,
151 scalar_test: Option<&FdMatrix>,
152 score_type: ClassificationScore,
153 calibration_indices: Option<&[usize]>,
154 cal_fraction: f64,
155 alpha: f64,
156 seed: u64,
157) -> Result<ConformalClassificationResult, FdarError> {
158 let n = data.nrows();
159 validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)?;
160 if y.len() != n {
161 return Err(FdarError::InvalidDimension {
162 parameter: "y",
163 expected: format!("{n}"),
164 actual: format!("{}", y.len()),
165 });
166 }
167
168 match model.task_type() {
169 TaskType::BinaryClassification => {}
170 TaskType::MulticlassClassification(g) => {
171 return Err(FdarError::InvalidParameter {
172 parameter: "model",
173 message: format!(
174 "conformal_generic_classification does not support multiclass models \
175 ({g} classes): FpcPredictor::predict_from_scores returns a class label, \
176 not probabilities, which produces degenerate one-hot conformal sets. \
177 Use cv_conformal_classification with a closure that returns proper \
178 probabilities instead."
179 ),
180 })
181 }
182 TaskType::Regression => {
183 return Err(FdarError::InvalidParameter {
184 parameter: "model",
185 message: "expected a classification model, got regression".to_string(),
186 })
187 }
188 };
189
190 let cal_idx = resolve_calibration_indices(calibration_indices, n, cal_fraction, seed)?;
191 let ncomp = model.ncomp();
192
193 let cal_data = subsample_rows(data, &cal_idx);
195 let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
196 let cal_scores_mat = model.project(&cal_data);
197 let cal_probs: Vec<Vec<f64>> = (0..cal_idx.len())
198 .map(|i| {
199 let s: Vec<f64> = (0..ncomp).map(|k| cal_scores_mat[(i, k)]).collect();
200 let sc_row: Option<Vec<f64>> = cal_sc
201 .as_ref()
202 .map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
203 let pred = model.predict_from_scores(&s, sc_row.as_deref());
204 vec![1.0 - pred, pred]
205 })
206 .collect();
207
208 let cal_true = subset_vec_usize(y, &cal_idx);
209 let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);
210
211 let test_scores_mat = model.project(test_data);
213 let test_probs: Vec<Vec<f64>> = (0..test_data.nrows())
214 .map(|i| {
215 let s: Vec<f64> = (0..ncomp).map(|k| test_scores_mat[(i, k)]).collect();
216 let sc_row: Option<Vec<f64>> =
217 scalar_test.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
218 let pred = model.predict_from_scores(&s, sc_row.as_deref());
219 vec![1.0 - pred, pred]
220 })
221 .collect();
222
223 let test_pred_classes: Vec<usize> = test_probs.iter().map(|p| argmax(p)).collect();
224
225 Ok(build_classification_result(
226 cal_scores,
227 &test_probs,
228 test_pred_classes,
229 alpha,
230 ConformalMethod::Split,
231 score_type,
232 ))
233}
234
235fn resolve_calibration_indices(
243 calibration_indices: Option<&[usize]>,
244 n: usize,
245 cal_fraction: f64,
246 seed: u64,
247) -> Result<Vec<usize>, FdarError> {
248 match calibration_indices {
249 Some(indices) => {
250 if indices.len() < 2 {
251 return Err(FdarError::InvalidParameter {
252 parameter: "calibration_indices",
253 message: format!(
254 "need at least 2 calibration observations, got {}",
255 indices.len()
256 ),
257 });
258 }
259 for (pos, &idx) in indices.iter().enumerate() {
260 if idx >= n {
261 return Err(FdarError::InvalidDimension {
262 parameter: "calibration_indices",
263 expected: format!("indices in 0..{n}"),
264 actual: format!("index {idx} at position {pos}"),
265 });
266 }
267 }
268 Ok(indices.to_vec())
269 }
270 None => {
271 let (_proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
272 Ok(cal_idx)
273 }
274 }
275}