#![allow(clippy::pedantic, clippy::unnecessary_wraps)]
use quantrs2_ml::utils::calibration::*;
use quantrs2_ml::utils::metrics;
use scirs2_core::ndarray::{array, Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== QuantRS2-ML Calibration Demo ===\n");

    println!("1. PLATT SCALING (Binary Classification)");
    println!("   Purpose: Parametric calibration using logistic regression");
    println!("   Best for: Well-separated binary classification\n");

    demo_platt_scaling()?;

    println!("\n{}\n", "=".repeat(60));

    println!("2. ISOTONIC REGRESSION (Binary Classification)");
    println!("   Purpose: Non-parametric monotonic calibration");
    println!("   Best for: Non-linearly separable binary data\n");

    demo_isotonic_regression()?;

    println!("\n{}\n", "=".repeat(60));

    println!("3. TEMPERATURE SCALING (Multi-class Classification)");
    println!("   Purpose: Scale logits by a single temperature parameter");
    println!("   Best for: Neural network outputs, multi-class problems\n");

    demo_temperature_scaling()?;

    println!("\n{}\n", "=".repeat(60));

    println!("4. CALIBRATION CURVE ANALYSIS");
    println!("   Purpose: Visualize calibration quality (reliability diagram)\n");

    demo_calibration_curve()?;

    println!("\n=== Demo Complete ===");
    println!("All calibration methods demonstrated successfully!");

    Ok(())
}

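/// Demonstrates Platt scaling: fit a logistic sigmoid that maps raw decision
/// scores to probabilities, commonly parameterized as
/// P(y = 1 | s) = 1 / (1 + exp(a*s + b)), with a and b chosen by maximum
/// likelihood on the (score, label) pairs.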
fn demo_platt_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // Raw decision-function scores: positive for class 1, negative for class 0.
    let scores = array![2.5, 2.0, 1.8, 1.5, 1.2, -1.2, -1.5, -1.8, -2.0, -2.5];
    let labels = array![1, 1, 1, 1, 1, 0, 0, 0, 0, 0];

    println!("   Input scores: {scores:?}");
    println!("   True labels: {labels:?}\n");

    let mut scaler = PlattScaler::new();
    scaler.fit(&scores, &labels)?;

    if let Some((a, b)) = scaler.parameters() {
        println!("   Fitted parameters:");
        println!("   - Slope (a): {a:.4}");
        println!("   - Intercept (b): {b:.4}");
    }

    let calibrated_probs = scaler.transform(&scores)?;
    println!("\n   Calibrated probabilities:");
    for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
        println!("   Sample {i}: score={score:6.2} → P(class=1)={prob:.4}");
    }

    // Threshold the calibrated probabilities at 0.5 to recover hard predictions.
    let predictions: Array1<usize> = calibrated_probs.mapv(|p| usize::from(p > 0.5));
    let accuracy = metrics::accuracy(&predictions, &labels);
    println!("\n   Calibrated accuracy: {:.2}%", accuracy * 100.0);

    Ok(())
}

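/// Demonstrates isotonic regression: fit a non-decreasing step function from
/// scores to probabilities (typically via the pool-adjacent-violators
/// algorithm), with no parametric assumption about the score distribution.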
fn demo_isotonic_regression() -> Result<(), Box<dyn std::error::Error>> {
    // Noisy scores in [0, 1]; deliberately not sorted and not linearly separable.
    let scores = array![0.1, 0.25, 0.2, 0.4, 0.35, 0.55, 0.6, 0.75, 0.7, 0.85, 0.95, 0.9];
    let labels = array![0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1];

    println!("   Input scores: {scores:?}");
    println!("   True labels: {labels:?}\n");

    let mut iso = IsotonicRegression::new();
    iso.fit(&scores, &labels)?;

    println!("   Fitted isotonic regression (maintains monotonicity)");

    let calibrated_probs = iso.transform(&scores)?;
    println!("\n   Calibrated probabilities:");
    for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
        println!("   Sample {i}: score={score:.2} → P(class=1)={prob:.4}");
    }

    // Isotonic regression guarantees monotonicity with respect to the scores,
    // not the (unsorted) sample order, so sort the pairs by score before checking.
    let mut pairs: Vec<(f64, f64)> = scores
        .iter()
        .copied()
        .zip(calibrated_probs.iter().copied())
        .collect();
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
    let is_monotonic = pairs.windows(2).all(|w| w[1].1 >= w[0].1 - 1e-6);
    println!(
        "\n   Monotonicity preserved: {}",
        if is_monotonic { "✓" } else { "✗" }
    );

    Ok(())
}

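/// Demonstrates temperature scaling: divide every logit by a single learned
/// temperature T before the softmax, i.e. p = softmax(z / T). T > 1 softens
/// overconfident predictions, T < 1 sharpens underconfident ones, and the
/// argmax is unchanged, so accuracy is preserved.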
fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
    // 4-class logits, one row per sample, strongly peaked on the true class.
    let logits = array![
        [5.0, 1.0, 0.5, 0.0],
        [1.0, 5.0, 0.5, 0.0],
        [0.5, 1.0, 5.0, 0.0],
        [0.0, 0.5, 1.0, 5.0],
        [3.0, 2.0, 1.0, 0.5],
        [1.0, 3.0, 2.0, 0.5],
        [0.5, 1.0, 3.0, 2.0],
        [0.5, 0.5, 1.0, 3.0],
    ];
    let labels = array![0, 1, 2, 3, 0, 1, 2, 3];

    println!("   Input: 4-class classification with 8 samples");
    println!("   Logits shape: {}×{}\n", logits.nrows(), logits.ncols());

    // Numerically stable row-wise softmax: subtract each row's max before exp.
    let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
    for i in 0..logits.nrows() {
        let max_logit = logits
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
        for j in 0..logits.ncols() {
            uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
        }
    }

    let mut scaler = TemperatureScaler::new();
    scaler.fit(&logits, &labels)?;

    if let Some(temp) = scaler.temperature() {
        println!("   Fitted temperature: T = {temp:.4}");
        println!(
            "   Interpretation: {}",
            if temp > 1.0 {
                "Model is overconfident (T > 1 reduces confidence)"
            } else if temp < 1.0 {
                "Model is underconfident (T < 1 increases confidence)"
            } else {
                "Model is well-calibrated (T ≈ 1)"
            }
        );
    }

    let calibrated_probs = scaler.transform(&logits)?;

    println!("\n   Comparison (first 4 samples):");
    println!(
        "   {:<8} | {:<20} | {:<20}",
        "Sample", "Uncalibrated Max P", "Calibrated Max P"
    );
    println!("   {}", "-".repeat(60));

    for i in 0..4 {
        let uncal_max = uncalibrated_probs
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let cal_max = calibrated_probs
            .row(i)
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        // Pad to the header widths so the columns line up.
        println!("   Sample {i} | {uncal_max:<20.4} | {cal_max:<20.4}");
    }

    // Accuracy from the argmax of each calibrated row; temperature scaling
    // never changes the argmax, so this matches the uncalibrated accuracy.
    let mut correct = 0;
    for i in 0..calibrated_probs.nrows() {
        let pred = calibrated_probs
            .row(i)
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();
        if pred == labels[i] {
            correct += 1;
        }
    }

    let accuracy = correct as f64 / labels.len() as f64;
    println!("\n   Calibrated accuracy: {:.2}%", accuracy * 100.0);

    Ok(())
}

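/// Demonstrates a reliability diagram: bin the predictions by confidence and
/// compare each bin's mean predicted probability with its empirical fraction
/// of positives. The Expected Calibration Error summarizes the gaps:
/// ECE = Σ_b (n_b / N) · |acc(b) − conf(b)|.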
fn demo_calibration_curve() -> Result<(), Box<dyn std::error::Error>> {
    let probabilities = array![0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95];
    let labels = array![0, 0, 0, 1, 0, 1, 1, 1, 1, 1];

    println!("   Probabilities: {probabilities:?}");
    println!("   True labels: {labels:?}\n");

    let (mean_predicted, fraction_positives) = calibration_curve(&probabilities, &labels, 5)?;

    println!("   Calibration Curve (5 bins):");
    println!(
        "   {:<5} | {:<18} | {:<20}",
        "Bin", "Mean Predicted P", "Fraction Positive"
    );
    println!("   {}", "-".repeat(60));

    for i in 0..mean_predicted.len() {
        println!(
            "   Bin {} | {:<18.4} | {:<20.4}",
            i + 1,
            mean_predicted[i],
            fraction_positives[i]
        );
    }

    // Accumulate |conf(bin) − acc(bin)| once per sample and divide by the
    // sample count; this equals the bin-weighted ECE formula above. It assumes
    // calibration_curve returns one entry per bin in bin order, which holds
    // here because all 5 equal-width bins are populated.
    let mut ece = 0.0;
    let mut total_samples = 0;

    let n_bins = 5;
    for i in 0..probabilities.len() {
        let bin_idx = ((probabilities[i] * n_bins as f64).floor() as usize).min(n_bins - 1);
        if bin_idx < mean_predicted.len() {
            ece += (mean_predicted[bin_idx] - fraction_positives[bin_idx]).abs();
            total_samples += 1;
        }
    }

    if total_samples > 0 {
        ece /= total_samples as f64;
        println!("\n   Expected Calibration Error (ECE): {ece:.4}");
        println!(
            "   Interpretation: {}",
            if ece < 0.1 {
                "Well-calibrated (ECE < 0.1)"
            } else if ece < 0.2 {
                "Moderately calibrated"
            } else {
                "Poorly calibrated (ECE > 0.2)"
            }
        );
    }

    Ok(())
}