calibration_demo/
calibration_demo.rs

#![allow(
    clippy::pedantic,
    clippy::unnecessary_wraps,
    clippy::needless_range_loop,
    clippy::useless_vec,
    clippy::needless_collect,
    clippy::too_many_arguments
)]
//! Comprehensive Calibration Demo
//!
//! This example demonstrates all three calibration methods available in QuantRS2-ML:
//! 1. Platt Scaling (binary classification)
//! 2. Isotonic Regression (binary classification, non-parametric)
//! 3. Temperature Scaling (multi-class classification)
//!
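//! Sketch of the underlying maps (standard conventions; the library's exact
//! parameterization may differ):
//! - Platt scaling:       P(y=1 | s) = 1 / (1 + exp(a*s + b))
//! - Isotonic regression: p = f(s) for a non-decreasing step function f
//! - Temperature scaling: p = softmax(z / T) for logits z
//!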
//! Run with: cargo run --example calibration_demo

use quantrs2_ml::utils::calibration::*;
use quantrs2_ml::utils::metrics;
use scirs2_core::ndarray::{array, Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== QuantRS2-ML Calibration Demo ===\n");

    // Demo 1: Platt Scaling for Binary Classification
    println!("1. PLATT SCALING (Binary Classification)");
    println!("   Purpose: Parametric calibration using logistic regression");
    println!("   Best for: Well-separated binary classification\n");

    demo_platt_scaling()?;

    println!("\n{}\n", "=".repeat(60));

    // Demo 2: Isotonic Regression for Binary Classification
    println!("2. ISOTONIC REGRESSION (Binary Classification)");
    println!("   Purpose: Non-parametric monotonic calibration");
    println!("   Best for: Non-linearly separable binary data\n");

    demo_isotonic_regression()?;

    println!("\n{}\n", "=".repeat(60));

    // Demo 3: Temperature Scaling for Multi-class
    println!("3. TEMPERATURE SCALING (Multi-class Classification)");
    println!("   Purpose: Scale logits by single temperature parameter");
    println!("   Best for: Neural network outputs, multi-class problems\n");

    demo_temperature_scaling()?;

    println!("\n{}\n", "=".repeat(60));

    // Demo 4: Calibration Curve Visualization
    println!("4. CALIBRATION CURVE ANALYSIS");
    println!("   Purpose: Visualize calibration quality (reliability diagram)\n");

    demo_calibration_curve()?;

    println!("\n=== Demo Complete ===");
    println!("All calibration methods demonstrated successfully!");

    Ok(())
}

fn demo_platt_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // Generate synthetic binary classification scores
    // Positive class: higher scores, Negative class: lower scores
    let scores = array![
        2.5, 2.0, 1.8, 1.5, 1.2, // Positive class (overconfident)
        -1.2, -1.5, -1.8, -2.0, -2.5 // Negative class (overconfident)
    ];
    let labels = array![1, 1, 1, 1, 1, 0, 0, 0, 0, 0];

    println!("   Input scores: {scores:?}");
    println!("   True labels:  {labels:?}\n");

    // Fit Platt scaler
    let mut scaler = PlattScaler::new();
    scaler.fit(&scores, &labels)?;

    // Get fitted parameters
    if let Some((a, b)) = scaler.parameters() {
        println!("   Fitted parameters:");
        println!("   - Slope (a):     {a:.4}");
        println!("   - Intercept (b): {b:.4}");
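
        // Hedged sanity check: Platt's method conventionally maps a raw score
        // s to P(y=1 | s) = 1 / (1 + exp(a*s + b)). If the library uses the
        // opposite sign convention, this manual value comes out as 1 - p.
        let s0 = scores[0];
        let manual_p = 1.0 / (1.0 + (a * s0 + b).exp());
        println!("   - Manual sigmoid check, sample 0: {manual_p:.4}");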
    }

    // Transform scores to calibrated probabilities
    let calibrated_probs = scaler.transform(&scores)?;
    println!("\n   Calibrated probabilities:");
    for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
        println!("   Sample {i}: score={score:6.2} → P(class=1)={prob:.4}");
    }

    // Compute accuracy on predictions
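    // (Platt's sigmoid is monotonic, so the ranking of samples is unchanged;
    // under the convention above, the 0.5 cut sits at raw score s = -b/a.)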
    let predictions: Array1<usize> = calibrated_probs.mapv(|p| usize::from(p > 0.5));
    let accuracy = metrics::accuracy(&predictions, &labels);
    println!("\n   Calibrated accuracy: {:.2}%", accuracy * 100.0);

    Ok(())
}

fn demo_isotonic_regression() -> Result<(), Box<dyn std::error::Error>> {
    // Generate non-linearly separable scores
    let scores = array![
        0.1, 0.25, 0.2, // Low scores
        0.4, 0.35, 0.55, // Mid-low scores
        0.6, 0.75, 0.7, // Mid-high scores
        0.85, 0.95, 0.9 // High scores
    ];
    // Non-linear relationship with labels
    let labels = array![0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1];

    println!("   Input scores: {scores:?}");
    println!("   True labels:  {labels:?}\n");

    // Fit isotonic regression
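    // (non-parametric; typically fit via the pool-adjacent-violators
    // algorithm, which pools any adjacent values that violate monotonicity)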
    let mut iso = IsotonicRegression::new();
    iso.fit(&scores, &labels)?;

    println!("   Fitted isotonic regression (maintains monotonicity)");

    // Transform scores
    let calibrated_probs = iso.transform(&scores)?;
    println!("\n   Calibrated probabilities:");
    for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
        println!("   Sample {i}: score={score:.2} → P(class=1)={prob:.4}");
    }

    // Verify monotonicity: the input scores are not sorted, so check that the
    // calibrated probabilities are non-decreasing in ascending-score order
    let mut order: Vec<usize> = (0..scores.len()).collect();
    order.sort_by(|&a, &b| scores[a].partial_cmp(&scores[b]).unwrap());
    let mut is_monotonic = true;
    for pair in order.windows(2) {
        if calibrated_probs[pair[0]] > calibrated_probs[pair[1]] + 1e-6 {
            is_monotonic = false;
            break;
        }
    }
    println!(
        "\n   Monotonicity preserved: {}",
        if is_monotonic { "✓" } else { "✗" }
    );

    Ok(())
}

fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // Generate multi-class logits (4 classes, 8 samples)
    let logits = array![
        [5.0, 1.0, 0.5, 0.0], // Overconfident for class 0
        [1.0, 5.0, 0.5, 0.0], // Overconfident for class 1
        [0.5, 1.0, 5.0, 0.0], // Overconfident for class 2
        [0.0, 0.5, 1.0, 5.0], // Overconfident for class 3
        [3.0, 2.0, 1.0, 0.5], // Moderately confident for class 0
        [1.0, 3.0, 2.0, 0.5], // Moderately confident for class 1
        [0.5, 1.0, 3.0, 2.0], // Moderately confident for class 2
        [0.5, 0.5, 1.0, 3.0], // Moderately confident for class 3
    ];
    let labels = array![0, 1, 2, 3, 0, 1, 2, 3];

    println!("   Input: 4-class classification with 8 samples");
    println!("   Logits shape: {}×{}\n", logits.nrows(), logits.ncols());

    // Compute uncalibrated softmax for comparison
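    // Subtracting the row max before exponentiating is the standard
    // log-sum-exp trick: it prevents overflow without changing the result.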
    let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
    for i in 0..logits.nrows() {
        let max_logit = logits
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
        for j in 0..logits.ncols() {
            uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
        }
    }

    // Fit temperature scaler
    let mut scaler = TemperatureScaler::new();
    scaler.fit(&logits, &labels)?;

    // Get fitted temperature
    if let Some(temp) = scaler.temperature() {
        println!("   Fitted temperature: T = {temp:.4}");
        println!(
            "   Interpretation: {}",
            if temp > 1.0 {
                "Model is overconfident (T > 1 reduces confidence)"
            } else if temp < 1.0 {
                "Model is underconfident (T < 1 increases confidence)"
            } else {
                "Model is well-calibrated (T ≈ 1)"
            }
        );
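
        // Hedged sanity check: temperature scaling conventionally computes
        // softmax(z / T); recompute sample 0 by hand with the fitted T.
        let row0 = logits.row(0);
        let max_scaled = row0.iter().map(|&z| z / temp).fold(f64::NEG_INFINITY, f64::max);
        let scaled_sum: f64 = row0.iter().map(|&z| (z / temp - max_scaled).exp()).sum();
        let manual_p0 = (row0[0] / temp - max_scaled).exp() / scaled_sum;
        println!("   Manual check, P(class 0 | sample 0): {manual_p0:.4}");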
    }

    // Transform to calibrated probabilities
    let calibrated_probs = scaler.transform(&logits)?;

    println!("\n   Comparison (first 4 samples):");
    println!(
        "   {:<8} | {:<20} | {:<20}",
        "Sample", "Uncalibrated Max P", "Calibrated Max P"
    );
    println!("   {}", "-".repeat(60));

    for i in 0..4 {
        let uncal_max = uncalibrated_probs
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let cal_max = calibrated_probs
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        println!("   Sample {i:<2}  | {uncal_max:.4}               | {cal_max:.4}");
    }

    // Compute predictions
    let mut correct = 0;
    for i in 0..calibrated_probs.nrows() {
        let pred = calibrated_probs
            .row(i)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();
        if pred == labels[i] {
            correct += 1;
        }
    }

    let accuracy = correct as f64 / labels.len() as f64;
    println!("\n   Calibrated accuracy: {:.2}%", accuracy * 100.0);

    Ok(())
}

fn demo_calibration_curve() -> Result<(), Box<dyn std::error::Error>> {
    // Generate predicted probabilities and true labels
    let probabilities = array![0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95];
    let labels = array![0, 0, 0, 1, 0, 1, 1, 1, 1, 1];

    println!("   Probabilities: {probabilities:?}");
    println!("   True labels:   {labels:?}\n");

    // Compute calibration curve
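    // (a perfectly calibrated model lies on the diagonal of the reliability
    // diagram: within each bin, fraction positive ≈ mean predicted probability)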
    let (mean_predicted, fraction_positives) = calibration_curve(&probabilities, &labels, 5)?;

    println!("   Calibration Curve (5 bins):");
    println!(
        "   {:<5} | {:<18} | {:<20}",
        "Bin", "Mean Predicted P", "Fraction Positive"
    );
    println!("   {}", "-".repeat(60));

    for i in 0..mean_predicted.len() {
        println!(
            "   Bin {} | {:.4}              | {:.4}",
            i + 1,
            mean_predicted[i],
            fraction_positives[i]
        );
    }

    // Compute calibration error (Expected Calibration Error - ECE)
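    // ECE = sum over bins b of (n_b / N) * |fraction_positive(b) - mean_predicted(b)|.
    // The per-sample loop below realizes the n_b weighting by re-binning each
    // sample; this assumes calibration_curve used the same equal-width bins
    // over [0, 1] and returned one entry per bin.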
    let mut ece = 0.0;
    let mut total_samples = 0;

    // Count samples in each bin
    let n_bins = 5;
    for i in 0..probabilities.len() {
        let bin_idx = ((probabilities[i] * n_bins as f64).floor() as usize).min(n_bins - 1);
        if bin_idx < mean_predicted.len() {
            ece += (mean_predicted[bin_idx] - fraction_positives[bin_idx]).abs();
            total_samples += 1;
        }
    }

    if total_samples > 0 {
        ece /= total_samples as f64;
        println!("\n   Expected Calibration Error (ECE): {ece:.4}");
        println!(
            "   Interpretation: {}",
            if ece < 0.1 {
                "Well-calibrated (ECE < 0.1)"
            } else if ece < 0.2 {
                "Moderately calibrated"
            } else {
                "Poorly calibrated (ECE > 0.2)"
            }
        );
    }

    Ok(())
}