// sklears_ensemble/analysis.rs

//! Advanced ensemble analysis and interpretation tools
//!
//! This module provides comprehensive analysis capabilities for ensemble methods,
//! including feature importance aggregation, uncertainty quantification, and
//! ensemble interpretation tools.

7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::thread_rng;
9use sklears_core::{
10    error::{Result, SklearsError},
11    types::Float,
12};
13
14/// Feature importance aggregation methods for ensembles
15#[derive(Debug, Clone)]
16pub enum ImportanceAggregationMethod {
17    /// Simple mean of feature importances
18    Mean,
19    /// Weighted mean based on model performance
20    WeightedMean(Vec<Float>),
21    /// Median aggregation (robust to outliers)
22    Median,
23    /// Use only top-k most important features
24    TopK(usize),
25    /// Rank-based aggregation
26    RankBased,
27    /// Bayesian model averaging of importances
28    BayesianAveraging,
29    /// Permutation-based importance aggregation
30    PermutationBased { n_repeats: usize },
31    /// SHAP-style additive feature attribution
32    SHAPBased { background_samples: usize },
33}
34
35/// Feature importance analysis results
36#[derive(Debug, Clone)]
37pub struct FeatureImportanceAnalysis {
38    /// Aggregated feature importances
39    pub feature_importances: Array1<Float>,
40    /// Standard deviation of importances across models
41    pub importance_std: Array1<Float>,
42    /// Individual model importances
43    pub individual_importances: Array2<Float>,
44    /// Feature rankings by importance
45    pub feature_rankings: Vec<usize>,
46    /// Stability measure of importance rankings
47    pub ranking_stability: Float,
48    /// Top-k most important features
49    pub top_features: Vec<(usize, Float)>,
50    /// Confidence intervals for importances
51    pub confidence_intervals: Vec<(Float, Float)>,
52    /// Pairwise feature interaction strengths
53    pub feature_interactions: Option<Array2<Float>>,
54}
55
56/// Ensemble uncertainty quantification results
57#[derive(Debug, Clone)]
58pub struct UncertaintyQuantification {
59    /// Epistemic uncertainty (model uncertainty)
60    pub epistemic_uncertainty: Array1<Float>,
61    /// Aleatoric uncertainty (data uncertainty)
62    pub aleatoric_uncertainty: Array1<Float>,
63    /// Total uncertainty
64    pub total_uncertainty: Array1<Float>,
65    /// Prediction confidence intervals
66    pub confidence_intervals: Array2<Float>,
67    /// Ensemble diversity at each prediction
68    pub prediction_diversity: Array1<Float>,
69    /// Uncertainty decomposition by source
70    pub uncertainty_decomposition: UncertaintyDecomposition,
71    /// Calibration metrics
72    pub calibration_metrics: CalibrationMetrics,
73}
74
75/// Decomposition of uncertainty into different sources
76#[derive(Debug, Clone)]
77pub struct UncertaintyDecomposition {
78    /// Uncertainty due to model disagreement
79    pub model_disagreement: Array1<Float>,
80    /// Uncertainty due to insufficient training data
81    pub data_uncertainty: Array1<Float>,
82    /// Uncertainty due to feature noise
83    pub feature_uncertainty: Array1<Float>,
84    /// Uncertainty due to label noise
85    pub label_uncertainty: Array1<Float>,
86    /// Irreducible uncertainty
87    pub irreducible_uncertainty: Array1<Float>,
88}
89
90/// Calibration metrics for ensemble predictions
91#[derive(Debug, Clone)]
92pub struct CalibrationMetrics {
93    /// Expected Calibration Error (ECE)
94    pub expected_calibration_error: Float,
95    /// Maximum Calibration Error (MCE)
96    pub maximum_calibration_error: Float,
97    /// Brier score for probability predictions
98    pub brier_score: Float,
99    /// Reliability diagram data
100    pub reliability_diagram: ReliabilityDiagram,
101    /// Over/under-confidence metrics
102    pub confidence_metrics: ConfidenceMetrics,
103}
104
105/// Reliability diagram for calibration analysis
106#[derive(Debug, Clone)]
107pub struct ReliabilityDiagram {
108    /// Bin boundaries for confidence levels
109    pub confidence_bins: Vec<Float>,
110    /// Accuracy within each confidence bin
111    pub bin_accuracies: Vec<Float>,
112    /// Proportion of predictions in each bin
113    pub bin_proportions: Vec<Float>,
114    /// Average confidence in each bin
115    pub bin_confidences: Vec<Float>,
116    /// Number of samples in each bin
117    pub bin_counts: Vec<usize>,
118}
119
120/// Confidence analysis metrics
121#[derive(Debug, Clone)]
122pub struct ConfidenceMetrics {
123    /// Average confidence on correct predictions
124    pub avg_confidence_correct: Float,
125    /// Average confidence on incorrect predictions
126    pub avg_confidence_incorrect: Float,
127    /// Confidence-accuracy correlation
128    pub confidence_accuracy_correlation: Float,
129    /// Over-confidence rate
130    pub overconfidence_rate: Float,
131    /// Under-confidence rate
132    pub underconfidence_rate: Float,
133}
134
135/// Ensemble analyzer for feature importance and uncertainty quantification
136pub struct EnsembleAnalyzer {
137    /// Method for aggregating feature importances
138    pub importance_method: ImportanceAggregationMethod,
139    /// Random state for reproducible analysis
140    pub random_state: Option<u64>,
141    /// Number of bootstrap samples for confidence intervals
142    pub n_bootstrap: usize,
143    /// Confidence level for intervals
144    pub confidence_level: Float,
145}
146
147impl EnsembleAnalyzer {
148    /// Create a new ensemble analyzer
149    pub fn new(importance_method: ImportanceAggregationMethod) -> Self {
150        Self {
151            importance_method,
152            random_state: None,
153            n_bootstrap: 100,
154            confidence_level: 0.95,
155        }
156    }
157
158    /// Configure bootstrap parameters
159    pub fn with_bootstrap(mut self, n_bootstrap: usize, confidence_level: Float) -> Self {
160        self.n_bootstrap = n_bootstrap;
161        self.confidence_level = confidence_level;
162        self
163    }
164
165    /// Set random state for reproducible results
166    pub fn with_random_state(mut self, random_state: u64) -> Self {
167        self.random_state = Some(random_state);
168        self
169    }
170
171    /// Analyze feature importances across ensemble models
172    pub fn analyze_feature_importance(
173        &self,
174        individual_importances: &Array2<Float>,
175        model_weights: Option<&Array1<Float>>,
176    ) -> Result<FeatureImportanceAnalysis> {
177        let n_models = individual_importances.nrows();
178        let n_features = individual_importances.ncols();
179
180        if n_models == 0 || n_features == 0 {
181            return Err(SklearsError::InvalidInput(
182                "Empty importance matrix".to_string(),
183            ));
184        }
185
186        // Aggregate feature importances using specified method
187        let feature_importances =
188            self.aggregate_importances(individual_importances, model_weights)?;
189
190        // Compute standard deviation across models
191        let importance_std =
192            self.compute_importance_std(individual_importances, &feature_importances);
193
194        // Rank features by importance
195        let feature_rankings = self.rank_features(&feature_importances);
196
197        // Compute ranking stability across models
198        let ranking_stability = self.compute_ranking_stability(individual_importances)?;
199
200        // Get top-k features
201        let top_features = self.get_top_features(&feature_importances, 10);
202
203        // Compute confidence intervals using bootstrap
204        let confidence_intervals = self.compute_confidence_intervals(individual_importances)?;
205
206        // Compute feature interactions if requested
207        let feature_interactions = self.compute_feature_interactions(individual_importances)?;
208
209        Ok(FeatureImportanceAnalysis {
210            feature_importances,
211            importance_std,
212            individual_importances: individual_importances.clone(),
213            feature_rankings,
214            ranking_stability,
215            top_features,
216            confidence_intervals,
217            feature_interactions,
218        })
219    }
220
221    /// Aggregate individual feature importances using the specified method
222    fn aggregate_importances(
223        &self,
224        individual_importances: &Array2<Float>,
225        model_weights: Option<&Array1<Float>>,
226    ) -> Result<Array1<Float>> {
227        let n_models = individual_importances.nrows();
228        let n_features = individual_importances.ncols();
229
230        match &self.importance_method {
231            ImportanceAggregationMethod::Mean => {
232                Ok(individual_importances.mean_axis(Axis(0)).unwrap())
233            }
234
235            ImportanceAggregationMethod::WeightedMean(weights) => {
236                if weights.len() != n_models {
237                    return Err(SklearsError::InvalidInput(
238                        "Weight vector length must match number of models".to_string(),
239                    ));
240                }
241
242                let mut aggregated = Array1::zeros(n_features);
243                let total_weight = weights.iter().sum::<Float>();
244
245                for i in 0..n_models {
246                    let row = individual_importances.row(i).to_owned();
247                    aggregated += &(row * weights[i]);
248                }
249
250                Ok(aggregated / total_weight)
251            }
252
253            ImportanceAggregationMethod::Median => {
254                let mut aggregated = Array1::zeros(n_features);
255                for j in 0..n_features {
256                    let mut feature_importances: Vec<Float> =
257                        individual_importances.column(j).iter().copied().collect();
258                    feature_importances.sort_by(|a, b| a.partial_cmp(b).unwrap());
259
260                    let median = if feature_importances.len() % 2 == 0 {
261                        let mid = feature_importances.len() / 2;
262                        (feature_importances[mid - 1] + feature_importances[mid]) / 2.0
263                    } else {
264                        feature_importances[feature_importances.len() / 2]
265                    };
266
267                    aggregated[j] = median;
268                }
269                Ok(aggregated)
270            }
271
272            ImportanceAggregationMethod::TopK(k) => {
273                // Use mean aggregation but zero out features not in top-k
274                let mean_importances = individual_importances.mean_axis(Axis(0)).unwrap();
275                let mut indexed_importances: Vec<(usize, Float)> = mean_importances
276                    .iter()
277                    .enumerate()
278                    .map(|(i, &imp)| (i, imp))
279                    .collect();
280
281                indexed_importances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
282
283                let mut aggregated = Array1::zeros(n_features);
284                for i in 0..*k.min(&n_features) {
285                    let (feature_idx, importance) = indexed_importances[i];
286                    aggregated[feature_idx] = importance;
287                }
288
289                Ok(aggregated)
290            }
291
292            ImportanceAggregationMethod::RankBased => {
293                let mut aggregated = Array1::zeros(n_features);
294
295                for i in 0..n_models {
296                    let row = individual_importances.row(i);
297                    let ranks = self.compute_ranks(&row.to_owned());
298
299                    for j in 0..n_features {
300                        aggregated[j] += ranks[j];
301                    }
302                }
303
304                // Normalize by number of models
305                aggregated /= n_models as Float;
306
307                // Convert average ranks back to importance scores (higher rank = higher importance)
308                let max_rank = aggregated.iter().fold(0.0f64, |a, &b| a.max(b));
309                for val in aggregated.iter_mut() {
310                    *val = max_rank - *val;
311                }
312
313                Ok(aggregated)
314            }
315
316            ImportanceAggregationMethod::BayesianAveraging => {
317                // Simplified Bayesian averaging with uniform priors
318                self.bayesian_average_importances(individual_importances, model_weights)
319            }
320
321            ImportanceAggregationMethod::PermutationBased { n_repeats } => {
322                // Placeholder for permutation-based importance
323                // In a real implementation, this would compute permutation importance
324                let mean_importances = individual_importances.mean_axis(Axis(0)).unwrap();
325                Ok(mean_importances)
326            }
327
328            ImportanceAggregationMethod::SHAPBased { background_samples } => {
329                // Placeholder for SHAP-based importance
330                // In a real implementation, this would compute SHAP values
331                let mean_importances = individual_importances.mean_axis(Axis(0)).unwrap();
332                Ok(mean_importances)
333            }
334        }
335    }
336
337    /// Compute ranks for a feature importance vector
338    fn compute_ranks(&self, importances: &Array1<Float>) -> Array1<Float> {
339        let mut indexed: Vec<(usize, Float)> = importances
340            .iter()
341            .enumerate()
342            .map(|(i, &val)| (i, val))
343            .collect();
344
345        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
346
347        let mut ranks = Array1::zeros(importances.len());
348        for (rank, &(idx, _)) in indexed.iter().enumerate() {
349            ranks[idx] = rank as Float;
350        }
351
352        ranks
353    }
354
355    /// Bayesian averaging of feature importances
356    fn bayesian_average_importances(
357        &self,
358        individual_importances: &Array2<Float>,
359        model_weights: Option<&Array1<Float>>,
360    ) -> Result<Array1<Float>> {
361        let n_models = individual_importances.nrows();
362        let n_features = individual_importances.ncols();
363
364        // Use model weights as posterior probabilities if provided
365        let weights = if let Some(w) = model_weights {
366            w.clone()
367        } else {
368            Array1::from_elem(n_models, 1.0 / n_models as Float)
369        };
370
371        // Compute weighted average with Bayesian interpretation
372        let mut aggregated = Array1::zeros(n_features);
373        for i in 0..n_models {
374            let row = individual_importances.row(i).to_owned();
375            aggregated += &(row * weights[i]);
376        }
377
378        Ok(aggregated)
379    }
380
381    /// Compute standard deviation of feature importances across models
382    fn compute_importance_std(
383        &self,
384        individual_importances: &Array2<Float>,
385        mean_importances: &Array1<Float>,
386    ) -> Array1<Float> {
387        let n_models = individual_importances.nrows();
388        let n_features = individual_importances.ncols();
389
390        let mut std_dev = Array1::zeros(n_features);
391
392        for j in 0..n_features {
393            let mean = mean_importances[j];
394            let variance = individual_importances
395                .column(j)
396                .iter()
397                .map(|&val| (val - mean).powi(2))
398                .sum::<Float>()
399                / n_models as Float;
400
401            std_dev[j] = variance.sqrt();
402        }
403
404        std_dev
405    }
406
407    /// Rank features by their aggregated importance
408    fn rank_features(&self, feature_importances: &Array1<Float>) -> Vec<usize> {
409        let mut indexed: Vec<(usize, Float)> = feature_importances
410            .iter()
411            .enumerate()
412            .map(|(i, &val)| (i, val))
413            .collect();
414
415        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
416
417        indexed.into_iter().map(|(idx, _)| idx).collect()
418    }
419
420    /// Compute ranking stability across models using Kendall's tau
421    fn compute_ranking_stability(&self, individual_importances: &Array2<Float>) -> Result<Float> {
422        let n_models = individual_importances.nrows();
423        if n_models < 2 {
424            return Ok(1.0);
425        }
426
427        let mut total_tau = 0.0;
428        let mut pair_count = 0;
429
430        for i in 0..n_models {
431            for j in (i + 1)..n_models {
432                let rank1 = self.compute_ranks(&individual_importances.row(i).to_owned());
433                let rank2 = self.compute_ranks(&individual_importances.row(j).to_owned());
434
435                let tau = self.kendall_tau(&rank1, &rank2)?;
436                total_tau += tau;
437                pair_count += 1;
438            }
439        }
440
441        Ok(total_tau / pair_count as Float)
442    }
443
444    /// Compute Kendall's tau correlation coefficient
445    fn kendall_tau(&self, rank1: &Array1<Float>, rank2: &Array1<Float>) -> Result<Float> {
446        if rank1.len() != rank2.len() {
447            return Err(SklearsError::InvalidInput(
448                "Rank vectors must have same length".to_string(),
449            ));
450        }
451
452        let n = rank1.len();
453        if n < 2 {
454            return Ok(1.0);
455        }
456
457        let mut concordant = 0;
458        let mut discordant = 0;
459
460        for i in 0..n {
461            for j in (i + 1)..n {
462                let diff1 = rank1[i] - rank1[j];
463                let diff2 = rank2[i] - rank2[j];
464
465                if diff1 * diff2 > 0.0 {
466                    concordant += 1;
467                } else if diff1 * diff2 < 0.0 {
468                    discordant += 1;
469                }
470            }
471        }
472
473        let total_pairs = (n * (n - 1)) / 2;
474        let tau = (concordant as Float - discordant as Float) / total_pairs as Float;
475
476        Ok(tau)
477    }
478
479    /// Get top-k most important features
480    fn get_top_features(
481        &self,
482        feature_importances: &Array1<Float>,
483        k: usize,
484    ) -> Vec<(usize, Float)> {
485        let mut indexed: Vec<(usize, Float)> = feature_importances
486            .iter()
487            .enumerate()
488            .map(|(i, &val)| (i, val))
489            .collect();
490
491        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
492
493        indexed.into_iter().take(k).collect()
494    }
495
496    /// Compute confidence intervals for feature importances using bootstrap
497    fn compute_confidence_intervals(
498        &self,
499        individual_importances: &Array2<Float>,
500    ) -> Result<Vec<(Float, Float)>> {
501        let n_models = individual_importances.nrows();
502        let n_features = individual_importances.ncols();
503
504        if self.n_bootstrap == 0 {
505            // Return dummy intervals if no bootstrap requested
506            return Ok(vec![(0.0, 0.0); n_features]);
507        }
508
509        let mut rng = thread_rng();
510        let mut bootstrap_importances = Vec::with_capacity(self.n_bootstrap);
511
512        for _ in 0..self.n_bootstrap {
513            // Bootstrap sample of models
514            let mut bootstrap_matrix = Array2::zeros((n_models, n_features));
515            for i in 0..n_models {
516                let idx = rng.gen_range(0..n_models);
517                bootstrap_matrix
518                    .row_mut(i)
519                    .assign(&individual_importances.row(idx));
520            }
521
522            // Compute mean importance for this bootstrap sample
523            let mean_importance = bootstrap_matrix.mean_axis(Axis(0)).unwrap();
524            bootstrap_importances.push(mean_importance);
525        }
526
527        // Compute confidence intervals
528        let alpha = 1.0 - self.confidence_level;
529        let lower_percentile = (alpha / 2.0) * 100.0;
530        let upper_percentile = (1.0 - alpha / 2.0) * 100.0;
531
532        let mut confidence_intervals = Vec::with_capacity(n_features);
533
534        for j in 0..n_features {
535            let mut feature_values: Vec<Float> = bootstrap_importances
536                .iter()
537                .map(|importance| importance[j])
538                .collect();
539
540            feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
541
542            let lower_idx =
543                ((lower_percentile / 100.0) * (feature_values.len() - 1) as Float) as usize;
544            let upper_idx =
545                ((upper_percentile / 100.0) * (feature_values.len() - 1) as Float) as usize;
546
547            let lower = feature_values[lower_idx];
548            let upper = feature_values[upper_idx];
549
550            confidence_intervals.push((lower, upper));
551        }
552
553        Ok(confidence_intervals)
554    }
555
556    /// Compute pairwise feature interactions
557    fn compute_feature_interactions(
558        &self,
559        individual_importances: &Array2<Float>,
560    ) -> Result<Option<Array2<Float>>> {
561        let n_features = individual_importances.ncols();
562
563        // For now, compute correlation between feature importances across models
564        let mut interactions = Array2::zeros((n_features, n_features));
565
566        for i in 0..n_features {
567            for j in 0..n_features {
568                if i == j {
569                    interactions[[i, j]] = 1.0;
570                } else {
571                    let corr = self.compute_feature_correlation(
572                        &individual_importances.column(i).to_owned(),
573                        &individual_importances.column(j).to_owned(),
574                    )?;
575                    interactions[[i, j]] = corr;
576                }
577            }
578        }
579
580        Ok(Some(interactions))
581    }
582
583    /// Compute correlation between two feature importance vectors
584    fn compute_feature_correlation(
585        &self,
586        feature1: &Array1<Float>,
587        feature2: &Array1<Float>,
588    ) -> Result<Float> {
589        if feature1.len() != feature2.len() || feature1.is_empty() {
590            return Ok(0.0);
591        }
592
593        let n = feature1.len() as Float;
594        let mean1 = feature1.sum() / n;
595        let mean2 = feature2.sum() / n;
596
597        let mut numerator = 0.0;
598        let mut sum_sq1 = 0.0;
599        let mut sum_sq2 = 0.0;
600
601        for i in 0..feature1.len() {
602            let diff1 = feature1[i] - mean1;
603            let diff2 = feature2[i] - mean2;
604
605            numerator += diff1 * diff2;
606            sum_sq1 += diff1 * diff1;
607            sum_sq2 += diff2 * diff2;
608        }
609
610        let denominator = (sum_sq1 * sum_sq2).sqrt();
611        let correlation = if denominator > 0.0 {
612            numerator / denominator
613        } else {
614            0.0
615        };
616
617        Ok(correlation)
618    }
619
620    /// Quantify uncertainty in ensemble predictions
621    pub fn quantify_uncertainty(
622        &self,
623        ensemble_predictions: &Array2<Float>,
624        true_labels: Option<&Array1<Float>>,
625    ) -> Result<UncertaintyQuantification> {
626        let n_models = ensemble_predictions.nrows();
627        let n_samples = ensemble_predictions.ncols();
628
629        if n_models == 0 || n_samples == 0 {
630            return Err(SklearsError::InvalidInput(
631                "Empty prediction matrix".to_string(),
632            ));
633        }
634
635        // Compute epistemic uncertainty (model disagreement)
636        let epistemic_uncertainty = self.compute_epistemic_uncertainty(ensemble_predictions);
637
638        // Compute aleatoric uncertainty (data uncertainty) - simplified version
639        let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(ensemble_predictions);
640
641        // Total uncertainty
642        let total_uncertainty =
643            self.compute_total_uncertainty(&epistemic_uncertainty, &aleatoric_uncertainty);
644
645        // Confidence intervals
646        let confidence_intervals =
647            self.compute_prediction_confidence_intervals(ensemble_predictions)?;
648
649        // Prediction diversity
650        let prediction_diversity = self.compute_prediction_diversity(ensemble_predictions);
651
652        // Uncertainty decomposition
653        let uncertainty_decomposition = self.decompose_uncertainty(ensemble_predictions)?;
654
655        // Calibration metrics (if true labels provided)
656        let calibration_metrics = if let Some(labels) = true_labels {
657            self.compute_calibration_metrics(ensemble_predictions, labels)?
658        } else {
659            self.default_calibration_metrics()
660        };
661
662        Ok(UncertaintyQuantification {
663            epistemic_uncertainty,
664            aleatoric_uncertainty,
665            total_uncertainty,
666            confidence_intervals,
667            prediction_diversity,
668            uncertainty_decomposition,
669            calibration_metrics,
670        })
671    }
672
673    /// Compute epistemic uncertainty (model disagreement)
674    fn compute_epistemic_uncertainty(&self, ensemble_predictions: &Array2<Float>) -> Array1<Float> {
675        let n_samples = ensemble_predictions.ncols();
676        let mut epistemic = Array1::zeros(n_samples);
677
678        for i in 0..n_samples {
679            let predictions = ensemble_predictions.column(i);
680            let mean_pred = predictions.mean().unwrap();
681            let variance = predictions
682                .iter()
683                .map(|&pred| (pred - mean_pred).powi(2))
684                .sum::<Float>()
685                / predictions.len() as Float;
686
687            epistemic[i] = variance.sqrt();
688        }
689
690        epistemic
691    }
692
693    /// Compute aleatoric uncertainty (simplified version)
694    fn compute_aleatoric_uncertainty(&self, ensemble_predictions: &Array2<Float>) -> Array1<Float> {
695        let n_samples = ensemble_predictions.ncols();
696        let mut aleatoric = Array1::zeros(n_samples);
697
698        // Simplified aleatoric uncertainty estimation
699        // In practice, this would require additional information about data noise
700        for i in 0..n_samples {
701            let predictions = ensemble_predictions.column(i);
702            let mean_pred = predictions.mean().unwrap();
703
704            // Use prediction magnitude as a proxy for aleatoric uncertainty
705            aleatoric[i] = mean_pred.abs() * 0.1; // Simplified heuristic
706        }
707
708        aleatoric
709    }
710
711    /// Compute total uncertainty
712    fn compute_total_uncertainty(
713        &self,
714        epistemic: &Array1<Float>,
715        aleatoric: &Array1<Float>,
716    ) -> Array1<Float> {
717        let mut total = Array1::zeros(epistemic.len());
718
719        for i in 0..epistemic.len() {
720            // Total uncertainty as combination of epistemic and aleatoric
721            total[i] = (epistemic[i].powi(2) + aleatoric[i].powi(2)).sqrt();
722        }
723
724        total
725    }
726
727    /// Compute confidence intervals for predictions
728    fn compute_prediction_confidence_intervals(
729        &self,
730        ensemble_predictions: &Array2<Float>,
731    ) -> Result<Array2<Float>> {
732        let n_samples = ensemble_predictions.ncols();
733        let alpha = 1.0 - self.confidence_level;
734        let lower_percentile = alpha / 2.0;
735        let upper_percentile = 1.0 - alpha / 2.0;
736
737        let mut intervals = Array2::zeros((n_samples, 2));
738
739        for i in 0..n_samples {
740            let mut predictions: Vec<Float> = ensemble_predictions.column(i).to_vec();
741            predictions.sort_by(|a, b| a.partial_cmp(b).unwrap());
742
743            let lower_idx = (lower_percentile * (predictions.len() - 1) as Float) as usize;
744            let upper_idx = (upper_percentile * (predictions.len() - 1) as Float) as usize;
745
746            intervals[[i, 0]] = predictions[lower_idx];
747            intervals[[i, 1]] = predictions[upper_idx];
748        }
749
750        Ok(intervals)
751    }
752
753    /// Compute prediction diversity across ensemble
754    fn compute_prediction_diversity(&self, ensemble_predictions: &Array2<Float>) -> Array1<Float> {
755        let n_samples = ensemble_predictions.ncols();
756        let mut diversity = Array1::zeros(n_samples);
757
758        for i in 0..n_samples {
759            let predictions = ensemble_predictions.column(i);
760            let mean_pred = predictions.mean().unwrap();
761
762            // Coefficient of variation as diversity measure
763            let std_pred = predictions
764                .iter()
765                .map(|&pred| (pred - mean_pred).powi(2))
766                .sum::<Float>()
767                / predictions.len() as Float;
768
769            diversity[i] = if mean_pred.abs() > 1e-8 {
770                std_pred.sqrt() / mean_pred.abs()
771            } else {
772                std_pred.sqrt()
773            };
774        }
775
776        diversity
777    }
778
779    /// Decompose uncertainty into different sources
780    fn decompose_uncertainty(
781        &self,
782        ensemble_predictions: &Array2<Float>,
783    ) -> Result<UncertaintyDecomposition> {
784        let n_samples = ensemble_predictions.ncols();
785
786        // Simplified uncertainty decomposition
787        let model_disagreement = self.compute_epistemic_uncertainty(ensemble_predictions);
788        let data_uncertainty = Array1::from_elem(n_samples, 0.1); // Placeholder
789        let feature_uncertainty = Array1::from_elem(n_samples, 0.05); // Placeholder
790        let label_uncertainty = Array1::from_elem(n_samples, 0.03); // Placeholder
791        let irreducible_uncertainty = Array1::from_elem(n_samples, 0.02); // Placeholder
792
793        Ok(UncertaintyDecomposition {
794            model_disagreement,
795            data_uncertainty,
796            feature_uncertainty,
797            label_uncertainty,
798            irreducible_uncertainty,
799        })
800    }
801
802    /// Compute calibration metrics for ensemble predictions
803    fn compute_calibration_metrics(
804        &self,
805        ensemble_predictions: &Array2<Float>,
806        true_labels: &Array1<Float>,
807    ) -> Result<CalibrationMetrics> {
808        let n_samples = ensemble_predictions.ncols();
809
810        if true_labels.len() != n_samples {
811            return Err(SklearsError::InvalidInput(
812                "Prediction and label lengths must match".to_string(),
813            ));
814        }
815
816        // Compute mean predictions
817        let mean_predictions = ensemble_predictions.mean_axis(Axis(0)).unwrap();
818
819        // Convert to binary classification for calibration analysis
820        let predicted_probs: Vec<Float> = mean_predictions.iter().copied().collect();
821        let true_binary: Vec<bool> = true_labels.iter().map(|&label| label > 0.5).collect();
822
823        // Compute ECE and MCE
824        let (ece, mce) = self.compute_calibration_errors(&predicted_probs, &true_binary)?;
825
826        // Compute Brier score
827        let brier_score = self.compute_brier_score(&predicted_probs, &true_binary);
828
829        // Create reliability diagram
830        let reliability_diagram =
831            self.create_reliability_diagram(&predicted_probs, &true_binary)?;
832
833        // Compute confidence metrics
834        let confidence_metrics = self.compute_confidence_metrics(&predicted_probs, &true_binary);
835
836        Ok(CalibrationMetrics {
837            expected_calibration_error: ece,
838            maximum_calibration_error: mce,
839            brier_score,
840            reliability_diagram,
841            confidence_metrics,
842        })
843    }
844
845    /// Compute Expected and Maximum Calibration Errors
846    fn compute_calibration_errors(
847        &self,
848        predicted_probs: &[Float],
849        true_binary: &[bool],
850    ) -> Result<(Float, Float)> {
851        let n_bins = 10;
852        let bin_size = 1.0 / n_bins as Float;
853
854        let mut ece: Float = 0.0;
855        let mut mce: Float = 0.0;
856        let mut total_samples = 0;
857
858        for i in 0..n_bins {
859            let bin_lower = i as Float * bin_size;
860            let bin_upper = (i + 1) as Float * bin_size;
861
862            // Find samples in this bin
863            let bin_indices: Vec<usize> = predicted_probs
864                .iter()
865                .enumerate()
866                .filter(|(_, &prob)| prob >= bin_lower && prob < bin_upper)
867                .map(|(idx, _)| idx)
868                .collect();
869
870            if bin_indices.is_empty() {
871                continue;
872            }
873
874            let bin_size_actual = bin_indices.len();
875            total_samples += bin_size_actual;
876
877            // Compute average confidence and accuracy in this bin
878            let avg_confidence = bin_indices
879                .iter()
880                .map(|&idx| predicted_probs[idx])
881                .sum::<Float>()
882                / bin_size_actual as Float;
883
884            let bin_accuracy = bin_indices
885                .iter()
886                .map(|&idx| if true_binary[idx] { 1.0 } else { 0.0 })
887                .sum::<Float>()
888                / bin_size_actual as Float;
889
890            let calibration_error = (avg_confidence - bin_accuracy).abs();
891
892            // Update ECE and MCE
893            ece += (bin_size_actual as Float / predicted_probs.len() as Float) * calibration_error;
894            mce = mce.max(calibration_error);
895        }
896
897        Ok((ece, mce))
898    }
899
900    /// Compute Brier score
901    fn compute_brier_score(&self, predicted_probs: &[Float], true_binary: &[bool]) -> Float {
902        predicted_probs
903            .iter()
904            .zip(true_binary.iter())
905            .map(|(&prob, &is_true)| {
906                let true_prob = if is_true { 1.0 } else { 0.0 };
907                (prob - true_prob).powi(2)
908            })
909            .sum::<Float>()
910            / predicted_probs.len() as Float
911    }
912
913    /// Create reliability diagram data
914    fn create_reliability_diagram(
915        &self,
916        predicted_probs: &[Float],
917        true_binary: &[bool],
918    ) -> Result<ReliabilityDiagram> {
919        let n_bins = 10;
920        let bin_size = 1.0 / n_bins as Float;
921
922        let mut confidence_bins = Vec::with_capacity(n_bins + 1);
923        let mut bin_accuracies = Vec::with_capacity(n_bins);
924        let mut bin_proportions = Vec::with_capacity(n_bins);
925        let mut bin_confidences = Vec::with_capacity(n_bins);
926        let mut bin_counts = Vec::with_capacity(n_bins);
927
928        // Create bin boundaries
929        for i in 0..=n_bins {
930            confidence_bins.push(i as Float * bin_size);
931        }
932
933        for i in 0..n_bins {
934            let bin_lower = i as Float * bin_size;
935            let bin_upper = (i + 1) as Float * bin_size;
936
937            // Find samples in this bin
938            let bin_indices: Vec<usize> = predicted_probs
939                .iter()
940                .enumerate()
941                .filter(|(_, &prob)| prob >= bin_lower && prob < bin_upper)
942                .map(|(idx, _)| idx)
943                .collect();
944
945            let bin_count = bin_indices.len();
946            bin_counts.push(bin_count);
947
948            if bin_count == 0 {
949                bin_accuracies.push(0.0);
950                bin_proportions.push(0.0);
951                bin_confidences.push((bin_lower + bin_upper) / 2.0);
952            } else {
953                let avg_confidence = bin_indices
954                    .iter()
955                    .map(|&idx| predicted_probs[idx])
956                    .sum::<Float>()
957                    / bin_count as Float;
958
959                let bin_accuracy = bin_indices
960                    .iter()
961                    .map(|&idx| if true_binary[idx] { 1.0 } else { 0.0 })
962                    .sum::<Float>()
963                    / bin_count as Float;
964
965                let bin_proportion = bin_count as Float / predicted_probs.len() as Float;
966
967                bin_accuracies.push(bin_accuracy);
968                bin_proportions.push(bin_proportion);
969                bin_confidences.push(avg_confidence);
970            }
971        }
972
973        Ok(ReliabilityDiagram {
974            confidence_bins,
975            bin_accuracies,
976            bin_proportions,
977            bin_confidences,
978            bin_counts,
979        })
980    }
981
982    /// Compute confidence-related metrics
983    fn compute_confidence_metrics(
984        &self,
985        predicted_probs: &[Float],
986        true_binary: &[bool],
987    ) -> ConfidenceMetrics {
988        let mut correct_confidences = Vec::new();
989        let mut incorrect_confidences = Vec::new();
990
991        for (&prob, &is_true) in predicted_probs.iter().zip(true_binary.iter()) {
992            let predicted_class = prob > 0.5;
993            if predicted_class == is_true {
994                correct_confidences.push(prob.max(1.0 - prob)); // Distance from 0.5
995            } else {
996                incorrect_confidences.push(prob.max(1.0 - prob));
997            }
998        }
999
1000        let avg_confidence_correct = if correct_confidences.is_empty() {
1001            0.0
1002        } else {
1003            correct_confidences.iter().sum::<Float>() / correct_confidences.len() as Float
1004        };
1005
1006        let avg_confidence_incorrect = if incorrect_confidences.is_empty() {
1007            0.0
1008        } else {
1009            incorrect_confidences.iter().sum::<Float>() / incorrect_confidences.len() as Float
1010        };
1011
1012        // Compute correlation between confidence and accuracy
1013        let confidence_accuracy_correlation =
1014            self.compute_confidence_accuracy_correlation(predicted_probs, true_binary);
1015
1016        // Over/under-confidence rates (simplified)
1017        let threshold = 0.8;
1018        let overconfident_count = predicted_probs
1019            .iter()
1020            .zip(true_binary.iter())
1021            .filter(|(&prob, &is_true)| {
1022                let predicted_class = prob > 0.5;
1023                prob.max(1.0 - prob) > threshold && predicted_class != is_true
1024            })
1025            .count();
1026
1027        let underconfident_count = predicted_probs
1028            .iter()
1029            .zip(true_binary.iter())
1030            .filter(|(&prob, &is_true)| {
1031                let predicted_class = prob > 0.5;
1032                prob.max(1.0 - prob) < (1.0 - threshold) && predicted_class == is_true
1033            })
1034            .count();
1035
1036        let overconfidence_rate = overconfident_count as Float / predicted_probs.len() as Float;
1037        let underconfidence_rate = underconfident_count as Float / predicted_probs.len() as Float;
1038
1039        ConfidenceMetrics {
1040            avg_confidence_correct,
1041            avg_confidence_incorrect,
1042            confidence_accuracy_correlation,
1043            overconfidence_rate,
1044            underconfidence_rate,
1045        }
1046    }
1047
1048    /// Compute correlation between confidence and accuracy
1049    fn compute_confidence_accuracy_correlation(
1050        &self,
1051        predicted_probs: &[Float],
1052        true_binary: &[bool],
1053    ) -> Float {
1054        let confidences: Vec<Float> = predicted_probs
1055            .iter()
1056            .map(|&prob| prob.max(1.0 - prob))
1057            .collect();
1058        let accuracies: Vec<Float> = predicted_probs
1059            .iter()
1060            .zip(true_binary.iter())
1061            .map(|(&prob, &is_true)| {
1062                let predicted_class = prob > 0.5;
1063                if predicted_class == is_true {
1064                    1.0
1065                } else {
1066                    0.0
1067                }
1068            })
1069            .collect();
1070
1071        // Compute Pearson correlation
1072        let n = confidences.len() as Float;
1073        let mean_conf = confidences.iter().sum::<Float>() / n;
1074        let mean_acc = accuracies.iter().sum::<Float>() / n;
1075
1076        let mut numerator = 0.0;
1077        let mut sum_sq_conf = 0.0;
1078        let mut sum_sq_acc = 0.0;
1079
1080        for i in 0..confidences.len() {
1081            let diff_conf = confidences[i] - mean_conf;
1082            let diff_acc = accuracies[i] - mean_acc;
1083
1084            numerator += diff_conf * diff_acc;
1085            sum_sq_conf += diff_conf * diff_conf;
1086            sum_sq_acc += diff_acc * diff_acc;
1087        }
1088
1089        let denominator = (sum_sq_conf * sum_sq_acc).sqrt();
1090        if denominator > 0.0 {
1091            numerator / denominator
1092        } else {
1093            0.0
1094        }
1095    }
1096
1097    /// Create default calibration metrics when no true labels are provided
1098    fn default_calibration_metrics(&self) -> CalibrationMetrics {
1099        CalibrationMetrics {
1100            expected_calibration_error: 0.0,
1101            maximum_calibration_error: 0.0,
1102            brier_score: 0.0,
1103            reliability_diagram: ReliabilityDiagram {
1104                confidence_bins: vec![],
1105                bin_accuracies: vec![],
1106                bin_proportions: vec![],
1107                bin_confidences: vec![],
1108                bin_counts: vec![],
1109            },
1110            confidence_metrics: ConfidenceMetrics {
1111                avg_confidence_correct: 0.0,
1112                avg_confidence_incorrect: 0.0,
1113                confidence_accuracy_correlation: 0.0,
1114                overconfidence_rate: 0.0,
1115                underconfidence_rate: 0.0,
1116            },
1117        }
1118    }
1119}
1120
1121impl Default for EnsembleAnalyzer {
1122    fn default() -> Self {
1123        Self::new(ImportanceAggregationMethod::Mean)
1124    }
1125}
1126
1127/// Convenience functions for common analysis tasks
1128impl EnsembleAnalyzer {
1129    /// Create analyzer for mean-based feature importance aggregation
1130    pub fn mean_importance() -> Self {
1131        Self::new(ImportanceAggregationMethod::Mean)
1132    }
1133
1134    /// Create analyzer for weighted feature importance aggregation
1135    pub fn weighted_importance(weights: Vec<Float>) -> Self {
1136        Self::new(ImportanceAggregationMethod::WeightedMean(weights))
1137    }
1138
1139    /// Create analyzer for robust median-based aggregation
1140    pub fn robust_importance() -> Self {
1141        Self::new(ImportanceAggregationMethod::Median)
1142    }
1143
1144    /// Create analyzer for rank-based aggregation
1145    pub fn rank_based_importance() -> Self {
1146        Self::new(ImportanceAggregationMethod::RankBased)
1147    }
1148
1149    /// Create analyzer with permutation-based importance
1150    pub fn permutation_importance(n_repeats: usize) -> Self {
1151        Self::new(ImportanceAggregationMethod::PermutationBased { n_repeats })
1152    }
1153
1154    /// Create analyzer with SHAP-based importance
1155    pub fn shap_importance(background_samples: usize) -> Self {
1156        Self::new(ImportanceAggregationMethod::SHAPBased { background_samples })
1157    }
1158}
1159
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The default analyzer should use plain mean aggregation.
    #[test]
    fn test_ensemble_analyzer_creation() {
        let analyzer = EnsembleAnalyzer::default();
        assert!(matches!(
            analyzer.importance_method,
            ImportanceAggregationMethod::Mean
        ));
    }

    /// Basic shape/invariant checks on a 3-model x 3-feature importance matrix.
    #[test]
    fn test_feature_importance_analysis() {
        let analyzer = EnsembleAnalyzer::mean_importance();

        // Create mock importance matrix (rows = models, columns = features)
        let importances = array![[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        assert_eq!(analysis.importance_std.len(), 3);
        assert_eq!(analysis.feature_rankings.len(), 3);
        // Stability is a normalized score.
        assert!(analysis.ranking_stability >= 0.0 && analysis.ranking_stability <= 1.0);
        assert_eq!(analysis.top_features.len(), 3);
    }

    /// Non-uniform model weights must change the aggregate away from the
    /// simple column mean.
    #[test]
    fn test_weighted_importance_aggregation() {
        let weights = vec![0.5, 0.3, 0.2];
        let analyzer = EnsembleAnalyzer::weighted_importance(weights);

        let importances = array![[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        // Weighted average should be different from simple mean
        let simple_mean = importances.mean_axis(Axis(0)).unwrap();
        assert!((analysis.feature_importances[0] - simple_mean[0]).abs() > 1e-10);
    }

    /// Median aggregation should down-weight a single outlier model.
    #[test]
    fn test_median_aggregation() {
        let analyzer = EnsembleAnalyzer::robust_importance();

        let importances = array![
            [0.1, 0.3, 0.6], // Outlier in first feature
            [0.5, 0.3, 0.2],
            [0.4, 0.4, 0.2]
        ];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        // Median should be more robust to outliers
        assert!((analysis.feature_importances[0] - 0.4).abs() < 1e-6); // Median of [0.1, 0.5, 0.4] is 0.4
    }

    /// Uncertainty outputs must be per-sample (4 columns of predictions
    /// from 3 ensemble members).
    #[test]
    fn test_uncertainty_quantification() {
        let analyzer = EnsembleAnalyzer::default();

        // Create mock ensemble predictions (rows = models, columns = samples)
        let predictions = array![
            [0.1, 0.8, 0.3, 0.7],
            [0.2, 0.9, 0.4, 0.6],
            [0.1, 0.7, 0.5, 0.8]
        ];

        let uncertainty = analyzer.quantify_uncertainty(&predictions, None).unwrap();

        assert_eq!(uncertainty.epistemic_uncertainty.len(), 4);
        assert_eq!(uncertainty.aleatoric_uncertainty.len(), 4);
        assert_eq!(uncertainty.total_uncertainty.len(), 4);
        assert_eq!(uncertainty.confidence_intervals.nrows(), 4);
        assert_eq!(uncertainty.confidence_intervals.ncols(), 2);
        assert_eq!(uncertainty.prediction_diversity.len(), 4);
    }

    /// With true labels supplied, calibration metrics should be populated
    /// and non-negative.
    #[test]
    fn test_calibration_metrics() {
        let analyzer = EnsembleAnalyzer::default();

        let predictions = array![
            [0.1, 0.8, 0.3, 0.7],
            [0.2, 0.9, 0.4, 0.6],
            [0.1, 0.7, 0.5, 0.8]
        ];

        let true_labels = array![0.0, 1.0, 0.0, 1.0];

        let uncertainty = analyzer
            .quantify_uncertainty(&predictions, Some(&true_labels))
            .unwrap();

        assert!(uncertainty.calibration_metrics.expected_calibration_error >= 0.0);
        assert!(uncertainty.calibration_metrics.maximum_calibration_error >= 0.0);
        assert!(uncertainty.calibration_metrics.brier_score >= 0.0);
        // A reliability diagram is only built when labels are provided.
        assert!(!uncertainty
            .calibration_metrics
            .reliability_diagram
            .confidence_bins
            .is_empty());
    }

    /// Rank aggregation should produce positive scores regardless of the
    /// differing raw scales across models.
    #[test]
    fn test_rank_based_aggregation() {
        let analyzer = EnsembleAnalyzer::rank_based_importance();

        let importances = array![
            [0.1, 0.5, 0.4], // Ranks: [2, 0, 1]
            [0.3, 0.2, 0.5], // Ranks: [1, 2, 0]
            [0.6, 0.1, 0.3]  // Ranks: [0, 2, 1]
        ];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        // Rank-based aggregation should handle different scales well
        let sum_importance = analysis.feature_importances.sum();
        assert!(sum_importance > 0.0);
    }

    /// Kendall's tau must be +1 for identical rankings and -1 for reversed
    /// rankings.
    #[test]
    fn test_kendall_tau_correlation() {
        let analyzer = EnsembleAnalyzer::default();

        let rank1 = array![1.0, 2.0, 3.0, 4.0];
        let rank2 = array![1.0, 2.0, 3.0, 4.0]; // Perfect correlation

        let tau = analyzer.kendall_tau(&rank1, &rank2).unwrap();
        assert!((tau - 1.0).abs() < 1e-6);

        let rank3 = array![4.0, 3.0, 2.0, 1.0]; // Perfect negative correlation
        let tau_neg = analyzer.kendall_tau(&rank1, &rank3).unwrap();
        assert!((tau_neg + 1.0).abs() < 1e-6);
    }

    /// Top-k extraction should return indices sorted by descending importance.
    #[test]
    fn test_top_features_extraction() {
        let analyzer = EnsembleAnalyzer::default();

        let importances = array![0.1, 0.5, 0.3, 0.8, 0.2];
        let top_features = analyzer.get_top_features(&importances, 3);

        assert_eq!(top_features.len(), 3);
        assert_eq!(top_features[0].0, 3); // Index of highest importance (0.8)
        assert_eq!(top_features[1].0, 1); // Index of second highest (0.5)
        assert_eq!(top_features[2].0, 2); // Index of third highest (0.3)
    }

    /// Bootstrap confidence intervals: one (lower, upper) pair per feature,
    /// with lower <= upper.
    #[test]
    fn test_confidence_intervals() {
        let analyzer = EnsembleAnalyzer::default().with_bootstrap(50, 0.95);

        let importances = array![
            [0.5, 0.3, 0.2],
            [0.4, 0.4, 0.2],
            [0.6, 0.2, 0.2],
            [0.5, 0.3, 0.2],
            [0.4, 0.5, 0.1]
        ];

        let intervals = analyzer.compute_confidence_intervals(&importances).unwrap();

        assert_eq!(intervals.len(), 3); // 3 features
        for (lower, upper) in &intervals {
            assert!(lower <= upper);
        }
    }

    /// Per-sample prediction intervals: one row per sample with ordered
    /// (lower, upper) bounds.
    #[test]
    fn test_prediction_confidence_intervals() {
        let analyzer = EnsembleAnalyzer::default().with_bootstrap(10, 0.9);

        let predictions = array![
            [0.1, 0.8, 0.3],
            [0.2, 0.9, 0.4],
            [0.0, 0.7, 0.5],
            [0.3, 0.8, 0.2]
        ];

        let intervals = analyzer
            .compute_prediction_confidence_intervals(&predictions)
            .unwrap();

        assert_eq!(intervals.nrows(), 3); // 3 samples
        assert_eq!(intervals.ncols(), 2); // Lower and upper bounds

        for i in 0..3 {
            assert!(intervals[[i, 0]] <= intervals[[i, 1]]); // Lower <= Upper
        }
    }
}