Skip to main content

anofox_ml_ensemble/
calibrated_classifier.rs

1//! CalibratedClassifierCV — probability calibration for classifiers.
2//!
3//! Wraps any classifier to produce well-calibrated probabilities using either
4//! Platt scaling (sigmoid) or isotonic regression, fitted via cross-validation.
5
6use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
7use ndarray::{Array1, Array2};
8
9/// Calibration method.
10#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
11pub enum CalibrationMethod {
12    /// Platt scaling: fits a sigmoid A*f(x)+B to map scores to probabilities.
13    Sigmoid,
14    /// Isotonic regression: non-parametric monotonic mapping.
15    Isotonic,
16}
17
18impl Default for CalibrationMethod {
19    fn default() -> Self {
20        CalibrationMethod::Sigmoid
21    }
22}
23
24/// Internal trait for type-erased fit/predict.
25trait FitPredBox<F: Float>: Send + Sync {
26    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
27}
28
29trait PredBox<F: Float>: Send + Sync {
30    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
31}
32
33impl<F, T> FitPredBox<F> for T
34where
35    F: Float,
36    T: Fit<F> + Send + Sync,
37    T::Fitted: Predict<F> + Send + Sync + 'static,
38{
39    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
40        let fitted = Fit::fit(self, x, y)?;
41        Ok(Box::new(fitted))
42    }
43}
44
45impl<F, T> PredBox<F> for T
46where
47    F: Float,
48    T: Predict<F> + Send + Sync,
49{
50    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
51        self.predict(x)
52    }
53}
54
55/// Calibrated classifier with cross-validation.
56///
57/// Wraps a base classifier and calibrates its predictions to produce
58/// well-calibrated probabilities. Uses cross-validation to generate
59/// out-of-fold predictions for calibration fitting.
60pub struct CalibratedClassifierCV<F: Float> {
61    base_estimator: Box<dyn FitPredBox<F>>,
62    method: CalibrationMethod,
63    cv_folds: usize,
64}
65
66impl<F: Float> CalibratedClassifierCV<F> {
67    /// Create a new CalibratedClassifierCV wrapping the given base estimator.
68    pub fn new<T>(base_estimator: T) -> Self
69    where
70        T: Fit<F> + Send + Sync + 'static,
71        T::Fitted: Predict<F> + Send + Sync + 'static,
72    {
73        Self {
74            base_estimator: Box::new(base_estimator),
75            method: CalibrationMethod::Sigmoid,
76            cv_folds: 5,
77        }
78    }
79
80    pub fn with_method(mut self, method: CalibrationMethod) -> Self {
81        self.method = method;
82        self
83    }
84
85    pub fn with_cv_folds(mut self, cv_folds: usize) -> Self {
86        self.cv_folds = cv_folds;
87        self
88    }
89}
90
91/// Fitted calibrated classifier.
92pub struct FittedCalibratedClassifier<F: Float> {
93    /// Base model fitted on full data.
94    base_model: Box<dyn PredBox<F>>,
95    /// Calibration parameters (Platt sigmoid: a, b).
96    cal_a: f64,
97    cal_b: f64,
98    /// For isotonic: sorted (score, prob) pairs.
99    isotonic_x: Vec<f64>,
100    isotonic_y: Vec<f64>,
101    method: CalibrationMethod,
102    n_features: usize,
103}
104
105impl<F: Float> FittedCalibratedClassifier<F> {
106    /// Predict calibrated probabilities for class 1.
107    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array1<F>> {
108        if x.ncols() != self.n_features {
109            return Err(RustMlError::ShapeMismatch(format!(
110                "expected {} features, got {}",
111                self.n_features,
112                x.ncols()
113            )));
114        }
115
116        let raw_preds = self.base_model.predict_box(x)?;
117        let n = raw_preds.len();
118        let mut proba = Array1::zeros(n);
119
120        for i in 0..n {
121            let score = raw_preds[i].to_f64().unwrap();
122            let p = match self.method {
123                CalibrationMethod::Sigmoid => {
124                    1.0 / (1.0 + (-(self.cal_a * score + self.cal_b)).exp())
125                }
126                CalibrationMethod::Isotonic => {
127                    isotonic_predict(score, &self.isotonic_x, &self.isotonic_y)
128                }
129            };
130            proba[i] = F::from_f64(p.clamp(0.0, 1.0)).unwrap();
131        }
132
133        Ok(proba)
134    }
135}
136
137impl<F: Float + 'static> Fit<F> for CalibratedClassifierCV<F> {
138    type Fitted = FittedCalibratedClassifier<F>;
139
140    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
141        if x.nrows() != y.len() {
142            return Err(RustMlError::ShapeMismatch(format!(
143                "X has {} rows but y has {} elements",
144                x.nrows(),
145                y.len()
146            )));
147        }
148        let n = x.nrows();
149        if n < 2 {
150            return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
151        }
152
153        let k = self.cv_folds.min(n);
154
155        // Generate out-of-fold predictions for calibration using stratified
156        // splits so each fold keeps the class distribution — otherwise on
157        // class-sorted data we end up training each fold on a single class.
158        let folds = stratified_k_fold(y, k);
159        let mut oof_scores = vec![0.0f64; n];
160        let mut oof_labels = vec![0.0f64; n];
161
162        for (train_idx, test_idx) in &folds {
163            let x_train = select_rows(x, train_idx);
164            let y_train = select_elements(y, train_idx);
165            let x_test = select_rows(x, test_idx);
166
167            let fitted = self.base_estimator.fit_box(&x_train, &y_train)?;
168            let preds = fitted.predict_box(&x_test)?;
169
170            for (li, &gi) in test_idx.iter().enumerate() {
171                oof_scores[gi] = preds[li].to_f64().unwrap();
172                oof_labels[gi] = y[gi].to_f64().unwrap();
173            }
174        }
175
176        // Fit calibration mapping on OOF predictions
177        let (cal_a, cal_b, isotonic_x, isotonic_y) = match self.method {
178            CalibrationMethod::Sigmoid => {
179                let (a, b) = fit_platt_sigmoid(&oof_scores, &oof_labels);
180                (a, b, Vec::new(), Vec::new())
181            }
182            CalibrationMethod::Isotonic => {
183                let (ix, iy) = fit_isotonic(&oof_scores, &oof_labels);
184                (0.0, 0.0, ix, iy)
185            }
186        };
187
188        // Refit base model on full data
189        let base_model = self.base_estimator.fit_box(x, y)?;
190
191        Ok(FittedCalibratedClassifier {
192            base_model,
193            cal_a,
194            cal_b,
195            isotonic_x,
196            isotonic_y,
197            method: self.method,
198            n_features: x.ncols(),
199        })
200    }
201}
202
203impl<F: Float> Predict<F> for FittedCalibratedClassifier<F> {
204    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
205        let proba = self.predict_proba(x)?;
206        let threshold = F::from_f64(0.5).unwrap();
207        Ok(proba.mapv(|p| if p >= threshold { F::one() } else { F::zero() }))
208    }
209}
210
/// Fit Platt scaling: find `a`, `b` such that
/// `P(y=1 | f) = 1 / (1 + exp(-(a*f + b)))`.
/// This is the form evaluated by `FittedCalibratedClassifier::predict_proba`.
///
/// `scores` and `labels` are paired element-wise; extra elements of the
/// longer slice are ignored. Labels are treated as binary: values `> 0.5`
/// count as positive.
///
/// Following Platt's method as refined by Lin, Lin & Weng (2007):
/// - Hard 0/1 labels are replaced by smoothed targets: `(N+ + 1) / (N+ + 2)`
///   for positives and `1 / (N- + 2)` for negatives. This keeps a perfectly
///   separated fold from driving the parameters to infinity.
/// - The cross-entropy is minimised with Newton's method plus a backtracking
///   line search. Unlike fixed-step gradient descent, this converges
///   regardless of the scale of the scores.
///
/// Returns `(1.0, 0.0)` when there are no samples.
fn fit_platt_sigmoid(scores: &[f64], labels: &[f64]) -> (f64, f64) {
    const MAX_ITER: usize = 100;
    const MIN_STEP: f64 = 1e-10;
    const HESSIAN_RIDGE: f64 = 1e-12;
    const GRAD_TOL: f64 = 1e-5;
    const ARMIJO_C: f64 = 1e-4;

    /// Numerically stable `ln(1 + exp(z))`.
    fn softplus(z: f64) -> f64 {
        z.max(0.0) + (-z.abs()).exp().ln_1p()
    }

    /// Numerically stable logistic function.
    fn sigmoid(z: f64) -> f64 {
        if z >= 0.0 {
            1.0 / (1.0 + (-z).exp())
        } else {
            let e = z.exp();
            e / (1.0 + e)
        }
    }

    let n = scores.len().min(labels.len());
    if n == 0 {
        return (1.0, 0.0);
    }
    let scores = &scores[..n];
    let labels = &labels[..n];

    let n_pos = labels.iter().filter(|&&l| l > 0.5).count() as f64;
    let n_neg = n as f64 - n_pos;
    let hi_target = (n_pos + 1.0) / (n_pos + 2.0);
    let lo_target = 1.0 / (n_neg + 2.0);
    let targets: Vec<f64> = labels
        .iter()
        .map(|&l| if l > 0.5 { hi_target } else { lo_target })
        .collect();

    // Cross-entropy of the smoothed targets under parameters (a, b):
    // -ln(p) = softplus(-z) and -ln(1 - p) = softplus(z), with z = a*f + b.
    let loss = |a: f64, b: f64| -> f64 {
        scores
            .iter()
            .zip(&targets)
            .map(|(&f, &t)| {
                let z = a * f + b;
                t * softplus(-z) + (1.0 - t) * softplus(z)
            })
            .sum::<f64>()
    };

    // Start from a = 0 and the prior log-odds, as in Lin, Lin & Weng.
    let mut a = 0.0;
    let mut b = ((n_pos + 1.0) / (n_neg + 1.0)).ln();
    let mut current = loss(a, b);

    for _ in 0..MAX_ITER {
        // Gradient and (ridge-regularised) Hessian of the loss.
        let (mut g_a, mut g_b) = (0.0, 0.0);
        let (mut h_aa, mut h_ab, mut h_bb) = (HESSIAN_RIDGE, 0.0, HESSIAN_RIDGE);
        for (&f, &t) in scores.iter().zip(&targets) {
            let p = sigmoid(a * f + b);
            let d = p - t;
            g_a += d * f;
            g_b += d;
            let w = p * (1.0 - p);
            h_aa += w * f * f;
            h_ab += w * f;
            h_bb += w;
        }
        if g_a.abs() < GRAD_TOL && g_b.abs() < GRAD_TOL {
            break;
        }

        // Newton direction: solve the 2x2 system H * [da, db] = -g.
        let det = h_aa * h_bb - h_ab * h_ab;
        let da = -(h_bb * g_a - h_ab * g_b) / det;
        let db = -(-h_ab * g_a + h_aa * g_b) / det;
        let slope = g_a * da + g_b * db;
        // Not a descent direction (numerical breakdown or NaN input): stop.
        if !(slope < 0.0) {
            break;
        }

        // Backtracking line search with an Armijo sufficient-decrease test.
        let mut step = 1.0;
        let mut accepted = false;
        while step >= MIN_STEP {
            let (cand_a, cand_b) = (a + step * da, b + step * db);
            let cand = loss(cand_a, cand_b);
            if cand < current + ARMIJO_C * step * slope {
                a = cand_a;
                b = cand_b;
                current = cand;
                accepted = true;
                break;
            }
            step /= 2.0;
        }
        if !accepted {
            break;
        }
    }

    (a, b)
}
243
/// Fit isotonic regression with the pool-adjacent-violators algorithm (PAVA).
///
/// `scores` and `labels` are paired element-wise; extra elements of the
/// longer slice are ignored. Pairs with a NaN score or label are dropped,
/// since they cannot be placed in a score ordering.
///
/// Samples with identical scores are first pooled into a single point (mean
/// label, weight = count). This makes the result independent of the order
/// in which tied samples arrive. Ties are the common case with classifiers
/// that emit only a few distinct scores, such as decision trees.
///
/// Returns `(x, y)` knots in ascending `x` order:
/// - each `x` is the weighted mean score of a pooled block;
/// - each `y` is that block's mean label, and `y` is non-decreasing.
///
/// Both vectors are empty when no usable pairs remain.
fn fit_isotonic(scores: &[f64], labels: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let mut pairs: Vec<(f64, f64)> = scores
        .iter()
        .zip(labels)
        .filter(|(s, l)| !s.is_nan() && !l.is_nan())
        .map(|(&s, &l)| (s, l))
        .collect();
    if pairs.is_empty() {
        return (Vec::new(), Vec::new());
    }
    // NaNs were filtered above, so every comparison is defined.
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("NaN scores are filtered out"));

    // Pass 1: pool tied scores into (score, label sum, count).
    let mut pooled: Vec<(f64, f64, f64)> = Vec::with_capacity(pairs.len());
    for (s, l) in pairs {
        if let Some(last) = pooled.last_mut() {
            if last.0 == s {
                last.1 += l;
                last.2 += 1.0;
                continue;
            }
        }
        pooled.push((s, l, 1.0));
    }

    // Pass 2: pool adjacent violators over the distinct scores.
    let mut x_out: Vec<f64> = Vec::with_capacity(pooled.len());
    let mut y_out: Vec<f64> = Vec::with_capacity(pooled.len());
    let mut weights: Vec<f64> = Vec::with_capacity(pooled.len());

    for (xi, label_sum, wi) in pooled {
        x_out.push(xi);
        y_out.push(label_sum / wi);
        weights.push(wi);

        // Merge the last two blocks while they violate monotonicity; a merge
        // can expose a new violation with the block before it.
        while y_out.len() >= 2 {
            let len = y_out.len();
            if y_out[len - 2] > y_out[len - 1] {
                let (w1, w2) = (weights[len - 2], weights[len - 1]);
                let total = w1 + w2;
                let merged_y = (y_out[len - 2] * w1 + y_out[len - 1] * w2) / total;
                let merged_x = (x_out[len - 2] * w1 + x_out[len - 1] * w2) / total;
                y_out.pop();
                x_out.pop();
                weights.pop();
                y_out[len - 2] = merged_y;
                x_out[len - 2] = merged_x;
                weights[len - 2] = total;
            } else {
                break;
            }
        }
    }

    (x_out, y_out)
}
290
/// Evaluate a fitted isotonic mapping at `score`.
///
/// `x` holds the knot positions (ascending) and `y` the value at each knot.
/// Scores at or beyond either end return the end value. Scores in between
/// are linearly interpolated between the surrounding knots; two knots closer
/// than 1e-15 are averaged instead. With no knots at all, 0.5 is returned.
fn isotonic_predict(score: f64, x: &[f64], y: &[f64]) -> f64 {
    let last = match x.len() {
        0 => return 0.5,
        len => len - 1,
    };

    // Flat extrapolation outside the fitted range.
    if score <= x[0] {
        return y[0];
    }
    if score >= x[last] {
        return y[y.len() - 1];
    }

    // First knot not below `score`; the interval is [hi - 1, hi].
    let hi = x.partition_point(|&knot| knot < score);
    if hi == 0 {
        return y[0];
    }
    if hi > last {
        return y[y.len() - 1];
    }
    let lo = hi - 1;

    let span = x[hi] - x[lo];
    if span.abs() < 1e-15 {
        (y[lo] + y[hi]) / 2.0
    } else {
        y[lo] + (y[hi] - y[lo]) * (score - x[lo]) / span
    }
}
324
325/// Stratified K-fold for classification calibration: groups samples by class
326/// label and distributes each class's samples across the folds in round-robin
327/// fashion, preserving the class proportions in every fold.
328fn stratified_k_fold<F: Float>(y: &Array1<F>, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
329    use std::collections::HashMap;
330    let n = y.len();
331
332    // Group sample indices by class (keyed by f64::to_bits to support arbitrary labels).
333    let mut by_class: HashMap<u64, Vec<usize>> = HashMap::new();
334    for i in 0..n {
335        let key = y[i].to_f64().unwrap().to_bits();
336        by_class.entry(key).or_default().push(i);
337    }
338
339    // Assign each sample a fold number, round-robin within each class.
340    let mut fold_of = vec![0usize; n];
341    for (_, class_indices) in by_class.iter() {
342        for (j, &idx) in class_indices.iter().enumerate() {
343            fold_of[idx] = j % k;
344        }
345    }
346
347    // Build (train, test) index pairs per fold.
348    let mut folds: Vec<(Vec<usize>, Vec<usize>)> =
349        (0..k).map(|_| (Vec::new(), Vec::new())).collect();
350    for i in 0..n {
351        for (f, (train, test)) in folds.iter_mut().enumerate() {
352            if fold_of[i] == f {
353                test.push(i);
354            } else {
355                train.push(i);
356            }
357        }
358    }
359    folds
360}
361
362fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
363    let ncols = x.ncols();
364    let mut data = Vec::with_capacity(indices.len() * ncols);
365    for &i in indices {
366        for j in 0..ncols {
367            data.push(x[[i, j]]);
368        }
369    }
370    Array2::from_shape_vec((indices.len(), ncols), data).unwrap()
371}
372
373fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
374    Array1::from_vec(indices.iter().map(|&i| y[i]).collect())
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeClassifier;
    use ndarray::array;

    /// Two well-separated clusters: four class-0 samples, then four class-1 samples.
    fn two_clusters() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [4.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0],
            [13.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        (x, y)
    }

    /// Assert that every entry is a valid probability.
    fn assert_unit_interval(proba: &Array1<f64>) {
        for &p in proba.iter() {
            assert!(
                (0.0..=1.0).contains(&p),
                "probability must be in [0,1], got {}",
                p
            );
        }
    }

    #[test]
    fn test_calibrated_classifier_sigmoid() {
        let (x, y) = two_clusters();
        let tree = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };
        let cal = CalibratedClassifierCV::new(tree)
            .with_method(CalibrationMethod::Sigmoid)
            .with_cv_folds(2);

        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();

        assert_unit_interval(&fitted.predict_proba(&x).unwrap());

        let labels = fitted.predict(&x).unwrap();
        assert!(labels.iter().all(|&p| p == 0.0 || p == 1.0));
    }

    #[test]
    fn test_calibrated_classifier_isotonic() {
        let (x, y) = two_clusters();
        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default())
            .with_method(CalibrationMethod::Isotonic)
            .with_cv_folds(2);

        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();
        assert_unit_interval(&fitted.predict_proba(&x).unwrap());
    }

    #[test]
    fn test_calibrated_classifier_predict_classes() {
        let x = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default()).with_cv_folds(2);

        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 6);
    }

    #[test]
    fn test_calibrated_classifier_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default()).with_cv_folds(2);
        let fitted: FittedCalibratedClassifier<f64> = cal.fit(&x, &y).unwrap();

        // One column instead of the two seen during fitting.
        let too_narrow = array![[1.0]];
        assert!(fitted.predict(&too_narrow).is_err());
    }

    #[test]
    fn test_calibrated_classifier_empty_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let cal = CalibratedClassifierCV::new(DecisionTreeClassifier::default());

        assert!(Fit::<f64>::fit(&cal, &x, &y).is_err());
    }
}