use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

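/// Strategy for aggregating per-model feature importances into a single
/// ensemble-level importance vector.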
#[derive(Debug, Clone)]
pub enum ImportanceAggregationMethod {
    /// Unweighted arithmetic mean across models.
    Mean,
    /// Weighted mean using one weight per model.
    WeightedMean(Vec<Float>),
    /// Per-feature median across models (robust to outlier models).
    Median,
    /// Keep only the k features with the highest mean importance.
    TopK(usize),
    /// Aggregate per-model ranks instead of raw importance values.
    RankBased,
    /// Weighted average using model weights as posterior weights (uniform if none given).
    BayesianAveraging,
    /// Permutation-based aggregation; currently falls back to the mean of the supplied importances.
    PermutationBased { n_repeats: usize },
    /// SHAP-based aggregation; currently falls back to the mean of the supplied importances.
    SHAPBased { background_samples: usize },
}

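/// Aggregated feature-importance results for an ensemble: per-feature means and
/// standard deviations, rankings and their stability, top features, bootstrap
/// confidence intervals, and an optional feature-interaction (correlation) matrix.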
#[derive(Debug, Clone)]
pub struct FeatureImportanceAnalysis {
    pub feature_importances: Array1<Float>,
    pub importance_std: Array1<Float>,
    pub individual_importances: Array2<Float>,
    pub feature_rankings: Vec<usize>,
    pub ranking_stability: Float,
    pub top_features: Vec<(usize, Float)>,
    pub confidence_intervals: Vec<(Float, Float)>,
    pub feature_interactions: Option<Array2<Float>>,
}

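/// Per-sample uncertainty estimates derived from ensemble predictions, split into
/// epistemic and aleatoric components, with prediction confidence intervals,
/// prediction diversity, an uncertainty decomposition, and calibration metrics.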
#[derive(Debug, Clone)]
pub struct UncertaintyQuantification {
    pub epistemic_uncertainty: Array1<Float>,
    pub aleatoric_uncertainty: Array1<Float>,
    pub total_uncertainty: Array1<Float>,
    pub confidence_intervals: Array2<Float>,
    pub prediction_diversity: Array1<Float>,
    pub uncertainty_decomposition: UncertaintyDecomposition,
    pub calibration_metrics: CalibrationMetrics,
}

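/// Breakdown of predictive uncertainty into model disagreement, data, feature,
/// label, and irreducible components.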
#[derive(Debug, Clone)]
pub struct UncertaintyDecomposition {
    pub model_disagreement: Array1<Float>,
    pub data_uncertainty: Array1<Float>,
    pub feature_uncertainty: Array1<Float>,
    pub label_uncertainty: Array1<Float>,
    pub irreducible_uncertainty: Array1<Float>,
}

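/// Calibration quality of the ensemble's probabilistic predictions: expected and
/// maximum calibration error, Brier score, reliability diagram, and confidence statistics.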
#[derive(Debug, Clone)]
pub struct CalibrationMetrics {
    pub expected_calibration_error: Float,
    pub maximum_calibration_error: Float,
    pub brier_score: Float,
    pub reliability_diagram: ReliabilityDiagram,
    pub confidence_metrics: ConfidenceMetrics,
}

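/// Binned data for a reliability diagram: bin edges plus per-bin accuracy,
/// proportion of samples, mean confidence, and sample count.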
#[derive(Debug, Clone)]
pub struct ReliabilityDiagram {
    pub confidence_bins: Vec<Float>,
    pub bin_accuracies: Vec<Float>,
    pub bin_proportions: Vec<Float>,
    pub bin_confidences: Vec<Float>,
    pub bin_counts: Vec<usize>,
}

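/// Summary statistics relating prediction confidence to correctness, including
/// over- and under-confidence rates.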
#[derive(Debug, Clone)]
pub struct ConfidenceMetrics {
    pub avg_confidence_correct: Float,
    pub avg_confidence_incorrect: Float,
    pub confidence_accuracy_correlation: Float,
    pub overconfidence_rate: Float,
    pub underconfidence_rate: Float,
}

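/// Analyzer for ensemble models: aggregates per-model feature importances and
/// quantifies prediction uncertainty.
///
/// A minimal usage sketch (`per_model_importances` is a hypothetical caller-supplied
/// `Array2<Float>` with one row per model):
///
/// ```ignore
/// let analyzer = EnsembleAnalyzer::mean_importance().with_bootstrap(200, 0.95);
/// let analysis = analyzer.analyze_feature_importance(&per_model_importances, None)?;
/// println!("top features: {:?}", analysis.top_features);
/// ```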
pub struct EnsembleAnalyzer {
    pub importance_method: ImportanceAggregationMethod,
    pub random_state: Option<u64>,
    pub n_bootstrap: usize,
    pub confidence_level: Float,
}

impl EnsembleAnalyzer {
    pub fn new(importance_method: ImportanceAggregationMethod) -> Self {
        Self {
            importance_method,
            random_state: None,
            n_bootstrap: 100,
            confidence_level: 0.95,
        }
    }

    pub fn with_bootstrap(mut self, n_bootstrap: usize, confidence_level: Float) -> Self {
        self.n_bootstrap = n_bootstrap;
        self.confidence_level = confidence_level;
        self
    }

    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

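    /// Aggregate a `(n_models, n_features)` importance matrix into a single analysis:
    /// mean importances, per-feature standard deviation, rankings, ranking stability
    /// (mean pairwise Kendall tau), top features, bootstrap confidence intervals, and
    /// a feature-interaction matrix. `model_weights` is only consulted by the
    /// Bayesian-averaging aggregation method.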
    pub fn analyze_feature_importance(
        &self,
        individual_importances: &Array2<Float>,
        model_weights: Option<&Array1<Float>>,
    ) -> Result<FeatureImportanceAnalysis> {
        let n_models = individual_importances.nrows();
        let n_features = individual_importances.ncols();

        if n_models == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "Empty importance matrix".to_string(),
            ));
        }

        let feature_importances =
            self.aggregate_importances(individual_importances, model_weights)?;

        let importance_std =
            self.compute_importance_std(individual_importances, &feature_importances);

        let feature_rankings = self.rank_features(&feature_importances);

        let ranking_stability = self.compute_ranking_stability(individual_importances)?;

        let top_features = self.get_top_features(&feature_importances, 10);

        let confidence_intervals = self.compute_confidence_intervals(individual_importances)?;

        let feature_interactions = self.compute_feature_interactions(individual_importances)?;

        Ok(FeatureImportanceAnalysis {
            feature_importances,
            importance_std,
            individual_importances: individual_importances.clone(),
            feature_rankings,
            ranking_stability,
            top_features,
            confidence_intervals,
            feature_interactions,
        })
    }

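    /// Combine per-model importances according to `self.importance_method`. The
    /// permutation- and SHAP-based variants currently fall back to the simple mean
    /// of the supplied importances.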
    fn aggregate_importances(
        &self,
        individual_importances: &Array2<Float>,
        model_weights: Option<&Array1<Float>>,
    ) -> Result<Array1<Float>> {
        let n_models = individual_importances.nrows();
        let n_features = individual_importances.ncols();

        match &self.importance_method {
            ImportanceAggregationMethod::Mean => {
                Ok(individual_importances.mean_axis(Axis(0)).unwrap())
            }

            ImportanceAggregationMethod::WeightedMean(weights) => {
                if weights.len() != n_models {
                    return Err(SklearsError::InvalidInput(
                        "Weight vector length must match number of models".to_string(),
                    ));
                }

                let mut aggregated = Array1::zeros(n_features);
                let total_weight = weights.iter().sum::<Float>();

                for i in 0..n_models {
                    let row = individual_importances.row(i).to_owned();
                    aggregated += &(row * weights[i]);
                }

                Ok(aggregated / total_weight)
            }

            ImportanceAggregationMethod::Median => {
                let mut aggregated = Array1::zeros(n_features);
                for j in 0..n_features {
                    let mut feature_importances: Vec<Float> =
                        individual_importances.column(j).iter().copied().collect();
                    feature_importances.sort_by(|a, b| a.partial_cmp(b).unwrap());

                    let median = if feature_importances.len() % 2 == 0 {
                        let mid = feature_importances.len() / 2;
                        (feature_importances[mid - 1] + feature_importances[mid]) / 2.0
                    } else {
                        feature_importances[feature_importances.len() / 2]
                    };

                    aggregated[j] = median;
                }
                Ok(aggregated)
            }

            ImportanceAggregationMethod::TopK(k) => {
                let mean_importances = individual_importances.mean_axis(Axis(0)).unwrap();
                let mut indexed_importances: Vec<(usize, Float)> = mean_importances
                    .iter()
                    .enumerate()
                    .map(|(i, &imp)| (i, imp))
                    .collect();

                indexed_importances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                let mut aggregated = Array1::zeros(n_features);
                for i in 0..*k.min(&n_features) {
                    let (feature_idx, importance) = indexed_importances[i];
                    aggregated[feature_idx] = importance;
                }

                Ok(aggregated)
            }

            ImportanceAggregationMethod::RankBased => {
                let mut aggregated = Array1::zeros(n_features);

                for i in 0..n_models {
                    let row = individual_importances.row(i);
                    let ranks = self.compute_ranks(&row.to_owned());

                    for j in 0..n_features {
                        aggregated[j] += ranks[j];
                    }
                }

                // Average the per-model ranks (rank 0 = most important feature).
                aggregated /= n_models as Float;

                // Invert the averaged ranks so that larger values again mean more important.
                let max_rank = aggregated.iter().fold(0.0f64, |a, &b| a.max(b));
                for val in aggregated.iter_mut() {
                    *val = max_rank - *val;
                }

                Ok(aggregated)
            }

            ImportanceAggregationMethod::BayesianAveraging => {
                self.bayesian_average_importances(individual_importances, model_weights)
            }

            ImportanceAggregationMethod::PermutationBased { .. } => {
                // True permutation importance needs access to the fitted models and the
                // validation data, which are not available here; fall back to the mean
                // of the supplied per-model importances.
                let mean_importances = individual_importances.mean_axis(Axis(0)).unwrap();
                Ok(mean_importances)
            }

            ImportanceAggregationMethod::SHAPBased { .. } => {
                // SHAP value estimation likewise requires the fitted models and background
                // data; fall back to the mean of the supplied per-model importances.
                let mean_importances = individual_importances.mean_axis(Axis(0)).unwrap();
                Ok(mean_importances)
            }
        }
    }

    fn compute_ranks(&self, importances: &Array1<Float>) -> Array1<Float> {
        let mut indexed: Vec<(usize, Float)> = importances
            .iter()
            .enumerate()
            .map(|(i, &val)| (i, val))
            .collect();

        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut ranks = Array1::zeros(importances.len());
        for (rank, &(idx, _)) in indexed.iter().enumerate() {
            ranks[idx] = rank as Float;
        }

        ranks
    }

    fn bayesian_average_importances(
        &self,
        individual_importances: &Array2<Float>,
        model_weights: Option<&Array1<Float>>,
    ) -> Result<Array1<Float>> {
        let n_models = individual_importances.nrows();
        let n_features = individual_importances.ncols();

        // Use the supplied model weights as posterior weights; default to a uniform
        // posterior. The weighted sum is not re-normalized, so the supplied weights
        // are expected to sum to 1.
        let weights = if let Some(w) = model_weights {
            w.clone()
        } else {
            Array1::from_elem(n_models, 1.0 / n_models as Float)
        };

        let mut aggregated = Array1::zeros(n_features);
        for i in 0..n_models {
            let row = individual_importances.row(i).to_owned();
            aggregated += &(row * weights[i]);
        }

        Ok(aggregated)
    }

    fn compute_importance_std(
        &self,
        individual_importances: &Array2<Float>,
        mean_importances: &Array1<Float>,
    ) -> Array1<Float> {
        let n_models = individual_importances.nrows();
        let n_features = individual_importances.ncols();

        let mut std_dev = Array1::zeros(n_features);

        for j in 0..n_features {
            let mean = mean_importances[j];
            let variance = individual_importances
                .column(j)
                .iter()
                .map(|&val| (val - mean).powi(2))
                .sum::<Float>()
                / n_models as Float;

            std_dev[j] = variance.sqrt();
        }

        std_dev
    }

    fn rank_features(&self, feature_importances: &Array1<Float>) -> Vec<usize> {
        let mut indexed: Vec<(usize, Float)> = feature_importances
            .iter()
            .enumerate()
            .map(|(i, &val)| (i, val))
            .collect();

        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        indexed.into_iter().map(|(idx, _)| idx).collect()
    }

    fn compute_ranking_stability(&self, individual_importances: &Array2<Float>) -> Result<Float> {
        let n_models = individual_importances.nrows();
        if n_models < 2 {
            return Ok(1.0);
        }

        let mut total_tau = 0.0;
        let mut pair_count = 0;

        for i in 0..n_models {
            for j in (i + 1)..n_models {
                let rank1 = self.compute_ranks(&individual_importances.row(i).to_owned());
                let rank2 = self.compute_ranks(&individual_importances.row(j).to_owned());

                let tau = self.kendall_tau(&rank1, &rank2)?;
                total_tau += tau;
                pair_count += 1;
            }
        }

        Ok(total_tau / pair_count as Float)
    }

    fn kendall_tau(&self, rank1: &Array1<Float>, rank2: &Array1<Float>) -> Result<Float> {
        if rank1.len() != rank2.len() {
            return Err(SklearsError::InvalidInput(
                "Rank vectors must have same length".to_string(),
            ));
        }

        let n = rank1.len();
        if n < 2 {
            return Ok(1.0);
        }

        let mut concordant = 0;
        let mut discordant = 0;

        for i in 0..n {
            for j in (i + 1)..n {
                let diff1 = rank1[i] - rank1[j];
                let diff2 = rank2[i] - rank2[j];

                if diff1 * diff2 > 0.0 {
                    concordant += 1;
                } else if diff1 * diff2 < 0.0 {
                    discordant += 1;
                }
            }
        }

        let total_pairs = (n * (n - 1)) / 2;
        let tau = (concordant as Float - discordant as Float) / total_pairs as Float;

        Ok(tau)
    }

    fn get_top_features(
        &self,
        feature_importances: &Array1<Float>,
        k: usize,
    ) -> Vec<(usize, Float)> {
        let mut indexed: Vec<(usize, Float)> = feature_importances
            .iter()
            .enumerate()
            .map(|(i, &val)| (i, val))
            .collect();

        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        indexed.into_iter().take(k).collect()
    }

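    /// Bootstrap confidence intervals for each feature's mean importance: resample the
    /// model rows `n_bootstrap` times and take empirical percentiles at the configured
    /// confidence level.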
    fn compute_confidence_intervals(
        &self,
        individual_importances: &Array2<Float>,
    ) -> Result<Vec<(Float, Float)>> {
        let n_models = individual_importances.nrows();
        let n_features = individual_importances.ncols();

        if self.n_bootstrap == 0 {
            return Ok(vec![(0.0, 0.0); n_features]);
        }

        // NOTE: bootstrap resampling currently uses `thread_rng()`; `random_state`
        // is not applied here.
        let mut rng = thread_rng();
        let mut bootstrap_importances = Vec::with_capacity(self.n_bootstrap);

        for _ in 0..self.n_bootstrap {
            // Resample model rows with replacement.
            let mut bootstrap_matrix = Array2::zeros((n_models, n_features));
            for i in 0..n_models {
                let idx = rng.gen_range(0..n_models);
                bootstrap_matrix
                    .row_mut(i)
                    .assign(&individual_importances.row(idx));
            }

            let mean_importance = bootstrap_matrix.mean_axis(Axis(0)).unwrap();
            bootstrap_importances.push(mean_importance);
        }

        let alpha = 1.0 - self.confidence_level;
        let lower_percentile = (alpha / 2.0) * 100.0;
        let upper_percentile = (1.0 - alpha / 2.0) * 100.0;

        let mut confidence_intervals = Vec::with_capacity(n_features);

        for j in 0..n_features {
            let mut feature_values: Vec<Float> = bootstrap_importances
                .iter()
                .map(|importance| importance[j])
                .collect();

            feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let lower_idx =
                ((lower_percentile / 100.0) * (feature_values.len() - 1) as Float) as usize;
            let upper_idx =
                ((upper_percentile / 100.0) * (feature_values.len() - 1) as Float) as usize;

            let lower = feature_values[lower_idx];
            let upper = feature_values[upper_idx];

            confidence_intervals.push((lower, upper));
        }

        Ok(confidence_intervals)
    }

    fn compute_feature_interactions(
        &self,
        individual_importances: &Array2<Float>,
    ) -> Result<Option<Array2<Float>>> {
        let n_features = individual_importances.ncols();

        let mut interactions = Array2::zeros((n_features, n_features));

        for i in 0..n_features {
            for j in 0..n_features {
                if i == j {
                    interactions[[i, j]] = 1.0;
                } else {
                    let corr = self.compute_feature_correlation(
                        &individual_importances.column(i).to_owned(),
                        &individual_importances.column(j).to_owned(),
                    )?;
                    interactions[[i, j]] = corr;
                }
            }
        }

        Ok(Some(interactions))
    }

    fn compute_feature_correlation(
        &self,
        feature1: &Array1<Float>,
        feature2: &Array1<Float>,
    ) -> Result<Float> {
        if feature1.len() != feature2.len() || feature1.is_empty() {
            return Ok(0.0);
        }

        let n = feature1.len() as Float;
        let mean1 = feature1.sum() / n;
        let mean2 = feature2.sum() / n;

        let mut numerator = 0.0;
        let mut sum_sq1 = 0.0;
        let mut sum_sq2 = 0.0;

        for i in 0..feature1.len() {
            let diff1 = feature1[i] - mean1;
            let diff2 = feature2[i] - mean2;

            numerator += diff1 * diff2;
            sum_sq1 += diff1 * diff1;
            sum_sq2 += diff2 * diff2;
        }

        let denominator = (sum_sq1 * sum_sq2).sqrt();
        let correlation = if denominator > 0.0 {
            numerator / denominator
        } else {
            0.0
        };

        Ok(correlation)
    }

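    /// Quantify per-sample uncertainty from a `(n_models, n_samples)` matrix of ensemble
    /// predictions. Calibration metrics are only computed when `true_labels` is supplied;
    /// otherwise zeroed defaults are returned.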
    pub fn quantify_uncertainty(
        &self,
        ensemble_predictions: &Array2<Float>,
        true_labels: Option<&Array1<Float>>,
    ) -> Result<UncertaintyQuantification> {
        let n_models = ensemble_predictions.nrows();
        let n_samples = ensemble_predictions.ncols();

        if n_models == 0 || n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Empty prediction matrix".to_string(),
            ));
        }

        let epistemic_uncertainty = self.compute_epistemic_uncertainty(ensemble_predictions);

        let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(ensemble_predictions);

        let total_uncertainty =
            self.compute_total_uncertainty(&epistemic_uncertainty, &aleatoric_uncertainty);

        let confidence_intervals =
            self.compute_prediction_confidence_intervals(ensemble_predictions)?;

        let prediction_diversity = self.compute_prediction_diversity(ensemble_predictions);

        let uncertainty_decomposition = self.decompose_uncertainty(ensemble_predictions)?;

        let calibration_metrics = if let Some(labels) = true_labels {
            self.compute_calibration_metrics(ensemble_predictions, labels)?
        } else {
            self.default_calibration_metrics()
        };

        Ok(UncertaintyQuantification {
            epistemic_uncertainty,
            aleatoric_uncertainty,
            total_uncertainty,
            confidence_intervals,
            prediction_diversity,
            uncertainty_decomposition,
            calibration_metrics,
        })
    }

    fn compute_epistemic_uncertainty(&self, ensemble_predictions: &Array2<Float>) -> Array1<Float> {
        let n_samples = ensemble_predictions.ncols();
        let mut epistemic = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let predictions = ensemble_predictions.column(i);
            let mean_pred = predictions.mean().unwrap();
            let variance = predictions
                .iter()
                .map(|&pred| (pred - mean_pred).powi(2))
                .sum::<Float>()
                / predictions.len() as Float;

            epistemic[i] = variance.sqrt();
        }

        epistemic
    }

    fn compute_aleatoric_uncertainty(&self, ensemble_predictions: &Array2<Float>) -> Array1<Float> {
        let n_samples = ensemble_predictions.ncols();
        let mut aleatoric = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let predictions = ensemble_predictions.column(i);
            let mean_pred = predictions.mean().unwrap();

            // Simple heuristic proxy: without per-model noise estimates, take the
            // aleatoric component as 10% of the mean prediction magnitude.
            aleatoric[i] = mean_pred.abs() * 0.1;
        }

        aleatoric
    }

    fn compute_total_uncertainty(
        &self,
        epistemic: &Array1<Float>,
        aleatoric: &Array1<Float>,
    ) -> Array1<Float> {
        let mut total = Array1::zeros(epistemic.len());

        for i in 0..epistemic.len() {
            total[i] = (epistemic[i].powi(2) + aleatoric[i].powi(2)).sqrt();
        }

        total
    }

    fn compute_prediction_confidence_intervals(
        &self,
        ensemble_predictions: &Array2<Float>,
    ) -> Result<Array2<Float>> {
        let n_samples = ensemble_predictions.ncols();
        let alpha = 1.0 - self.confidence_level;
        let lower_percentile = alpha / 2.0;
        let upper_percentile = 1.0 - alpha / 2.0;

        let mut intervals = Array2::zeros((n_samples, 2));

        for i in 0..n_samples {
            let mut predictions: Vec<Float> = ensemble_predictions.column(i).to_vec();
            predictions.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let lower_idx = (lower_percentile * (predictions.len() - 1) as Float) as usize;
            let upper_idx = (upper_percentile * (predictions.len() - 1) as Float) as usize;

            intervals[[i, 0]] = predictions[lower_idx];
            intervals[[i, 1]] = predictions[upper_idx];
        }

        Ok(intervals)
    }

    fn compute_prediction_diversity(&self, ensemble_predictions: &Array2<Float>) -> Array1<Float> {
        let n_samples = ensemble_predictions.ncols();
        let mut diversity = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let predictions = ensemble_predictions.column(i);
            let mean_pred = predictions.mean().unwrap();

            let variance = predictions
                .iter()
                .map(|&pred| (pred - mean_pred).powi(2))
                .sum::<Float>()
                / predictions.len() as Float;
            let std_pred = variance.sqrt();

            // Coefficient of variation; fall back to the raw standard deviation when
            // the mean prediction is (near) zero.
            diversity[i] = if mean_pred.abs() > 1e-8 {
                std_pred / mean_pred.abs()
            } else {
                std_pred
            };
        }

        diversity
    }

    fn decompose_uncertainty(
        &self,
        ensemble_predictions: &Array2<Float>,
    ) -> Result<UncertaintyDecomposition> {
        let n_samples = ensemble_predictions.ncols();

        let model_disagreement = self.compute_epistemic_uncertainty(ensemble_predictions);

        // The remaining components are fixed placeholder values; only the model
        // disagreement term is estimated from the predictions.
        let data_uncertainty = Array1::from_elem(n_samples, 0.1);
        let feature_uncertainty = Array1::from_elem(n_samples, 0.05);
        let label_uncertainty = Array1::from_elem(n_samples, 0.03);
        let irreducible_uncertainty = Array1::from_elem(n_samples, 0.02);

        Ok(UncertaintyDecomposition {
            model_disagreement,
            data_uncertainty,
            feature_uncertainty,
            label_uncertainty,
            irreducible_uncertainty,
        })
    }

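    /// Compute calibration metrics (ECE, MCE, Brier score, reliability diagram, and
    /// confidence statistics), treating the ensemble mean prediction as a probability
    /// and binarizing the labels at 0.5.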
    fn compute_calibration_metrics(
        &self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
    ) -> Result<CalibrationMetrics> {
        let n_samples = ensemble_predictions.ncols();

        if true_labels.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Prediction and label lengths must match".to_string(),
            ));
        }

        let mean_predictions = ensemble_predictions.mean_axis(Axis(0)).unwrap();

        let predicted_probs: Vec<Float> = mean_predictions.iter().copied().collect();
        let true_binary: Vec<bool> = true_labels.iter().map(|&label| label > 0.5).collect();

        let (ece, mce) = self.compute_calibration_errors(&predicted_probs, &true_binary)?;

        let brier_score = self.compute_brier_score(&predicted_probs, &true_binary);

        let reliability_diagram =
            self.create_reliability_diagram(&predicted_probs, &true_binary)?;

        let confidence_metrics = self.compute_confidence_metrics(&predicted_probs, &true_binary);

        Ok(CalibrationMetrics {
            expected_calibration_error: ece,
            maximum_calibration_error: mce,
            brier_score,
            reliability_diagram,
            confidence_metrics,
        })
    }

    fn compute_calibration_errors(
        &self,
        predicted_probs: &[Float],
        true_binary: &[bool],
    ) -> Result<(Float, Float)> {
        let n_bins = 10;
        let bin_size = 1.0 / n_bins as Float;

        let mut ece: Float = 0.0;
        let mut mce: Float = 0.0;

        for i in 0..n_bins {
            let bin_lower = i as Float * bin_size;
            let bin_upper = (i + 1) as Float * bin_size;
            // The last bin is closed on the right so that probabilities of exactly 1.0
            // are not dropped.
            let last_bin = i + 1 == n_bins;

            let bin_indices: Vec<usize> = predicted_probs
                .iter()
                .enumerate()
                .filter(|(_, &prob)| {
                    prob >= bin_lower && (prob < bin_upper || (last_bin && prob <= bin_upper))
                })
                .map(|(idx, _)| idx)
                .collect();

            if bin_indices.is_empty() {
                continue;
            }

            let bin_size_actual = bin_indices.len();

            let avg_confidence = bin_indices
                .iter()
                .map(|&idx| predicted_probs[idx])
                .sum::<Float>()
                / bin_size_actual as Float;

            let bin_accuracy = bin_indices
                .iter()
                .map(|&idx| if true_binary[idx] { 1.0 } else { 0.0 })
                .sum::<Float>()
                / bin_size_actual as Float;

            let calibration_error = (avg_confidence - bin_accuracy).abs();

            ece += (bin_size_actual as Float / predicted_probs.len() as Float) * calibration_error;
            mce = mce.max(calibration_error);
        }

        Ok((ece, mce))
    }

    fn compute_brier_score(&self, predicted_probs: &[Float], true_binary: &[bool]) -> Float {
        predicted_probs
            .iter()
            .zip(true_binary.iter())
            .map(|(&prob, &is_true)| {
                let true_prob = if is_true { 1.0 } else { 0.0 };
                (prob - true_prob).powi(2)
            })
            .sum::<Float>()
            / predicted_probs.len() as Float
    }

    fn create_reliability_diagram(
        &self,
        predicted_probs: &[Float],
        true_binary: &[bool],
    ) -> Result<ReliabilityDiagram> {
        let n_bins = 10;
        let bin_size = 1.0 / n_bins as Float;

        let mut confidence_bins = Vec::with_capacity(n_bins + 1);
        let mut bin_accuracies = Vec::with_capacity(n_bins);
        let mut bin_proportions = Vec::with_capacity(n_bins);
        let mut bin_confidences = Vec::with_capacity(n_bins);
        let mut bin_counts = Vec::with_capacity(n_bins);

        for i in 0..=n_bins {
            confidence_bins.push(i as Float * bin_size);
        }

        for i in 0..n_bins {
            let bin_lower = i as Float * bin_size;
            let bin_upper = (i + 1) as Float * bin_size;
            // As in `compute_calibration_errors`, the last bin includes its upper edge.
            let last_bin = i + 1 == n_bins;

            let bin_indices: Vec<usize> = predicted_probs
                .iter()
                .enumerate()
                .filter(|(_, &prob)| {
                    prob >= bin_lower && (prob < bin_upper || (last_bin && prob <= bin_upper))
                })
                .map(|(idx, _)| idx)
                .collect();

            let bin_count = bin_indices.len();
            bin_counts.push(bin_count);

            if bin_count == 0 {
                bin_accuracies.push(0.0);
                bin_proportions.push(0.0);
                bin_confidences.push((bin_lower + bin_upper) / 2.0);
            } else {
                let avg_confidence = bin_indices
                    .iter()
                    .map(|&idx| predicted_probs[idx])
                    .sum::<Float>()
                    / bin_count as Float;

                let bin_accuracy = bin_indices
                    .iter()
                    .map(|&idx| if true_binary[idx] { 1.0 } else { 0.0 })
                    .sum::<Float>()
                    / bin_count as Float;

                let bin_proportion = bin_count as Float / predicted_probs.len() as Float;

                bin_accuracies.push(bin_accuracy);
                bin_proportions.push(bin_proportion);
                bin_confidences.push(avg_confidence);
            }
        }

        Ok(ReliabilityDiagram {
            confidence_bins,
            bin_accuracies,
            bin_proportions,
            bin_confidences,
            bin_counts,
        })
    }

    fn compute_confidence_metrics(
        &self,
        predicted_probs: &[Float],
        true_binary: &[bool],
    ) -> ConfidenceMetrics {
        let mut correct_confidences = Vec::new();
        let mut incorrect_confidences = Vec::new();

        for (&prob, &is_true) in predicted_probs.iter().zip(true_binary.iter()) {
            let predicted_class = prob > 0.5;
            // Confidence of the predicted class is max(p, 1 - p).
            if predicted_class == is_true {
                correct_confidences.push(prob.max(1.0 - prob));
            } else {
                incorrect_confidences.push(prob.max(1.0 - prob));
            }
        }

        let avg_confidence_correct = if correct_confidences.is_empty() {
            0.0
        } else {
            correct_confidences.iter().sum::<Float>() / correct_confidences.len() as Float
        };

        let avg_confidence_incorrect = if incorrect_confidences.is_empty() {
            0.0
        } else {
            incorrect_confidences.iter().sum::<Float>() / incorrect_confidences.len() as Float
        };

        let confidence_accuracy_correlation =
            self.compute_confidence_accuracy_correlation(predicted_probs, true_binary);

        let threshold = 0.8;
        let overconfident_count = predicted_probs
            .iter()
            .zip(true_binary.iter())
            .filter(|(&prob, &is_true)| {
                let predicted_class = prob > 0.5;
                prob.max(1.0 - prob) > threshold && predicted_class != is_true
            })
            .count();

        let underconfident_count = predicted_probs
            .iter()
            .zip(true_binary.iter())
            .filter(|(&prob, &is_true)| {
                let predicted_class = prob > 0.5;
                prob.max(1.0 - prob) < (1.0 - threshold) && predicted_class == is_true
            })
            .count();

        let overconfidence_rate = overconfident_count as Float / predicted_probs.len() as Float;
        let underconfidence_rate = underconfident_count as Float / predicted_probs.len() as Float;

        ConfidenceMetrics {
            avg_confidence_correct,
            avg_confidence_incorrect,
            confidence_accuracy_correlation,
            overconfidence_rate,
            underconfidence_rate,
        }
    }

    fn compute_confidence_accuracy_correlation(
        &self,
        predicted_probs: &[Float],
        true_binary: &[bool],
    ) -> Float {
        let confidences: Vec<Float> = predicted_probs
            .iter()
            .map(|&prob| prob.max(1.0 - prob))
            .collect();
        let accuracies: Vec<Float> = predicted_probs
            .iter()
            .zip(true_binary.iter())
            .map(|(&prob, &is_true)| {
                let predicted_class = prob > 0.5;
                if predicted_class == is_true {
                    1.0
                } else {
                    0.0
                }
            })
            .collect();

        let n = confidences.len() as Float;
        let mean_conf = confidences.iter().sum::<Float>() / n;
        let mean_acc = accuracies.iter().sum::<Float>() / n;

        let mut numerator = 0.0;
        let mut sum_sq_conf = 0.0;
        let mut sum_sq_acc = 0.0;

        for i in 0..confidences.len() {
            let diff_conf = confidences[i] - mean_conf;
            let diff_acc = accuracies[i] - mean_acc;

            numerator += diff_conf * diff_acc;
            sum_sq_conf += diff_conf * diff_conf;
            sum_sq_acc += diff_acc * diff_acc;
        }

        let denominator = (sum_sq_conf * sum_sq_acc).sqrt();
        if denominator > 0.0 {
            numerator / denominator
        } else {
            0.0
        }
    }

    fn default_calibration_metrics(&self) -> CalibrationMetrics {
        CalibrationMetrics {
            expected_calibration_error: 0.0,
            maximum_calibration_error: 0.0,
            brier_score: 0.0,
            reliability_diagram: ReliabilityDiagram {
                confidence_bins: vec![],
                bin_accuracies: vec![],
                bin_proportions: vec![],
                bin_confidences: vec![],
                bin_counts: vec![],
            },
            confidence_metrics: ConfidenceMetrics {
                avg_confidence_correct: 0.0,
                avg_confidence_incorrect: 0.0,
                confidence_accuracy_correlation: 0.0,
                overconfidence_rate: 0.0,
                underconfidence_rate: 0.0,
            },
        }
    }
}

impl Default for EnsembleAnalyzer {
    fn default() -> Self {
        Self::new(ImportanceAggregationMethod::Mean)
    }
}

impl EnsembleAnalyzer {
    pub fn mean_importance() -> Self {
        Self::new(ImportanceAggregationMethod::Mean)
    }

    pub fn weighted_importance(weights: Vec<Float>) -> Self {
        Self::new(ImportanceAggregationMethod::WeightedMean(weights))
    }

    pub fn robust_importance() -> Self {
        Self::new(ImportanceAggregationMethod::Median)
    }

    pub fn rank_based_importance() -> Self {
        Self::new(ImportanceAggregationMethod::RankBased)
    }

    pub fn permutation_importance(n_repeats: usize) -> Self {
        Self::new(ImportanceAggregationMethod::PermutationBased { n_repeats })
    }

    pub fn shap_importance(background_samples: usize) -> Self {
        Self::new(ImportanceAggregationMethod::SHAPBased { background_samples })
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_ensemble_analyzer_creation() {
        let analyzer = EnsembleAnalyzer::default();
        assert!(matches!(
            analyzer.importance_method,
            ImportanceAggregationMethod::Mean
        ));
    }

    #[test]
    fn test_feature_importance_analysis() {
        let analyzer = EnsembleAnalyzer::mean_importance();

        let importances = array![[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        assert_eq!(analysis.importance_std.len(), 3);
        assert_eq!(analysis.feature_rankings.len(), 3);
        assert!(analysis.ranking_stability >= 0.0 && analysis.ranking_stability <= 1.0);
        assert_eq!(analysis.top_features.len(), 3);
    }
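
    // Sketch: with no model weights, Bayesian averaging uses a uniform posterior
    // and should reduce to the simple per-feature mean.
    #[test]
    fn test_bayesian_averaging_uniform_weights_matches_mean() {
        let analyzer = EnsembleAnalyzer::new(ImportanceAggregationMethod::BayesianAveraging);

        let importances = array![[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();
        let simple_mean = importances.mean_axis(Axis(0)).unwrap();

        for j in 0..3 {
            assert!((analysis.feature_importances[j] - simple_mean[j]).abs() < 1e-8);
        }
    }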

    #[test]
    fn test_weighted_importance_aggregation() {
        let weights = vec![0.5, 0.3, 0.2];
        let analyzer = EnsembleAnalyzer::weighted_importance(weights);

        let importances = array![[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        // With non-uniform weights the result should differ from the simple mean.
        let simple_mean = importances.mean_axis(Axis(0)).unwrap();
        assert!((analysis.feature_importances[0] - simple_mean[0]).abs() > 1e-10);
    }

    #[test]
    fn test_median_aggregation() {
        let analyzer = EnsembleAnalyzer::robust_importance();

        let importances = array![
            [0.1, 0.3, 0.6],
            [0.5, 0.3, 0.2],
            [0.4, 0.4, 0.2]
        ];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        // Median of column 0 (0.1, 0.5, 0.4) is 0.4.
        assert!((analysis.feature_importances[0] - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_uncertainty_quantification() {
        let analyzer = EnsembleAnalyzer::default();

        let predictions = array![
            [0.1, 0.8, 0.3, 0.7],
            [0.2, 0.9, 0.4, 0.6],
            [0.1, 0.7, 0.5, 0.8]
        ];

        let uncertainty = analyzer.quantify_uncertainty(&predictions, None).unwrap();

        assert_eq!(uncertainty.epistemic_uncertainty.len(), 4);
        assert_eq!(uncertainty.aleatoric_uncertainty.len(), 4);
        assert_eq!(uncertainty.total_uncertainty.len(), 4);
        assert_eq!(uncertainty.confidence_intervals.nrows(), 4);
        assert_eq!(uncertainty.confidence_intervals.ncols(), 2);
        assert_eq!(uncertainty.prediction_diversity.len(), 4);
    }

    #[test]
    fn test_calibration_metrics() {
        let analyzer = EnsembleAnalyzer::default();

        let predictions = array![
            [0.1, 0.8, 0.3, 0.7],
            [0.2, 0.9, 0.4, 0.6],
            [0.1, 0.7, 0.5, 0.8]
        ];

        let true_labels = array![0.0, 1.0, 0.0, 1.0];

        let uncertainty = analyzer
            .quantify_uncertainty(&predictions, Some(&true_labels))
            .unwrap();

        assert!(uncertainty.calibration_metrics.expected_calibration_error >= 0.0);
        assert!(uncertainty.calibration_metrics.maximum_calibration_error >= 0.0);
        assert!(uncertainty.calibration_metrics.brier_score >= 0.0);
        assert!(!uncertainty
            .calibration_metrics
            .reliability_diagram
            .confidence_bins
            .is_empty());
    }

    #[test]
    fn test_rank_based_aggregation() {
        let analyzer = EnsembleAnalyzer::rank_based_importance();

        let importances = array![
            [0.1, 0.5, 0.4],
            [0.3, 0.2, 0.5],
            [0.6, 0.1, 0.3]
        ];

        let analysis = analyzer
            .analyze_feature_importance(&importances, None)
            .unwrap();

        assert_eq!(analysis.feature_importances.len(), 3);
        let sum_importance = analysis.feature_importances.sum();
        assert!(sum_importance > 0.0);
    }

    #[test]
    fn test_kendall_tau_correlation() {
        let analyzer = EnsembleAnalyzer::default();

        let rank1 = array![1.0, 2.0, 3.0, 4.0];
        let rank2 = array![1.0, 2.0, 3.0, 4.0];
        let tau = analyzer.kendall_tau(&rank1, &rank2).unwrap();
        assert!((tau - 1.0).abs() < 1e-6);

        let rank3 = array![4.0, 3.0, 2.0, 1.0];
        let tau_neg = analyzer.kendall_tau(&rank1, &rank3).unwrap();
        assert!((tau_neg + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_top_features_extraction() {
        let analyzer = EnsembleAnalyzer::default();

        let importances = array![0.1, 0.5, 0.3, 0.8, 0.2];
        let top_features = analyzer.get_top_features(&importances, 3);

        assert_eq!(top_features.len(), 3);
        assert_eq!(top_features[0].0, 3);
        assert_eq!(top_features[1].0, 1);
        assert_eq!(top_features[2].0, 2);
    }

    #[test]
    fn test_confidence_intervals() {
        let analyzer = EnsembleAnalyzer::default().with_bootstrap(50, 0.95);

        let importances = array![
            [0.5, 0.3, 0.2],
            [0.4, 0.4, 0.2],
            [0.6, 0.2, 0.2],
            [0.5, 0.3, 0.2],
            [0.4, 0.5, 0.1]
        ];

        let intervals = analyzer.compute_confidence_intervals(&importances).unwrap();

        assert_eq!(intervals.len(), 3);
        for (lower, upper) in &intervals {
            assert!(lower <= upper);
        }
    }

    #[test]
    fn test_prediction_confidence_intervals() {
        let analyzer = EnsembleAnalyzer::default().with_bootstrap(10, 0.9);

        let predictions = array![
            [0.1, 0.8, 0.3],
            [0.2, 0.9, 0.4],
            [0.0, 0.7, 0.5],
            [0.3, 0.8, 0.2]
        ];

        let intervals = analyzer
            .compute_prediction_confidence_intervals(&predictions)
            .unwrap();

        assert_eq!(intervals.nrows(), 3);
        assert_eq!(intervals.ncols(), 2);
        for i in 0..3 {
            assert!(intervals[[i, 0]] <= intervals[[i, 1]]);
        }
    }
}