quantrs2_ml/utils/calibration/functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use super::*;
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};

/// Calculate the calibration curve (reliability diagram).
/// Returns `(mean_predicted_prob, fraction_of_positives)` for each non-empty bin.
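///
/// # Example
///
/// A minimal sketch with hypothetical data; marked `ignore` because it relies
/// on this crate's `Result` alias being in scope:
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let probs = Array1::from_vec(vec![0.1, 0.4, 0.35, 0.8]);
/// let labels = Array1::from_vec(vec![0usize, 0, 1, 1]);
/// let (mean_pred, frac_pos) = calibration_curve(&probs, &labels, 2)?;
/// assert_eq!(mean_pred.len(), frac_pos.len());
/// ```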
pub fn calibration_curve(
    probabilities: &Array1<f64>,
    labels: &Array1<usize>,
    n_bins: usize,
) -> Result<(Array1<f64>, Array1<f64>)> {
    if probabilities.len() != labels.len() {
        return Err(MLError::InvalidInput(
            "Probabilities and labels must have same length".to_string(),
        ));
    }
    if n_bins < 2 {
        return Err(MLError::InvalidInput(
            "Number of bins must be at least 2".to_string(),
        ));
    }
    let mut bins = vec![Vec::new(); n_bins];
    for (i, &prob) in probabilities.iter().enumerate() {
        let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
        bins[bin_idx].push((prob, labels[i]));
    }
    let mut mean_predicted = Vec::new();
    let mut fraction_positives = Vec::new();
    for bin in bins {
        if !bin.is_empty() {
            let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
            let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
            mean_predicted.push(sum_prob / bin.len() as f64);
            fraction_positives.push(sum_labels / bin.len() as f64);
        }
    }
    Ok((
        Array1::from_vec(mean_predicted),
        Array1::from_vec(fraction_positives),
    ))
}

/// Calibration visualization and analysis utilities
pub mod visualization {
    use super::*;

    /// Calibration plot data for reliability diagrams
    #[derive(Debug, Clone)]
    pub struct CalibrationPlotData {
        /// Mean predicted probabilities in each bin
        pub mean_predicted: Array1<f64>,
        /// Fraction of positives in each bin
        pub fraction_positives: Array1<f64>,
        /// Number of samples in each bin
        pub bin_counts: Array1<usize>,
        /// Bin edges
        pub bin_edges: Vec<f64>,
    }

    /// Generate comprehensive calibration plot data
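    ///
    /// Unlike [`calibration_curve`], empty bins are kept: they are reported at
    /// the bin midpoint with a count of 0, so every output field has exactly
    /// `n_bins` entries (and `bin_edges` has `n_bins + 1`). A short sketch with
    /// hypothetical `probs`/`labels` arrays:
    /// ```ignore
    /// let data = generate_calibration_plot_data(&probs, &labels, 10)?;
    /// assert_eq!(data.bin_counts.len(), 10);
    /// assert_eq!(data.bin_edges.len(), 11);
    /// ```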
    pub fn generate_calibration_plot_data(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        n_bins: usize,
    ) -> Result<CalibrationPlotData> {
        if probabilities.len() != labels.len() {
            return Err(MLError::InvalidInput(
                "Probabilities and labels must have same length".to_string(),
            ));
        }
        if n_bins < 2 {
            return Err(MLError::InvalidInput(
                "Number of bins must be at least 2".to_string(),
            ));
        }
        let mut bins = vec![Vec::new(); n_bins];
        let bin_edges: Vec<f64> = (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect();
        for (i, &prob) in probabilities.iter().enumerate() {
            let bin_idx = ((prob * n_bins as f64).floor() as usize).min(n_bins - 1);
            bins[bin_idx].push((prob, labels[i]));
        }
        let mut mean_predicted = Vec::new();
        let mut fraction_positives = Vec::new();
        let mut bin_counts = Vec::new();
        for bin in bins {
            if !bin.is_empty() {
                let sum_prob: f64 = bin.iter().map(|(p, _)| p).sum();
                let sum_labels: f64 = bin.iter().map(|(_, l)| *l as f64).sum();
                mean_predicted.push(sum_prob / bin.len() as f64);
                fraction_positives.push(sum_labels / bin.len() as f64);
                bin_counts.push(bin.len());
            } else {
                // Empty bin: report the bin midpoint with a count of 0.
                mean_predicted.push(
                    (bin_edges[mean_predicted.len()] + bin_edges[mean_predicted.len() + 1]) / 2.0,
                );
                fraction_positives.push(0.0);
                bin_counts.push(0);
            }
        }
        Ok(CalibrationPlotData {
            mean_predicted: Array1::from_vec(mean_predicted),
            fraction_positives: Array1::from_vec(fraction_positives),
            bin_counts: Array1::from_vec(bin_counts),
            bin_edges,
        })
    }

    /// Comprehensive calibration analysis report
    #[derive(Debug, Clone)]
    pub struct CalibrationAnalysis {
        /// Expected Calibration Error
        pub ece: f64,
        /// Maximum Calibration Error
        pub mce: f64,
        /// Brier score
        pub brier_score: f64,
        /// Negative log-likelihood
        pub nll: f64,
        /// Number of bins used
        pub n_bins: usize,
        /// Per-bin calibration errors
        pub bin_errors: Array1<f64>,
        /// Interpretation of calibration quality
        pub interpretation: String,
    }

    impl CalibrationAnalysis {
        /// Generate interpretation based on ECE
        fn interpret_ece(ece: f64) -> String {
            if ece < 0.01 {
                "Excellent calibration - predictions are highly reliable".to_string()
            } else if ece < 0.05 {
                "Good calibration - predictions are generally reliable".to_string()
            } else if ece < 0.10 {
                "Moderate calibration - some miscalibration present".to_string()
            } else if ece < 0.20 {
                "Poor calibration - significant miscalibration detected".to_string()
            } else {
                "Very poor calibration - predictions are unreliable".to_string()
            }
        }
    }

    /// Perform comprehensive calibration analysis
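    ///
    /// Computes the standard calibration metrics implemented below: the
    /// Expected Calibration Error ECE = Σ_b (n_b / N) · |conf_b − acc_b|, the
    /// maximum per-bin gap (MCE), the Brier score (1/N) Σ_i (p_i − y_i)², and
    /// the mean negative log-likelihood. A short sketch with hypothetical data:
    /// ```ignore
    /// let analysis = analyze_calibration(&probs, &labels, 10)?;
    /// println!("ECE = {:.4} ({})", analysis.ece, analysis.interpretation);
    /// ```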
    pub fn analyze_calibration(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        n_bins: usize,
    ) -> Result<CalibrationAnalysis> {
        let plot_data = generate_calibration_plot_data(probabilities, labels, n_bins)?;
        let mut ece = 0.0;
        let total_samples = probabilities.len() as f64;
        for i in 0..plot_data.mean_predicted.len() {
            let bin_error = (plot_data.mean_predicted[i] - plot_data.fraction_positives[i]).abs();
            let bin_weight = plot_data.bin_counts[i] as f64 / total_samples;
            ece += bin_weight * bin_error;
        }
        let bin_errors: Array1<f64> =
            (&plot_data.mean_predicted - &plot_data.fraction_positives).mapv(|x| x.abs());
        let mce = bin_errors.iter().cloned().fold(0.0f64, f64::max);
        let mut brier_score = 0.0;
        for (i, &prob) in probabilities.iter().enumerate() {
            let true_label = labels[i] as f64;
            brier_score += (prob - true_label).powi(2);
        }
        brier_score /= probabilities.len() as f64;
        let mut nll = 0.0;
        for (i, &prob) in probabilities.iter().enumerate() {
            let true_label = labels[i];
            // Clamp to avoid ln(0) for probabilities at the extremes.
            let prob_clamped = prob.clamp(1e-10, 1.0 - 1e-10);
            if true_label == 1 {
                nll -= prob_clamped.ln();
            } else {
                nll -= (1.0 - prob_clamped).ln();
            }
        }
        nll /= probabilities.len() as f64;
        let interpretation = CalibrationAnalysis::interpret_ece(ece);
        Ok(CalibrationAnalysis {
            ece,
            mce,
            brier_score,
            nll,
            n_bins,
            bin_errors,
            interpretation,
        })
    }

    /// Calibration results for a single method, for side-by-side comparison
    #[derive(Debug, Clone)]
    pub struct CalibrationComparison {
        /// Method name
        pub method_name: String,
        /// Calibration analysis
        pub analysis: CalibrationAnalysis,
        /// Calibrated probabilities
        pub calibrated_probs: Array1<f64>,
    }

    /// Compare multiple calibration methods on the same dataset
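    ///
    /// The uncalibrated baseline is always included; Platt, isotonic, and BBQ
    /// variants are added only when the labels are binary. A short sketch:
    /// ```ignore
    /// let comparisons = compare_calibration_methods(&probs, &labels, 10)?;
    /// println!("{}", generate_comparison_report(&comparisons));
    /// ```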
    pub fn compare_calibration_methods(
        uncalibrated_probs: &Array1<f64>,
        labels: &Array1<usize>,
        n_bins: usize,
    ) -> Result<Vec<CalibrationComparison>> {
        let mut comparisons = Vec::new();
        let uncal_analysis = analyze_calibration(uncalibrated_probs, labels, n_bins)?;
        comparisons.push(CalibrationComparison {
            method_name: "Uncalibrated".to_string(),
            analysis: uncal_analysis,
            calibrated_probs: uncalibrated_probs.clone(),
        });
        // The remaining calibrators only apply to binary labels.
        let is_binary = labels.iter().max().unwrap_or(&0) == &1;
        if is_binary {
            let mut platt = PlattScaler::new();
            if let Ok(calibrated) = platt.fit_transform(uncalibrated_probs, labels) {
                let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
                comparisons.push(CalibrationComparison {
                    method_name: "Platt Scaling".to_string(),
                    analysis,
                    calibrated_probs: calibrated,
                });
            }
            let mut isotonic = IsotonicRegression::new();
            if let Ok(calibrated) = isotonic.fit_transform(uncalibrated_probs, labels) {
                let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
                comparisons.push(CalibrationComparison {
                    method_name: "Isotonic Regression".to_string(),
                    analysis,
                    calibrated_probs: calibrated,
                });
            }
            let mut bbq = BayesianBinningQuantiles::new(n_bins);
            if let Ok(calibrated) = bbq.fit_transform(uncalibrated_probs, labels) {
                let analysis = analyze_calibration(&calibrated, labels, n_bins)?;
                comparisons.push(CalibrationComparison {
                    method_name: "Bayesian Binning (BBQ)".to_string(),
                    analysis,
                    calibrated_probs: calibrated,
                });
            }
        }
        Ok(comparisons)
    }

    /// Generate a text report comparing calibration methods
    pub fn generate_comparison_report(comparisons: &[CalibrationComparison]) -> String {
        let mut report = String::new();
        report.push_str("=== Calibration Methods Comparison Report ===\n\n");
        // Guard against an empty slice; the index lookups below would panic otherwise.
        if comparisons.is_empty() {
            report.push_str("(no methods to compare)\n");
            return report;
        }
        let mut best_ece_idx = 0;
        let mut best_mce_idx = 0;
        let mut best_brier_idx = 0;
        let mut best_nll_idx = 0;
        for (i, comp) in comparisons.iter().enumerate() {
            if comp.analysis.ece < comparisons[best_ece_idx].analysis.ece {
                best_ece_idx = i;
            }
            if comp.analysis.mce < comparisons[best_mce_idx].analysis.mce {
                best_mce_idx = i;
            }
            if comp.analysis.brier_score < comparisons[best_brier_idx].analysis.brier_score {
                best_brier_idx = i;
            }
            if comp.analysis.nll < comparisons[best_nll_idx].analysis.nll {
                best_nll_idx = i;
            }
        }
        for (i, comp) in comparisons.iter().enumerate() {
            report.push_str(&format!("\n{}\n", comp.method_name));
            report.push_str(&format!("{}\n", "=".repeat(comp.method_name.len())));
            report.push_str(&format!(
                "ECE: {:.4}{}\n",
                comp.analysis.ece,
                if i == best_ece_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "MCE: {:.4}{}\n",
                comp.analysis.mce,
                if i == best_mce_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "Brier Score: {:.4}{}\n",
                comp.analysis.brier_score,
                if i == best_brier_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "NLL: {:.4}{}\n",
                comp.analysis.nll,
                if i == best_nll_idx { " ⭐ BEST" } else { "" }
            ));
            report.push_str(&format!(
                "Interpretation: {}\n",
                comp.analysis.interpretation
            ));
        }
        report.push_str("\n=== Recommendations ===\n");
        report.push_str(&format!(
            "Best overall (ECE): {}\n",
            comparisons[best_ece_idx].method_name
        ));
        report.push_str(&format!(
            "Most reliable (MCE): {}\n",
            comparisons[best_mce_idx].method_name
        ));
        report.push_str(&format!(
            "Best probability estimates (Brier): {}\n",
            comparisons[best_brier_idx].method_name
        ));
        report
    }
}

/// Post-hoc calibration for Quantum Neural Networks.
/// Provides specialized calibration methods for quantum ML models.
pub mod quantum_calibration {
    use super::*;

    /// Quantum-aware calibration configuration
    #[derive(Debug, Clone)]
    pub struct QuantumCalibrationConfig {
        /// Number of bins for histogram-based methods
        pub n_bins: usize,
        /// Whether to use quantum-aware error mitigation
        pub use_error_mitigation: bool,
        /// Confidence level for uncertainty quantification
        pub confidence_level: f64,
        /// Whether to account for shot noise
        pub account_shot_noise: bool,
    }

    impl Default for QuantumCalibrationConfig {
        fn default() -> Self {
            Self {
                n_bins: 10,
                use_error_mitigation: true,
                confidence_level: 0.95,
                account_shot_noise: true,
            }
        }
    }

    /// Quantum Neural Network Calibrator.
    ///
    /// Specialized calibration for quantum ML models accounting for:
    /// - Quantum measurement noise
    /// - Shot noise from finite sampling
    /// - Hardware-specific errors
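    ///
    /// A minimal fit/transform sketch with hypothetical data (the scaler types
    /// come from `super::*`):
    /// ```ignore
    /// let mut cal = QuantumNeuralNetworkCalibrator::with_method(
    ///     CalibrationMethod::Platt(PlattScaler::new()),
    /// );
    /// cal.fit_binary(&probs, &labels, Some(&shot_counts))?;
    /// let calibrated = cal.transform_binary(&probs)?;
    /// ```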
    #[derive(Debug, Clone)]
    pub struct QuantumNeuralNetworkCalibrator {
        /// Base calibration method
        method: CalibrationMethod,
        /// Configuration
        config: QuantumCalibrationConfig,
        /// Shot noise estimates per prediction
        shot_noise_estimates: Option<Array1<f64>>,
        /// Whether calibrator is fitted
        fitted: bool,
    }

    /// Calibration method selection
    #[derive(Debug, Clone)]
    pub enum CalibrationMethod {
        /// Temperature scaling (for multi-class)
        Temperature(TemperatureScaler),
        /// Vector scaling (for multi-class with class-specific parameters)
        Vector(VectorScaler),
        /// Platt scaling (for binary)
        Platt(PlattScaler),
        /// Isotonic regression (for binary)
        Isotonic(IsotonicRegression),
        /// Bayesian Binning (for binary with uncertainty)
        BayesianBinning(BayesianBinningQuantiles),
    }

    impl QuantumNeuralNetworkCalibrator {
        /// Create a new quantum calibrator with default temperature scaling
        pub fn new() -> Self {
            Self {
                method: CalibrationMethod::Temperature(TemperatureScaler::new()),
                config: QuantumCalibrationConfig::default(),
                shot_noise_estimates: None,
                fitted: false,
            }
        }

        /// Create a calibrator with a specific method
        pub fn with_method(method: CalibrationMethod) -> Self {
            Self {
                method,
                config: QuantumCalibrationConfig::default(),
                shot_noise_estimates: None,
                fitted: false,
            }
        }

        /// Set configuration
        pub fn with_config(mut self, config: QuantumCalibrationConfig) -> Self {
            self.config = config;
            self
        }

        /// Fit calibrator for binary classification
        pub fn fit_binary(
            &mut self,
            probabilities: &Array1<f64>,
            labels: &Array1<usize>,
            shot_counts: Option<&Array1<usize>>,
        ) -> Result<()> {
            if let Some(shots) = shot_counts {
                if self.config.account_shot_noise {
                    self.shot_noise_estimates =
                        Some(self.estimate_shot_noise(probabilities, shots));
                }
            }
            match &mut self.method {
                CalibrationMethod::Platt(scaler) => {
                    scaler.fit(probabilities, labels)?;
                }
                CalibrationMethod::Isotonic(scaler) => {
                    scaler.fit(probabilities, labels)?;
                }
                CalibrationMethod::BayesianBinning(scaler) => {
                    scaler.fit(probabilities, labels)?;
                }
                _ => {
                    return Err(MLError::InvalidInput(
                        "Binary calibration requires Platt, Isotonic, or BBQ method".to_string(),
                    ));
                }
            }
            self.fitted = true;
            Ok(())
        }

        /// Fit calibrator for multi-class classification
        pub fn fit_multiclass(
            &mut self,
            logits: &Array2<f64>,
            labels: &Array1<usize>,
            shot_counts: Option<&Array1<usize>>,
        ) -> Result<()> {
            if let Some(shots) = shot_counts {
                if self.config.account_shot_noise {
                    // Average over classes to obtain a per-sample probability
                    // scale for the shot-noise estimate.
                    let avg_probs = logits
                        .mean_axis(scirs2_core::ndarray::Axis(1))
                        .expect("logits should have valid axis");
                    self.shot_noise_estimates = Some(self.estimate_shot_noise(&avg_probs, shots));
                }
            }
            match &mut self.method {
                CalibrationMethod::Temperature(scaler) => {
                    scaler.fit(logits, labels)?;
                }
                CalibrationMethod::Vector(scaler) => {
                    scaler.fit(logits, labels)?;
                }
                _ => {
                    return Err(MLError::InvalidInput(
                        "Multi-class calibration requires Temperature or Vector method".to_string(),
                    ));
                }
            }
            self.fitted = true;
            Ok(())
        }

        /// Estimate shot noise for each probability
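        ///
        /// Uses the binomial standard error of a measured probability,
        /// σ = sqrt(p · (1 − p) / n) for `n` shots (0 when no shots were taken).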
        fn estimate_shot_noise(
            &self,
            probabilities: &Array1<f64>,
            shot_counts: &Array1<usize>,
        ) -> Array1<f64> {
            probabilities
                .iter()
                .zip(shot_counts.iter())
                .map(|(&p, &n)| {
                    if n > 0 {
                        (p * (1.0 - p) / n as f64).sqrt()
                    } else {
                        0.0
                    }
                })
                .collect::<Vec<_>>()
                .into()
        }

        /// Transform binary probabilities
        pub fn transform_binary(&self, probabilities: &Array1<f64>) -> Result<Array1<f64>> {
            if !self.fitted {
                return Err(MLError::InvalidInput(
                    "Calibrator must be fitted before transform".to_string(),
                ));
            }
            match &self.method {
                CalibrationMethod::Platt(scaler) => scaler.transform(probabilities),
                CalibrationMethod::Isotonic(scaler) => scaler.transform(probabilities),
                CalibrationMethod::BayesianBinning(scaler) => scaler.transform(probabilities),
                _ => Err(MLError::InvalidInput(
                    "Method does not support binary transformation".to_string(),
                )),
            }
        }

        /// Transform multi-class logits
        pub fn transform_multiclass(&self, logits: &Array2<f64>) -> Result<Array2<f64>> {
            if !self.fitted {
                return Err(MLError::InvalidInput(
                    "Calibrator must be fitted before transform".to_string(),
                ));
            }
            match &self.method {
                CalibrationMethod::Temperature(scaler) => scaler.transform(logits),
                CalibrationMethod::Vector(scaler) => scaler.transform(logits),
                _ => Err(MLError::InvalidInput(
                    "Method does not support multi-class transformation".to_string(),
                )),
            }
        }

        /// Transform with uncertainty quantification (binary only).
        /// Returns `(calibrated, lower, upper)` per sample.
        pub fn transform_with_uncertainty(
            &self,
            probabilities: &Array1<f64>,
        ) -> Result<Vec<(f64, f64, f64)>> {
            if !self.fitted {
                return Err(MLError::InvalidInput(
                    "Calibrator must be fitted before transform".to_string(),
                ));
            }
            match &self.method {
                CalibrationMethod::BayesianBinning(scaler) => {
                    scaler.predict_with_uncertainty(probabilities, self.config.confidence_level)
                }
                _ => {
                    let calibrated = self.transform_binary(probabilities)?;
                    if let Some(noise) = &self.shot_noise_estimates {
                        let results = calibrated
                            .iter()
                            .zip(noise.iter())
                            .map(|(&p, &sigma)| {
                                // Normal z-score for a 95% interval; fixed here
                                // rather than derived from config.confidence_level.
                                let z = 1.96;
                                let lower = (p - z * sigma).max(0.0);
                                let upper = (p + z * sigma).min(1.0);
                                (p, lower, upper)
                            })
                            .collect();
                        Ok(results)
                    } else {
                        // No shot-noise estimates: fall back to degenerate intervals.
                        Ok(calibrated.iter().map(|&p| (p, p, p)).collect())
                    }
                }
            }
        }

        /// Get calibration quality metrics for quantum model
        pub fn evaluate_quantum_calibration(
            &self,
            probabilities: &Array1<f64>,
            labels: &Array1<usize>,
        ) -> Result<QuantumCalibrationMetrics> {
            let calibrated = self.transform_binary(probabilities)?;
            let analysis =
                visualization::analyze_calibration(&calibrated, labels, self.config.n_bins)?;
            let shot_noise_impact = if let Some(noise) = &self.shot_noise_estimates {
                noise.mean().unwrap_or(0.0)
            } else {
                0.0
            };
            Ok(QuantumCalibrationMetrics {
                ece: analysis.ece,
                mce: analysis.mce,
                brier_score: analysis.brier_score,
                nll: analysis.nll,
                shot_noise_impact,
                interpretation: analysis.interpretation,
            })
        }
    }

    impl Default for QuantumNeuralNetworkCalibrator {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Quantum calibration metrics
    #[derive(Debug, Clone)]
    pub struct QuantumCalibrationMetrics {
        /// Expected Calibration Error
        pub ece: f64,
        /// Maximum Calibration Error
        pub mce: f64,
        /// Brier score
        pub brier_score: f64,
        /// Negative log-likelihood
        pub nll: f64,
        /// Average shot noise impact
        pub shot_noise_impact: f64,
        /// Interpretation
        pub interpretation: String,
    }

    /// Quantum-aware ensemble calibration.
    /// Combines multiple calibration methods with quantum circuit execution results.
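    ///
    /// Fits Platt, isotonic, and BBQ calibrators, then blends their outputs
    /// with weights proportional to 1 / (ECE + ε), so better-calibrated methods
    /// dominate the average. A short sketch with hypothetical data:
    /// ```ignore
    /// let (ensemble_probs, metrics) =
    ///     quantum_ensemble_calibration(&probs, &labels, &shot_counts, 10)?;
    /// println!("ensemble ECE = {:.4}", metrics.ece);
    /// ```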
    pub fn quantum_ensemble_calibration(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        shot_counts: &Array1<usize>,
        n_bins: usize,
    ) -> Result<(Array1<f64>, QuantumCalibrationMetrics)> {
        let mut platt_cal = QuantumNeuralNetworkCalibrator::with_method(CalibrationMethod::Platt(
            PlattScaler::new(),
        ));
        platt_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
        let mut isotonic_cal = QuantumNeuralNetworkCalibrator::with_method(
            CalibrationMethod::Isotonic(IsotonicRegression::new()),
        );
        isotonic_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
        let mut bbq_cal = QuantumNeuralNetworkCalibrator::with_method(
            CalibrationMethod::BayesianBinning(BayesianBinningQuantiles::new(n_bins)),
        );
        bbq_cal.fit_binary(probabilities, labels, Some(shot_counts))?;
        let platt_probs = platt_cal.transform_binary(probabilities)?;
        let isotonic_probs = isotonic_cal.transform_binary(probabilities)?;
        let bbq_probs = bbq_cal.transform_binary(probabilities)?;
        let platt_metrics = platt_cal.evaluate_quantum_calibration(probabilities, labels)?;
        let isotonic_metrics = isotonic_cal.evaluate_quantum_calibration(probabilities, labels)?;
        let bbq_metrics = bbq_cal.evaluate_quantum_calibration(probabilities, labels)?;
        // Weight each method inversely to its ECE (epsilon avoids division by zero).
        let platt_weight = 1.0 / (platt_metrics.ece + 1e-6);
        let isotonic_weight = 1.0 / (isotonic_metrics.ece + 1e-6);
        let bbq_weight = 1.0 / (bbq_metrics.ece + 1e-6);
        let total_weight = platt_weight + isotonic_weight + bbq_weight;
        let ensemble_probs = (&platt_probs * (platt_weight / total_weight))
            + (&isotonic_probs * (isotonic_weight / total_weight))
            + (&bbq_probs * (bbq_weight / total_weight));
        let ensemble_analysis =
            visualization::analyze_calibration(&ensemble_probs, labels, n_bins)?;
        let metrics = QuantumCalibrationMetrics {
            ece: ensemble_analysis.ece,
            mce: ensemble_analysis.mce,
            brier_score: ensemble_analysis.brier_score,
            nll: ensemble_analysis.nll,
            shot_noise_impact: platt_metrics.shot_noise_impact,
            interpretation: ensemble_analysis.interpretation,
        };
        Ok((ensemble_probs, metrics))
    }
}

/// Ensemble selection and calibration-aware model selection
pub mod ensemble_selection {
    use super::*;
    use crate::utils::split::KFold;

    /// Ensemble calibration method with metadata
    #[derive(Debug, Clone)]
    pub struct CalibratorCandidate {
        /// Name of the calibration method
        pub name: String,
        /// Cross-validation ECE scores
        pub cv_ece_scores: Vec<f64>,
        /// Mean ECE across folds
        pub mean_ece: f64,
        /// Standard deviation of ECE
        pub std_ece: f64,
        /// Whether this is for binary or multiclass
        pub is_binary: bool,
    }

    /// Ensemble selection strategy
    #[derive(Debug, Clone)]
    pub enum SelectionStrategy {
        /// Select single best method by mean ECE
        BestSingle,
        /// Select top K methods
        TopK(usize),
        /// Select all methods with ECE below threshold
        Threshold(f64),
        /// Weighted ensemble of all methods
        WeightedAll,
    }

    /// Result of ensemble selection
    #[derive(Debug, Clone)]
    pub struct EnsembleSelectionResult {
        /// Selected calibrator names
        pub selected_methods: Vec<String>,
        /// Weights for ensemble (if applicable)
        pub weights: Vec<f64>,
        /// Performance metrics for each method
        pub method_performances: Vec<CalibratorCandidate>,
        /// Best individual method
        pub best_method: String,
        /// Ensemble expected ECE
        pub ensemble_ece: f64,
    }

    /// Perform cross-validated ensemble selection for binary calibration
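    ///
    /// Each candidate (Platt, isotonic, BBQ with 5 and 10 bins) is fitted on
    /// the training folds and scored by ECE on the held-out fold. A short
    /// sketch with hypothetical data:
    /// ```ignore
    /// let result = select_binary_ensemble(&probs, &labels, 5, SelectionStrategy::TopK(2))?;
    /// println!("best method: {}", result.best_method);
    /// ```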
    pub fn select_binary_ensemble(
        probabilities: &Array1<f64>,
        labels: &Array1<usize>,
        n_folds: usize,
        strategy: SelectionStrategy,
    ) -> Result<EnsembleSelectionResult> {
        if n_folds < 2 {
            return Err(MLError::InvalidInput(
                "Need at least 2 folds for cross-validation".to_string(),
            ));
        }
        let kfold = KFold::new(probabilities.len(), n_folds, true)?;
        let method_names = vec!["Platt", "Isotonic", "BBQ-5", "BBQ-10"];
        let mut candidates = Vec::new();
        for method_name in method_names {
            let mut cv_ece_scores = Vec::new();
            for fold in 0..n_folds {
                let (train_indices, val_indices) = kfold.get_fold(fold)?;
                let train_probs: Array1<f64> =
                    train_indices.iter().map(|&i| probabilities[i]).collect();
                let train_labels: Array1<usize> =
                    train_indices.iter().map(|&i| labels[i]).collect();
                let val_probs: Array1<f64> =
                    val_indices.iter().map(|&i| probabilities[i]).collect();
                let val_labels: Array1<usize> = val_indices.iter().map(|&i| labels[i]).collect();
                let calibrated_val = match method_name {
                    "Platt" => {
                        let mut scaler = PlattScaler::new();
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    "Isotonic" => {
                        let mut scaler = IsotonicRegression::new();
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    "BBQ-5" => {
                        let mut scaler = BayesianBinningQuantiles::new(5);
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    "BBQ-10" => {
                        let mut scaler = BayesianBinningQuantiles::new(10);
                        scaler.fit(&train_probs, &train_labels)?;
                        scaler.transform(&val_probs)?
                    }
                    _ => {
                        return Err(MLError::InvalidInput(format!(
                            "Unknown method: {}",
                            method_name
                        )));
                    }
                };
                let analysis =
                    visualization::analyze_calibration(&calibrated_val, &val_labels, 10)?;
                cv_ece_scores.push(analysis.ece);
            }
            let mean_ece = cv_ece_scores.iter().sum::<f64>() / cv_ece_scores.len() as f64;
            let variance = cv_ece_scores
                .iter()
                .map(|&x| (x - mean_ece).powi(2))
                .sum::<f64>()
                / cv_ece_scores.len() as f64;
            let std_ece = variance.sqrt();
            candidates.push(CalibratorCandidate {
                name: method_name.to_string(),
                cv_ece_scores,
                mean_ece,
                std_ece,
                is_binary: true,
            });
        }
        // Sort by mean ECE (ascending); every strategy below selects a prefix of
        // this order, so zipping candidates with `weights` stays aligned.
        candidates.sort_by(|a, b| {
            a.mean_ece
                .partial_cmp(&b.mean_ece)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let (selected_methods, weights) = match strategy {
            SelectionStrategy::BestSingle => (vec![candidates[0].name.clone()], vec![1.0]),
            SelectionStrategy::TopK(k) => {
                let k = k.min(candidates.len());
                let methods: Vec<String> = candidates[..k].iter().map(|c| c.name.clone()).collect();
                let weights = vec![1.0 / k as f64; k];
                (methods, weights)
            }
            SelectionStrategy::Threshold(threshold) => {
                let selected: Vec<_> = candidates
                    .iter()
                    .filter(|c| c.mean_ece < threshold)
                    .map(|c| c.name.clone())
                    .collect();
                if selected.is_empty() {
                    // Nothing cleared the threshold: fall back to the single best method.
                    (vec![candidates[0].name.clone()], vec![1.0])
                } else {
                    let n = selected.len();
                    let weights = vec![1.0 / n as f64; n];
                    (selected, weights)
                }
            }
            SelectionStrategy::WeightedAll => {
                let methods: Vec<String> = candidates.iter().map(|c| c.name.clone()).collect();
                let inv_eces: Vec<f64> = candidates
                    .iter()
                    .map(|c| 1.0 / (c.mean_ece + 1e-6))
                    .collect();
                let sum_inv: f64 = inv_eces.iter().sum();
                let weights: Vec<f64> = inv_eces.iter().map(|&w| w / sum_inv).collect();
                (methods, weights)
            }
        };
        let best_method = candidates[0].name.clone();
        let ensemble_ece = if weights.len() == 1 {
            candidates[0].mean_ece
        } else {
            candidates
                .iter()
                .zip(&weights)
                .map(|(c, &w)| c.mean_ece * w)
                .sum()
        };
        Ok(EnsembleSelectionResult {
            selected_methods,
            weights,
            method_performances: candidates,
            best_method,
            ensemble_ece,
        })
    }

    /// Calibration-aware model selection.
    /// Selects the best calibration method for a given model based on validation performance.
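    ///
    /// A short configuration sketch:
    /// ```ignore
    /// let selector = CalibrationAwareSelector::new(5, true)
    ///     .with_strategy(SelectionStrategy::WeightedAll);
    /// let result = selector.select_binary(&probs, &labels)?;
    /// println!("{}", selector.generate_selection_report(&result));
    /// ```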
    #[derive(Debug, Clone)]
    pub struct CalibrationAwareSelector {
        /// Selection strategy
        strategy: SelectionStrategy,
        /// Number of cross-validation folds
        n_folds: usize,
        /// Whether to use binary or multiclass calibration
        is_binary: bool,
    }

    impl CalibrationAwareSelector {
        /// Create a new calibration-aware selector
        pub fn new(n_folds: usize, is_binary: bool) -> Self {
            Self {
                strategy: SelectionStrategy::BestSingle,
                n_folds,
                is_binary,
            }
        }

        /// Set selection strategy
        pub fn with_strategy(mut self, strategy: SelectionStrategy) -> Self {
            self.strategy = strategy;
            self
        }

        /// Select best calibration method for binary classification
        pub fn select_binary(
            &self,
            probabilities: &Array1<f64>,
            labels: &Array1<usize>,
        ) -> Result<EnsembleSelectionResult> {
            select_binary_ensemble(probabilities, labels, self.n_folds, self.strategy.clone())
        }

        /// Generate a detailed report of calibration method comparison
        pub fn generate_selection_report(&self, result: &EnsembleSelectionResult) -> String {
            let mut report = String::new();
            report.push_str("=== Calibration Method Selection Report ===\n\n");
            report.push_str("Cross-Validation Results:\n");
            report.push_str(&format!("{:-<60}\n", ""));
            for method in &result.method_performances {
                report.push_str(&format!(
                    "{:<15} | Mean ECE: {:.4} ± {:.4}\n",
                    method.name, method.mean_ece, method.std_ece
                ));
            }
            report.push_str(&format!("\n{:-<60}\n", ""));
            report.push_str(&format!(
                "\nBest Individual Method: {}\n",
                result.best_method
            ));
            report.push_str(&format!(
                "Expected Ensemble ECE: {:.4}\n\n",
                result.ensemble_ece
            ));
            report.push_str("Selected Ensemble:\n");
            for (method, weight) in result.selected_methods.iter().zip(&result.weights) {
                report.push_str(&format!("  {} (weight: {:.3})\n", method, weight));
            }
            report.push_str("\nRecommendation:\n");
            if result.selected_methods.len() == 1 {
                report.push_str(&format!(
                    "Use {} for best calibration performance.\n",
                    result.selected_methods[0]
                ));
            } else {
                report.push_str(
                    "Use weighted ensemble of selected methods for robust calibration.\n",
                );
            }
            report
        }
    }
}