/// Settings bundle for an ADVI run.
///
/// NOTE(review): the per-field descriptions are inferred from the field names
/// alone; the code that consumes this struct is not visible here, so confirm
/// them against the optimizer before relying on them. The default values
/// quoted below are exactly those produced by [`AdviConfig::default`].
#[derive(Clone, Debug)]
pub struct AdviConfig {
    /// Sample count (presumably Monte Carlo draws per step — TODO confirm).
    /// Defaults to `1`.
    pub n_samples: usize,
    /// Iteration count (presumably an upper bound on optimizer steps — TODO
    /// confirm). Defaults to `1000`.
    pub n_iter: usize,
    /// Learning rate (presumably the optimizer step size — TODO confirm).
    /// Defaults to `0.01`.
    pub lr: f64,
    /// Tolerance (presumably the convergence threshold — TODO confirm).
    /// Defaults to `1e-4`.
    pub tol: f64,
    /// Prior precision (presumably of a Gaussian prior — TODO confirm).
    /// Defaults to `1.0`.
    pub prior_precision: f64,
    /// Finite-difference step (presumably for numerical gradients — TODO
    /// confirm). Defaults to `1e-5`.
    pub fd_step: f64,
    /// Seed value (presumably for the random number generator — TODO confirm).
    /// Defaults to `42`.
    pub seed: u64,
}

impl Default for AdviConfig {
    /// Returns the stock configuration: one sample, 1000 iterations,
    /// `lr = 0.01`, `tol = 1e-4`, `prior_precision = 1.0`, `fd_step = 1e-5`
    /// and `seed = 42`.
    fn default() -> AdviConfig {
        AdviConfig {
            // Integer-valued settings.
            n_samples: 1,
            n_iter: 1000,
            seed: 42,
            // Floating-point settings.
            lr: 0.01,
            tol: 1e-4,
            prior_precision: 1.0,
            fd_step: 1e-5,
        }
    }
}
/// Selector for the variational family.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a
/// wildcard arm when matching on it.
///
/// NOTE(review): the variant descriptions are inferred from the names only;
/// confirm against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum VariationalFamily {
    /// Presumably a fully factorized (diagonal) Gaussian — TODO confirm.
    MeanField,
    /// Presumably a Gaussian with a dense covariance — TODO confirm.
    FullRank,
    /// Presumably a flow-based approximating distribution — TODO confirm.
    NormalizingFlow,
}
/// Kind of constraint attached to a parameter.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a
/// wildcard arm when matching on it. Equality is derived; because `Bounded`
/// carries `f64` fields, a `Bounded` value holding NaN does not compare equal
/// to itself.
///
/// NOTE(review): the variant descriptions are inferred from the names only;
/// confirm against the code that applies these constraints.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ConstraintType {
    /// Presumably no restriction on the parameter — TODO confirm.
    Unconstrained,
    /// Presumably restricts the parameter to positive values — TODO confirm.
    Positive,
    /// Presumably restricts a vector to the probability simplex — TODO confirm.
    Simplex,
    /// Presumably restricts the parameter to an interval between `lo` and
    /// `hi`; whether the endpoints are included is not visible here — TODO
    /// confirm.
    Bounded {
        /// Lower end of the interval.
        lo: f64,
        /// Upper end of the interval.
        hi: f64,
    },
}
/// Output record of an ADVI run.
///
/// NOTE(review): the field descriptions are inferred from the field names and
/// from how the methods below use them; the code that fills this struct is
/// not visible here, so confirm before relying on them.
#[derive(Clone, Debug)]
pub struct AdviResult {
    /// Recorded ELBO values; [`AdviResult::final_elbo`] reads the last entry.
    pub elbo_history: Vec<f64>,
    /// Presumably the variational means, one per parameter — TODO confirm.
    pub mu: Vec<f64>,
    /// Log-scale values; [`AdviResult::sigma`] exponentiates them.
    pub log_sigma: Vec<f64>,
    /// Presumably whether the run met its convergence criterion — TODO confirm.
    pub converged: bool,
    /// Presumably the number of iterations actually executed — TODO confirm.
    pub n_iter_performed: usize,
}

impl AdviResult {
    /// Returns a new vector holding `exp(w)` for every entry `w` of
    /// `log_sigma`, in the same order.
    pub fn sigma(&self) -> Vec<f64> {
        self.log_sigma.iter().copied().map(f64::exp).collect()
    }

    /// Returns the last entry of `elbo_history`, or `f64::NAN` when the
    /// history is empty.
    pub fn final_elbo(&self) -> f64 {
        match self.elbo_history.last() {
            Some(&latest) => latest,
            None => f64::NAN,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `AdviConfig::default()` yields the expected sample count, iteration
    /// count, learning rate, tolerance, prior precision and
    /// finite-difference step.
    #[test]
    fn test_advi_config_default() {
        let AdviConfig {
            n_samples,
            n_iter,
            lr,
            tol,
            prior_precision,
            fd_step,
            ..
        } = AdviConfig::default();

        assert_eq!(n_samples, 1);
        assert_eq!(n_iter, 1000);
        assert!((lr - 0.01).abs() < 1e-12);
        assert!((tol - 1e-4).abs() < 1e-12);
        assert!((prior_precision - 1.0).abs() < 1e-12);
        assert!((fd_step - 1e-5).abs() < 1e-12);
    }

    /// Equal variants compare equal; distinct variants compare unequal.
    #[test]
    fn test_variational_family_eq() {
        let mean_field = VariationalFamily::MeanField;
        assert_eq!(mean_field, VariationalFamily::MeanField);
        assert_ne!(mean_field, VariationalFamily::FullRank);
    }

    /// A `Bounded` constraint keeps the endpoints it was built with.
    #[test]
    fn test_constraint_type_variants() {
        let bounded = ConstraintType::Bounded { lo: 0.0, hi: 1.0 };
        if let ConstraintType::Bounded { lo, hi } = bounded {
            assert!((lo - 0.0).abs() < 1e-15);
            assert!((hi - 1.0).abs() < 1e-15);
        } else {
            panic!("Expected Bounded variant");
        }
    }

    /// `sigma()` exponentiates `log_sigma`, and `final_elbo()` returns the
    /// last history entry.
    #[test]
    fn test_advi_result_sigma() {
        let fit = AdviResult {
            elbo_history: vec![-10.0, -5.0, -2.0],
            mu: vec![1.0, 2.0],
            log_sigma: vec![0.0, -1.0],
            converged: true,
            n_iter_performed: 3,
        };

        let scales = fit.sigma();
        assert!((scales[0] - 1.0).abs() < 1e-12, "exp(0) = 1");
        assert!((scales[1] - (-1.0_f64).exp()).abs() < 1e-12);

        let last_elbo = fit.final_elbo();
        assert!((last_elbo - (-2.0)).abs() < 1e-12);
    }
}