use scirs2_core::ndarray::ArrayD;
use scirs2_core::random::{RngExt, SeedableRng, StdRng};
use tensorlogic_quantrs_hooks::vmp::beta as vmp_beta;
use tensorlogic_quantrs_hooks::vmp::gamma as vmp_gamma;
use tensorlogic_quantrs_hooks::{
BayesianNetwork, BetaNP, GammaNP, VariationalMessagePassing, VariationalState, VmpConfig,
VmpFactor,
};
/// Runs VMP on a one-variable network: a unit-precision Gaussian prior centred at
/// 0 on `mu`, combined with a single Gaussian observation 4.0 at precision 3.0.
///
/// Conjugacy gives a closed-form posterior — precision 1 + 3 = 4 and mean
/// (1·0 + 3·4) / 4 = 3 — which the engine must reproduce after converging. The
/// recorded ELBO trace must be non-empty and may never drop by more than 1e-7
/// between consecutive iterations.
#[test]
fn vmp_on_bayesian_network_structure() {
    // The graph only needs to declare `mu`; its CPD is a single-entry placeholder.
    let mut network = BayesianNetwork::new();
    network.add_variable("mu".to_string(), 1);
    let placeholder = ArrayD::from_shape_vec(vec![1], vec![1.0]).expect("cpd");
    network
        .add_cpd("mu".to_string(), vec![], placeholder)
        .expect("cpd");

    let observation = VmpFactor::GaussianObservation {
        target: "mu".to_string(),
        observation: 4.0,
        precision: 3.0,
    };
    let config = VmpConfig::new()
        .with_gaussian("mu", 0.0, 1.0)
        .expect("register mu")
        .with_factor(observation)
        .with_limits(100, 1e-10);

    let mut engine =
        VariationalMessagePassing::with_graph(network.graph(), config).expect("engine");
    let outcome = engine.run().expect("run");
    assert!(outcome.converged);

    let VariationalState::Gaussian { q, .. } = outcome.states.get("mu").expect("mu") else {
        panic!("expected Gaussian");
    };
    assert!((q.mean - 3.0).abs() < 1e-9, "posterior mean = {}", q.mean);
    assert!(
        (q.precision - 4.0).abs() < 1e-9,
        "posterior precision = {}",
        q.precision
    );

    // The ELBO should be monotone non-decreasing, up to a small numerical slack.
    assert!(!outcome.elbo_history.is_empty());
    for pair in outcome.elbo_history.windows(2) {
        let (before, after) = (pair[0], pair[1]);
        assert!(
            after + 1e-7 >= before,
            "ELBO decreased: {} -> {}",
            before,
            after
        );
    }
}
/// Constructing an engine must fail when the configuration registers a variable
/// (`missing`) that the underlying (empty) graph does not declare.
#[test]
fn vmp_rejects_variables_missing_from_graph() {
    let empty_network = BayesianNetwork::new();
    let config_with_unknown_var = VmpConfig::new()
        .with_gaussian("missing", 0.0, 1.0)
        .expect("register missing");

    let built =
        VariationalMessagePassing::with_graph(empty_network.graph(), config_with_unknown_var);
    assert!(built.is_err(), "missing variable must be rejected");
}
/// Dirichlet–categorical conjugacy through the full VMP engine.
///
/// `pi` has a flat Dir(1, 1, 1) prior, is linked to a latent categorical `x`, and
/// receives two direct observations of category 1. The converged posterior must:
/// - be strictly positive;
/// - be symmetric in categories 0 and 2;
/// - be peaked at category 1;
/// - carry a total concentration of 6 (3 from the prior, 2 from the observations,
///   1 from the latent `x`).
#[test]
fn vmp_dirichlet_categorical_conjugate_integration() {
    let mut network = BayesianNetwork::new();
    network.add_variable("pi".to_string(), 3);
    network.add_variable("x".to_string(), 3);

    // Builds one direct observation of `category` against the Dirichlet `pi`.
    let observe_category = |category| VmpFactor::CategoricalObservation {
        dirichlet: "pi".to_string(),
        observation: category,
        num_categories: 3,
    };

    let config = VmpConfig::new()
        .with_dirichlet("pi", vec![1.0; 3])
        .expect("dir")
        .with_categorical("x", 3)
        .expect("cat")
        .with_factor(VmpFactor::DirichletCategorical {
            dirichlet: "pi".to_string(),
            categorical: "x".to_string(),
        })
        .with_factor(observe_category(1))
        .with_factor(observe_category(1))
        .with_limits(100, 1e-8);

    let mut engine =
        VariationalMessagePassing::with_graph(network.graph(), config).expect("engine");
    let outcome = engine.run().expect("run");
    assert!(outcome.converged);

    let VariationalState::Dirichlet { q, .. } = outcome.states.get("pi").expect("pi") else {
        panic!("expected Dirichlet");
    };
    let conc = &q.concentration;
    assert_eq!(conc.len(), 3);
    assert!(
        conc.iter().all(|&c| c > 0.0),
        "concentrations must be strictly positive"
    );

    // Safe to index: the length was checked above.
    let (a0, a1, a2) = (conc[0], conc[1], conc[2]);
    assert!(
        (a0 - a2).abs() < 1e-8,
        "α[0] ({}) should equal α[2] ({}) by symmetry",
        a0,
        a2
    );
    assert!(a1 > a0, "α[1] ({}) must dominate α[0] ({})", a1, a0);
    assert!(a1 > a2, "α[1] ({}) must dominate α[2] ({})", a1, a2);

    let total: f64 = conc.iter().sum();
    assert!(
        (total - 6.0).abs() < 1e-8,
        "Σα = {} (expected 6 = 3 prior + 2 obs + 1 latent)",
        total
    );
}
/// Draws one sample from a Poisson(`lambda`) distribution using Knuth's
/// multiplicative method.
///
/// The method multiplies uniforms in `[0, 1)` until the running product drops to
/// `exp(-lambda)`. The number of factors needed, minus one, is the sample. The
/// expected number of uniform draws grows linearly with `lambda`, so this suits
/// only the small rates used in these tests.
///
/// # Panics
///
/// Panics if `lambda` is negative, NaN, or so large that `exp(-lambda)`
/// underflows to zero. Without this check:
/// - a NaN rate would loop forever;
/// - a negative rate would silently return 0;
/// - an underflowing rate would return a meaningless count.
fn sample_poisson(lambda: f64, rng: &mut StdRng) -> u64 {
    let threshold = (-lambda).exp();
    // `lambda >= 0.0` is false for NaN; `threshold > 0.0` rejects +inf and underflow.
    assert!(
        lambda >= 0.0 && threshold > 0.0,
        "sample_poisson needs a finite, non-negative rate whose exp(-rate) is > 0, got {}",
        lambda
    );

    let mut product = 1.0_f64;
    let mut draws: u64 = 0;
    loop {
        let u: f64 = rng.random();
        product *= u;
        if product <= threshold {
            return draws;
        }
        draws += 1;
    }
}
/// Gamma–Poisson conjugate update on 100 synthetic counts drawn with rate 2.5.
///
/// With a `GammaNP::new(1.0, 1.0)` prior, the posterior must be exactly
/// (alpha = 1 + Σx, beta = 1 + n). Its mean alpha / beta must land within 0.3 of
/// the generating rate, and it must have moved away from the prior (positive KL).
#[test]
fn vmp_gamma_poisson_end_to_end() {
    let true_lambda = 2.5_f64;
    let sample_count = 100_usize;
    let mut rng = StdRng::seed_from_u64(42);
    let observations: Vec<u64> =
        std::iter::repeat_with(|| sample_poisson(true_lambda, &mut rng))
            .take(sample_count)
            .collect();

    let prior = GammaNP::new(1.0, 1.0).expect("prior");
    let posterior =
        vmp_gamma::posterior_from_prior_and_observations(&prior, &observations).expect("posterior");

    let mean = posterior.alpha / posterior.beta;
    assert!(
        (mean - true_lambda).abs() < 0.3,
        "Gamma-Poisson posterior mean {:.4} should be within 0.3 of true λ = {}",
        mean,
        true_lambda
    );

    // Closed-form conjugate update: shape gains the total count, rate gains n.
    let total: u64 = observations.iter().copied().sum();
    let expected_alpha = 1.0 + total as f64;
    let expected_beta = 1.0 + sample_count as f64;
    assert!(
        (posterior.alpha - expected_alpha).abs() < 1e-12,
        "posterior alpha = {}, expected {}",
        posterior.alpha,
        expected_alpha
    );
    assert!(
        (posterior.beta - expected_beta).abs() < 1e-12,
        "posterior beta = {}, expected {}",
        posterior.beta,
        expected_beta
    );

    let divergence = posterior.kl_to(&prior);
    assert!(divergence > 0.0, "KL(posterior || prior) = {}", divergence);
}
/// Beta–Bernoulli conjugate update on 500 synthetic coin flips with p = 0.7.
///
/// With a `BetaNP::new(1.0, 1.0)` prior, the posterior must be exactly
/// (alpha = 1 + successes, beta = 1 + failures). Its mean must land within 0.05
/// of the generating probability, and it must differ from the prior (positive KL).
#[test]
fn vmp_beta_bernoulli_end_to_end() {
    let true_p = 0.7_f64;
    let trials = 500_usize;
    let mut rng = StdRng::seed_from_u64(99);

    // A trial succeeds when its uniform draw falls below `true_p`.
    let successes = (0..trials)
        .filter(|_| {
            let draw: f64 = rng.random();
            draw < true_p
        })
        .count() as u64;
    let failures = trials as u64 - successes;

    let prior = BetaNP::new(1.0, 1.0).expect("prior");
    let posterior = vmp_beta::posterior_from_prior_and_observations(&prior, successes, failures)
        .expect("posterior");

    let mean = posterior.alpha / (posterior.alpha + posterior.beta);
    assert!(
        (mean - true_p).abs() < 0.05,
        "Beta-Bernoulli posterior mean {:.4} should be within 0.05 of true p = {}",
        mean,
        true_p
    );

    let expected_alpha = 1.0 + successes as f64;
    let expected_beta = 1.0 + failures as f64;
    assert!(
        (posterior.alpha - expected_alpha).abs() < 1e-12,
        "posterior alpha = {}, expected {}",
        posterior.alpha,
        expected_alpha
    );
    assert!(
        (posterior.beta - expected_beta).abs() < 1e-12,
        "posterior beta = {}, expected {}",
        posterior.beta,
        expected_beta
    );
    assert_eq!(successes + failures, trials as u64);

    let divergence = posterior.kl_to(&prior);
    assert!(divergence > 0.0, "KL(posterior || prior) = {}", divergence);
}