1use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
7use ndarray::{Array1, Array2};
8
/// Strategy used to turn raw base-model scores into calibrated probabilities.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum CalibrationMethod {
    /// Logistic mapping `1 / (1 + exp(-(a * score + b)))` with fitted `a` and `b`.
    Sigmoid,
    /// Piecewise-linear mapping through knots fitted by isotonic regression.
    Isotonic,
}
17
18impl Default for CalibrationMethod {
19 fn default() -> Self {
20 CalibrationMethod::Sigmoid
21 }
22}
23
/// Type-erased counterpart of `Fit`: lets estimators of different concrete
/// types be stored behind `Box<dyn FitPredBox<F>>`.
trait FitPredBox<F: Float>: Send + Sync {
    /// Fits on `x`/`y` and returns the fitted model as a boxed [`PredBox`].
    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
}
28
/// Type-erased counterpart of `Predict`, usable as `Box<dyn PredBox<F>>`.
trait PredBox<F: Float>: Send + Sync {
    /// Returns one prediction per row of `x`.
    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
}
32
33impl<F, T> FitPredBox<F> for T
34where
35 F: Float,
36 T: Fit<F> + Send + Sync,
37 T::Fitted: Predict<F> + Send + Sync + 'static,
38{
39 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
40 let fitted = Fit::fit(self, x, y)?;
41 Ok(Box::new(fitted))
42 }
43}
44
45impl<F, T> PredBox<F> for T
46where
47 F: Float,
48 T: Predict<F> + Send + Sync,
49{
50 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
51 self.predict(x)
52 }
53}
54
/// Unfitted calibration wrapper around a base estimator.
///
/// Fitting produces out-of-fold scores from the base estimator via
/// cross-validation, fits the selected calibration map on them, and then
/// refits the base estimator on the full data.
pub struct CalibratedClassifierCV<F: Float> {
    /// Type-erased base estimator; refit once per fold and once on all data.
    base_estimator: Box<dyn FitPredBox<F>>,
    /// Calibration map to fit on the out-of-fold scores.
    method: CalibrationMethod,
    /// Requested number of cross-validation folds (capped at the sample count when fitting).
    cv_folds: usize,
}
65
66impl<F: Float> CalibratedClassifierCV<F> {
67 pub fn new<T>(base_estimator: T) -> Self
69 where
70 T: Fit<F> + Send + Sync + 'static,
71 T::Fitted: Predict<F> + Send + Sync + 'static,
72 {
73 Self {
74 base_estimator: Box::new(base_estimator),
75 method: CalibrationMethod::Sigmoid,
76 cv_folds: 5,
77 }
78 }
79
80 pub fn with_method(mut self, method: CalibrationMethod) -> Self {
81 self.method = method;
82 self
83 }
84
85 pub fn with_cv_folds(mut self, cv_folds: usize) -> Self {
86 self.cv_folds = cv_folds;
87 self
88 }
89}
90
/// Fitted calibration wrapper: a base model plus the learned calibration map.
pub struct FittedCalibratedClassifier<F: Float> {
    /// Base model fitted on the full training data.
    base_model: Box<dyn PredBox<F>>,
    /// Sigmoid slope `a` in `1 / (1 + exp(-(a * score + b)))`; `0.0` when isotonic is used.
    cal_a: f64,
    /// Sigmoid intercept `b`; `0.0` when isotonic is used.
    cal_b: f64,
    /// Score coordinates of the isotonic knots; empty when sigmoid is used.
    isotonic_x: Vec<f64>,
    /// Calibrated values at the isotonic knots; empty when sigmoid is used.
    isotonic_y: Vec<f64>,
    /// Which calibration map `predict_proba` applies.
    method: CalibrationMethod,
    /// Number of feature columns seen during fitting.
    n_features: usize,
}
104
105impl<F: Float> FittedCalibratedClassifier<F> {
106 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array1<F>> {
108 if x.ncols() != self.n_features {
109 return Err(RustMlError::ShapeMismatch(format!(
110 "expected {} features, got {}",
111 self.n_features,
112 x.ncols()
113 )));
114 }
115
116 let raw_preds = self.base_model.predict_box(x)?;
117 let n = raw_preds.len();
118 let mut proba = Array1::zeros(n);
119
120 for i in 0..n {
121 let score = raw_preds[i].to_f64().unwrap();
122 let p = match self.method {
123 CalibrationMethod::Sigmoid => {
124 1.0 / (1.0 + (-(self.cal_a * score + self.cal_b)).exp())
125 }
126 CalibrationMethod::Isotonic => {
127 isotonic_predict(score, &self.isotonic_x, &self.isotonic_y)
128 }
129 };
130 proba[i] = F::from_f64(p.clamp(0.0, 1.0)).unwrap();
131 }
132
133 Ok(proba)
134 }
135}
136
137impl<F: Float + 'static> Fit<F> for CalibratedClassifierCV<F> {
138 type Fitted = FittedCalibratedClassifier<F>;
139
140 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
141 if x.nrows() != y.len() {
142 return Err(RustMlError::ShapeMismatch(format!(
143 "X has {} rows but y has {} elements",
144 x.nrows(),
145 y.len()
146 )));
147 }
148 let n = x.nrows();
149 if n < 2 {
150 return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
151 }
152
153 let k = self.cv_folds.min(n);
154
155 let folds = stratified_k_fold(y, k);
159 let mut oof_scores = vec![0.0f64; n];
160 let mut oof_labels = vec![0.0f64; n];
161
162 for (train_idx, test_idx) in &folds {
163 let x_train = select_rows(x, train_idx);
164 let y_train = select_elements(y, train_idx);
165 let x_test = select_rows(x, test_idx);
166
167 let fitted = self.base_estimator.fit_box(&x_train, &y_train)?;
168 let preds = fitted.predict_box(&x_test)?;
169
170 for (li, &gi) in test_idx.iter().enumerate() {
171 oof_scores[gi] = preds[li].to_f64().unwrap();
172 oof_labels[gi] = y[gi].to_f64().unwrap();
173 }
174 }
175
176 let (cal_a, cal_b, isotonic_x, isotonic_y) = match self.method {
178 CalibrationMethod::Sigmoid => {
179 let (a, b) = fit_platt_sigmoid(&oof_scores, &oof_labels);
180 (a, b, Vec::new(), Vec::new())
181 }
182 CalibrationMethod::Isotonic => {
183 let (ix, iy) = fit_isotonic(&oof_scores, &oof_labels);
184 (0.0, 0.0, ix, iy)
185 }
186 };
187
188 let base_model = self.base_estimator.fit_box(x, y)?;
190
191 Ok(FittedCalibratedClassifier {
192 base_model,
193 cal_a,
194 cal_b,
195 isotonic_x,
196 isotonic_y,
197 method: self.method,
198 n_features: x.ncols(),
199 })
200 }
201}
202
203impl<F: Float> Predict<F> for FittedCalibratedClassifier<F> {
204 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
205 let proba = self.predict_proba(x)?;
206 let threshold = F::from_f64(0.5).unwrap();
207 Ok(proba.mapv(|p| if p >= threshold { F::one() } else { F::zero() }))
208 }
209}
210
/// Fits `(a, b)` of the sigmoid `1 / (1 + exp(-(a * score + b)))` to `labels`.
///
/// Runs 1000 steps of full-batch gradient descent with step size 0.01 on the
/// mean logistic-loss gradient, starting from `a = 0`, `b = 0`. Returns
/// `(1.0, 0.0)` when `scores` is empty.
///
/// # Panics
///
/// Panics if `labels` is shorter than `scores`.
fn fit_platt_sigmoid(scores: &[f64], labels: &[f64]) -> (f64, f64) {
    const STEP_SIZE: f64 = 0.01;
    const ITERATIONS: usize = 1000;

    if scores.is_empty() {
        return (1.0, 0.0);
    }
    let count = scores.len() as f64;
    // Only the first `scores.len()` labels take part; extra labels are ignored.
    let targets = &labels[..scores.len()];

    let mut slope = 0.0f64;
    let mut intercept = 0.0f64;

    for _ in 0..ITERATIONS {
        // Accumulate both gradient components in one left-to-right pass.
        let (sum_slope, sum_intercept) = scores.iter().zip(targets.iter()).fold(
            (0.0f64, 0.0f64),
            |(acc_slope, acc_intercept), (&score, &target)| {
                let predicted = 1.0 / (1.0 + (-(slope * score + intercept)).exp());
                let residual = predicted - target;
                (acc_slope + residual * score, acc_intercept + residual)
            },
        );

        let mean_slope = sum_slope / count;
        let mean_intercept = sum_intercept / count;

        slope -= STEP_SIZE * mean_slope;
        intercept -= STEP_SIZE * mean_intercept;
    }

    (slope, intercept)
}
243
/// Fits a non-decreasing step/knot curve from `scores` to `labels` with the
/// pool-adjacent-violators algorithm (PAVA).
///
/// Returns `(x, y)` where `x` holds the knot positions in strictly increasing
/// order and `y` the non-decreasing fitted values at those knots. Both are
/// empty when no usable pair exists.
///
/// Behaviour worth knowing:
/// * Pairs are formed positionally; surplus entries of the longer slice are ignored.
/// * Pairs whose score or label is NaN or infinite are dropped, because they
///   cannot be placed on the score axis or averaged meaningfully.
/// * Samples with identical scores are pooled into a single block (mean label,
///   weight = count) before PAVA runs, so the result does not depend on the
///   order of the input.
/// * When blocks are merged, the knot position is the weighted mean of the
///   merged positions.
fn fit_isotonic(scores: &[f64], labels: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let mut pairs: Vec<(f64, f64)> = scores
        .iter()
        .zip(labels.iter())
        .map(|(&s, &l)| (s, l))
        .filter(|&(s, l)| s.is_finite() && l.is_finite())
        .collect();
    if pairs.is_empty() {
        return (Vec::new(), Vec::new());
    }

    // `total_cmp` is a total order, so sorting can never panic.
    pairs.sort_by(|a, b| a.0.total_cmp(&b.0));

    let mut x_out: Vec<f64> = Vec::with_capacity(pairs.len());
    let mut y_out: Vec<f64> = Vec::with_capacity(pairs.len());
    let mut weights: Vec<f64> = Vec::with_capacity(pairs.len());

    let mut start = 0;
    while start < pairs.len() {
        // Pool every sample that shares this exact score into one block.
        let xi = pairs[start].0;
        let mut end = start;
        let mut label_sum = 0.0;
        while end < pairs.len() && pairs[end].0 == xi {
            label_sum += pairs[end].1;
            end += 1;
        }
        let count = (end - start) as f64;
        start = end;

        x_out.push(xi);
        y_out.push(label_sum / count);
        weights.push(count);

        // Merge backwards while the newest block violates monotonicity.
        while y_out.len() >= 2 {
            let len = y_out.len();
            if y_out[len - 2] <= y_out[len - 1] {
                break;
            }
            let w1 = weights[len - 2];
            let w2 = weights[len - 1];
            let total = w1 + w2;
            let merged_y = (y_out[len - 2] * w1 + y_out[len - 1] * w2) / total;
            let merged_x = (x_out[len - 2] * w1 + x_out[len - 1] * w2) / total;

            y_out.pop();
            x_out.pop();
            weights.pop();
            y_out[len - 2] = merged_y;
            x_out[len - 2] = merged_x;
            weights[len - 2] = total;
        }
    }

    (x_out, y_out)
}
290
/// Evaluates the isotonic curve with knots `(x, y)` at `score`.
///
/// * No knots: returns `0.5`.
/// * `score` at or below the first knot: returns the first value.
/// * `score` at or above the last knot: returns the last value.
/// * Otherwise: linear interpolation between the two surrounding knots, or
///   their average when the knots are closer than `1e-15`.
///
/// `x` is expected to be sorted ascending (it is searched with `partition_point`).
fn isotonic_predict(score: f64, x: &[f64], y: &[f64]) -> f64 {
    let last = match x.len().checked_sub(1) {
        Some(index) => index,
        None => return 0.5,
    };

    // Clamp to the end values outside the fitted range.
    if score <= x[0] {
        return y[0];
    }
    if score >= x[last] {
        return y[y.len() - 1];
    }

    // Index of the first knot that is not strictly below `score`.
    let upper = x.partition_point(|&knot| knot < score);
    match upper {
        // Reached when `score` compares false against everything (e.g. NaN).
        0 => y[0],
        idx if idx >= x.len() => y[y.len() - 1],
        idx => {
            let (left_x, right_x) = (x[idx - 1], x[idx]);
            let (left_y, right_y) = (y[idx - 1], y[idx]);
            let span = right_x - left_x;

            if span.abs() < 1e-15 {
                // Practically coincident knots: avoid dividing by ~0.
                (left_y + right_y) / 2.0
            } else {
                left_y + (right_y - left_y) * (score - left_x) / span
            }
        }
    }
}
324
325fn stratified_k_fold<F: Float>(y: &Array1<F>, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
329 use std::collections::HashMap;
330 let n = y.len();
331
332 let mut by_class: HashMap<u64, Vec<usize>> = HashMap::new();
334 for i in 0..n {
335 let key = y[i].to_f64().unwrap().to_bits();
336 by_class.entry(key).or_default().push(i);
337 }
338
339 let mut fold_of = vec![0usize; n];
341 for (_, class_indices) in by_class.iter() {
342 for (j, &idx) in class_indices.iter().enumerate() {
343 fold_of[idx] = j % k;
344 }
345 }
346
347 let mut folds: Vec<(Vec<usize>, Vec<usize>)> =
349 (0..k).map(|_| (Vec::new(), Vec::new())).collect();
350 for i in 0..n {
351 for (f, (train, test)) in folds.iter_mut().enumerate() {
352 if fold_of[i] == f {
353 test.push(i);
354 } else {
355 train.push(i);
356 }
357 }
358 }
359 folds
360}
361
362fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
363 let ncols = x.ncols();
364 let mut data = Vec::with_capacity(indices.len() * ncols);
365 for &i in indices {
366 for j in 0..ncols {
367 data.push(x[[i, j]]);
368 }
369 }
370 Array2::from_shape_vec((indices.len(), ncols), data).unwrap()
371}
372
373fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
374 Array1::from_vec(indices.iter().map(|&i| y[i]).collect())
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeClassifier;
    use ndarray::array;

    /// Eight samples in two clearly separated groups, four per class.
    fn eight_sample_data() -> (Array2<f64>, Array1<f64>) {
        let features = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [13.0, 1.0]
        ];
        let targets = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        (features, targets)
    }

    /// Sigmoid calibration yields probabilities in [0, 1] and binary labels.
    #[test]
    fn test_calibrated_classifier_sigmoid() {
        let (x, y) = eight_sample_data();

        let base = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };
        let cal = CalibratedClassifierCV::new(base)
            .with_method(CalibrationMethod::Sigmoid)
            .with_cv_folds(2);

        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();

        let proba = fitted.predict_proba(&x).unwrap();
        for &p in proba.iter() {
            assert!(
                (0.0..=1.0).contains(&p),
                "probability must be in [0,1], got {}",
                p
            );
        }

        let preds = fitted.predict(&x).unwrap();
        for &label in preds.iter() {
            assert!(label == 0.0 || label == 1.0);
        }
    }

    /// Isotonic calibration yields probabilities in [0, 1].
    #[test]
    fn test_calibrated_classifier_isotonic() {
        let (x, y) = eight_sample_data();

        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default())
            .with_method(CalibrationMethod::Isotonic)
            .with_cv_folds(2);

        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        for &p in proba.iter() {
            assert!((0.0..=1.0).contains(&p));
        }
    }

    /// `predict` returns one label per input row.
    #[test]
    fn test_calibrated_classifier_predict_classes() {
        let x = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default()).with_cv_folds(2);

        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    /// Predicting with the wrong number of feature columns is an error.
    #[test]
    fn test_calibrated_classifier_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];

        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default()).with_cv_folds(2);
        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();

        let single_column = array![[1.0]];
        assert!(fitted.predict(&single_column).is_err());
    }

    /// Fitting on zero samples is an error.
    #[test]
    fn test_calibrated_classifier_empty_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default());
        let outcome = Fit::<f64>::fit(&cal, &x, &y);
        assert!(outcome.is_err());
    }
}