scirs2-metrics 0.3.2

Machine Learning evaluation metrics module for SciRS2 (scirs2-metrics)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
//! Core uncertainty quantification types and analyzer
//!
//! This module provides the main uncertainty quantification framework
//! and core types for estimating prediction uncertainty.

#![allow(clippy::too_many_arguments)]
#![allow(dead_code)]

use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Float;
use std::collections::HashMap;

/// Uncertainty quantification analyzer
pub struct UncertaintyQuantifier<F: Float> {
    /// Number of Monte Carlo samples
    pub n_mc_samples: usize,
    /// Confidence level for intervals
    pub confidence_level: F,
    /// Bootstrap samples for confidence estimation
    pub n_bootstrap: usize,
    /// Random seed
    pub random_seed: Option<u64>,
    /// Random number generator type
    pub rng_type: RandomNumberGenerator,
    /// Number of conformal calibration samples
    pub n_conformal_calibration: usize,
    /// Enable Bayesian uncertainty estimation
    pub enable_bayesian: bool,
    /// Number of MCMC samples
    pub n_mcmc_samples: usize,
    /// MCMC burn-in samples
    pub mcmc_burn_in: usize,
    /// Enable temperature scaling
    pub enable_temperature_scaling: bool,
    /// Enable SIMD acceleration
    pub enable_simd: bool,
}

/// Random number generator types
#[derive(Debug, Clone)]
pub enum RandomNumberGenerator {
    /// Linear Congruential Generator (fast, basic quality)
    Lcg,
    /// Xorshift (good balance of speed and quality)
    Xorshift,
    /// Permuted Congruential Generator (high quality)
    Pcg,
    /// ChaCha (cryptographically secure)
    ChaCha,
}

/// Uncertainty analysis results
#[derive(Debug, Clone)]
pub struct UncertaintyAnalysis<F: Float> {
    /// Mean prediction
    pub mean_prediction: Array1<F>,
    /// Prediction variance
    pub prediction_variance: Array1<F>,
    /// Epistemic uncertainty (model uncertainty)
    pub epistemic_uncertainty: EpistemicUncertainty<F>,
    /// Aleatoric uncertainty (data uncertainty)
    pub aleatoric_uncertainty: AleatoricUncertainty<F>,
    /// Prediction intervals
    pub prediction_intervals: PredictionIntervals<F>,
    /// Calibration metrics
    pub calibration_metrics: CalibrationMetrics<F>,
    /// Confidence scores
    pub confidence_scores: ConfidenceScores<F>,
    /// Out-of-distribution scores
    pub ood_scores: OODScores<F>,
}

/// Epistemic uncertainty (model uncertainty)
#[derive(Debug, Clone)]
pub struct EpistemicUncertainty<F: Float> {
    /// Model variance across ensemble
    pub model_variance: Array1<F>,
    /// Mutual information
    pub mutual_information: F,
    /// Knowledge uncertainty
    pub knowledge_uncertainty: Array1<F>,
    /// Prediction entropy
    pub prediction_entropy: Array1<F>,
}

/// Aleatoric uncertainty (data uncertainty)
#[derive(Debug, Clone)]
pub struct AleatoricUncertainty<F: Float> {
    /// Data noise variance
    pub data_variance: Array1<F>,
    /// Observation noise
    pub observation_noise: F,
    /// Input-dependent variance
    pub heteroscedastic_variance: Array1<F>,
}

/// Prediction intervals
#[derive(Debug, Clone)]
pub struct PredictionIntervals<F: Float> {
    /// Lower bounds
    pub lower_bounds: Array1<F>,
    /// Upper bounds
    pub upper_bounds: Array1<F>,
    /// Confidence level
    pub confidence_level: F,
    /// Interval widths
    pub interval_widths: Array1<F>,
}

/// Calibration metrics
#[derive(Debug, Clone)]
pub struct CalibrationMetrics<F: Float> {
    /// Expected calibration error
    pub expected_calibration_error: F,
    /// Maximum calibration error
    pub maximum_calibration_error: F,
    /// Brier score decomposition
    pub brier_decomposition: BrierDecomposition<F>,
    /// Reliability curve
    pub reliability_curve: Array2<F>,
    /// Sharpness measure
    pub sharpness: F,
}

/// Brier score decomposition
#[derive(Debug, Clone)]
pub struct BrierDecomposition<F: Float> {
    /// Reliability component
    pub reliability: F,
    /// Resolution component
    pub resolution: F,
    /// Uncertainty component
    pub uncertainty: F,
    /// Overall Brier score
    pub brier_score: F,
}

/// Confidence scores
#[derive(Debug, Clone)]
pub struct ConfidenceScores<F: Float> {
    /// Maximum predicted probability
    pub max_probability: Array1<F>,
    /// Entropy-based confidence
    pub entropy_confidence: Array1<F>,
    /// Temperature-scaled confidence
    pub temperature_scaled_confidence: Array1<F>,
    /// Margin-based confidence
    pub margin_confidence: Array1<F>,
}

/// Out-of-distribution detection scores
#[derive(Debug, Clone)]
pub struct OODScores<F: Float> {
    /// Maximum softmax probability
    pub msp_scores: Array1<F>,
    /// ODIN scores
    pub odin_scores: Array1<F>,
    /// Mahalanobis distance scores
    pub mahalanobis_scores: Array1<F>,
    /// Energy scores
    pub energy_scores: Array1<F>,
}

impl<
        F: Float
            + scirs2_core::numeric::FromPrimitive
            + std::iter::Sum
            + scirs2_core::ndarray::ScalarOperand,
    > UncertaintyQuantifier<F>
{
    /// Create new uncertainty quantifier
    pub fn new() -> Self {
        Self {
            n_mc_samples: 100,
            confidence_level: F::from(0.95).expect("Failed to convert constant to float"),
            n_bootstrap: 1000,
            random_seed: None,
            rng_type: RandomNumberGenerator::Xorshift,
            n_conformal_calibration: 1000,
            enable_bayesian: false,
            n_mcmc_samples: 5000,
            mcmc_burn_in: 1000,
            enable_temperature_scaling: true,
            enable_simd: true,
        }
    }

    /// Create uncertainty quantifier with custom configuration
    pub fn with_config(n_mc_samples: usize, confidence_level: F, n_bootstrap: usize) -> Self {
        Self {
            n_mc_samples,
            confidence_level,
            n_bootstrap,
            ..Self::new()
        }
    }

    /// Set random seed
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }

    /// Set RNG type
    pub fn with_rng(mut self, rng_type: RandomNumberGenerator) -> Self {
        self.rng_type = rng_type;
        self
    }

    /// Enable Bayesian uncertainty estimation
    pub fn with_bayesian(mut self, enabled: bool) -> Self {
        self.enable_bayesian = enabled;
        self
    }

    /// Compute uncertainty analysis for predictions
    pub fn analyze_uncertainty(
        &self,
        predictions: &ArrayView2<F>,
        ground_truth: Option<&ArrayView1<F>>,
        model_outputs: Option<&[ArrayView2<F>]>,
    ) -> Result<UncertaintyAnalysis<F>> {
        let n_samples = predictions.nrows();
        let n_classes = predictions.ncols();

        // Compute mean prediction
        let mean_prediction = predictions
            .mean_axis(scirs2_core::ndarray::Axis(1))
            .expect("Operation failed");

        // Compute prediction variance
        let prediction_variance = self.compute_prediction_variance(predictions)?;

        // Compute epistemic uncertainty
        let epistemic_uncertainty =
            self.compute_epistemic_uncertainty(predictions, model_outputs)?;

        // Compute aleatoric uncertainty
        let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(predictions)?;

        // Compute prediction intervals
        let prediction_intervals = self
            .compute_prediction_intervals(&mean_prediction.view(), &prediction_variance.view())?;

        // Compute calibration metrics
        let calibration_metrics = if let Some(gt) = ground_truth {
            self.compute_calibration_metrics(predictions, gt)?
        } else {
            CalibrationMetrics::default()
        };

        // Compute confidence scores
        let confidence_scores = self.compute_confidence_scores(predictions)?;

        // Compute OOD scores
        let ood_scores = self.compute_ood_scores(predictions)?;

        Ok(UncertaintyAnalysis {
            mean_prediction,
            prediction_variance,
            epistemic_uncertainty,
            aleatoric_uncertainty,
            prediction_intervals,
            calibration_metrics,
            confidence_scores,
            ood_scores,
        })
    }

    /// Compute prediction variance
    fn compute_prediction_variance(&self, predictions: &ArrayView2<F>) -> Result<Array1<F>> {
        let variance = predictions.var_axis(
            scirs2_core::ndarray::Axis(1),
            F::from(1.0).expect("Failed to convert constant to float"),
        );
        Ok(variance)
    }

    /// Compute epistemic uncertainty
    fn compute_epistemic_uncertainty(
        &self,
        predictions: &ArrayView2<F>,
        model_outputs: Option<&[ArrayView2<F>]>,
    ) -> Result<EpistemicUncertainty<F>> {
        let n_samples = predictions.nrows();

        // Default values
        let model_variance = Array1::zeros(n_samples);
        let mutual_information = F::zero();
        let knowledge_uncertainty = Array1::zeros(n_samples);

        // Compute prediction entropy
        let prediction_entropy = self.compute_entropy(predictions)?;

        Ok(EpistemicUncertainty {
            model_variance,
            mutual_information,
            knowledge_uncertainty,
            prediction_entropy,
        })
    }

    /// Compute aleatoric uncertainty
    fn compute_aleatoric_uncertainty(
        &self,
        predictions: &ArrayView2<F>,
    ) -> Result<AleatoricUncertainty<F>> {
        let n_samples = predictions.nrows();

        // Simplified aleatoric uncertainty computation
        let data_variance = predictions.var_axis(
            scirs2_core::ndarray::Axis(1),
            F::from(1.0).expect("Failed to convert constant to float"),
        );
        let observation_noise = F::from(0.1).expect("Failed to convert constant to float"); // Default noise level
        let heteroscedastic_variance = Array1::zeros(n_samples);

        Ok(AleatoricUncertainty {
            data_variance,
            observation_noise,
            heteroscedastic_variance,
        })
    }

    /// Compute prediction intervals
    fn compute_prediction_intervals(
        &self,
        mean_prediction: &ArrayView1<F>,
        prediction_variance: &ArrayView1<F>,
    ) -> Result<PredictionIntervals<F>> {
        let alpha = F::one() - self.confidence_level;
        let z_score = F::from(1.96).expect("Failed to convert constant to float"); // 95% confidence interval

        let std_dev = prediction_variance.mapv(|v| v.sqrt());

        let lower_bounds = mean_prediction - &(&std_dev * z_score);
        let upper_bounds = mean_prediction + &(&std_dev * z_score);
        let interval_widths = &upper_bounds - &lower_bounds;

        Ok(PredictionIntervals {
            lower_bounds,
            upper_bounds,
            confidence_level: self.confidence_level,
            interval_widths,
        })
    }

    /// Compute calibration metrics
    fn compute_calibration_metrics(
        &self,
        predictions: &ArrayView2<F>,
        ground_truth: &ArrayView1<F>,
    ) -> Result<CalibrationMetrics<F>> {
        // Simplified calibration computation
        let expected_calibration_error =
            F::from(0.05).expect("Failed to convert constant to float"); // Placeholder
        let maximum_calibration_error = F::from(0.1).expect("Failed to convert constant to float"); // Placeholder

        let brier_decomposition = BrierDecomposition {
            reliability: F::from(0.02).expect("Failed to convert constant to float"),
            resolution: F::from(0.1).expect("Failed to convert constant to float"),
            uncertainty: F::from(0.25).expect("Failed to convert constant to float"),
            brier_score: F::from(0.15).expect("Failed to convert constant to float"),
        };

        let reliability_curve = Array2::zeros((10, 2)); // Placeholder
        let sharpness = F::from(0.8).expect("Failed to convert constant to float");

        Ok(CalibrationMetrics {
            expected_calibration_error,
            maximum_calibration_error,
            brier_decomposition,
            reliability_curve,
            sharpness,
        })
    }

    /// Compute confidence scores
    fn compute_confidence_scores(
        &self,
        predictions: &ArrayView2<F>,
    ) -> Result<ConfidenceScores<F>> {
        let n_samples = predictions.nrows();

        // Maximum probability
        let max_probability = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
            row.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc })
        });

        // Entropy-based confidence
        let entropy_confidence = self.compute_entropy(predictions)?;

        // Temperature-scaled confidence (simplified)
        let temperature_scaled_confidence = max_probability.clone();

        // Margin-based confidence (difference between top two predictions)
        let margin_confidence = Array1::zeros(n_samples); // Simplified

        Ok(ConfidenceScores {
            max_probability,
            entropy_confidence,
            temperature_scaled_confidence,
            margin_confidence,
        })
    }

    /// Compute OOD scores
    fn compute_ood_scores(&self, predictions: &ArrayView2<F>) -> Result<OODScores<F>> {
        let n_samples = predictions.nrows();

        // Maximum softmax probability (MSP)
        let msp_scores = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
            row.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc })
        });

        // Simplified scores for other methods
        let odin_scores = Array1::zeros(n_samples);
        let mahalanobis_scores = Array1::zeros(n_samples);
        let energy_scores = Array1::zeros(n_samples);

        Ok(OODScores {
            msp_scores,
            odin_scores,
            mahalanobis_scores,
            energy_scores,
        })
    }

    /// Compute entropy of predictions
    fn compute_entropy(&self, predictions: &ArrayView2<F>) -> Result<Array1<F>> {
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");
        let entropy = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
            row.iter()
                .map(|&p| {
                    let p_safe = if p < epsilon { epsilon } else { p };
                    -p_safe * p_safe.ln()
                })
                .fold(F::zero(), |acc, x| acc + x)
        });

        Ok(entropy)
    }
}

impl<
        F: Float
            + scirs2_core::numeric::FromPrimitive
            + std::iter::Sum
            + scirs2_core::ndarray::ScalarOperand,
    > Default for UncertaintyQuantifier<F>
{
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float> Default for CalibrationMetrics<F> {
    fn default() -> Self {
        Self {
            expected_calibration_error: F::zero(),
            maximum_calibration_error: F::zero(),
            brier_decomposition: BrierDecomposition {
                reliability: F::zero(),
                resolution: F::zero(),
                uncertainty: F::zero(),
                brier_score: F::zero(),
            },
            reliability_curve: Array2::zeros((0, 0)),
            sharpness: F::zero(),
        }
    }
}