use scirs2_core::ndarray::{Array1, Array2};
/// Names a component of predictive uncertainty.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it. It is a plain field-less enum and
/// therefore also implements `Hash`, which lets it serve as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum UncertaintyType {
    /// Aleatoric uncertainty (conventionally: noise inherent in the data).
    Aleatoric,
    /// Epistemic uncertainty (conventionally: uncertainty about the model).
    Epistemic,
    /// Total uncertainty.
    // NOTE(review): presumably aleatoric + epistemic combined — confirm
    // against the code that consumes this enum.
    Total,
}
/// Identifies the posterior approximation scheme (see `BNNConfig::method`).
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it. It is a plain field-less enum and
/// therefore also implements `Hash`, which lets it serve as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ApproximationMethod {
    /// Laplace approximation restricted to the last layer.
    // NOTE(review): semantics inferred from the variant name — confirm.
    LastLayerLaplace,
    /// Laplace approximation over the full parameter set.
    // NOTE(review): semantics inferred from the variant name — confirm.
    FullLaplace,
    /// SWAG.
    // NOTE(review): presumably "Stochastic Weight Averaging-Gaussian" with a
    // low-rank-plus-diagonal covariance — confirm against the implementation.
    SWAG,
    /// Diagonal-only SWAG variant.
    // NOTE(review): presumably SWAG with a purely diagonal covariance — confirm.
    SWAGDiag,
}
/// Configuration bundle for the Bayesian-neural-network routines.
///
/// All fields are public; `BNNConfig::default()` supplies a ready-made set of
/// values that can be overridden with struct-update syntax.
#[derive(Clone, Debug)]
pub struct BNNConfig {
    /// Which posterior approximation scheme to use.
    pub method: ApproximationMethod,

    /// Sample count.
    // NOTE(review): presumably the number of posterior samples drawn when
    // forming predictions — confirm against the consumers of this field.
    pub n_samples: usize,

    /// Precision of the prior.
    // NOTE(review): presumably the precision (inverse variance) of an
    // isotropic Gaussian weight prior — confirm.
    pub prior_precision: f64,

    /// Rank setting for SWAG.
    // NOTE(review): presumably the maximum number of deviation columns kept
    // for the low-rank covariance part — confirm.
    pub swag_rank: usize,

    /// Point at which SWAG collection starts.
    // NOTE(review): unit (epoch vs. step) is not visible here — confirm.
    pub swag_collection_start: usize,

    /// Frequency of SWAG collection.
    // NOTE(review): unit (epoch vs. step) is not visible here — confirm.
    pub swag_collection_freq: usize,
}
impl Default for BNNConfig {
fn default() -> Self {
Self {
method: ApproximationMethod::FullLaplace,
n_samples: 30,
prior_precision: 1.0,
swag_rank: 20,
swag_collection_start: 0,
swag_collection_freq: 1,
}
}
}
/// Storage layouts for a covariance, each variant carrying its own data.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum CovarianceType {
    /// Covariance held as a single two-dimensional array.
    Full(Array2<f64>),

    /// Covariance held as a one-dimensional array.
    // NOTE(review): presumably the diagonal entries (variances) — confirm.
    Diagonal(Array1<f64>),

    /// Covariance split into a diagonal part and a deviation matrix.
    // NOTE(review): how the two parts are combined (scaling, orientation of
    // `deviation`) is not visible here — confirm against the sampler.
    LowRankPlusDiagonal {
        /// One-dimensional diagonal component.
        d_diag: Array1<f64>,
        /// Two-dimensional deviation component.
        deviation: Array2<f64>,
    },

    /// Covariance represented through two matrix factors.
    // NOTE(review): presumably a Kronecker product of the two factors; the
    // order of the product is not visible here — confirm.
    KroneckerFactored {
        /// First factor.
        a_factor: Array2<f64>,
        /// Second factor.
        b_factor: Array2<f64>,
    },
}
/// Posterior description: a mean vector, a covariance representation, and a
/// log marginal likelihood value.
#[derive(Clone, Debug)]
pub struct BNNPosterior {
    /// Posterior mean as a one-dimensional array.
    // NOTE(review): presumably the flattened network parameters — confirm.
    pub mean: Array1<f64>,

    /// Covariance of the posterior, in one of the supported layouts.
    pub covariance_type: CovarianceType,

    /// Log marginal likelihood associated with this posterior.
    // NOTE(review): whether this is exact or an approximation depends on the
    // fitting code, which is not visible here.
    pub log_marginal_likelihood: f64,
}
/// Summary of a predictive distribution: mean, variance, and optionally the
/// underlying samples.
#[derive(Clone, Debug)]
pub struct PredictiveDistribution {
    /// Predictive mean as a one-dimensional array.
    pub mean: Array1<f64>,

    /// Predictive variance as a one-dimensional array.
    // NOTE(review): presumably element-wise aligned with `mean` — confirm.
    pub variance: Array1<f64>,

    /// Optional two-dimensional array of samples; `None` when not retained.
    // NOTE(review): axis meaning (samples x outputs or the reverse) is not
    // visible here — confirm against the producer.
    pub samples: Option<Array2<f64>>,
}
/// One bin of aggregated predicted-versus-observed statistics.
// NOTE(review): presumably a single bin of a reliability (calibration)
// diagram — confirm against the code that builds these bins.
#[derive(Clone, Debug)]
pub struct ReliabilityBin {
    /// Mean of the predicted values that fell into this bin.
    pub mean_predicted: f64,

    /// Mean of the observed values for this bin.
    pub mean_observed: f64,

    /// Number of entries in this bin.
    pub count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration carries its expected value.
    #[test]
    fn test_default_config() {
        let cfg = BNNConfig::default();
        assert_eq!(cfg.method, ApproximationMethod::FullLaplace);
        assert_eq!(cfg.n_samples, 30);
        // Float field: compare with a tolerance instead of `==`.
        assert!((cfg.prior_precision - 1.0).abs() < 1e-12);
        assert_eq!(cfg.swag_rank, 20);
        assert_eq!(cfg.swag_collection_start, 0);
        assert_eq!(cfg.swag_collection_freq, 1);
    }

    /// Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_uncertainty_type_variants() {
        let u = UncertaintyType::Total;
        assert_eq!(u, UncertaintyType::Total);
        assert_ne!(UncertaintyType::Aleatoric, UncertaintyType::Epistemic);
    }

    /// The `Diagonal` variant keeps the array it was built from.
    #[test]
    fn test_covariance_type_diagonal() {
        // The array is moved straight into the variant; no clone is needed
        // because it is not used again afterwards.
        let cov = CovarianceType::Diagonal(Array1::from_vec(vec![1.0, 2.0, 3.0]));
        match &cov {
            CovarianceType::Diagonal(d) => assert_eq!(d.len(), 3),
            _ => panic!("Expected Diagonal variant"),
        }
    }

    /// A predictive distribution can be built without retained samples.
    #[test]
    fn test_predictive_distribution_creation() {
        let pd = PredictiveDistribution {
            mean: Array1::from_vec(vec![1.0, 2.0]),
            variance: Array1::from_vec(vec![0.1, 0.2]),
            samples: None,
        };
        assert_eq!(pd.mean.len(), 2);
        assert_eq!(pd.variance.len(), 2);
        assert!(pd.samples.is_none());
    }

    /// Fields of a reliability bin round-trip through construction.
    #[test]
    fn test_reliability_bin() {
        let bin = ReliabilityBin {
            mean_predicted: 0.5,
            mean_observed: 0.48,
            count: 100,
        };
        assert_eq!(bin.count, 100);
        assert!((bin.mean_predicted - 0.5).abs() < 1e-12);
        assert!((bin.mean_observed - 0.48).abs() < 1e-12);
    }
}