use super::distributions::{CategoricalNP, DirichletNP, GaussianNP};
use super::engine::{VariationalMessagePassing, VariationalState, VmpConfig, VmpFactor};
use super::exponential_family::ExponentialFamily;
/// Returns the mean of a Gaussian variational factor `q`.
///
/// # Panics
///
/// Panics with "expected Gaussian" when `state` is any non-Gaussian variant.
fn gaussian_mean(state: &VariationalState) -> f64 {
    if let VariationalState::Gaussian { q, .. } = state {
        q.mean
    } else {
        panic!("expected Gaussian")
    }
}
/// Returns the precision (inverse variance) of a Gaussian variational factor `q`.
///
/// # Panics
///
/// Panics with "expected Gaussian" when `state` is any non-Gaussian variant.
fn gaussian_precision(state: &VariationalState) -> f64 {
    if let VariationalState::Gaussian { q, .. } = state {
        q.precision
    } else {
        panic!("expected Gaussian")
    }
}
/// One Gaussian observation against a Gaussian prior has a closed-form posterior:
///
/// - precision: prior 1.0 + observation 2.0 = 3.0
/// - mean: (1.0 * 0.0 + 2.0 * 3.0) / 3.0 = 2.0
#[test]
fn gaussian_single_observation_matches_closed_form() {
    let observation = VmpFactor::GaussianObservation {
        target: "mu".to_string(),
        observation: 3.0,
        precision: 2.0,
    };
    let config = VmpConfig::new()
        .with_gaussian("mu", 0.0, 1.0)
        .expect("prior")
        .with_factor(observation)
        .with_limits(50, 1e-10);

    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("run");
    assert!(result.converged, "should converge on single observation");

    let posterior = result.states.get("mu").expect("mu");
    let (mean, precision) = (gaussian_mean(posterior), gaussian_precision(posterior));
    assert!((mean - 2.0).abs() < 1e-9);
    assert!((precision - 3.0).abs() < 1e-9);
}
/// Two Gaussian means are tied together by a `GaussianStep` factor. The expected
/// values below are the exact joint-posterior means, assuming the step factor
/// puts a precision-10 penalty on the difference between `m1` and `m2`:
///
/// ```text
/// [ 12  -10 ] [m1]   [1]
/// [-10   12 ] [m2] = [5]
/// ```
///
/// Each diagonal entry is prior 1 + observation 1 + step 10. The right-hand side
/// is observation precision times observed value. Solving gives m1 = 31/22 and
/// m2 = 35/22, and the test expects the converged mean-field means to match them.
/// The midpoint check, (31 + 35) / 44 = 1.5, follows from the two above and is
/// kept as a readable sanity check.
#[test]
fn gaussian_chain_recovers_analytical_joint() {
    // Both direct observations share unit precision; only target and value differ.
    let observe = |target: &str, observation| VmpFactor::GaussianObservation {
        target: target.to_string(),
        observation,
        precision: 1.0,
    };
    let config = VmpConfig::new()
        .with_gaussian("m1", 0.0, 1.0)
        .expect("prior m1")
        .with_gaussian("m2", 0.0, 1.0)
        .expect("prior m2")
        .with_factor(observe("m1", 1.0))
        .with_factor(observe("m2", 5.0))
        .with_factor(VmpFactor::GaussianStep {
            lhs: "m1".to_string(),
            rhs: "m2".to_string(),
            precision: 10.0,
        })
        .with_limits(400, 1e-10);

    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("run");
    assert!(result.converged, "chain should converge");

    let mean_of = |name: &str| gaussian_mean(result.states.get(name).expect(name));
    let (m1, m2) = (mean_of("m1"), mean_of("m2"));
    let (expected_m1, expected_m2) = (31.0 / 22.0, 35.0 / 22.0);

    assert!(
        (m1 - expected_m1).abs() < 1e-4,
        "m1 = {}, expected {}",
        m1,
        expected_m1
    );
    assert!(
        (m2 - expected_m2).abs() < 1e-4,
        "m2 = {}, expected {}",
        m2,
        expected_m2
    );

    let midpoint = (m1 + m2) / 2.0;
    assert!(
        (midpoint - 1.5).abs() < 1e-4,
        "midpoint = {}, m1 = {}, m2 = {}",
        midpoint,
        m1,
        m2
    );
}
/// Categorical observations are conjugate to a Dirichlet prior. Three
/// observations of category 0 on a flat `[1, 1, 1]` prior must give the posterior
/// concentration `[4, 1, 1]`, i.e. one count added per observation.
#[test]
fn dirichlet_categorical_conjugate_counts_posterior() {
    let mut config = VmpConfig::new()
        .with_dirichlet("pi", vec![1.0, 1.0, 1.0])
        .expect("dir prior");
    for _ in 0..3 {
        config = config.with_factor(VmpFactor::CategoricalObservation {
            dirichlet: "pi".to_string(),
            observation: 0,
            num_categories: 3,
        });
    }
    let config = config.with_limits(10, 1e-10);

    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("run");
    assert!(result.converged);

    match result.states.get("pi").expect("pi") {
        VariationalState::Dirichlet { q, .. } => {
            let expected = [4.0, 1.0, 1.0];
            // Index (rather than zip) so a too-short concentration vector still panics.
            for (k, &want) in expected.iter().enumerate() {
                assert!((q.concentration[k] - want).abs() < 1e-12);
            }
        }
        _ => panic!("expected Dirichlet"),
    }
}
/// The recorded evidence lower bound (ELBO) must never go down from one entry
/// to the next. The model mixes a Gaussian observation, a Dirichlet–categorical
/// pair and a categorical observation, so several factor kinds contribute. Each
/// consecutive pair of `elbo_history` values is compared with `1e-6` of numerical
/// slack.
///
/// Note: with fewer than two recorded values, `windows(2)` yields nothing and the
/// check passes trivially.
#[test]
fn elbo_is_monotonically_non_decreasing() {
    let config = VmpConfig::new()
        .with_gaussian("mu", 0.0, 1.0)
        .expect("gauss")
        .with_dirichlet("pi", vec![1.0, 1.0])
        .expect("dir")
        .with_categorical("x", 2)
        .expect("cat")
        .with_factor(VmpFactor::GaussianObservation {
            target: "mu".to_string(),
            observation: 2.0,
            precision: 1.0,
        })
        .with_factor(VmpFactor::DirichletCategorical {
            dirichlet: "pi".to_string(),
            categorical: "x".to_string(),
        })
        .with_factor(VmpFactor::CategoricalObservation {
            dirichlet: "pi".to_string(),
            observation: 1,
            num_categories: 2,
        })
        .with_limits(50, 1e-10);

    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("run");

    let history = &result.elbo_history;
    for pair in history.windows(2) {
        let (prev, next) = (pair[0], pair[1]);
        assert!(
            next + 1e-6 >= prev,
            "ELBO decreased: {} -> {} (history: {:?})",
            prev,
            next,
            history
        );
    }
}
/// Runs one conjugate Gaussian update under the engine's *default* iteration
/// limits (there is no `with_limits` call). It checks that the run finishes
/// without error, reports convergence, and lands on the closed-form posterior.
///
/// NOTE(review): despite its name, this test asserts successful convergence,
/// not a convergence failure. A name such as
/// `default_limits_converge_on_single_observation` would describe it better.
/// The name is left unchanged so anything that selects tests by name keeps
/// working.
///
/// Closed form: a prior with mean 0 and precision 1, plus one observation of 1.0
/// at precision 1, gives posterior precision 1 + 1 = 2 and mean
/// (0 * 1 + 1 * 1) / 2 = 0.5. Checking these values stops the test from passing
/// on a converged-but-wrong state. The tolerance is looser than in the
/// `with_limits(.., 1e-10)` tests because the default convergence tolerance is
/// not visible from here.
#[test]
fn divergence_tolerance_triggers_convergence_failure() {
    let config = VmpConfig::new()
        .with_gaussian("mu", 0.0, 1.0)
        .expect("gauss")
        .with_factor(VmpFactor::GaussianObservation {
            target: "mu".to_string(),
            observation: 1.0,
            precision: 1.0,
        });
    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("should not diverge");
    assert!(
        result.converged,
        "default limits should converge on a single observation"
    );

    let state = result.states.get("mu").expect("mu");
    let (mean, precision) = (gaussian_mean(state), gaussian_precision(state));
    assert!(
        (mean - 0.5).abs() < 1e-6,
        "posterior mean = {}, expected 0.5",
        mean
    );
    assert!(
        (precision - 2.0).abs() < 1e-6,
        "posterior precision = {}, expected 2.0",
        precision
    );
}
/// Attaching a `GaussianObservation` to a categorical variable is a family
/// mismatch. Constructing the engine from such a config must return an error.
#[test]
fn validate_rejects_family_mismatch() {
    let mismatched = VmpFactor::GaussianObservation {
        target: "x".to_string(),
        observation: 0.5,
        precision: 1.0,
    };
    let config = VmpConfig::new()
        .with_categorical("x", 3)
        .expect("categorical")
        .with_factor(mismatched);

    assert!(
        VariationalMessagePassing::new(config).is_err(),
        "family mismatch must be rejected"
    );
}
/// Overwriting a categorical's natural parameters with arbitrary values must
/// still yield a valid probability vector: it sums to 1 (within 1e-12) and
/// every entry lies in `[0, 1]`.
#[test]
fn categorical_natural_params_renormalise_after_update() {
    let mut cat = CategoricalNP::from_probs(&[0.3, 0.3, 0.4]).expect("ctor");
    cat.set_natural(&[2.0, -1.5, 0.7]).expect("set nat");

    let probs = cat.probs();
    let total: f64 = probs.iter().sum();
    assert!((total - 1.0).abs() < 1e-12);
    assert!(probs.iter().all(|&p| (0.0..=1.0).contains(&p)));
}
/// A sparse `[0.5, 0.5]` Dirichlet prior receives one categorical observation per
/// category. After at most 5 sweeps, every posterior concentration must still
/// be strictly positive.
#[test]
fn dirichlet_posterior_stays_positive_after_multiple_updates() {
    // The prior goes through `DirichletNP::new` so its constructor checks run as well.
    let prior = DirichletNP::new(vec![0.5, 0.5]).expect("ctor");
    let mut config = VmpConfig::new()
        .with_dirichlet("pi", prior.concentration.clone())
        .expect("dir");
    for observation in 0..2 {
        config = config.with_factor(VmpFactor::CategoricalObservation {
            dirichlet: "pi".to_string(),
            observation,
            num_categories: 2,
        });
    }
    let config = config.with_limits(5, 1e-10);

    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("run");

    match result.states.get("pi").expect("pi") {
        VariationalState::Dirichlet { q, .. } => assert!(
            q.concentration.iter().all(|&a| a > 0.0),
            "concentration must stay positive"
        ),
        _ => panic!("expected Dirichlet"),
    }
}
/// With no factors attached, running the engine must leave the Gaussian
/// posterior equal to its prior, to within 1e-12. Presumably the engine stores
/// state as natural parameters internally (TODO confirm). If so, this also
/// guards against drift in that conversion.
#[test]
fn natural_params_round_trip_through_sweep() {
    let prior = GaussianNP::new(0.7, 1.3).expect("prior");
    let (prior_mean, prior_precision) = (prior.mean, prior.precision);

    let config = VmpConfig::new()
        .with_gaussian("mu", prior_mean, prior_precision)
        .expect("gauss")
        .with_limits(10, 1e-12);
    let mut engine = VariationalMessagePassing::new(config).expect("engine");
    let result = engine.run().expect("run");

    let posterior = result.states.get("mu").expect("mu");
    assert!((gaussian_mean(posterior) - prior_mean).abs() < 1e-12);
    assert!((gaussian_precision(posterior) - prior_precision).abs() < 1e-12);
}
/// For a Gaussian with mean 2.5 and precision 4.0, `natural_params()` must return
/// exactly one entry equal to precision * mean = 10.0.
#[test]
fn gaussian_natural_params_are_tau_mu() {
    let (mean, precision) = (2.5, 4.0);
    let gaussian = GaussianNP::new(mean, precision).expect("ctor");

    let eta = gaussian.natural_params();
    assert_eq!(eta.len(), 1);
    assert!((eta[0] - precision * mean).abs() < 1e-12);
}