calibration_demo/calibration_demo.rs

#![allow(clippy::pedantic, clippy::unnecessary_wraps)]
//! Comprehensive Calibration Demo
//!
//! This example demonstrates all three calibration methods available in QuantRS2-ML:
//! 1. Platt Scaling (binary classification)
//! 2. Isotonic Regression (binary classification, non-parametric)
//! 3. Temperature Scaling (multi-class classification)
//!
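//! For reference, the standard forms of these methods are:
//! - Platt scaling: P(y=1 | s) = 1 / (1 + exp(a*s + b)), with fitted slope a and intercept b
//! - Isotonic regression: a non-decreasing step function fitted to (score, label) pairs
//! - Temperature scaling: softmax(z / T), for logits z and a single fitted temperature T > 0
//!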
//! Run with: cargo run --example calibration_demo

use quantrs2_ml::utils::calibration::*;
use quantrs2_ml::utils::metrics;
use scirs2_core::ndarray::{array, Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== QuantRS2-ML Calibration Demo ===\n");

    // Demo 1: Platt Scaling for Binary Classification
    println!("1. PLATT SCALING (Binary Classification)");
    println!("   Purpose: Parametric calibration using logistic regression");
    println!("   Best for: Well-separated binary classification\n");

    demo_platt_scaling()?;

    println!("\n{}\n", "=".repeat(60));

    // Demo 2: Isotonic Regression for Binary Classification
    println!("2. ISOTONIC REGRESSION (Binary Classification)");
    println!("   Purpose: Non-parametric monotonic calibration");
    println!("   Best for: Non-linearly separable binary data\n");

    demo_isotonic_regression()?;

    println!("\n{}\n", "=".repeat(60));

    // Demo 3: Temperature Scaling for Multi-class
    println!("3. TEMPERATURE SCALING (Multi-class Classification)");
    println!("   Purpose: Scale logits by single temperature parameter");
    println!("   Best for: Neural network outputs, multi-class problems\n");

    demo_temperature_scaling()?;

    println!("\n{}\n", "=".repeat(60));

    // Demo 4: Calibration Curve Visualization
    println!("4. CALIBRATION CURVE ANALYSIS");
    println!("   Purpose: Visualize calibration quality (reliability diagram)\n");

    demo_calibration_curve()?;

    println!("\n=== Demo Complete ===");
    println!("All calibration methods demonstrated successfully!");

    Ok(())
}

fn demo_platt_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // Generate synthetic binary classification scores
    // Positive class: higher scores, Negative class: lower scores
    let scores = array![
        2.5, 2.0, 1.8, 1.5, 1.2, // Positive class (overconfident)
        -1.2, -1.5, -1.8, -2.0, -2.5 // Negative class (overconfident)
    ];
    let labels = array![1, 1, 1, 1, 1, 0, 0, 0, 0, 0];

    println!("   Input scores: {scores:?}");
    println!("   True labels:  {labels:?}\n");

    // Fit Platt scaler
    let mut scaler = PlattScaler::new();
    scaler.fit(&scores, &labels)?;

    // Get fitted parameters
    if let Some((a, b)) = scaler.parameters() {
        println!("   Fitted parameters:");
        println!("   - Slope (a):     {a:.4}");
        println!("   - Intercept (b): {b:.4}");
    }
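
    // A minimal sanity check, assuming the usual Platt parameterization
    // P(y=1 | s) = 1 / (1 + exp(a*s + b)); the sign convention of (a, b) is
    // implementation-specific, so treat this as illustrative only.
    if let Some((a, b)) = scaler.parameters() {
        let manual: Array1<f64> = scores.mapv(|s| 1.0 / (1.0 + (a * s + b).exp()));
        println!("   Manual sigmoid on sample 0: {:.4}", manual[0]);
    }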

    // Transform scores to calibrated probabilities
    let calibrated_probs = scaler.transform(&scores)?;
    println!("\n   Calibrated probabilities:");
    for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
        println!("   Sample {i}: score={score:6.2} → P(class=1)={prob:.4}");
    }

    // Compute accuracy on predictions
    let predictions: Array1<usize> = calibrated_probs.mapv(|p| usize::from(p > 0.5));
    let accuracy = metrics::accuracy(&predictions, &labels);
    println!("\n   Calibrated accuracy: {:.2}%", accuracy * 100.0);

    Ok(())
}

fn demo_isotonic_regression() -> Result<(), Box<dyn std::error::Error>> {
    // Generate non-linearly separable scores
    let scores = array![
        0.1, 0.25, 0.2, // Low scores
        0.4, 0.35, 0.55, // Mid-low scores
        0.6, 0.75, 0.7, // Mid-high scores
        0.85, 0.95, 0.9 // High scores
    ];
    // Non-linear relationship with labels
    let labels = array![0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1];

    println!("   Input scores: {scores:?}");
    println!("   True labels:  {labels:?}\n");

    // Fit isotonic regression
    let mut iso = IsotonicRegression::new();
    iso.fit(&scores, &labels)?;
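    // Under the hood, isotonic regression is typically fitted with the
    // pool-adjacent-violators (PAV) algorithm: adjacent fitted values that
    // violate the ordering are pooled and averaged until the result is
    // non-decreasing in the score.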

    println!("   Fitted isotonic regression (maintains monotonicity)");

    // Transform scores
    let calibrated_probs = iso.transform(&scores)?;
    println!("\n   Calibrated probabilities:");
    for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
        println!("   Sample {i}: score={score:.2} → P(class=1)={prob:.4}");
    }

    // Verify monotonicity: the calibrated probabilities must be non-decreasing
    // as a function of the input score, so check them in score order (the
    // input scores themselves are not sorted).
    let mut order: Vec<usize> = (0..scores.len()).collect();
    order.sort_by(|&a, &b| scores[a].partial_cmp(&scores[b]).unwrap());
    let is_monotonic = order
        .windows(2)
        .all(|w| calibrated_probs[w[0]] <= calibrated_probs[w[1]] + 1e-6);
    println!(
        "\n   Monotonicity preserved: {}",
        if is_monotonic { "✓" } else { "✗" }
    );

    Ok(())
}

fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // Generate multi-class logits (4 classes, 8 samples)
    let logits = array![
        [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
        [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
        [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
        [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
        [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
        [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
        [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
        [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
    ];
    let labels = array![0, 1, 2, 3, 0, 1, 2, 3];

    println!("   Input: 4-class classification with 8 samples");
    println!("   Logits shape: {}×{}\n", logits.nrows(), logits.ncols());

    // Compute uncalibrated softmax for comparison
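    // (subtracting the row maximum before exponentiating is the standard
    // numerical-stability trick for softmax; it does not change the result)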
    let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
    for i in 0..logits.nrows() {
        let max_logit = logits
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
        for j in 0..logits.ncols() {
            uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
        }
    }

    // Fit temperature scaler
    let mut scaler = TemperatureScaler::new();
    scaler.fit(&logits, &labels)?;
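    // Temperature scaling fits a single scalar T, conventionally by minimizing
    // the negative log-likelihood of softmax(logits / T) on held-out data; the
    // exact optimizer is an implementation detail of the library.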

    // Get fitted temperature
    if let Some(temp) = scaler.temperature() {
        println!("   Fitted temperature: T = {temp:.4}");
        println!(
            "   Interpretation: {}",
            if temp > 1.0 {
                "Model is overconfident (T > 1 reduces confidence)"
            } else if temp < 1.0 {
                "Model is underconfident (T < 1 increases confidence)"
            } else {
                "Model is well-calibrated (T ≈ 1)"
            }
        );
    }

    // Transform to calibrated probabilities
    let calibrated_probs = scaler.transform(&logits)?;

    println!("\n   Comparison (first 4 samples):");
    println!(
        "   {:<8} | {:<20} | {:<20}",
        "Sample", "Uncalibrated Max P", "Calibrated Max P"
    );
    println!("   {}", "-".repeat(60));

    for i in 0..4 {
        let uncal_max = uncalibrated_probs
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let cal_max = calibrated_probs
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        println!("   Sample {i:<2}  | {uncal_max:.4}               | {cal_max:.4}");
    }

    // Compute predictions
    let mut correct = 0;
    for i in 0..calibrated_probs.nrows() {
        let pred = calibrated_probs
            .row(i)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();
        if pred == labels[i] {
            correct += 1;
        }
    }

    let accuracy = correct as f64 / labels.len() as f64;
    println!("\n   Calibrated accuracy: {:.2}%", accuracy * 100.0);

    Ok(())
}

fn demo_calibration_curve() -> Result<(), Box<dyn std::error::Error>> {
    // Generate predicted probabilities and true labels
    let probabilities = array![0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95];
    let labels = array![0, 0, 0, 1, 0, 1, 1, 1, 1, 1];

    println!("   Probabilities: {probabilities:?}");
    println!("   True labels:   {labels:?}\n");

    // Compute calibration curve
    let (mean_predicted, fraction_positives) = calibration_curve(&probabilities, &labels, 5)?;
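    // Each bin groups samples by predicted probability; for a perfectly
    // calibrated model, the mean predicted probability in each bin equals the
    // empirical fraction of positives in that bin.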

    println!("   Calibration Curve (5 bins):");
    println!(
        "   {:<5} | {:<18} | {:<20}",
        "Bin", "Mean Predicted P", "Fraction Positive"
    );
    println!("   {}", "-".repeat(60));

    for i in 0..mean_predicted.len() {
        println!(
            "   Bin {} | {:.4}              | {:.4}",
            i + 1,
            mean_predicted[i],
            fraction_positives[i]
        );
    }

    // Compute calibration error (Expected Calibration Error, ECE)
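    // ECE = Σ_b (n_b / N) · |fraction_positives(b) − mean_predicted(b)|.
    // The per-sample loop below computes exactly this bin-count-weighted sum,
    // assuming `calibration_curve` returns one entry per probability bin
    // (true here, since every bin is non-empty).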
    let mut ece = 0.0;
    let mut total_samples = 0;

    // Count samples in each bin
    let n_bins = 5;
    for i in 0..probabilities.len() {
        let bin_idx = ((probabilities[i] * n_bins as f64).floor() as usize).min(n_bins - 1);
        if bin_idx < mean_predicted.len() {
            ece += (mean_predicted[bin_idx] - fraction_positives[bin_idx]).abs();
            total_samples += 1;
        }
    }

    if total_samples > 0 {
        ece /= total_samples as f64;
        println!("\n   Expected Calibration Error (ECE): {ece:.4}");
        println!(
            "   Interpretation: {}",
            if ece < 0.1 {
                "Well-calibrated (ECE < 0.1)"
            } else if ece < 0.2 {
                "Moderately calibrated"
            } else {
                "Poorly calibrated (ECE ≥ 0.2)"
            }
        );
    }

    Ok(())
}