quantrs2_ml/utils/calibration/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use super::*;
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};
/// Calculate the calibration curve (reliability diagram).
/// Returns `(mean_predicted_prob, fraction_of_positives)` for each non-empty bin.
pub fn calibration_curve(
    probabilities: &Array1<f64>,
    labels: &Array1<usize>,
    n_bins: usize,
) -> Result<(Array1<f64>, Array1<f64>)> {
    if probabilities.len() != labels.len() {
        return Err(MLError::InvalidInput(
            "Probabilities and labels must have same length".to_string(),
        ));
    }
    if n_bins < 2 {
        return Err(MLError::InvalidInput(
            "Number of bins must be at least 2".to_string(),
        ));
    }
    let mut bins = vec![Vec::new(); n_bins];
    for (i, &prob) in probabilities.iter().enumerate() {
        let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
        bins[bin_idx].push((prob, labels[i]));
    }
    let mut mean_predicted = Vec::new();
    let mut fraction_positives = Vec::new();
    for bin in bins {
        if !bin.is_empty() {
            let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
            let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
            mean_predicted.push(sum_prob / bin.len() as f64);
            fraction_positives.push(sum_labels / bin.len() as f64);
        }
    }
    Ok((
        Array1::from_vec(mean_predicted),
        Array1::from_vec(fraction_positives),
    ))
}
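// A minimal usage sketch (illustrative only; assumes the caller propagates
// `MLError` with `?`):
//
//     use scirs2_core::ndarray::Array1;
//
//     let probs = Array1::from_vec(vec![0.10, 0.35, 0.40, 0.80]);
//     let labels = Array1::from_vec(vec![0usize, 1, 0, 1]);
//     let (mean_pred, frac_pos) = calibration_curve(&probs, &labels, 2)?;
//     // A perfectly calibrated model gives mean_pred[i] ≈ frac_pos[i] in every
//     // bin; plotting one against the other yields the reliability diagram.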
/// Calibration visualization and analysis utilities
pub mod visualization {
    use super::*;
    /// Calibration plot data for reliability diagrams
    #[derive(Debug, Clone)]
    pub struct CalibrationPlotData {
        /// Mean predicted probabilities in each bin
        pub mean_predicted: Array1<f64>,
        /// Fraction of positives in each bin
        pub fraction_positives: Array1<f64>,
        /// Number of samples in each bin
        pub bin_counts: Array1<usize>,
        /// Bin edges
        pub bin_edges: Vec<f64>,
    }
    /// Generate comprehensive calibration plot data
    pub fn generate_calibration_plot_data(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        n_bins: usize,
    ) -> Result<CalibrationPlotData> {
        if probabilities.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Probabilities and labels must have same length".to_string(),
            ));
        }
        if n_bins < 2 {
            return Err(MLError::InvalidInput(
                "Number of bins must be at least 2".to_string(),
            ));
        }
        let mut bins = vec![Vec::new(); n_bins];
        let bin_edges: Vec<f64> = (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect();
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
            bins[bin_idx].push((prob, labels[i]));
        }
        let mut mean_predicted = Vec::new();
        let mut fraction_positives = Vec::new();
        let mut bin_counts = Vec::new();
        for bin in bins {
            if !bin.is_empty() {
                let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
                let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
                mean_predicted.push(sum_prob / bin.len() as f64);
                fraction_positives.push(sum_labels / bin.len() as f64);
                bin_counts.push(bin.len());
            } else {
                // Empty bins are kept, centered on the bin midpoint, so every
                // output array has exactly `n_bins` entries.
                mean_predicted.push(
                    (bin_edges[mean_predicted.len()] + bin_edges[mean_predicted.len() + 1]) / 2.0,
                );
                fraction_positives.push(0.0);
                bin_counts.push(0);
            }
        }
        Ok(CalibrationPlotData {
            mean_predicted: Array1::from_vec(mean_predicted),
            fraction_positives: Array1::from_vec(fraction_positives),
            bin_counts: Array1::from_vec(bin_counts),
            bin_edges,
        })
    }
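    // Note (sketch): unlike `calibration_curve`, empty bins are kept here
    // (centered on the bin midpoint, count 0), so the output arrays always
    // hold exactly `n_bins` entries:
    //
    //     let data = generate_calibration_plot_data(&probs, &labels, 10)?;
    //     assert_eq!(data.bin_counts.len(), 10);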
    /// Comprehensive calibration analysis report
    #[derive(Debug, Clone)]
    pub struct CalibrationAnalysis {
        /// Expected Calibration Error
        pub ece: f64,
        /// Maximum Calibration Error
        pub mce: f64,
        /// Brier score
        pub brier_score: f64,
        /// Negative log-likelihood
        pub nll: f64,
        /// Number of bins used
        pub n_bins: usize,
        /// Per-bin calibration errors
        pub bin_errors: Array1<f64>,
        /// Interpretation of calibration quality
        pub interpretation: String,
    }
    impl CalibrationAnalysis {
        /// Generate interpretation based on ECE
        fn interpret_ece(ece: f64) -> String {
            if ece < 0.01 {
                "Excellent calibration - predictions are highly reliable".to_string()
            } else if ece < 0.05 {
                "Good calibration - predictions are generally reliable".to_string()
            } else if ece < 0.10 {
                "Moderate calibration - some miscalibration present".to_string()
            } else if ece < 0.20 {
                "Poor calibration - significant miscalibration detected".to_string()
            } else {
                "Very poor calibration - predictions are unreliable".to_string()
            }
        }
    }
    /// Perform comprehensive calibration analysis
    pub fn analyze_calibration(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        n_bins: usize,
    ) -> Result<CalibrationAnalysis> {
        let plot_data = generate_calibration_plot_data(probabilities, labels, n_bins)?;
        let mut ece = 0.0;
        let total_samples = probabilities.len() as f64;
        for i in 0..plot_data.mean_predicted.len() {
            let bin_error = (plot_data.mean_predicted[i] - plot_data.fraction_positives[i]).abs();
            let bin_weight = plot_data.bin_counts[i] as f64 / total_samples;
            ece += bin_weight * bin_error;
        }
        let bin_errors: Array1<f64> =
            (&plot_data.mean_predicted - &plot_data.fraction_positives).mapv(|x| x.abs());
        let mce = bin_errors.iter().cloned().fold(0.0f64, f64::max);
        let mut brier_score = 0.0;
        for (i, &prob) in probabilities.iter().enumerate() {
            let true_label = labels[i] as f64;
            brier_score += (prob - true_label).powi(2);
        }
        brier_score /= probabilities.len() as f64;
        let mut nll = 0.0;
        for (i, &prob) in probabilities.iter().enumerate() {
            let true_label = labels[i];
            let prob_clamped = prob.clamp(1e-10, 1.0 - 1e-10);
            if true_label == 1 {
                nll -= prob_clamped.ln();
            } else {
                nll -= (1.0 - prob_clamped).ln();
            }
        }
        nll /= probabilities.len() as f64;
        let interpretation = CalibrationAnalysis::interpret_ece(ece);
        Ok(CalibrationAnalysis {
            ece,
            mce,
            brier_score,
            nll,
            n_bins,
            bin_errors,
            interpretation,
        })
    }
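    // For reference, the quantities computed above, with N samples, predicted
    // probabilities p̂_i, labels y_i ∈ {0, 1}, and per-bin confidence conf_b
    // (mean predicted probability) and accuracy acc_b (fraction of positives):
    //
    //     ECE   = Σ_b (n_b / N) · |conf_b − acc_b|     (bin-weighted gap)
    //     MCE   = max_b |conf_b − acc_b|                (worst single bin)
    //     Brier = (1/N) Σ_i (p̂_i − y_i)²
    //     NLL   = −(1/N) Σ_i [y_i ln p̂_i + (1 − y_i) ln(1 − p̂_i)]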
    /// Compare multiple calibration methods
    #[derive(Debug, Clone)]
    pub struct CalibrationComparison {
        /// Method name
        pub method_name: String,
        /// Calibration analysis
        pub analysis: CalibrationAnalysis,
        /// Calibrated probabilities
        pub calibrated_probs: Array1<f64>,
    }
    /// Compare multiple calibration methods on the same dataset
    pub fn compare_calibration_methods(
        uncalibrated_probs: &Array1<f64>,
        labels: &Array1<usize>,
        n_bins: usize,
    ) -> Result<Vec<CalibrationComparison>> {
        let mut comparisons = Vec::new();
        let uncal_analysis = analyze_calibration(uncalibrated_probs, labels, n_bins)?;
        comparisons.push(CalibrationComparison {
            method_name: "Uncalibrated".to_string(),
            analysis: uncal_analysis,
            calibrated_probs: uncalibrated_probs.clone(),
        });
        // The remaining calibrators apply only to binary labels (max label == 1).
        if labels.iter().max().copied().unwrap_or(0) == 1 {
            let mut platt = PlattScaler::new();
            if let Ok(calibrated) = platt.fit_transform(uncalibrated_probs, labels) {
                let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
                comparisons.push(CalibrationComparison {
                    method_name: "Platt Scaling".to_string(),
                    analysis,
                    calibrated_probs: calibrated,
                });
            }
            let mut isotonic = IsotonicRegression::new();
            if let Ok(calibrated) = isotonic.fit_transform(uncalibrated_probs, labels) {
                let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
                comparisons.push(CalibrationComparison {
                    method_name: "Isotonic Regression".to_string(),
                    analysis,
                    calibrated_probs: calibrated,
                });
            }
            let mut bbq = BayesianBinningQuantiles::new(n_bins);
            if let Ok(calibrated) = bbq.fit_transform(uncalibrated_probs, labels) {
                let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
                comparisons.push(CalibrationComparison {
                    method_name: "Bayesian Binning (BBQ)".to_string(),
                    analysis,
                    calibrated_probs: calibrated,
                });
            }
        }
        Ok(comparisons)
    }
    /// Generate a text report comparing calibration methods
    pub fn generate_comparison_report(comparisons: &[CalibrationComparison]) -> String {
        let mut report = String::new();
        report.push_str("=== Calibration Methods Comparison Report ===\n\n");
        let mut best_ece_idx = 0;
        let mut best_mce_idx = 0;
        let mut best_brier_idx = 0;
        let mut best_nll_idx = 0;
        for (i, comp) in comparisons.iter().enumerate() {
            if comp.analysis.ece < comparisons[best_ece_idx].analysis.ece {
                best_ece_idx = i;
            }
            if comp.analysis.mce < comparisons[best_mce_idx].analysis.mce {
                best_mce_idx = i;
            }
            if comp.analysis.brier_score < comparisons[best_brier_idx].analysis.brier_score {
                best_brier_idx = i;
            }
            if comp.analysis.nll < comparisons[best_nll_idx].analysis.nll {
                best_nll_idx = i;
            }
        }
        for (i, comp) in comparisons.iter().enumerate() {
            report.push_str(&format!("\n{}\n", comp.method_name));
            report.push_str(&format!("{}\n", "=".repeat(comp.method_name.len())));
            report.push_str(&format!(
                "ECE: {:.4}{}\n",
                comp.analysis.ece,
                if i == best_ece_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "MCE: {:.4}{}\n",
                comp.analysis.mce,
                if i == best_mce_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "Brier Score: {:.4}{}\n",
                comp.analysis.brier_score,
                if i == best_brier_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "NLL: {:.4}{}\n",
                comp.analysis.nll,
                if i == best_nll_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "Interpretation: {}\n",
                comp.analysis.interpretation
            ));
        }
        report.push_str("\n=== Recommendations ===\n");
        report.push_str(&format!(
            "Best overall (ECE): {}\n",
            comparisons[best_ece_idx].method_name
        ));
        report.push_str(&format!(
            "Most reliable (MCE): {}\n",
            comparisons[best_mce_idx].method_name
        ));
        report.push_str(&format!(
            "Best probability estimates (Brier): {}\n",
            comparisons[best_brier_idx].method_name
        ));
        report
    }
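    // Call pattern from outside this module (hedged sketch; `probs`/`labels`
    // as in the earlier examples):
    //
    //     let comparisons =
    //         visualization::compare_calibration_methods(&probs, &labels, 10)?;
    //     println!("{}", visualization::generate_comparison_report(&comparisons));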
}
/// Post-hoc calibration for Quantum Neural Networks
/// Provides specialized calibration methods for quantum ML models
pub mod quantum_calibration {
    use super::*;
    /// Quantum-aware calibration configuration
    #[derive(Debug, Clone)]
    pub struct QuantumCalibrationConfig {
        /// Number of bins for histogram-based methods
        pub n_bins: usize,
        /// Whether to use quantum-aware error mitigation
        pub use_error_mitigation: bool,
        /// Confidence level for uncertainty quantification
        pub confidence_level: f64,
        /// Whether to account for shot noise
        pub account_shot_noise: bool,
    }
    impl Default for QuantumCalibrationConfig {
        fn default() -> Self {
            Self {
                n_bins: 10,
                use_error_mitigation: true,
                confidence_level: 0.95,
                account_shot_noise: true,
            }
        }
    }
    /// Quantum Neural Network Calibrator
    /// Specialized calibration for quantum ML models accounting for:
    /// - Quantum measurement noise
    /// - Shot noise from finite sampling
    /// - Hardware-specific errors
    #[derive(Debug, Clone)]
    pub struct QuantumNeuralNetworkCalibrator {
        /// Base calibration method
        method: CalibrationMethod,
        /// Configuration
        config: QuantumCalibrationConfig,
        /// Shot noise estimates per prediction
        shot_noise_estimates: Option<Array1<f64>>,
        /// Whether calibrator is fitted
        fitted: bool,
    }
    /// Calibration method selection
    #[derive(Debug, Clone)]
    pub enum CalibrationMethod {
        /// Temperature scaling (for multi-class)
        Temperature(TemperatureScaler),
        /// Vector scaling (for multi-class with class-specific parameters)
        Vector(VectorScaler),
        /// Platt scaling (for binary)
        Platt(PlattScaler),
        /// Isotonic regression (for binary)
        Isotonic(IsotonicRegression),
        /// Bayesian Binning (for binary with uncertainty)
        BayesianBinning(BayesianBinningQuantiles),
    }
    impl QuantumNeuralNetworkCalibrator {
        /// Create new quantum calibrator with default temperature scaling
        pub fn new() -> Self {
            Self {
                method: CalibrationMethod::Temperature(TemperatureScaler::new()),
                config: QuantumCalibrationConfig::default(),
                shot_noise_estimates: None,
                fitted: false,
            }
        }
        /// Create calibrator with specific method
        pub fn with_method(method: CalibrationMethod) -> Self {
            Self {
                method,
                config: QuantumCalibrationConfig::default(),
                shot_noise_estimates: None,
                fitted: false,
            }
        }
        /// Set configuration
        pub fn with_config(mut self, config: QuantumCalibrationConfig) -> Self {
            self.config = config;
            self
        }
        /// Fit calibrator for binary classification
        pub fn fit_binary(
            &mut self,
            probabilities: &Array1<f64>,
            labels: &Array1<usize>,
            shot_counts: Option<&Array1<usize>>,
        ) -> Result<()> {
            if let Some(shots) = shot_counts {
                if self.config.account_shot_noise {
                    self.shot_noise_estimates =
                        Some(self.estimate_shot_noise(probabilities, shots));
                }
            }
            match &mut self.method {
                CalibrationMethod::Platt(scaler) => {
                    scaler.fit(probabilities, labels)?;
                }
                CalibrationMethod::Isotonic(scaler) => {
                    scaler.fit(probabilities, labels)?;
                }
                CalibrationMethod::BayesianBinning(scaler) => {
                    scaler.fit(probabilities, labels)?;
                }
                _ => {
                    return Err(MLError::InvalidInput(
                        "Binary calibration requires Platt, Isotonic, or BBQ method".to_string(),
                    ));
                }
            }
            self.fitted = true;
            Ok(())
        }
        /// Fit calibrator for multi-class classification
        pub fn fit_multiclass(
            &mut self,
            logits: &Array2<f64>,
            labels: &Array1<usize>,
            shot_counts: Option<&Array1<usize>>,
        ) -> Result<()> {
            if let Some(shots) = shot_counts {
                if self.config.account_shot_noise {
                    // Crude proxy for a per-sample probability: average the logits
                    // over the class axis before applying the binomial noise model.
                    let avg_probs = logits.mean_axis(scirs2_core::ndarray::Axis(1)).unwrap();
                    self.shot_noise_estimates = Some(self.estimate_shot_noise(&avg_probs, shots));
                }
            }
            match &mut self.method {
                CalibrationMethod::Temperature(scaler) => {
                    scaler.fit(logits, labels)?;
                }
                CalibrationMethod::Vector(scaler) => {
                    scaler.fit(logits, labels)?;
                }
                _ => {
                    return Err(MLError::InvalidInput(
                        "Multi-class calibration requires Temperature or Vector method".to_string(),
                    ));
                }
            }
            self.fitted = true;
            Ok(())
        }
        /// Estimate the shot noise for each probability as the binomial standard
        /// error `sqrt(p * (1 - p) / n)` over `n` measurement shots.
        fn estimate_shot_noise(
            &self,
            probabilities: &Array1<f64>,
            shot_counts: &Array1<usize>,
        ) -> Array1<f64> {
            probabilities
                .iter()
                .zip(shot_counts.iter())
                .map(|(&p, &n)| {
                    if n > 0 {
                        (p * (1.0 - p) / n as f64).sqrt()
                    } else {
                        0.0
                    }
                })
                .collect::<Vec<_>>()
                .into()
        }
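        // Worked example: p = 0.5 measured with n = 1000 shots gives
        // sqrt(0.5 * 0.5 / 1000) ≈ 0.0158, i.e. roughly ±3% at 95% confidence
        // (1.96σ).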
        /// Transform binary probabilities
        pub fn transform_binary(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
            if !self.fitted {
                return Err(MLError::InvalidInput(
                    "Calibrator must be fitted before transform".to_string(),
                ));
            }
            match &self.method {
                CalibrationMethod::Platt(scaler) => scaler.transform(probabilities),
                CalibrationMethod::Isotonic(scaler) => scaler.transform(probabilities),
                CalibrationMethod::BayesianBinning(scaler) => scaler.transform(probabilities),
                _ => Err(MLError::InvalidInput(
                    "Method does not support binary transformation".to_string(),
                )),
            }
        }
        /// Transform multi-class logits
        pub fn transform_multiclass(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
            if !self.fitted {
                return Err(MLError::InvalidInput(
                    "Calibrator must be fitted before transform".to_string(),
                ));
            }
            match &self.method {
                CalibrationMethod::Temperature(scaler) => scaler.transform(logits),
                CalibrationMethod::Vector(scaler) => scaler.transform(logits),
                _ => Err(MLError::InvalidInput(
                    "Method does not support multi-class transformation".to_string(),
                )),
            }
        }
        /// Transform with uncertainty quantification (binary only)
        pub fn transform_with_uncertainty(
            &self,
            probabilities: &Array1<f64>,
        ) -> Result<Vec<(f64, f64, f64)>> {
            if !self.fitted {
                return Err(MLError::InvalidInput(
                    "Calibrator must be fitted before transform".to_string(),
                ));
            }
            match &self.method {
                CalibrationMethod::BayesianBinning(scaler) => {
                    scaler.predict_with_uncertainty(probabilities, self.config.confidence_level)
                }
                _ => {
                    let calibrated = self.transform_binary(probabilities)?;
                    if let Some(noise) = &self.shot_noise_estimates {
                        let results = calibrated
                            .iter()
                            .zip(noise.iter())
                            .map(|(&p, &sigma)| {
                                // Fixed two-sided 95% normal interval; only the BBQ
                                // branch above honors `config.confidence_level`.
                                let z = 1.96;
                                let lower = (p - z * sigma).max(0.0);
                                let upper = (p + z * sigma).min(1.0);
                                (p, lower, upper)
                            })
                            .collect();
                        Ok(results)
                    } else {
                        Ok(calibrated.iter().map(|&p| (p, p, p)).collect())
                    }
                }
            }
        }
        /// Get calibration quality metrics for quantum model
        pub fn evaluate_quantum_calibration(
            &self,
            probabilities: &Array1<f64>,
            labels: &Array1<usize>,
        ) -> Result<QuantumCalibrationMetrics> {
            let calibrated = self.transform_binary(probabilities)?;
            let analysis =
                visualization::analyze_calibration(&calibrated, labels, self.config.n_bins)?;
            let shot_noise_impact = if let Some(noise) = &self.shot_noise_estimates {
                noise.mean().unwrap_or(0.0)
            } else {
                0.0
            };
            Ok(QuantumCalibrationMetrics {
                ece: analysis.ece,
                mce: analysis.mce,
                brier_score: analysis.brier_score,
                nll: analysis.nll,
                shot_noise_impact,
                interpretation: analysis.interpretation,
            })
        }
    }
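    // A hedged end-to-end sketch for the binary path (`probs`, `labels`, and
    // `shots` are placeholders supplied by the caller):
    //
    //     let mut cal = QuantumNeuralNetworkCalibrator::with_method(
    //         CalibrationMethod::Platt(PlattScaler::new()),
    //     );
    //     cal.fit_binary(&probs, &labels, Some(&shots))?;
    //     let calibrated = cal.transform_binary(&probs)?;
    //     let intervals = cal.transform_with_uncertainty(&probs)?; // (p, lo, hi)
    //     let metrics = cal.evaluate_quantum_calibration(&probs, &labels)?;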
    impl Default for QuantumNeuralNetworkCalibrator {
        fn default() -> Self {
            Self::new()
        }
    }
    /// Quantum calibration metrics
    #[derive(Debug, Clone)]
    pub struct QuantumCalibrationMetrics {
        /// Expected Calibration Error
        pub ece: f64,
        /// Maximum Calibration Error
        pub mce: f64,
        /// Brier score
        pub brier_score: f64,
        /// Negative log-likelihood
        pub nll: f64,
        /// Average shot noise impact
        pub shot_noise_impact: f64,
        /// Interpretation
        pub interpretation: String,
    }
    /// Quantum-aware ensemble calibration
    /// Combines multiple calibration methods with quantum circuit execution results
    pub fn quantum_ensemble_calibration(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        shot_counts: &Array1<usize>,
        n_bins: usize,
    ) -> Result<(Array1<f64>, QuantumCalibrationMetrics)> {
        let mut platt_cal = QuantumNeuralNetworkCalibrator::with_method(CalibrationMethod::Platt(
            PlattScaler::new(),
        ));
        platt_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
        let mut isotonic_cal = QuantumNeuralNetworkCalibrator::with_method(
            CalibrationMethod::Isotonic(IsotonicRegression::new()),
        );
        isotonic_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
        let mut bbq_cal = QuantumNeuralNetworkCalibrator::with_method(
            CalibrationMethod::BayesianBinning(BayesianBinningQuantiles::new(n_bins)),
        );
        bbq_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
        let platt_probs = platt_cal.transform_binary(probabilities)?;
        let isotonic_probs = isotonic_cal.transform_binary(probabilities)?;
        let bbq_probs = bbq_cal.transform_binary(probabilities)?;
        let platt_metrics = platt_cal.evaluate_quantum_calibration(probabilities, labels)?;
        let isotonic_metrics = isotonic_cal.evaluate_quantum_calibration(probabilities, labels)?;
        let bbq_metrics = bbq_cal.evaluate_quantum_calibration(probabilities, labels)?;
        let platt_weight = 1.0 / (platt_metrics.ece + 1e-6);
        let isotonic_weight = 1.0 / (isotonic_metrics.ece + 1e-6);
        let bbq_weight = 1.0 / (bbq_metrics.ece + 1e-6);
        let total_weight = platt_weight + isotonic_weight + bbq_weight;
        let ensemble_probs = (&platt_probs * (platt_weight / total_weight))
            + (&isotonic_probs * (isotonic_weight / total_weight))
            + (&bbq_probs * (bbq_weight / total_weight));
        let ensemble_analysis =
            visualization::analyze_calibration(&ensemble_probs, labels, n_bins)?;
        let metrics = QuantumCalibrationMetrics {
            ece: ensemble_analysis.ece,
            mce: ensemble_analysis.mce,
            brier_score: ensemble_analysis.brier_score,
            nll: ensemble_analysis.nll,
            shot_noise_impact: platt_metrics.shot_noise_impact,
            interpretation: ensemble_analysis.interpretation,
        };
        Ok((ensemble_probs, metrics))
    }
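    // Typical invocation (sketch). Each calibrator is weighted by
    // 1 / (ECE + 1e-6), so better-calibrated methods dominate the blend:
    //
    //     let (blended, metrics) =
    //         quantum_ensemble_calibration(&probs, &labels, &shots, 10)?;
    //     println!("ensemble ECE = {:.4} ({})", metrics.ece, metrics.interpretation);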
}
/// Ensemble selection and calibration-aware model selection
pub mod ensemble_selection {
    use super::*;
    use crate::utils::split::KFold;
    /// Ensemble calibration method with metadata
    #[derive(Debug, Clone)]
    pub struct CalibratorCandidate {
        /// Name of the calibration method
        pub name: String,
        /// Cross-validation ECE scores
        pub cv_ece_scores: Vec<f64>,
        /// Mean ECE across folds
        pub mean_ece: f64,
        /// Standard deviation of ECE
        pub std_ece: f64,
        /// Whether this is for binary or multiclass
        pub is_binary: bool,
    }
    /// Ensemble selection strategy
    #[derive(Debug, Clone)]
    pub enum SelectionStrategy {
        /// Select single best method by mean ECE
        BestSingle,
        /// Select top K methods
        TopK(usize),
        /// Select all methods with ECE below threshold
        Threshold(f64),
        /// Weighted ensemble of all methods
        WeightedAll,
    }
    /// Result of ensemble selection
    #[derive(Debug, Clone)]
    pub struct EnsembleSelectionResult {
        /// Selected calibrator names
        pub selected_methods: Vec<String>,
        /// Weights for ensemble (if applicable)
        pub weights: Vec<f64>,
        /// Performance metrics for each method
        pub method_performances: Vec<CalibratorCandidate>,
        /// Best individual method
        pub best_method: String,
        /// Ensemble expected ECE
        pub ensemble_ece: f64,
    }
    /// Perform cross-validated ensemble selection for binary calibration
    pub fn select_binary_ensemble(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        n_folds: usize,
        strategy: SelectionStrategy,
    ) -> Result<EnsembleSelectionResult> {
        if n_folds < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 folds for cross-validation".to_string(),
            ));
        }
        let kfold = KFold::new(probabilities.len(), n_folds, true)?;
        let method_names = vec!["Platt", "Isotonic", "BBQ-5", "BBQ-10"];
        let mut candidates = Vec::new();
        for method_name in method_names {
            let mut cv_ece_scores = Vec::new();
            for fold in 0..n_folds {
                let (train_indices, val_indices) = kfold.get_fold(fold)?;
                let train_probs: Array1<f64> =
                    train_indices.iter().map(|&i| probabilities[i]).collect();
                let train_labels: Array1<usize> =
                    train_indices.iter().map(|&i| labels[i]).collect();
                let val_probs: Array1<f64> =
                    val_indices.iter().map(|&i| probabilities[i]).collect();
                let val_labels: Array1<usize> = val_indices.iter().map(|&i| labels[i]).collect();
                let calibrated_val = match method_name {
                    "Platt" => {
                        let mut scaler = PlattScaler::new();
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    "Isotonic" => {
                        let mut scaler = IsotonicRegression::new();
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    "BBQ-5" => {
                        let mut scaler = BayesianBinningQuantiles::new(5);
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    "BBQ-10" => {
                        let mut scaler = BayesianBinningQuantiles::new(10);
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    _ => {
                        return Err(MLError::InvalidInput(format!(
                            "Unknown method: {}",
                            method_name
                        )));
                    }
                };
                let analysis =
                    visualization::analyze_calibration(&calibrated_val, &val_labels, 10)?;
                cv_ece_scores.push(analysis.ece);
            }
            let mean_ece = cv_ece_scores.iter().sum::<f64>() / cv_ece_scores.len() as f64;
            let variance = cv_ece_scores
                .iter()
                .map(|&x| (x - mean_ece).powi(2))
                .sum::<f64>()
                / cv_ece_scores.len() as f64;
            let std_ece = variance.sqrt();
            candidates.push(CalibratorCandidate {
                name: method_name.to_string(),
                cv_ece_scores,
                mean_ece,
                std_ece,
                is_binary: true,
            });
        }
        // `candidates` is sorted ascending by mean ECE, so every strategy below
        // selects a prefix of it; zipping `candidates` with `weights` when computing
        // the ensemble ECE therefore pairs each weight with its selected method.
        candidates.sort_by(|a, b| a.mean_ece.partial_cmp(&b.mean_ece).unwrap());
        let (selected_methods, weights) = match strategy {
            SelectionStrategy::BestSingle => (vec![candidates[0].name.clone()], vec![1.0]),
            SelectionStrategy::TopK(k) => {
                let k = k.min(candidates.len());
                let methods: Vec<String> = candidates[..k].iter().map(|c| c.name.clone()).collect();
                let weights = vec![1.0 / k as f64; k];
                (methods, weights)
            }
            SelectionStrategy::Threshold(threshold) => {
                let selected: Vec<_> = candidates
                    .iter()
                    .filter(|c| c.mean_ece < threshold)
                    .map(|c| c.name.clone())
                    .collect();
                if selected.is_empty() {
                    (vec![candidates[0].name.clone()], vec![1.0])
                } else {
                    let n = selected.len();
                    let weights = vec![1.0 / n as f64; n];
                    (selected, weights)
                }
            }
            SelectionStrategy::WeightedAll => {
                let methods: Vec<String> = candidates.iter().map(|c| c.name.clone()).collect();
                let inv_eces: Vec<f64> = candidates
                    .iter()
                    .map(|c| 1.0 / (c.mean_ece + 1e-6))
                    .collect();
                let sum_inv: f64 = inv_eces.iter().sum();
                let weights: Vec<f64> = inv_eces.iter().map(|&w| w / sum_inv).collect();
                (methods, weights)
            }
        };
        let best_method = candidates[0].name.clone();
        let ensemble_ece = if weights.len() == 1 {
            candidates[0].mean_ece
        } else {
            candidates
                .iter()
                .zip(&weights)
                .map(|(c, &w)| c.mean_ece * w)
                .sum()
        };
        Ok(EnsembleSelectionResult {
            selected_methods,
            weights,
            method_performances: candidates,
            best_method,
            ensemble_ece,
        })
    }
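    // Cross-validated selection sketch (assumes `probs`/`labels` come from a
    // held-out validation set):
    //
    //     let result =
    //         select_binary_ensemble(&probs, &labels, 5, SelectionStrategy::TopK(2))?;
    //     println!("best single: {}", result.best_method);
    //     for (m, w) in result.selected_methods.iter().zip(&result.weights) {
    //         println!("  {} (weight {:.3})", m, w);
    //     }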
    /// Calibration-aware model selection
    /// Selects the best calibration method for a given model based on validation performance
    #[derive(Debug, Clone)]
    pub struct CalibrationAwareSelector {
        /// Selection strategy
        strategy: SelectionStrategy,
        /// Number of cross-validation folds
        n_folds: usize,
        /// Whether to use binary or multiclass calibration
        is_binary: bool,
    }
    impl CalibrationAwareSelector {
        /// Create a new calibration-aware selector
        pub fn new(n_folds: usize, is_binary: bool) -> Self {
            Self {
                strategy: SelectionStrategy::BestSingle,
                n_folds,
                is_binary,
            }
        }
        /// Set selection strategy
        pub fn with_strategy(mut self, strategy: SelectionStrategy) -> Self {
            self.strategy = strategy;
            self
        }
        /// Select best calibration method for binary classification
        pub fn select_binary(
            &self,
            probabilities: &Array1<f64>,
            labels: &Array1<usize>,
        ) -> Result<EnsembleSelectionResult> {
            select_binary_ensemble(probabilities, labels, self.n_folds, self.strategy.clone())
        }
        /// Generate a detailed report of calibration method comparison
        pub fn generate_selection_report(&self, result: &EnsembleSelectionResult) -> String {
            let mut report = String::new();
            report.push_str("=== Calibration Method Selection Report ===\n\n");
            report.push_str("Cross-Validation Results:\n");
            report.push_str(&format!("{:-<60}\n", ""));
            for method in &result.method_performances {
                report.push_str(&format!(
                    "{:<15} | Mean ECE: {:.4} ± {:.4}\n",
                    method.name, method.mean_ece, method.std_ece
                ));
            }
            report.push_str(&format!("\n{:-<60}\n", ""));
            report.push_str(&format!(
                "\nBest Individual Method: {}\n",
                result.best_method
            ));
            report.push_str(&format!(
                "Expected Ensemble ECE: {:.4}\n\n",
                result.ensemble_ece
            ));
            report.push_str("Selected Ensemble:\n");
            for (method, weight) in result.selected_methods.iter().zip(&result.weights) {
                report.push_str(&format!("  {} (weight: {:.3})\n", method, weight));
            }
            report.push_str("\nRecommendation:\n");
            if result.selected_methods.len() == 1 {
                report.push_str(&format!(
                    "Use {} for best calibration performance.\n",
                    result.selected_methods[0]
                ));
            } else {
                report.push_str(
                    "Use weighted ensemble of selected methods for robust calibration.\n",
                );
            }
            report
        }
    }
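    // Selector usage sketch (illustrative):
    //
    //     let selector = CalibrationAwareSelector::new(5, true)
    //         .with_strategy(SelectionStrategy::WeightedAll);
    //     let result = selector.select_binary(&probs, &labels)?;
    //     println!("{}", selector.generate_selection_report(&result));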
}