1use super::*;
6use crate::error::{MLError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8pub fn calibration_curve(
11 probabilities: &Array1<f64>,
12 labels: &Array1<usize>,
13 n_bins: usize,
14) -> Result<(Array1<f64>, Array1<f64>)> {
15 if probabilities.len() != labels.len() {
16 return Err(MLError::InvalidInput(
17 "Probabilities and labels must have same length".to_string(),
18 ));
19 }
20 if n_bins < 2 {
21 return Err(MLError::InvalidInput(
22 "Number of bins must be at least 2".to_string(),
23 ));
24 }
25 let mut bins = vec![Vec::new(); n_bins];
26 for (i, &prob) in probabilities.iter().enumerate() {
27 let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
28 bins[bin_idx].push((prob, labels[i]));
29 }
30 let mut mean_predicted = Vec::new();
31 let mut fraction_positives = Vec::new();
32 for bin in bins {
33 if !bin.is_empty() {
34 let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
35 let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
36 mean_predicted.push(sum_prob / bin.len() as f64);
37 fraction_positives.push(sum_labels / bin.len() as f64);
38 }
39 }
40 Ok((
41 Array1::from_vec(mean_predicted),
42 Array1::from_vec(fraction_positives),
43 ))
44}
45pub mod visualization {
47 use super::*;
48 #[derive(Debug, Clone)]
50 pub struct CalibrationPlotData {
51 pub mean_predicted: Array1<f64>,
53 pub fraction_positives: Array1<f64>,
55 pub bin_counts: Array1<usize>,
57 pub bin_edges: Vec<f64>,
59 }
60 pub fn generate_calibration_plot_data(
62 probabilities: &Array1<f64>,
63 labels: &Array1<usize>,
64 n_bins: usize,
65 ) -> Result<CalibrationPlotData> {
66 if probabilities.len() != labels.len() {
67 return Err(MLError::InvalidInput(
68 "Probabilities and labels must have same length".to_string(),
69 ));
70 }
71 if n_bins < 2 {
72 return Err(MLError::InvalidInput(
73 "Number of bins must be at least 2".to_string(),
74 ));
75 }
76 let mut bins = vec![Vec::new(); n_bins];
77 let bin_edges: Vec<f64> = (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect();
78 for (i, &prob) in probabilities.iter().enumerate() {
79 let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
80 bins[bin_idx].push((prob, labels[i]));
81 }
82 let mut mean_predicted = Vec::new();
83 let mut fraction_positives = Vec::new();
84 let mut bin_counts = Vec::new();
85 for bin in bins {
86 if !bin.is_empty() {
87 let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
88 let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
89 mean_predicted.push(sum_prob / bin.len() as f64);
90 fraction_positives.push(sum_labels / bin.len() as f64);
91 bin_counts.push(bin.len());
92 } else {
93 mean_predicted.push(
94 (bin_edges[mean_predicted.len()] + bin_edges[mean_predicted.len() + 1]) / 2.0,
95 );
96 fraction_positives.push(0.0);
97 bin_counts.push(0);
98 }
99 }
100 Ok(CalibrationPlotData {
101 mean_predicted: Array1::from_vec(mean_predicted),
102 fraction_positives: Array1::from_vec(fraction_positives),
103 bin_counts: Array1::from_vec(bin_counts),
104 bin_edges,
105 })
106 }
107 #[derive(Debug, Clone)]
109 pub struct CalibrationAnalysis {
110 pub ece: f64,
112 pub mce: f64,
114 pub brier_score: f64,
116 pub nll: f64,
118 pub n_bins: usize,
120 pub bin_errors: Array1<f64>,
122 pub interpretation: String,
124 }
125 impl CalibrationAnalysis {
126 fn interpret_ece(ece: f64) -> String {
128 if ece < 0.01 {
129 "Excellent calibration - predictions are highly reliable".to_string()
130 } else if ece < 0.05 {
131 "Good calibration - predictions are generally reliable".to_string()
132 } else if ece < 0.10 {
133 "Moderate calibration - some miscalibration present".to_string()
134 } else if ece < 0.20 {
135 "Poor calibration - significant miscalibration detected".to_string()
136 } else {
137 "Very poor calibration - predictions are unreliable".to_string()
138 }
139 }
140 }
141 pub fn analyze_calibration(
143 probabilities: &Array1<f64>,
144 labels: &Array1<usize>,
145 n_bins: usize,
146 ) -> Result<CalibrationAnalysis> {
147 let plot_data = generate_calibration_plot_data(probabilities, labels, n_bins)?;
148 let mut ece = 0.0;
149 let total_samples = probabilities.len() as f64;
150 for i in 0..plot_data.mean_predicted.len() {
151 let bin_error = (plot_data.mean_predicted[i] - plot_data.fraction_positives[i]).abs();
152 let bin_weight = plot_data.bin_counts[i] as f64 / total_samples;
153 ece += bin_weight * bin_error;
154 }
155 let bin_errors: Array1<f64> =
156 (&plot_data.mean_predicted - &plot_data.fraction_positives).mapv(|x| x.abs());
157 let mce = bin_errors.iter().cloned().fold(0.0f64, f64::max);
158 let mut brier_score = 0.0;
159 for (i, &prob) in probabilities.iter().enumerate() {
160 let true_label = labels[i] as f64;
161 brier_score += (prob - true_label).powi(2);
162 }
163 brier_score /= probabilities.len() as f64;
164 let mut nll = 0.0;
165 for (i, &prob) in probabilities.iter().enumerate() {
166 let true_label = labels[i];
167 let prob_clamped = prob.max(1e-10).min(1.0 - 1e-10);
168 if true_label == 1 {
169 nll -= prob_clamped.ln();
170 } else {
171 nll -= (1.0 - prob_clamped).ln();
172 }
173 }
174 nll /= probabilities.len() as f64;
175 let interpretation = CalibrationAnalysis::interpret_ece(ece);
176 Ok(CalibrationAnalysis {
177 ece,
178 mce,
179 brier_score,
180 nll,
181 n_bins,
182 bin_errors,
183 interpretation,
184 })
185 }
186 #[derive(Debug, Clone)]
188 pub struct CalibrationComparison {
189 pub method_name: String,
191 pub analysis: CalibrationAnalysis,
193 pub calibrated_probs: Array1<f64>,
195 }
196 pub fn compare_calibration_methods(
198 uncalibrated_probs: &Array1<f64>,
199 labels: &Array1<usize>,
200 n_bins: usize,
201 ) -> Result<Vec<CalibrationComparison>> {
202 let mut comparisons = Vec::new();
203 let uncal_analysis = analyze_calibration(uncalibrated_probs, labels, n_bins)?;
204 comparisons.push(CalibrationComparison {
205 method_name: "Uncalibrated".to_string(),
206 analysis: uncal_analysis,
207 calibrated_probs: uncalibrated_probs.clone(),
208 });
209 if labels.iter().max().unwrap_or(&0) == &1 {
210 let mut platt = PlattScaler::new();
211 if let Ok(calibrated) = platt.fit_transform(uncalibrated_probs, labels) {
212 let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
213 comparisons.push(CalibrationComparison {
214 method_name: "Platt Scaling".to_string(),
215 analysis,
216 calibrated_probs: calibrated,
217 });
218 }
219 }
220 if labels.iter().max().unwrap_or(&0) == &1 {
221 let mut isotonic = IsotonicRegression::new();
222 if let Ok(calibrated) = isotonic.fit_transform(uncalibrated_probs, labels) {
223 let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
224 comparisons.push(CalibrationComparison {
225 method_name: "Isotonic Regression".to_string(),
226 analysis,
227 calibrated_probs: calibrated,
228 });
229 }
230 }
231 if labels.iter().max().unwrap_or(&0) == &1 {
232 let mut bbq = BayesianBinningQuantiles::new(n_bins);
233 if let Ok(calibrated) = bbq.fit_transform(uncalibrated_probs, labels) {
234 let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
235 comparisons.push(CalibrationComparison {
236 method_name: "Bayesian Binning (BBQ)".to_string(),
237 analysis,
238 calibrated_probs: calibrated,
239 });
240 }
241 }
242 Ok(comparisons)
243 }
244 pub fn generate_comparison_report(comparisons: &[CalibrationComparison]) -> String {
246 let mut report = String::new();
247 report.push_str("=== Calibration Methods Comparison Report ===\n\n");
248 let mut best_ece_idx = 0;
249 let mut best_mce_idx = 0;
250 let mut best_brier_idx = 0;
251 let mut best_nll_idx = 0;
252 for (i, comp) in comparisons.iter().enumerate() {
253 if comp.analysis.ece < comparisons[best_ece_idx].analysis.ece {
254 best_ece_idx = i;
255 }
256 if comp.analysis.mce < comparisons[best_mce_idx].analysis.mce {
257 best_mce_idx = i;
258 }
259 if comp.analysis.brier_score < comparisons[best_brier_idx].analysis.brier_score {
260 best_brier_idx = i;
261 }
262 if comp.analysis.nll < comparisons[best_nll_idx].analysis.nll {
263 best_nll_idx = i;
264 }
265 }
266 for (i, comp) in comparisons.iter().enumerate() {
267 report.push_str(&format!("\n{}\n", comp.method_name));
268 report.push_str(&format!("{}\n", "=".repeat(comp.method_name.len())));
269 report.push_str(&format!(
270 "ECE: {:.4}{}\n",
271 comp.analysis.ece,
272 if i == best_ece_idx { " ⭐ BEST" } else { "" }
273 ));
274 report.push_str(&format!(
275 "MCE: {:.4}{}\n",
276 comp.analysis.mce,
277 if i == best_mce_idx { " ⭐ BEST" } else { "" }
278 ));
279 report.push_str(&format!(
280 "Brier Score: {:.4}{}\n",
281 comp.analysis.brier_score,
282 if i == best_brier_idx { " ⭐ BEST" } else { "" }
283 ));
284 report.push_str(&format!(
285 "NLL: {:.4}{}\n",
286 comp.analysis.nll,
287 if i == best_nll_idx { " ⭐ BEST" } else { "" }
288 ));
289 report.push_str(&format!(
290 "Interpretation: {}\n",
291 comp.analysis.interpretation
292 ));
293 }
294 report.push_str("\n=== Recommendations ===\n");
295 report.push_str(&format!(
296 "Best overall (ECE): {}\n",
297 comparisons[best_ece_idx].method_name
298 ));
299 report.push_str(&format!(
300 "Most reliable (MCE): {}\n",
301 comparisons[best_mce_idx].method_name
302 ));
303 report.push_str(&format!(
304 "Best probability estimates (Brier): {}\n",
305 comparisons[best_brier_idx].method_name
306 ));
307 report
308 }
309}
310pub mod quantum_calibration {
313 use super::*;
/// Settings for `QuantumNeuralNetworkCalibrator`.
#[derive(Debug, Clone)]
pub struct QuantumCalibrationConfig {
    /// Number of bins used when evaluating calibration.
    pub n_bins: usize,
    /// Error-mitigation switch.
    pub use_error_mitigation: bool,
    /// Confidence level requested for uncertainty intervals.
    pub confidence_level: f64,
    /// Whether shot-noise estimates are computed when shot counts are given.
    pub account_shot_noise: bool,
}

impl Default for QuantumCalibrationConfig {
    /// Ten bins, error mitigation on, 95% confidence, shot noise accounted for.
    fn default() -> Self {
        QuantumCalibrationConfig {
            n_bins: 10,
            use_error_mitigation: true,
            confidence_level: 0.95,
            account_shot_noise: true,
        }
    }
}
336 #[derive(Debug, Clone)]
342 pub struct QuantumNeuralNetworkCalibrator {
343 method: CalibrationMethod,
345 config: QuantumCalibrationConfig,
347 shot_noise_estimates: Option<Array1<f64>>,
349 fitted: bool,
351 }
352 #[derive(Debug, Clone)]
354 pub enum CalibrationMethod {
355 Temperature(TemperatureScaler),
357 Vector(VectorScaler),
359 Platt(PlattScaler),
361 Isotonic(IsotonicRegression),
363 BayesianBinning(BayesianBinningQuantiles),
365 }
366 impl QuantumNeuralNetworkCalibrator {
367 pub fn new() -> Self {
369 Self {
370 method: CalibrationMethod::Temperature(TemperatureScaler::new()),
371 config: QuantumCalibrationConfig::default(),
372 shot_noise_estimates: None,
373 fitted: false,
374 }
375 }
376 pub fn with_method(method: CalibrationMethod) -> Self {
378 Self {
379 method,
380 config: QuantumCalibrationConfig::default(),
381 shot_noise_estimates: None,
382 fitted: false,
383 }
384 }
385 pub fn with_config(mut self, config: QuantumCalibrationConfig) -> Self {
387 self.config = config;
388 self
389 }
390 pub fn fit_binary(
392 &mut self,
393 probabilities: &Array1<f64>,
394 labels: &Array1<usize>,
395 shot_counts: Option<&Array1<usize>>,
396 ) -> Result<()> {
397 if let Some(shots) = shot_counts {
398 if self.config.account_shot_noise {
399 self.shot_noise_estimates =
400 Some(self.estimate_shot_noise(probabilities, shots));
401 }
402 }
403 match &mut self.method {
404 CalibrationMethod::Platt(scaler) => {
405 scaler.fit(probabilities, labels)?;
406 }
407 CalibrationMethod::Isotonic(scaler) => {
408 scaler.fit(probabilities, labels)?;
409 }
410 CalibrationMethod::BayesianBinning(scaler) => {
411 scaler.fit(probabilities, labels)?;
412 }
413 _ => {
414 return Err(MLError::InvalidInput(
415 "Binary calibration requires Platt, Isotonic, or BBQ method".to_string(),
416 ));
417 }
418 }
419 self.fitted = true;
420 Ok(())
421 }
422 pub fn fit_multiclass(
424 &mut self,
425 logits: &Array2<f64>,
426 labels: &Array1<usize>,
427 shot_counts: Option<&Array1<usize>>,
428 ) -> Result<()> {
429 if let Some(shots) = shot_counts {
430 if self.config.account_shot_noise {
431 let avg_probs = logits
432 .mean_axis(scirs2_core::ndarray::Axis(1))
433 .expect("logits should have valid axis");
434 self.shot_noise_estimates = Some(self.estimate_shot_noise(&avg_probs, shots));
435 }
436 }
437 match &mut self.method {
438 CalibrationMethod::Temperature(scaler) => {
439 scaler.fit(logits, labels)?;
440 }
441 CalibrationMethod::Vector(scaler) => {
442 scaler.fit(logits, labels)?;
443 }
444 _ => {
445 return Err(MLError::InvalidInput(
446 "Multi-class calibration requires Temperature or Vector method".to_string(),
447 ));
448 }
449 }
450 self.fitted = true;
451 Ok(())
452 }
453 fn estimate_shot_noise(
455 &self,
456 probabilities: &Array1<f64>,
457 shot_counts: &Array1<usize>,
458 ) -> Array1<f64> {
459 probabilities
460 .iter()
461 .zip(shot_counts.iter())
462 .map(|(&p, &n)| {
463 if n > 0 {
464 (p * (1.0 - p) / n as f64).sqrt()
465 } else {
466 0.0
467 }
468 })
469 .collect::<Vec<_>>()
470 .into()
471 }
472 pub fn transform_binary(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
474 if !self.fitted {
475 return Err(MLError::InvalidInput(
476 "Calibrator must be fitted before transform".to_string(),
477 ));
478 }
479 match &self.method {
480 CalibrationMethod::Platt(scaler) => scaler.transform(probabilities),
481 CalibrationMethod::Isotonic(scaler) => scaler.transform(probabilities),
482 CalibrationMethod::BayesianBinning(scaler) => scaler.transform(probabilities),
483 _ => Err(MLError::InvalidInput(
484 "Method does not support binary transformation".to_string(),
485 )),
486 }
487 }
488 pub fn transform_multiclass(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
490 if !self.fitted {
491 return Err(MLError::InvalidInput(
492 "Calibrator must be fitted before transform".to_string(),
493 ));
494 }
495 match &self.method {
496 CalibrationMethod::Temperature(scaler) => scaler.transform(logits),
497 CalibrationMethod::Vector(scaler) => scaler.transform(logits),
498 _ => Err(MLError::InvalidInput(
499 "Method does not support multi-class transformation".to_string(),
500 )),
501 }
502 }
503 pub fn transform_with_uncertainty(
505 &self,
506 probabilities: &Array1<f64>,
507 ) -> Result<Vec<(f64, f64, f64)>> {
508 if !self.fitted {
509 return Err(MLError::InvalidInput(
510 "Calibrator must be fitted before transform".to_string(),
511 ));
512 }
513 match &self.method {
514 CalibrationMethod::BayesianBinning(scaler) => {
515 scaler.predict_with_uncertainty(probabilities, self.config.confidence_level)
516 }
517 _ => {
518 let calibrated = self.transform_binary(probabilities)?;
519 if let Some(noise) = &self.shot_noise_estimates {
520 let results = calibrated
521 .iter()
522 .zip(noise.iter())
523 .map(|(&p, &sigma)| {
524 let z = 1.96;
525 let lower = (p - z * sigma).max(0.0);
526 let upper = (p + z * sigma).min(1.0);
527 (p, lower, upper)
528 })
529 .collect();
530 Ok(results)
531 } else {
532 Ok(calibrated.iter().map(|&p| (p, p, p)).collect())
533 }
534 }
535 }
536 }
537 pub fn evaluate_quantum_calibration(
539 &self,
540 probabilities: &Array1<f64>,
541 labels: &Array1<usize>,
542 ) -> Result<QuantumCalibrationMetrics> {
543 let calibrated = self.transform_binary(probabilities)?;
544 let analysis =
545 visualization::analyze_calibration(&calibrated, labels, self.config.n_bins)?;
546 let shot_noise_impact = if let Some(noise) = &self.shot_noise_estimates {
547 noise.mean().unwrap_or(0.0)
548 } else {
549 0.0
550 };
551 Ok(QuantumCalibrationMetrics {
552 ece: analysis.ece,
553 mce: analysis.mce,
554 brier_score: analysis.brier_score,
555 nll: analysis.nll,
556 shot_noise_impact,
557 interpretation: analysis.interpretation,
558 })
559 }
560 }
561 impl Default for QuantumNeuralNetworkCalibrator {
562 fn default() -> Self {
563 Self::new()
564 }
565 }
/// Calibration metrics reported for a quantum calibrator.
#[derive(Debug, Clone)]
pub struct QuantumCalibrationMetrics {
    /// Expected calibration error.
    pub ece: f64,
    /// Maximum calibration error.
    pub mce: f64,
    /// Brier score.
    pub brier_score: f64,
    /// Mean negative log-likelihood.
    pub nll: f64,
    /// Mean of the stored shot-noise estimates (0.0 when none are stored).
    pub shot_noise_impact: f64,
    /// Human-readable verdict derived from the ECE.
    pub interpretation: String,
}
582 pub fn quantum_ensemble_calibration(
585 probabilities: &Array1<f64>,
586 labels: &Array1<usize>,
587 shot_counts: &Array1<usize>,
588 n_bins: usize,
589 ) -> Result<(Array1<f64>, QuantumCalibrationMetrics)> {
590 let mut platt_cal = QuantumNeuralNetworkCalibrator::with_method(CalibrationMethod::Platt(
591 PlattScaler::new(),
592 ));
593 platt_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
594 let mut isotonic_cal = QuantumNeuralNetworkCalibrator::with_method(
595 CalibrationMethod::Isotonic(IsotonicRegression::new()),
596 );
597 isotonic_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
598 let mut bbq_cal = QuantumNeuralNetworkCalibrator::with_method(
599 CalibrationMethod::BayesianBinning(BayesianBinningQuantiles::new(n_bins)),
600 );
601 bbq_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
602 let platt_probs = platt_cal.transform_binary(probabilities)?;
603 let isotonic_probs = isotonic_cal.transform_binary(probabilities)?;
604 let bbq_probs = bbq_cal.transform_binary(probabilities)?;
605 let platt_metrics = platt_cal.evaluate_quantum_calibration(probabilities, labels)?;
606 let isotonic_metrics = isotonic_cal.evaluate_quantum_calibration(probabilities, labels)?;
607 let bbq_metrics = bbq_cal.evaluate_quantum_calibration(probabilities, labels)?;
608 let platt_weight = 1.0 / (platt_metrics.ece + 1e-6);
609 let isotonic_weight = 1.0 / (isotonic_metrics.ece + 1e-6);
610 let bbq_weight = 1.0 / (bbq_metrics.ece + 1e-6);
611 let total_weight = platt_weight + isotonic_weight + bbq_weight;
612 let ensemble_probs = (&platt_probs * (platt_weight / total_weight))
613 + (&isotonic_probs * (isotonic_weight / total_weight))
614 + (&bbq_probs * (bbq_weight / total_weight));
615 let ensemble_analysis =
616 visualization::analyze_calibration(&ensemble_probs, labels, n_bins)?;
617 let metrics = QuantumCalibrationMetrics {
618 ece: ensemble_analysis.ece,
619 mce: ensemble_analysis.mce,
620 brier_score: ensemble_analysis.brier_score,
621 nll: ensemble_analysis.nll,
622 shot_noise_impact: platt_metrics.shot_noise_impact,
623 interpretation: ensemble_analysis.interpretation,
624 };
625 Ok((ensemble_probs, metrics))
626 }
627}
628pub mod ensemble_selection {
630 use super::*;
631 use crate::utils::split::KFold;
/// Cross-validated performance of one calibration method.
#[derive(Debug, Clone)]
pub struct CalibratorCandidate {
    /// Name of the method.
    pub name: String,
    /// ECE measured on each validation fold.
    pub cv_ece_scores: Vec<f64>,
    /// Mean of `cv_ece_scores`.
    pub mean_ece: f64,
    /// Standard deviation of `cv_ece_scores`.
    pub std_ece: f64,
    /// Whether the method is a binary calibrator.
    pub is_binary: bool,
}
/// How calibration methods are picked from the cross-validation ranking.
#[derive(Debug, Clone)]
pub enum SelectionStrategy {
    /// Use only the method with the lowest mean ECE.
    BestSingle,
    /// Use the given number of best-ranked methods with equal weights.
    TopK(usize),
    /// Use every method whose mean ECE is below the given value.
    Threshold(f64),
    /// Use all methods, weighted by inverse mean ECE.
    WeightedAll,
}
658 #[derive(Debug, Clone)]
660 pub struct EnsembleSelectionResult {
661 pub selected_methods: Vec<String>,
663 pub weights: Vec<f64>,
665 pub method_performances: Vec<CalibratorCandidate>,
667 pub best_method: String,
669 pub ensemble_ece: f64,
671 }
672 pub fn select_binary_ensemble(
674 probabilities: &Array1<f64>,
675 labels: &Array1<usize>,
676 n_folds: usize,
677 strategy: SelectionStrategy,
678 ) -> Result<EnsembleSelectionResult> {
679 if n_folds < 2 {
680 return Err(MLError::InvalidInput(
681 "Need at least 2 folds for cross-validation".to_string(),
682 ));
683 }
684 let kfold = KFold::new(probabilities.len(), n_folds, true)?;
685 let method_names = vec!["Platt", "Isotonic", "BBQ-5", "BBQ-10"];
686 let mut candidates = Vec::new();
687 for method_name in method_names {
688 let mut cv_ece_scores = Vec::new();
689 for fold in 0..n_folds {
690 let (train_indices, val_indices) = kfold.get_fold(fold)?;
691 let train_probs: Array1<f64> =
692 train_indices.iter().map(|&i| probabilities[i]).collect();
693 let train_labels: Array1<usize> =
694 train_indices.iter().map(|&i| labels[i]).collect();
695 let val_probs: Array1<f64> =
696 val_indices.iter().map(|&i| probabilities[i]).collect();
697 let val_labels: Array1<usize> = val_indices.iter().map(|&i| labels[i]).collect();
698 let calibrated_val = match method_name {
699 "Platt" => {
700 let mut scaler = PlattScaler::new();
701 scaler.fit(&train_probs, &train_labels)?;
702 scaler.transform(&val_probs)?
703 }
704 "Isotonic" => {
705 let mut scaler = IsotonicRegression::new();
706 scaler.fit(&train_probs, &train_labels)?;
707 scaler.transform(&val_probs)?
708 }
709 "BBQ-5" => {
710 let mut scaler = BayesianBinningQuantiles::new(5);
711 scaler.fit(&train_probs, &train_labels)?;
712 scaler.transform(&val_probs)?
713 }
714 "BBQ-10" => {
715 let mut scaler = BayesianBinningQuantiles::new(10);
716 scaler.fit(&train_probs, &train_labels)?;
717 scaler.transform(&val_probs)?
718 }
719 _ => {
720 return Err(MLError::InvalidInput(format!(
721 "Unknown method: {}",
722 method_name
723 )));
724 }
725 };
726 let analysis =
727 visualization::analyze_calibration(&calibrated_val, &val_labels, 10)?;
728 cv_ece_scores.push(analysis.ece);
729 }
730 let mean_ece = cv_ece_scores.iter().sum::<f64>() / cv_ece_scores.len() as f64;
731 let variance = cv_ece_scores
732 .iter()
733 .map(|&x| (x - mean_ece).powi(2))
734 .sum::<f64>()
735 / cv_ece_scores.len() as f64;
736 let std_ece = variance.sqrt();
737 candidates.push(CalibratorCandidate {
738 name: method_name.to_string(),
739 cv_ece_scores,
740 mean_ece,
741 std_ece,
742 is_binary: true,
743 });
744 }
745 candidates.sort_by(|a, b| {
746 a.mean_ece
747 .partial_cmp(&b.mean_ece)
748 .unwrap_or(std::cmp::Ordering::Equal)
749 });
750 let (selected_methods, weights) = match strategy {
751 SelectionStrategy::BestSingle => (vec![candidates[0].name.clone()], vec![1.0]),
752 SelectionStrategy::TopK(k) => {
753 let k = k.min(candidates.len());
754 let methods: Vec<String> = candidates[..k].iter().map(|c| c.name.clone()).collect();
755 let weights = vec![1.0 / k as f64; k];
756 (methods, weights)
757 }
758 SelectionStrategy::Threshold(threshold) => {
759 let selected: Vec<_> = candidates
760 .iter()
761 .filter(|c| c.mean_ece < threshold)
762 .map(|c| c.name.clone())
763 .collect();
764 if selected.is_empty() {
765 (vec![candidates[0].name.clone()], vec![1.0])
766 } else {
767 let n = selected.len();
768 let weights = vec![1.0 / n as f64; n];
769 (selected, weights)
770 }
771 }
772 SelectionStrategy::WeightedAll => {
773 let methods: Vec<String> = candidates.iter().map(|c| c.name.clone()).collect();
774 let inv_eces: Vec<f64> = candidates
775 .iter()
776 .map(|c| 1.0 / (c.mean_ece + 1e-6))
777 .collect();
778 let sum_inv: f64 = inv_eces.iter().sum();
779 let weights: Vec<f64> = inv_eces.iter().map(|&w| w / sum_inv).collect();
780 (methods, weights)
781 }
782 };
783 let best_method = candidates[0].name.clone();
784 let ensemble_ece = if weights.len() == 1 {
785 candidates[0].mean_ece
786 } else {
787 candidates
788 .iter()
789 .zip(&weights)
790 .map(|(c, &w)| c.mean_ece * w)
791 .sum()
792 };
793 Ok(EnsembleSelectionResult {
794 selected_methods,
795 weights,
796 method_performances: candidates,
797 best_method,
798 ensemble_ece,
799 })
800 }
801 #[derive(Debug, Clone)]
804 pub struct CalibrationAwareSelector {
805 strategy: SelectionStrategy,
807 n_folds: usize,
809 is_binary: bool,
811 }
812 impl CalibrationAwareSelector {
813 pub fn new(n_folds: usize, is_binary: bool) -> Self {
815 Self {
816 strategy: SelectionStrategy::BestSingle,
817 n_folds,
818 is_binary,
819 }
820 }
821 pub fn with_strategy(mut self, strategy: SelectionStrategy) -> Self {
823 self.strategy = strategy;
824 self
825 }
826 pub fn select_binary(
828 &self,
829 probabilities: &Array1<f64>,
830 labels: &Array1<usize>,
831 ) -> Result<EnsembleSelectionResult> {
832 select_binary_ensemble(probabilities, labels, self.n_folds, self.strategy.clone())
833 }
834 pub fn generate_selection_report(&self, result: &EnsembleSelectionResult) -> String {
836 let mut report = String::new();
837 report.push_str("=== Calibration Method Selection Report ===\n\n");
838 report.push_str("Cross-Validation Results:\n");
839 report.push_str(&format!("{:-<60}\n", ""));
840 for method in &result.method_performances {
841 report.push_str(&format!(
842 "{:<15} | Mean ECE: {:.4} ± {:.4}\n",
843 method.name, method.mean_ece, method.std_ece
844 ));
845 }
846 report.push_str(&format!("\n{:-<60}\n", ""));
847 report.push_str(&format!(
848 "\nBest Individual Method: {}\n",
849 result.best_method
850 ));
851 report.push_str(&format!(
852 "Expected Ensemble ECE: {:.4}\n\n",
853 result.ensemble_ece
854 ));
855 report.push_str("Selected Ensemble:\n");
856 for (method, weight) in result.selected_methods.iter().zip(&result.weights) {
857 report.push_str(&format!(" {} (weight: {:.3})\n", method, weight));
858 }
859 report.push_str("\nRecommendation:\n");
860 if result.selected_methods.len() == 1 {
861 report.push_str(&format!(
862 "Use {} for best calibration performance.\n",
863 result.selected_methods[0]
864 ));
865 } else {
866 report.push_str(
867 "Use weighted ensemble of selected methods for robust calibration.\n",
868 );
869 }
870 report
871 }
872 }
873}