1use super::*;
6use crate::error::{MLError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8pub fn calibration_curve(
11 probabilities: &Array1<f64>,
12 labels: &Array1<usize>,
13 n_bins: usize,
14) -> Result<(Array1<f64>, Array1<f64>)> {
15 if probabilities.len() != labels.len() {
16 return Err(MLError::InvalidInput(
17 "Probabilities and labels must have same length".to_string(),
18 ));
19 }
20 if n_bins < 2 {
21 return Err(MLError::InvalidInput(
22 "Number of bins must be at least 2".to_string(),
23 ));
24 }
25 let mut bins = vec![Vec::new(); n_bins];
26 for (i, &prob) in probabilities.iter().enumerate() {
27 let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
28 bins[bin_idx].push((prob, labels[i]));
29 }
30 let mut mean_predicted = Vec::new();
31 let mut fraction_positives = Vec::new();
32 for bin in bins {
33 if !bin.is_empty() {
34 let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
35 let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
36 mean_predicted.push(sum_prob / bin.len() as f64);
37 fraction_positives.push(sum_labels / bin.len() as f64);
38 }
39 }
40 Ok((
41 Array1::from_vec(mean_predicted),
42 Array1::from_vec(fraction_positives),
43 ))
44}
45pub mod visualization {
47 use super::*;
    /// Per-bin data for drawing a reliability (calibration) diagram, as produced by
    /// `generate_calibration_plot_data`. Every bin is represented, including empty ones.
    #[derive(Debug, Clone)]
    pub struct CalibrationPlotData {
        /// Mean predicted probability in each bin; for an empty bin, the bin's midpoint.
        pub mean_predicted: Array1<f64>,
        /// Mean label value in each bin (fraction of positives for 0/1 labels);
        /// `0.0` for an empty bin.
        pub fraction_positives: Array1<f64>,
        /// Number of samples that fell into each bin.
        pub bin_counts: Array1<usize>,
        /// The `n_bins + 1` equally spaced bin boundaries, from `0.0` to `1.0`.
        pub bin_edges: Vec<f64>,
    }
60 pub fn generate_calibration_plot_data(
62 probabilities: &Array1<f64>,
63 labels: &Array1<usize>,
64 n_bins: usize,
65 ) -> Result<CalibrationPlotData> {
66 if probabilities.len() != labels.len() {
67 return Err(MLError::InvalidInput(
68 "Probabilities and labels must have same length".to_string(),
69 ));
70 }
71 if n_bins < 2 {
72 return Err(MLError::InvalidInput(
73 "Number of bins must be at least 2".to_string(),
74 ));
75 }
76 let mut bins = vec![Vec::new(); n_bins];
77 let bin_edges: Vec<f64> = (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect();
78 for (i, &prob) in probabilities.iter().enumerate() {
79 let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
80 bins[bin_idx].push((prob, labels[i]));
81 }
82 let mut mean_predicted = Vec::new();
83 let mut fraction_positives = Vec::new();
84 let mut bin_counts = Vec::new();
85 for bin in bins {
86 if !bin.is_empty() {
87 let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
88 let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
89 mean_predicted.push(sum_prob / bin.len() as f64);
90 fraction_positives.push(sum_labels / bin.len() as f64);
91 bin_counts.push(bin.len());
92 } else {
93 mean_predicted.push(
94 (bin_edges[mean_predicted.len()] + bin_edges[mean_predicted.len() + 1]) / 2.0,
95 );
96 fraction_positives.push(0.0);
97 bin_counts.push(0);
98 }
99 }
100 Ok(CalibrationPlotData {
101 mean_predicted: Array1::from_vec(mean_predicted),
102 fraction_positives: Array1::from_vec(fraction_positives),
103 bin_counts: Array1::from_vec(bin_counts),
104 bin_edges,
105 })
106 }
    /// Summary calibration metrics for a set of binary predictions, produced by
    /// `analyze_calibration`.
    #[derive(Debug, Clone)]
    pub struct CalibrationAnalysis {
        /// Expected calibration error: the sample-weighted mean of the per-bin gaps
        /// between mean predicted probability and observed fraction.
        pub ece: f64,
        /// Maximum calibration error: the largest entry of `bin_errors`.
        pub mce: f64,
        /// Mean squared difference between predicted probability and label.
        pub brier_score: f64,
        /// Mean negative log-likelihood, computed with probabilities clamped to
        /// `[1e-10, 1 - 1e-10]`.
        pub nll: f64,
        /// Number of bins used for the binned metrics.
        pub n_bins: usize,
        /// Absolute calibration gap for each bin, one entry per bin.
        pub bin_errors: Array1<f64>,
        /// Human-readable verdict derived from `ece`.
        pub interpretation: String,
    }
125 impl CalibrationAnalysis {
126 fn interpret_ece(ece: f64) -> String {
128 if ece < 0.01 {
129 "Excellent calibration - predictions are highly reliable".to_string()
130 } else if ece < 0.05 {
131 "Good calibration - predictions are generally reliable".to_string()
132 } else if ece < 0.10 {
133 "Moderate calibration - some miscalibration present".to_string()
134 } else if ece < 0.20 {
135 "Poor calibration - significant miscalibration detected".to_string()
136 } else {
137 "Very poor calibration - predictions are unreliable".to_string()
138 }
139 }
140 }
141 pub fn analyze_calibration(
143 probabilities: &Array1<f64>,
144 labels: &Array1<usize>,
145 n_bins: usize,
146 ) -> Result<CalibrationAnalysis> {
147 let plot_data = generate_calibration_plot_data(probabilities, labels, n_bins)?;
148 let mut ece = 0.0;
149 let total_samples = probabilities.len() as f64;
150 for i in 0..plot_data.mean_predicted.len() {
151 let bin_error = (plot_data.mean_predicted[i] - plot_data.fraction_positives[i]).abs();
152 let bin_weight = plot_data.bin_counts[i] as f64 / total_samples;
153 ece += bin_weight * bin_error;
154 }
155 let bin_errors: Array1<f64> =
156 (&plot_data.mean_predicted - &plot_data.fraction_positives).mapv(|x| x.abs());
157 let mce = bin_errors.iter().cloned().fold(0.0f64, f64::max);
158 let mut brier_score = 0.0;
159 for (i, &prob) in probabilities.iter().enumerate() {
160 let true_label = labels[i] as f64;
161 brier_score += (prob - true_label).powi(2);
162 }
163 brier_score /= probabilities.len() as f64;
164 let mut nll = 0.0;
165 for (i, &prob) in probabilities.iter().enumerate() {
166 let true_label = labels[i];
167 let prob_clamped = prob.max(1e-10).min(1.0 - 1e-10);
168 if true_label == 1 {
169 nll -= prob_clamped.ln();
170 } else {
171 nll -= (1.0 - prob_clamped).ln();
172 }
173 }
174 nll /= probabilities.len() as f64;
175 let interpretation = CalibrationAnalysis::interpret_ece(ece);
176 Ok(CalibrationAnalysis {
177 ece,
178 mce,
179 brier_score,
180 nll,
181 n_bins,
182 bin_errors,
183 interpretation,
184 })
185 }
    /// One entry of `compare_calibration_methods`: a method's calibrated output and
    /// its calibration metrics.
    #[derive(Debug, Clone)]
    pub struct CalibrationComparison {
        /// Display name of the method (e.g. "Uncalibrated", "Platt Scaling").
        pub method_name: String,
        /// Calibration metrics of `calibrated_probs` against the labels.
        pub analysis: CalibrationAnalysis,
        /// The probabilities produced by the method (the input itself for "Uncalibrated").
        pub calibrated_probs: Array1<f64>,
    }
196 pub fn compare_calibration_methods(
198 uncalibrated_probs: &Array1<f64>,
199 labels: &Array1<usize>,
200 n_bins: usize,
201 ) -> Result<Vec<CalibrationComparison>> {
202 let mut comparisons = Vec::new();
203 let uncal_analysis = analyze_calibration(uncalibrated_probs, labels, n_bins)?;
204 comparisons.push(CalibrationComparison {
205 method_name: "Uncalibrated".to_string(),
206 analysis: uncal_analysis,
207 calibrated_probs: uncalibrated_probs.clone(),
208 });
209 if labels.iter().max().unwrap_or(&0) == &1 {
210 let mut platt = PlattScaler::new();
211 if let Ok(calibrated) = platt.fit_transform(uncalibrated_probs, labels) {
212 let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
213 comparisons.push(CalibrationComparison {
214 method_name: "Platt Scaling".to_string(),
215 analysis,
216 calibrated_probs: calibrated,
217 });
218 }
219 }
220 if labels.iter().max().unwrap_or(&0) == &1 {
221 let mut isotonic = IsotonicRegression::new();
222 if let Ok(calibrated) = isotonic.fit_transform(uncalibrated_probs, labels) {
223 let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
224 comparisons.push(CalibrationComparison {
225 method_name: "Isotonic Regression".to_string(),
226 analysis,
227 calibrated_probs: calibrated,
228 });
229 }
230 }
231 if labels.iter().max().unwrap_or(&0) == &1 {
232 let mut bbq = BayesianBinningQuantiles::new(n_bins);
233 if let Ok(calibrated) = bbq.fit_transform(uncalibrated_probs, labels) {
234 let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
235 comparisons.push(CalibrationComparison {
236 method_name: "Bayesian Binning (BBQ)".to_string(),
237 analysis,
238 calibrated_probs: calibrated,
239 });
240 }
241 }
242 Ok(comparisons)
243 }
244 pub fn generate_comparison_report(comparisons: &[CalibrationComparison]) -> String {
246 let mut report = String::new();
247 report.push_str("=== Calibration Methods Comparison Report ===\n\n");
248 let mut best_ece_idx = 0;
249 let mut best_mce_idx = 0;
250 let mut best_brier_idx = 0;
251 let mut best_nll_idx = 0;
252 for (i, comp) in comparisons.iter().enumerate() {
253 if comp.analysis.ece < comparisons[best_ece_idx].analysis.ece {
254 best_ece_idx = i;
255 }
256 if comp.analysis.mce < comparisons[best_mce_idx].analysis.mce {
257 best_mce_idx = i;
258 }
259 if comp.analysis.brier_score < comparisons[best_brier_idx].analysis.brier_score {
260 best_brier_idx = i;
261 }
262 if comp.analysis.nll < comparisons[best_nll_idx].analysis.nll {
263 best_nll_idx = i;
264 }
265 }
266 for (i, comp) in comparisons.iter().enumerate() {
267 report.push_str(&format!("\n{}\n", comp.method_name));
268 report.push_str(&format!("{}\n", "=".repeat(comp.method_name.len())));
269 report.push_str(&format!(
270 "ECE: {:.4}{}\n",
271 comp.analysis.ece,
272 if i == best_ece_idx { " ⭐ BEST" } else { "" }
273 ));
274 report.push_str(&format!(
275 "MCE: {:.4}{}\n",
276 comp.analysis.mce,
277 if i == best_mce_idx { " ⭐ BEST" } else { "" }
278 ));
279 report.push_str(&format!(
280 "Brier Score: {:.4}{}\n",
281 comp.analysis.brier_score,
282 if i == best_brier_idx { " ⭐ BEST" } else { "" }
283 ));
284 report.push_str(&format!(
285 "NLL: {:.4}{}\n",
286 comp.analysis.nll,
287 if i == best_nll_idx { " ⭐ BEST" } else { "" }
288 ));
289 report.push_str(&format!(
290 "Interpretation: {}\n",
291 comp.analysis.interpretation
292 ));
293 }
294 report.push_str("\n=== Recommendations ===\n");
295 report.push_str(&format!(
296 "Best overall (ECE): {}\n",
297 comparisons[best_ece_idx].method_name
298 ));
299 report.push_str(&format!(
300 "Most reliable (MCE): {}\n",
301 comparisons[best_mce_idx].method_name
302 ));
303 report.push_str(&format!(
304 "Best probability estimates (Brier): {}\n",
305 comparisons[best_brier_idx].method_name
306 ));
307 report
308 }
309}
310pub mod quantum_calibration {
313 use super::*;
    /// Settings for `QuantumNeuralNetworkCalibrator`.
    #[derive(Debug, Clone)]
    pub struct QuantumCalibrationConfig {
        /// Number of equal-width bins used when evaluating calibration quality.
        pub n_bins: usize,
        /// NOTE(review): not read by any code visible in this file; confirm whether error
        /// mitigation is applied elsewhere.
        pub use_error_mitigation: bool,
        /// Confidence level forwarded to BBQ's `predict_with_uncertainty`.
        pub confidence_level: f64,
        /// When true, fitting records per-sample shot-noise estimates if shot counts are
        /// supplied.
        pub account_shot_noise: bool,
    }
326 impl Default for QuantumCalibrationConfig {
327 fn default() -> Self {
328 Self {
329 n_bins: 10,
330 use_error_mitigation: true,
331 confidence_level: 0.95,
332 account_shot_noise: true,
333 }
334 }
335 }
    /// Calibrator that wraps one `CalibrationMethod`, a configuration, and optional
    /// per-sample shot-noise estimates gathered while fitting.
    #[derive(Debug, Clone)]
    pub struct QuantumNeuralNetworkCalibrator {
        /// The wrapped calibration algorithm. Its variant decides whether the binary or
        /// the multi-class fit/transform methods apply.
        method: CalibrationMethod,
        /// Binning, confidence-level, and shot-noise settings.
        config: QuantumCalibrationConfig,
        /// Per-sample standard errors `sqrt(p(1-p)/shots)`. Set by the fit methods when
        /// shot counts are supplied and `config.account_shot_noise` is enabled.
        shot_noise_estimates: Option<Array1<f64>>,
        /// Set to true by a successful fit; the transform methods return an error while false.
        fitted: bool,
    }
    /// The calibration algorithm wrapped by `QuantumNeuralNetworkCalibrator`.
    ///
    /// `Temperature` and `Vector` are fitted on multi-class logits (`fit_multiclass`).
    /// `Platt`, `Isotonic`, and `BayesianBinning` are fitted on binary probabilities
    /// (`fit_binary`).
    #[derive(Debug, Clone)]
    pub enum CalibrationMethod {
        /// Temperature scaling (multi-class).
        Temperature(TemperatureScaler),
        /// Vector scaling (multi-class).
        Vector(VectorScaler),
        /// Platt scaling (binary).
        Platt(PlattScaler),
        /// Isotonic regression (binary).
        Isotonic(IsotonicRegression),
        /// Bayesian Binning into Quantiles (binary). The only variant used for
        /// `transform_with_uncertainty`'s native interval path.
        BayesianBinning(BayesianBinningQuantiles),
    }
366 impl QuantumNeuralNetworkCalibrator {
367 pub fn new() -> Self {
369 Self {
370 method: CalibrationMethod::Temperature(TemperatureScaler::new()),
371 config: QuantumCalibrationConfig::default(),
372 shot_noise_estimates: None,
373 fitted: false,
374 }
375 }
376 pub fn with_method(method: CalibrationMethod) -> Self {
378 Self {
379 method,
380 config: QuantumCalibrationConfig::default(),
381 shot_noise_estimates: None,
382 fitted: false,
383 }
384 }
385 pub fn with_config(mut self, config: QuantumCalibrationConfig) -> Self {
387 self.config = config;
388 self
389 }
390 pub fn fit_binary(
392 &mut self,
393 probabilities: &Array1<f64>,
394 labels: &Array1<usize>,
395 shot_counts: Option<&Array1<usize>>,
396 ) -> Result<()> {
397 if let Some(shots) = shot_counts {
398 if self.config.account_shot_noise {
399 self.shot_noise_estimates =
400 Some(self.estimate_shot_noise(probabilities, shots));
401 }
402 }
403 match &mut self.method {
404 CalibrationMethod::Platt(scaler) => {
405 scaler.fit(probabilities, labels)?;
406 }
407 CalibrationMethod::Isotonic(scaler) => {
408 scaler.fit(probabilities, labels)?;
409 }
410 CalibrationMethod::BayesianBinning(scaler) => {
411 scaler.fit(probabilities, labels)?;
412 }
413 _ => {
414 return Err(MLError::InvalidInput(
415 "Binary calibration requires Platt, Isotonic, or BBQ method".to_string(),
416 ));
417 }
418 }
419 self.fitted = true;
420 Ok(())
421 }
422 pub fn fit_multiclass(
424 &mut self,
425 logits: &Array2<f64>,
426 labels: &Array1<usize>,
427 shot_counts: Option<&Array1<usize>>,
428 ) -> Result<()> {
429 if let Some(shots) = shot_counts {
430 if self.config.account_shot_noise {
431 let avg_probs = logits.mean_axis(scirs2_core::ndarray::Axis(1)).unwrap();
432 self.shot_noise_estimates = Some(self.estimate_shot_noise(&avg_probs, shots));
433 }
434 }
435 match &mut self.method {
436 CalibrationMethod::Temperature(scaler) => {
437 scaler.fit(logits, labels)?;
438 }
439 CalibrationMethod::Vector(scaler) => {
440 scaler.fit(logits, labels)?;
441 }
442 _ => {
443 return Err(MLError::InvalidInput(
444 "Multi-class calibration requires Temperature or Vector method".to_string(),
445 ));
446 }
447 }
448 self.fitted = true;
449 Ok(())
450 }
451 fn estimate_shot_noise(
453 &self,
454 probabilities: &Array1<f64>,
455 shot_counts: &Array1<usize>,
456 ) -> Array1<f64> {
457 probabilities
458 .iter()
459 .zip(shot_counts.iter())
460 .map(|(&p, &n)| {
461 if n > 0 {
462 (p * (1.0 - p) / n as f64).sqrt()
463 } else {
464 0.0
465 }
466 })
467 .collect::<Vec<_>>()
468 .into()
469 }
470 pub fn transform_binary(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
472 if !self.fitted {
473 return Err(MLError::InvalidInput(
474 "Calibrator must be fitted before transform".to_string(),
475 ));
476 }
477 match &self.method {
478 CalibrationMethod::Platt(scaler) => scaler.transform(probabilities),
479 CalibrationMethod::Isotonic(scaler) => scaler.transform(probabilities),
480 CalibrationMethod::BayesianBinning(scaler) => scaler.transform(probabilities),
481 _ => Err(MLError::InvalidInput(
482 "Method does not support binary transformation".to_string(),
483 )),
484 }
485 }
486 pub fn transform_multiclass(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
488 if !self.fitted {
489 return Err(MLError::InvalidInput(
490 "Calibrator must be fitted before transform".to_string(),
491 ));
492 }
493 match &self.method {
494 CalibrationMethod::Temperature(scaler) => scaler.transform(logits),
495 CalibrationMethod::Vector(scaler) => scaler.transform(logits),
496 _ => Err(MLError::InvalidInput(
497 "Method does not support multi-class transformation".to_string(),
498 )),
499 }
500 }
501 pub fn transform_with_uncertainty(
503 &self,
504 probabilities: &Array1<f64>,
505 ) -> Result<Vec<(f64, f64, f64)>> {
506 if !self.fitted {
507 return Err(MLError::InvalidInput(
508 "Calibrator must be fitted before transform".to_string(),
509 ));
510 }
511 match &self.method {
512 CalibrationMethod::BayesianBinning(scaler) => {
513 scaler.predict_with_uncertainty(probabilities, self.config.confidence_level)
514 }
515 _ => {
516 let calibrated = self.transform_binary(probabilities)?;
517 if let Some(noise) = &self.shot_noise_estimates {
518 let results = calibrated
519 .iter()
520 .zip(noise.iter())
521 .map(|(&p, &sigma)| {
522 let z = 1.96;
523 let lower = (p - z * sigma).max(0.0);
524 let upper = (p + z * sigma).min(1.0);
525 (p, lower, upper)
526 })
527 .collect();
528 Ok(results)
529 } else {
530 Ok(calibrated.iter().map(|&p| (p, p, p)).collect())
531 }
532 }
533 }
534 }
535 pub fn evaluate_quantum_calibration(
537 &self,
538 probabilities: &Array1<f64>,
539 labels: &Array1<usize>,
540 ) -> Result<QuantumCalibrationMetrics> {
541 let calibrated = self.transform_binary(probabilities)?;
542 let analysis =
543 visualization::analyze_calibration(&calibrated, labels, self.config.n_bins)?;
544 let shot_noise_impact = if let Some(noise) = &self.shot_noise_estimates {
545 noise.mean().unwrap_or(0.0)
546 } else {
547 0.0
548 };
549 Ok(QuantumCalibrationMetrics {
550 ece: analysis.ece,
551 mce: analysis.mce,
552 brier_score: analysis.brier_score,
553 nll: analysis.nll,
554 shot_noise_impact,
555 interpretation: analysis.interpretation,
556 })
557 }
558 }
559 impl Default for QuantumNeuralNetworkCalibrator {
560 fn default() -> Self {
561 Self::new()
562 }
563 }
    /// Calibration metrics for a quantum calibrator, extending the classical metrics with
    /// a shot-noise summary.
    #[derive(Debug, Clone)]
    pub struct QuantumCalibrationMetrics {
        /// Expected calibration error of the calibrated probabilities.
        pub ece: f64,
        /// Maximum calibration error of the calibrated probabilities.
        pub mce: f64,
        /// Brier score of the calibrated probabilities.
        pub brier_score: f64,
        /// Mean negative log-likelihood of the calibrated probabilities.
        pub nll: f64,
        /// Mean of the recorded per-sample shot-noise standard errors (`0.0` if none).
        pub shot_noise_impact: f64,
        /// Human-readable verdict derived from `ece`.
        pub interpretation: String,
    }
580 pub fn quantum_ensemble_calibration(
583 probabilities: &Array1<f64>,
584 labels: &Array1<usize>,
585 shot_counts: &Array1<usize>,
586 n_bins: usize,
587 ) -> Result<(Array1<f64>, QuantumCalibrationMetrics)> {
588 let mut platt_cal = QuantumNeuralNetworkCalibrator::with_method(CalibrationMethod::Platt(
589 PlattScaler::new(),
590 ));
591 platt_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
592 let mut isotonic_cal = QuantumNeuralNetworkCalibrator::with_method(
593 CalibrationMethod::Isotonic(IsotonicRegression::new()),
594 );
595 isotonic_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
596 let mut bbq_cal = QuantumNeuralNetworkCalibrator::with_method(
597 CalibrationMethod::BayesianBinning(BayesianBinningQuantiles::new(n_bins)),
598 );
599 bbq_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
600 let platt_probs = platt_cal.transform_binary(probabilities)?;
601 let isotonic_probs = isotonic_cal.transform_binary(probabilities)?;
602 let bbq_probs = bbq_cal.transform_binary(probabilities)?;
603 let platt_metrics = platt_cal.evaluate_quantum_calibration(probabilities, labels)?;
604 let isotonic_metrics = isotonic_cal.evaluate_quantum_calibration(probabilities, labels)?;
605 let bbq_metrics = bbq_cal.evaluate_quantum_calibration(probabilities, labels)?;
606 let platt_weight = 1.0 / (platt_metrics.ece + 1e-6);
607 let isotonic_weight = 1.0 / (isotonic_metrics.ece + 1e-6);
608 let bbq_weight = 1.0 / (bbq_metrics.ece + 1e-6);
609 let total_weight = platt_weight + isotonic_weight + bbq_weight;
610 let ensemble_probs = (&platt_probs * (platt_weight / total_weight))
611 + (&isotonic_probs * (isotonic_weight / total_weight))
612 + (&bbq_probs * (bbq_weight / total_weight));
613 let ensemble_analysis =
614 visualization::analyze_calibration(&ensemble_probs, labels, n_bins)?;
615 let metrics = QuantumCalibrationMetrics {
616 ece: ensemble_analysis.ece,
617 mce: ensemble_analysis.mce,
618 brier_score: ensemble_analysis.brier_score,
619 nll: ensemble_analysis.nll,
620 shot_noise_impact: platt_metrics.shot_noise_impact,
621 interpretation: ensemble_analysis.interpretation,
622 };
623 Ok((ensemble_probs, metrics))
624 }
625}
626pub mod ensemble_selection {
628 use super::*;
629 use crate::utils::split::KFold;
    /// Cross-validation results for one candidate calibration method.
    #[derive(Debug, Clone)]
    pub struct CalibratorCandidate {
        /// Method name, e.g. "Platt", "Isotonic", "BBQ-5", "BBQ-10".
        pub name: String,
        /// ECE measured on each validation fold, in fold order.
        pub cv_ece_scores: Vec<f64>,
        /// Mean of `cv_ece_scores`.
        pub mean_ece: f64,
        /// Population standard deviation of `cv_ece_scores` (divides by the fold count).
        pub std_ece: f64,
        /// Whether the method is a binary calibrator.
        pub is_binary: bool,
    }
    /// How `select_binary_ensemble` turns the ranked candidates into an ensemble.
    #[derive(Debug, Clone)]
    pub enum SelectionStrategy {
        /// Keep only the candidate with the lowest mean cross-validated ECE.
        BestSingle,
        /// Keep the `k` best candidates (capped at the number of candidates) with equal
        /// weights.
        TopK(usize),
        /// Keep every candidate whose mean ECE is below the threshold, with equal weights.
        /// Falls back to the best single candidate if none qualify.
        Threshold(f64),
        /// Keep all candidates, weighted in proportion to `1 / (mean ECE + 1e-6)`.
        WeightedAll,
    }
    /// Outcome of cross-validated calibrator selection.
    #[derive(Debug, Clone)]
    pub struct EnsembleSelectionResult {
        /// Names of the methods kept by the strategy, best first.
        pub selected_methods: Vec<String>,
        /// Ensemble weight for each entry of `selected_methods`, in the same order.
        pub weights: Vec<f64>,
        /// Every candidate's cross-validation results, sorted by mean ECE (ascending).
        pub method_performances: Vec<CalibratorCandidate>,
        /// Name of the candidate with the lowest mean ECE.
        pub best_method: String,
        /// Weighted average of the selected methods' mean CV ECE. This is an estimate,
        /// not the measured ECE of the blended probabilities.
        pub ensemble_ece: f64,
    }
670 pub fn select_binary_ensemble(
672 probabilities: &Array1<f64>,
673 labels: &Array1<usize>,
674 n_folds: usize,
675 strategy: SelectionStrategy,
676 ) -> Result<EnsembleSelectionResult> {
677 if n_folds < 2 {
678 return Err(MLError::InvalidInput(
679 "Need at least 2 folds for cross-validation".to_string(),
680 ));
681 }
682 let kfold = KFold::new(probabilities.len(), n_folds, true)?;
683 let method_names = vec!["Platt", "Isotonic", "BBQ-5", "BBQ-10"];
684 let mut candidates = Vec::new();
685 for method_name in method_names {
686 let mut cv_ece_scores = Vec::new();
687 for fold in 0..n_folds {
688 let (train_indices, val_indices) = kfold.get_fold(fold)?;
689 let train_probs: Array1<f64> =
690 train_indices.iter().map(|&i| probabilities[i]).collect();
691 let train_labels: Array1<usize> =
692 train_indices.iter().map(|&i| labels[i]).collect();
693 let val_probs: Array1<f64> =
694 val_indices.iter().map(|&i| probabilities[i]).collect();
695 let val_labels: Array1<usize> = val_indices.iter().map(|&i| labels[i]).collect();
696 let calibrated_val = match method_name {
697 "Platt" => {
698 let mut scaler = PlattScaler::new();
699 scaler.fit(&train_probs, &train_labels)?;
700 scaler.transform(&val_probs)?
701 }
702 "Isotonic" => {
703 let mut scaler = IsotonicRegression::new();
704 scaler.fit(&train_probs, &train_labels)?;
705 scaler.transform(&val_probs)?
706 }
707 "BBQ-5" => {
708 let mut scaler = BayesianBinningQuantiles::new(5);
709 scaler.fit(&train_probs, &train_labels)?;
710 scaler.transform(&val_probs)?
711 }
712 "BBQ-10" => {
713 let mut scaler = BayesianBinningQuantiles::new(10);
714 scaler.fit(&train_probs, &train_labels)?;
715 scaler.transform(&val_probs)?
716 }
717 _ => {
718 return Err(MLError::InvalidInput(format!(
719 "Unknown method: {}",
720 method_name
721 )));
722 }
723 };
724 let analysis =
725 visualization::analyze_calibration(&calibrated_val, &val_labels, 10)?;
726 cv_ece_scores.push(analysis.ece);
727 }
728 let mean_ece = cv_ece_scores.iter().sum::<f64>() / cv_ece_scores.len() as f64;
729 let variance = cv_ece_scores
730 .iter()
731 .map(|&x| (x - mean_ece).powi(2))
732 .sum::<f64>()
733 / cv_ece_scores.len() as f64;
734 let std_ece = variance.sqrt();
735 candidates.push(CalibratorCandidate {
736 name: method_name.to_string(),
737 cv_ece_scores,
738 mean_ece,
739 std_ece,
740 is_binary: true,
741 });
742 }
743 candidates.sort_by(|a, b| a.mean_ece.partial_cmp(&b.mean_ece).unwrap());
744 let (selected_methods, weights) = match strategy {
745 SelectionStrategy::BestSingle => (vec![candidates[0].name.clone()], vec![1.0]),
746 SelectionStrategy::TopK(k) => {
747 let k = k.min(candidates.len());
748 let methods: Vec<String> = candidates[..k].iter().map(|c| c.name.clone()).collect();
749 let weights = vec![1.0 / k as f64; k];
750 (methods, weights)
751 }
752 SelectionStrategy::Threshold(threshold) => {
753 let selected: Vec<_> = candidates
754 .iter()
755 .filter(|c| c.mean_ece < threshold)
756 .map(|c| c.name.clone())
757 .collect();
758 if selected.is_empty() {
759 (vec![candidates[0].name.clone()], vec![1.0])
760 } else {
761 let n = selected.len();
762 let weights = vec![1.0 / n as f64; n];
763 (selected, weights)
764 }
765 }
766 SelectionStrategy::WeightedAll => {
767 let methods: Vec<String> = candidates.iter().map(|c| c.name.clone()).collect();
768 let inv_eces: Vec<f64> = candidates
769 .iter()
770 .map(|c| 1.0 / (c.mean_ece + 1e-6))
771 .collect();
772 let sum_inv: f64 = inv_eces.iter().sum();
773 let weights: Vec<f64> = inv_eces.iter().map(|&w| w / sum_inv).collect();
774 (methods, weights)
775 }
776 };
777 let best_method = candidates[0].name.clone();
778 let ensemble_ece = if weights.len() == 1 {
779 candidates[0].mean_ece
780 } else {
781 candidates
782 .iter()
783 .zip(&weights)
784 .map(|(c, &w)| c.mean_ece * w)
785 .sum()
786 };
787 Ok(EnsembleSelectionResult {
788 selected_methods,
789 weights,
790 method_performances: candidates,
791 best_method,
792 ensemble_ece,
793 })
794 }
    /// Convenience wrapper that stores a fold count and a `SelectionStrategy` and runs
    /// cross-validated calibrator selection.
    #[derive(Debug, Clone)]
    pub struct CalibrationAwareSelector {
        /// Strategy passed to `select_binary_ensemble`. `BestSingle` unless overridden
        /// with `with_strategy`.
        strategy: SelectionStrategy,
        /// Number of cross-validation folds.
        n_folds: usize,
        /// NOTE(review): set by `new` but not read by any method visible in this file.
        is_binary: bool,
    }
806 impl CalibrationAwareSelector {
807 pub fn new(n_folds: usize, is_binary: bool) -> Self {
809 Self {
810 strategy: SelectionStrategy::BestSingle,
811 n_folds,
812 is_binary,
813 }
814 }
815 pub fn with_strategy(mut self, strategy: SelectionStrategy) -> Self {
817 self.strategy = strategy;
818 self
819 }
820 pub fn select_binary(
822 &self,
823 probabilities: &Array1<f64>,
824 labels: &Array1<usize>,
825 ) -> Result<EnsembleSelectionResult> {
826 select_binary_ensemble(probabilities, labels, self.n_folds, self.strategy.clone())
827 }
828 pub fn generate_selection_report(&self, result: &EnsembleSelectionResult) -> String {
830 let mut report = String::new();
831 report.push_str("=== Calibration Method Selection Report ===\n\n");
832 report.push_str("Cross-Validation Results:\n");
833 report.push_str(&format!("{:-<60}\n", ""));
834 for method in &result.method_performances {
835 report.push_str(&format!(
836 "{:<15} | Mean ECE: {:.4} ± {:.4}\n",
837 method.name, method.mean_ece, method.std_ece
838 ));
839 }
840 report.push_str(&format!("\n{:-<60}\n", ""));
841 report.push_str(&format!(
842 "\nBest Individual Method: {}\n",
843 result.best_method
844 ));
845 report.push_str(&format!(
846 "Expected Ensemble ECE: {:.4}\n\n",
847 result.ensemble_ece
848 ));
849 report.push_str("Selected Ensemble:\n");
850 for (method, weight) in result.selected_methods.iter().zip(&result.weights) {
851 report.push_str(&format!(" {} (weight: {:.3})\n", method, weight));
852 }
853 report.push_str("\nRecommendation:\n");
854 if result.selected_methods.len() == 1 {
855 report.push_str(&format!(
856 "Use {} for best calibration performance.\n",
857 result.selected_methods[0]
858 ));
859 } else {
860 report.push_str(
861 "Use weighted ensemble of selected methods for robust calibration.\n",
862 );
863 }
864 report
865 }
866 }
867}