/// Curvature (Hessian) approximation selected through `LaplaceConfig::hessian_method`.
///
/// The enum is `#[non_exhaustive]`, so downstream crates must keep a wildcard arm
/// when matching on it; this allows further approximations to be added later.
/// The default variant is [`HessianMethod::GGN`].
///
/// NOTE(review): the code that interprets these variants is not visible here; the
/// expansions below follow the usual Laplace-approximation abbreviations — confirm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum HessianMethod {
    /// Presumably the generalized Gauss–Newton approximation. This is the default.
    #[default]
    GGN,
    /// Presumably a diagonal-only curvature approximation.
    Diagonal,
    /// Presumably the Kronecker-factored approximate curvature (K-FAC).
    KFAC,
}
/// Settings for a Laplace approximation of a network's weight posterior.
#[derive(Debug, Clone)]
pub struct LaplaceConfig {
    /// Curvature approximation to use; defaults to [`HessianMethod::GGN`].
    pub hessian_method: HessianMethod,
    /// Damping value, `1.0` by default.
    /// NOTE(review): presumably added to the curvature / prior precision — confirm with the consumer.
    pub damping: f64,
    /// Step size, `1e-5` by default.
    /// NOTE(review): the name suggests a finite-difference step — confirm with the consumer.
    pub fd_step: f64,
}

impl Default for LaplaceConfig {
    /// GGN curvature, damping `1.0`, finite-difference step `1e-5`.
    fn default() -> Self {
        let (damping, fd_step) = (1.0, 1e-5);
        Self {
            hessian_method: HessianMethod::GGN,
            damping,
            fd_step,
        }
    }
}
/// Settings for SWAG (stochastic weight averaging – Gaussian).
///
/// NOTE(review): the training loop that reads these fields is not visible here;
/// the per-field notes below are hedged accordingly.
#[derive(Debug, Clone)]
pub struct SwagConfig {
    /// Number of epochs to run; `20` by default.
    pub n_epochs: usize,
    /// Parameter `c`; `20` by default.
    /// NOTE(review): in the SWAG paper `c` is the moment-update frequency — confirm.
    pub c: usize,
    /// Learning rate; `0.01` by default.
    pub lr: f64,
}

impl Default for SwagConfig {
    /// 20 epochs, `c = 20`, learning rate `0.01`.
    fn default() -> Self {
        let (n_epochs, c, lr) = (20, 20, 0.01);
        Self { n_epochs, c, lr }
    }
}
/// Output of an approximate Bayesian neural-network posterior computation.
///
/// `mean_weights[i]` and `uncertainty[i]` describe the same parameter.
/// [`BnnApproxResult::std_devs`] treats every `uncertainty` entry as a variance.
#[derive(Debug, Clone)]
pub struct BnnApproxResult {
    /// Mean value of each parameter (presumably the posterior mean).
    pub mean_weights: Vec<f64>,
    /// Per-parameter variance, i.e. the square of the standard deviation.
    pub uncertainty: Vec<f64>,
    /// Free-form label of the method that produced the result (e.g. `"Laplace"`).
    pub method: String,
}

impl BnnApproxResult {
    /// Per-parameter standard deviations: the square roots of `uncertainty`.
    ///
    /// Negative entries are clamped to `0.0` before taking the root. A variance
    /// cannot be negative, and slightly negative values typically come from
    /// floating-point cancellation in variance estimates. Without the clamp,
    /// `f64::sqrt` would return NaN and poison every downstream computation.
    /// NaN entries are passed through unchanged so genuine numerical failures
    /// remain visible.
    #[must_use]
    pub fn std_devs(&self) -> Vec<f64> {
        self.uncertainty
            .iter()
            // `NaN < 0.0` is false, so NaN still reaches `sqrt` and stays NaN.
            .map(|&var| if var < 0.0 { 0.0 } else { var.sqrt() })
            .collect()
    }

    /// Number of parameters, taken from the length of `mean_weights`.
    #[must_use]
    pub fn n_params(&self) -> usize {
        self.mean_weights.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `actual` lies within 1e-12 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-12);
    }

    #[test]
    fn test_bayesian_approx_config_default() {
        let laplace = LaplaceConfig::default();
        assert_eq!(laplace.hessian_method, HessianMethod::GGN);
        assert_close(laplace.damping, 1.0);

        let swag = SwagConfig::default();
        assert_eq!(swag.n_epochs, 20);
        assert_eq!(swag.c, 20);
        assert_close(swag.lr, 0.01);
    }

    #[test]
    fn test_hessian_method_default_is_ggn() {
        assert_eq!(HessianMethod::default(), HessianMethod::GGN);
    }

    #[test]
    fn test_bnn_approx_result_std_devs() {
        let result = BnnApproxResult {
            mean_weights: vec![1.0, 2.0, 3.0],
            uncertainty: vec![4.0, 9.0, 16.0],
            method: "Laplace".to_string(),
        };
        let std_devs = result.std_devs();
        // Indexing (rather than zipping) keeps the panic if fewer values come back.
        for (idx, &want) in [2.0, 3.0, 4.0].iter().enumerate() {
            assert_close(std_devs[idx], want);
        }
        assert_eq!(result.n_params(), 3);
    }
}