1#![allow(
2 clippy::pedantic,
3 clippy::unnecessary_wraps,
4 clippy::needless_range_loop,
5 clippy::useless_vec,
6 clippy::needless_collect,
7 clippy::too_many_arguments
8)]
9use quantrs2_ml::utils::calibration::*;
19use quantrs2_ml::utils::metrics;
20use scirs2_core::ndarray::{array, Array1, Array2};
21use scirs2_core::random::prelude::*;
22
23fn main() -> Result<(), Box<dyn std::error::Error>> {
24 println!("=== QuantRS2-ML Calibration Demo ===\n");
25
26 println!("1. PLATT SCALING (Binary Classification)");
28 println!(" Purpose: Parametric calibration using logistic regression");
29 println!(" Best for: Well-separated binary classification\n");
30
31 demo_platt_scaling()?;
32
33 println!("\n{}\n", "=".repeat(60));
34
35 println!("2. ISOTONIC REGRESSION (Binary Classification)");
37 println!(" Purpose: Non-parametric monotonic calibration");
38 println!(" Best for: Non-linearly separable binary data\n");
39
40 demo_isotonic_regression()?;
41
42 println!("\n{}\n", "=".repeat(60));
43
44 println!("3. TEMPERATURE SCALING (Multi-class Classification)");
46 println!(" Purpose: Scale logits by single temperature parameter");
47 println!(" Best for: Neural network outputs, multi-class problems\n");
48
49 demo_temperature_scaling()?;
50
51 println!("\n{}\n", "=".repeat(60));
52
53 println!("4. CALIBRATION CURVE ANALYSIS");
55 println!(" Purpose: Visualize calibration quality (reliability diagram)\n");
56
57 demo_calibration_curve()?;
58
59 println!("\n=== Demo Complete ===");
60 println!("All calibration methods demonstrated successfully!");
61
62 Ok(())
63}
64
65fn demo_platt_scaling() -> Result<(), Box<dyn std::error::Error>> {
66 let scores = array![
69 2.5, 2.0, 1.8, 1.5, 1.2, -1.2, -1.5, -1.8, -2.0, -2.5 ];
72 let labels = array![1, 1, 1, 1, 1, 0, 0, 0, 0, 0];
73
74 println!(" Input scores: {scores:?}");
75 println!(" True labels: {labels:?}\n");
76
77 let mut scaler = PlattScaler::new();
79 scaler.fit(&scores, &labels)?;
80
81 if let Some((a, b)) = scaler.parameters() {
83 println!(" Fitted parameters:");
84 println!(" - Slope (a): {a:.4}");
85 println!(" - Intercept (b): {b:.4}");
86 }
87
88 let calibrated_probs = scaler.transform(&scores)?;
90 println!("\n Calibrated probabilities:");
91 for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
92 println!(" Sample {i}: score={score:6.2} → P(class=1)={prob:.4}");
93 }
94
95 let predictions: Array1<usize> = calibrated_probs.mapv(|p| usize::from(p > 0.5));
97 let accuracy = metrics::accuracy(&predictions, &labels);
98 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
99
100 Ok(())
101}
102
103fn demo_isotonic_regression() -> Result<(), Box<dyn std::error::Error>> {
104 let scores = array![
106 0.1, 0.25, 0.2, 0.4, 0.35, 0.55, 0.6, 0.75, 0.7, 0.85, 0.95, 0.9 ];
111 let labels = array![0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1];
113
114 println!(" Input scores: {scores:?}");
115 println!(" True labels: {labels:?}\n");
116
117 let mut iso = IsotonicRegression::new();
119 iso.fit(&scores, &labels)?;
120
121 println!(" Fitted isotonic regression (maintains monotonicity)");
122
123 let calibrated_probs = iso.transform(&scores)?;
125 println!("\n Calibrated probabilities:");
126 for (i, (&score, &prob)) in scores.iter().zip(calibrated_probs.iter()).enumerate() {
127 println!(" Sample {i}: score={score:.2} → P(class=1)={prob:.4}");
128 }
129
130 let mut is_monotonic = true;
132 for i in 0..calibrated_probs.len() - 1 {
133 if calibrated_probs[i] > calibrated_probs[i + 1] + 1e-6 {
134 is_monotonic = false;
135 break;
136 }
137 }
138 println!(
139 "\n Monotonicity preserved: {}",
140 if is_monotonic { "✓" } else { "✗" }
141 );
142
143 Ok(())
144}
145
146fn demo_temperature_scaling() -> Result<(), Box<dyn std::error::Error>> {
147 let logits = array![
149 [5.0, 1.0, 0.5, 0.0], [1.0, 5.0, 0.5, 0.0], [0.5, 1.0, 5.0, 0.0], [0.0, 0.5, 1.0, 5.0], [3.0, 2.0, 1.0, 0.5], [1.0, 3.0, 2.0, 0.5], [0.5, 1.0, 3.0, 2.0], [0.5, 0.5, 1.0, 3.0], ];
158 let labels = array![0, 1, 2, 3, 0, 1, 2, 3];
159
160 println!(" Input: 4-class classification with 8 samples");
161 println!(" Logits shape: {}×{}\n", logits.nrows(), logits.ncols());
162
163 let mut uncalibrated_probs = Array2::zeros((logits.nrows(), logits.ncols()));
165 for i in 0..logits.nrows() {
166 let max_logit = logits
167 .row(i)
168 .iter()
169 .copied()
170 .fold(f64::NEG_INFINITY, f64::max);
171 let exp_sum: f64 = logits.row(i).iter().map(|&x| (x - max_logit).exp()).sum();
172 for j in 0..logits.ncols() {
173 uncalibrated_probs[(i, j)] = ((logits[(i, j)] - max_logit).exp()) / exp_sum;
174 }
175 }
176
177 let mut scaler = TemperatureScaler::new();
179 scaler.fit(&logits, &labels)?;
180
181 if let Some(temp) = scaler.temperature() {
183 println!(" Fitted temperature: T = {temp:.4}");
184 println!(
185 " Interpretation: {}",
186 if temp > 1.0 {
187 "Model is overconfident (T > 1 reduces confidence)"
188 } else if temp < 1.0 {
189 "Model is underconfident (T < 1 increases confidence)"
190 } else {
191 "Model is well-calibrated (T ≈ 1)"
192 }
193 );
194 }
195
196 let calibrated_probs = scaler.transform(&logits)?;
198
199 println!("\n Comparison (first 4 samples):");
200 println!(
201 " {:<8} | {:<20} | {:<20}",
202 "Sample", "Uncalibrated Max P", "Calibrated Max P"
203 );
204 println!(" {}", "-".repeat(60));
205
206 for i in 0..4 {
207 let uncal_max = uncalibrated_probs
208 .row(i)
209 .iter()
210 .copied()
211 .fold(f64::NEG_INFINITY, f64::max);
212 let cal_max = calibrated_probs
213 .row(i)
214 .iter()
215 .copied()
216 .fold(f64::NEG_INFINITY, f64::max);
217 println!(" Sample {i:<2} | {uncal_max:.4} | {cal_max:.4}");
218 }
219
220 let mut correct = 0;
222 for i in 0..calibrated_probs.nrows() {
223 let pred = calibrated_probs
224 .row(i)
225 .iter()
226 .enumerate()
227 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
228 .map(|(idx, _)| idx)
229 .unwrap();
230 if pred == labels[i] {
231 correct += 1;
232 }
233 }
234
235 let accuracy = correct as f64 / labels.len() as f64;
236 println!("\n Calibrated accuracy: {:.2}%", accuracy * 100.0);
237
238 Ok(())
239}
240
241fn demo_calibration_curve() -> Result<(), Box<dyn std::error::Error>> {
242 let probabilities = array![0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95];
244 let labels = array![0, 0, 0, 1, 0, 1, 1, 1, 1, 1];
245
246 println!(" Probabilities: {probabilities:?}");
247 println!(" True labels: {labels:?}\n");
248
249 let (mean_predicted, fraction_positives) = calibration_curve(&probabilities, &labels, 5)?;
251
252 println!(" Calibration Curve (5 bins):");
253 println!(
254 " {:<5} | {:<18} | {:<20}",
255 "Bin", "Mean Predicted P", "Fraction Positive"
256 );
257 println!(" {}", "-".repeat(60));
258
259 for i in 0..mean_predicted.len() {
260 println!(
261 " Bin {} | {:.4} | {:.4}",
262 i + 1,
263 mean_predicted[i],
264 fraction_positives[i]
265 );
266 }
267
268 let mut ece = 0.0;
270 let mut total_samples = 0;
271
272 let n_bins = 5;
274 for i in 0..probabilities.len() {
275 let bin_idx = ((probabilities[i] * n_bins as f64).floor() as usize).min(n_bins - 1);
276 if bin_idx < mean_predicted.len() {
277 ece += (mean_predicted[bin_idx] - fraction_positives[bin_idx]).abs();
278 total_samples += 1;
279 }
280 }
281
282 if total_samples > 0 {
283 ece /= total_samples as f64;
284 println!("\n Expected Calibration Error (ECE): {ece:.4}");
285 println!(
286 " Interpretation: {}",
287 if ece < 0.1 {
288 "Well-calibrated (ECE < 0.1)"
289 } else if ece < 0.2 {
290 "Moderately calibrated"
291 } else {
292 "Poorly calibrated (ECE > 0.2)"
293 }
294 );
295 }
296
297 Ok(())
298}