use scirs2_core::ndarray::{Array1, Array2};
/// Aggregated metrics recorded for a model: per-epoch histories plus optional
/// test-set results, confusion matrix and feature importances.
///
/// The derived `Default` yields empty epoch histories and `None` for every
/// optional field.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
/// Metrics recorded on the training data, one [`EpochMetrics`] entry per
/// recorded epoch (presumably appended in epoch order — TODO confirm with the producer).
pub training_metrics: Vec<EpochMetrics>,
/// Metrics recorded on the validation data, one [`EpochMetrics`] entry per
/// recorded epoch (same ordering caveat as `training_metrics`).
pub validation_metrics: Vec<EpochMetrics>,
/// Final test-set metrics, if an evaluation has been run.
pub test_metrics: Option<TestMetrics>,
/// Confusion matrix of counts, if computed.
/// NOTE(review): row/column orientation (true vs. predicted) is not fixed by
/// this type — confirm against the code that fills it.
pub confusion_matrix: Option<Array2<usize>>,
/// Per-feature importance scores, if computed (presumably indexed in input
/// feature order — verify).
pub feature_importance: Option<Array1<f64>>,
}
/// Metrics captured for a single epoch of training or validation.
///
/// Optional fields are `None` when the corresponding metric was not computed.
#[derive(Debug, Clone)]
pub struct EpochMetrics {
/// Epoch index. NOTE(review): 0- vs 1-based numbering is not established here.
pub epoch: usize,
/// Loss value recorded for this epoch.
pub loss: f64,
/// Accuracy, if computed.
pub accuracy: Option<f64>,
/// Precision values, if computed (presumably one entry per class — TODO confirm).
pub precision: Option<Vec<f64>>,
/// Recall values, if computed (presumably one entry per class — TODO confirm).
pub recall: Option<Vec<f64>>,
/// F1 scores, if computed (presumably one entry per class — TODO confirm).
pub f1_score: Option<Vec<f64>>,
/// Learning rate recorded for this epoch.
pub learning_rate: f64,
}
/// Metrics from a final evaluation on held-out test data.
///
/// Unlike [`EpochMetrics`], every metric here is mandatory.
#[derive(Debug, Clone)]
pub struct TestMetrics {
/// Overall accuracy.
pub accuracy: f64,
/// Precision values (presumably one entry per class — TODO confirm).
pub precision: Vec<f64>,
/// Recall values (presumably one entry per class — TODO confirm).
pub recall: Vec<f64>,
/// F1 scores (presumably one entry per class — TODO confirm).
pub f1_score: Vec<f64>,
/// Area under the ROC curve.
pub auc_roc: f64,
/// Area under the precision-recall curve.
pub auc_pr: f64,
/// Presumably the Matthews correlation coefficient — confirm with the code
/// that computes it.
pub mcc: f64,
}
/// Optional configuration for each supported uncertainty-quantification
/// approach.
///
/// The derived `Default` sets every field to `None`.
/// NOTE(review): whether several approaches may be enabled at once is decided
/// by the consumer of this struct, not here.
#[derive(Debug, Clone, Default)]
pub struct UncertaintyQuantification {
/// Bayesian (variational) settings, if that approach is configured.
pub bayesian_config: Option<BayesianConfig>,
/// Monte Carlo dropout settings, if that approach is configured.
pub mc_dropout_config: Option<MCDropoutConfig>,
/// Ensemble settings, if that approach is configured.
pub ensemble_config: Option<EnsembleConfig>,
/// Conformal-prediction settings, if that approach is configured.
pub conformal_config: Option<ConformalConfig>,
}
/// Settings for Bayesian uncertainty estimation via variational inference.
#[derive(Debug, Clone)]
pub struct BayesianConfig {
/// Prior parameters for weights and biases (see [`PriorParams`]).
pub prior_params: PriorParams,
/// Variational approximation to use (see [`VariationalMethod`]).
pub variational_method: VariationalMethod,
/// Number of Monte Carlo samples (presumably drawn per prediction or
/// per loss estimate — TODO confirm with the consumer).
pub mc_samples: usize,
/// Weight applied to the KL-divergence term (presumably in the variational
/// training objective — confirm with the training code).
pub kl_weight: f64,
}
/// Location/scale parameters of the prior placed on network weights and biases.
///
/// NOTE(review): the prior's distribution family (e.g. Gaussian) is not
/// specified by this type. No validation is performed here, so nothing enforces
/// positive standard deviations.
#[derive(Debug, Clone)]
pub struct PriorParams {
/// Mean of the prior over weights.
pub weight_mean: f64,
/// Standard deviation of the prior over weights.
pub weight_std: f64,
/// Mean of the prior over biases.
pub bias_mean: f64,
/// Standard deviation of the prior over biases.
pub bias_std: f64,
}
/// Variational approximation selected in [`BayesianConfig`].
///
/// The variants are plain tags; how each one is realised is decided by the
/// code that consumes the configuration. `PartialEq`, `Eq` and `Hash` are
/// derived so methods can be compared with `==` and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariationalMethod {
/// Mean-field approximation.
MeanField,
/// Matrix-variate approximation.
MatrixVariate,
/// Normalizing-flow based approximation.
NormalizingFlows,
}
/// Settings for Monte Carlo dropout uncertainty estimation.
#[derive(Debug, Clone)]
pub struct MCDropoutConfig {
/// Dropout rate (presumably a probability in `[0, 1]`; nothing here validates it).
/// NOTE: spelled `dropoutrate` without an underscore, unlike the other
/// snake_case fields in this file; renaming it would break existing callers.
pub dropoutrate: f64,
/// Number of samples to draw (presumably stochastic forward passes — TODO confirm).
pub num_samples: usize,
/// Flag controlling stochastic mask behaviour.
/// TODO(review): document its exact effect once confirmed in the consumer.
pub stochastic_masks: bool,
}
/// Settings for ensemble-based uncertainty estimation.
#[derive(Debug, Clone)]
pub struct EnsembleConfig {
/// Number of member models in the ensemble.
pub num_models: usize,
/// How member outputs are combined (see [`EnsembleAggregation`]).
pub aggregation_method: EnsembleAggregation,
/// How diversity between members is induced (see [`DiversityMethod`]).
pub diversity_method: DiversityMethod,
}
/// Strategy for combining ensemble member outputs, selected in [`EnsembleConfig`].
///
/// The variants are plain tags; the aggregation itself is implemented by the
/// consumer of the configuration. `PartialEq`, `Eq` and `Hash` are derived so
/// strategies can be compared with `==` and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleAggregation {
/// Plain average of member outputs.
Average,
/// Weighted average of member outputs.
WeightedAverage,
/// Voting among members.
Voting,
/// Stacking (a learned combiner over member outputs).
Stacking,
}
/// Technique used to make ensemble members differ, selected in [`EnsembleConfig`].
///
/// The variants are plain tags; the technique itself is applied by the
/// consumer of the configuration. `PartialEq`, `Eq` and `Hash` are derived so
/// methods can be compared with `==` and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiversityMethod {
/// Bagging (training members on resampled data).
Bagging,
/// Different random initialisations per member.
RandomInit,
/// Different architectures per member.
DifferentArchitectures,
/// Adversarial training.
AdversarialTraining,
}
/// Settings for conformal-prediction uncertainty estimation.
#[derive(Debug, Clone)]
pub struct ConformalConfig {
/// Target confidence level (presumably a fraction in `(0, 1)`, e.g. `0.9`;
/// nothing here validates it).
pub confidence_level: f64,
/// Non-conformity score to use (see [`ConformityScore`]).
pub score_function: ConformityScore,
/// Calibration set size (presumably a number of samples — TODO confirm with the consumer).
pub calibration_size: usize,
}
/// Non-conformity score function, selected in [`ConformalConfig`].
///
/// The variants are plain tags; the scores themselves are computed by the
/// consumer of the configuration. `PartialEq`, `Eq` and `Hash` are derived so
/// score kinds can be compared with `==` and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConformityScore {
/// Absolute residuals.
AbsoluteResiduals,
/// Normalized residuals.
NormalizedResiduals,
/// Scores derived from softmax outputs.
SoftmaxScores,
/// Margin-based scores.
MarginScores,
}