use super::{cpd::CPD, dag::DAG, exact_inference::BayesianNetwork};
use crate::StatsError;
use std::collections::HashMap;
use std::sync::Arc;
/// Minimal random-number source used by the samplers in this module.
///
/// Implementors only need to supply [`Rng::next_f64`]; `next_usize` is
/// derived from it.
pub trait Rng {
    /// Returns a uniform draw in `[0, 1)`.
    fn next_f64(&mut self) -> f64;

    /// Returns a uniform draw in `[0, n)`, or 0 when `n == 0`.
    ///
    /// The result is clamped so that even a generator whose `next_f64`
    /// misbehaves and returns exactly 1.0 cannot produce the out-of-range
    /// index `n` (which would panic at the call sites indexing with it).
    fn next_usize(&mut self, n: usize) -> usize {
        ((self.next_f64() * n as f64) as usize).min(n.saturating_sub(1))
    }
}
/// Deterministic 64-bit linear congruential generator.
///
/// Cheap and reproducible — suitable for seeding Monte Carlo runs in tests,
/// not for anything security-sensitive.
#[derive(Debug, Clone)]
pub struct LcgRng {
    // Current 64-bit generator state; advanced on every draw.
    state: u64,
}

impl LcgRng {
    /// Seeds the generator. A seed of 0 is mapped to 1 so the initial state
    /// is never zero.
    pub fn new(seed: u64) -> Self {
        let state = seed.max(1);
        LcgRng { state }
    }
}
impl Rng for LcgRng {
    /// Advances the LCG one step and maps the new state to a uniform
    /// `f64` in `[0, 1)`.
    fn next_f64(&mut self) -> f64 {
        // LCG step; constants are Knuth's MMIX multiplier/increment.
        const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const INCREMENT: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        // Overlay state bits onto the mantissa of a float in [1, 2)
        // (exponent 0x3FF), then shift down to [0, 1). The clamp guarantees
        // the half-open range even at the extreme mantissa.
        let mantissa = self.state >> 11;
        let unit = f64::from_bits(mantissa | 0x3FF0_0000_0000_0000u64) - 1.0;
        unit.clamp(0.0, 1.0 - f64::EPSILON)
    }
}
/// Draws an index from an (unnormalized) categorical distribution by
/// inverse-CDF sampling.
///
/// Falls back to the last index when floating-point round-off prevents the
/// cumulative sum from crossing the threshold.
fn sample_categorical(probs: &[f64], rng: &mut impl Rng) -> usize {
    let total: f64 = probs.iter().sum();
    let threshold = rng.next_f64() * total;
    let mut acc = 0.0f64;
    probs
        .iter()
        .position(|&p| {
            acc += p;
            threshold < acc
        })
        .unwrap_or_else(|| probs.len() - 1)
}
/// Approximate inference over a discrete Bayesian network via Gibbs sampling
/// (systematic-scan MCMC).
pub struct GibbsSampler {
    /// Shared network (DAG structure plus per-node CPDs) to sample from.
    pub bn: Arc<BayesianNetwork>,
    /// Number of samples retained after burn-in.
    pub n_samples: usize,
    /// Number of initial sweeps discarded before samples are collected.
    pub burn_in: usize,
}
impl GibbsSampler {
    /// Creates a sampler over `bn` that retains `n_samples` states after
    /// discarding `burn_in` initial sweeps.
    pub fn new(bn: Arc<BayesianNetwork>, n_samples: usize, burn_in: usize) -> Self {
        Self {
            bn,
            n_samples,
            burn_in,
        }
    }

    /// Runs the Gibbs chain and returns `n_samples` full joint states, one
    /// `Vec<usize>` of node values per retained iteration.
    ///
    /// Evidence nodes are clamped to their observed values; every remaining
    /// (free) node is resampled from its full conditional on each sweep.
    /// Returns an empty vector when every node is observed, since nothing is
    /// left to sample.
    ///
    /// # Errors
    /// Returns `StatsError::InvalidInput` if any free node or one of its
    /// children is continuous (its CPD reports cardinality 0).
    pub fn sample(
        &self,
        evidence: &HashMap<usize, usize>,
        rng: &mut impl Rng,
    ) -> Result<Vec<Vec<usize>>, StatsError> {
        let n = self.bn.dag.n_nodes;
        let free_vars: Vec<usize> = (0..n).filter(|v| !evidence.contains_key(v)).collect();
        if free_vars.is_empty() {
            return Ok(Vec::new());
        }
        // Initial state: evidence nodes clamped, free discrete nodes start at
        // a uniformly random value (continuous nodes fall back to 0 here and
        // are rejected later in compute_conditional).
        let mut state: Vec<usize> = (0..n)
            .map(|v| {
                if let Some(&val) = evidence.get(&v) {
                    val
                } else {
                    let card = self.bn.cpds[v].cardinality();
                    if card == 0 {
                        0
                    } else {
                        rng.next_usize(card)
                    }
                }
            })
            .collect();
        let total = self.burn_in + self.n_samples;
        let mut samples = Vec::with_capacity(self.n_samples);
        for iter in 0..total {
            // One systematic-scan sweep: resample each free variable from its
            // full conditional given the current state of all other nodes.
            for &v in &free_vars {
                let cond_dist = self.compute_conditional(v, &state)?;
                state[v] = sample_categorical(&cond_dist, rng);
            }
            if iter >= self.burn_in {
                samples.push(state.clone());
            }
        }
        Ok(samples)
    }

    /// Estimates the posterior marginal P(query_var | evidence) from Gibbs
    /// samples, returned as a normalized probability vector over the query
    /// variable's states.
    ///
    /// # Errors
    /// Returns `StatsError::InvalidInput` if `query_var` is out of range or
    /// refers to a continuous node.
    pub fn query(
        &self,
        query_var: usize,
        evidence: &HashMap<usize, usize>,
        rng: &mut impl Rng,
    ) -> Result<Vec<f64>, StatsError> {
        // Guard against an index panic on self.bn.cpds below.
        if query_var >= self.bn.cpds.len() {
            return Err(StatsError::InvalidInput(format!(
                "Query variable {query_var} is out of range"
            )));
        }
        let card = self.bn.cpds[query_var].cardinality();
        if card == 0 {
            return Err(StatsError::InvalidInput(format!(
                "Node {query_var} is continuous; use density estimation instead"
            )));
        }
        let samples = self.sample(evidence, rng)?;
        if samples.is_empty() {
            // Every node is observed: the posterior is a point mass at the
            // evidence value. (Previously this path divided by zero and
            // returned a vector of NaNs.)
            let mut probs = vec![0.0; card];
            if let Some(&val) = evidence.get(&query_var) {
                if val < card {
                    probs[val] = 1.0;
                }
            }
            return Ok(probs);
        }
        let mut counts = vec![0usize; card];
        for sample in &samples {
            counts[sample[query_var]] += 1;
        }
        let total = samples.len() as f64;
        let mut probs: Vec<f64> = counts.iter().map(|&c| c as f64 / total).collect();
        // Renormalize to wash out floating-point round-off from the division.
        let s: f64 = probs.iter().sum();
        if s > 1e-300 {
            for p in &mut probs {
                *p /= s;
            }
        }
        Ok(probs)
    }

    /// Unnormalized full conditional P(v = val | rest of the state): the
    /// node's own CPD entry times each child's CPD entry with `v` set to
    /// `val` — i.e. the Markov-blanket factors of `v`. Computed in log space
    /// and exponentiated to avoid underflow over many children.
    fn compute_conditional(&self, v: usize, state: &[usize]) -> Result<Vec<f64>, StatsError> {
        let card = self.bn.cpds[v].cardinality();
        if card == 0 {
            return Err(StatsError::InvalidInput(format!(
                "Node {v} is continuous; Gibbs sampling requires discrete nodes"
            )));
        }
        let dag = &self.bn.dag;
        // The parent assignment of v does not depend on the candidate value,
        // so gather it once outside the loop.
        let pa: Vec<usize> = dag.parents[v].iter().map(|&p| state[p]).collect();
        let mut probs = vec![0.0f64; card];
        for val in 0..card {
            let p = self.bn.cpds[v].prob(val, &pa);
            if p <= 0.0 {
                // Zero prior support: no child factor can resurrect it.
                probs[val] = 0.0;
                continue;
            }
            let mut log_prob = p.ln();
            for &ch in &dag.children[v] {
                // Child's parent assignment with v hypothetically set to val.
                let ch_pa: Vec<usize> = dag.parents[ch]
                    .iter()
                    .map(|&p| if p == v { val } else { state[p] })
                    .collect();
                let p_ch = self.bn.cpds[ch].prob(state[ch], &ch_pa);
                if p_ch <= 0.0 {
                    log_prob = f64::NEG_INFINITY;
                    break;
                }
                log_prob += p_ch.ln();
            }
            probs[val] = log_prob.exp();
        }
        Ok(probs)
    }
}
/// Approximate inference by likelihood weighting: forward sampling where
/// evidence nodes are clamped and each sample carries the likelihood of the
/// observations as its weight.
pub struct LikelihoodWeighting {
    /// Number of weighted samples drawn per query.
    pub n_samples: usize,
}
impl LikelihoodWeighting {
    /// Builds an estimator that draws `n_samples` weighted samples per query.
    pub fn new(n_samples: usize) -> Self {
        Self { n_samples }
    }

    /// Estimates P(query_var | evidence) as the normalized weighted frequency
    /// of each query-variable value across the drawn samples.
    ///
    /// Falls back to a uniform distribution when every sample received
    /// (numerically) zero weight.
    ///
    /// # Errors
    /// Returns `StatsError::InvalidInput` if the query node is continuous.
    pub fn query(
        &self,
        bn: &BayesianNetwork,
        query_var: usize,
        evidence: &HashMap<usize, usize>,
        rng: &mut impl Rng,
    ) -> Result<Vec<f64>, StatsError> {
        let card = bn.cpds[query_var].cardinality();
        if card == 0 {
            return Err(StatsError::InvalidInput(format!(
                "Node {query_var} is continuous"
            )));
        }
        let order = bn.dag.topological_sort();
        let mut totals = vec![0.0f64; card];
        for _ in 0..self.n_samples {
            let (assignment, w) = self.generate_weighted_sample(bn, &order, evidence, rng)?;
            totals[assignment[query_var]] += w;
        }
        let norm: f64 = totals.iter().sum();
        if norm < 1e-300 {
            // All weights vanished (evidence incompatible with every draw).
            return Ok(vec![1.0 / card as f64; card]);
        }
        Ok(totals.iter().map(|&c| c / norm).collect())
    }

    /// Draws one forward sample in topological order. Free nodes are sampled
    /// from their CPD given already-assigned parents; evidence nodes are
    /// clamped, contributing their observation likelihood to the weight.
    fn generate_weighted_sample(
        &self,
        bn: &BayesianNetwork,
        topo: &[usize],
        evidence: &HashMap<usize, usize>,
        rng: &mut impl Rng,
    ) -> Result<(Vec<usize>, f64), StatsError> {
        let mut assignment = vec![0usize; bn.dag.n_nodes];
        // Weight is accumulated in log space; exp'd once at the end.
        let mut log_w = 0.0f64;
        for &node in topo {
            let cpd = &bn.cpds[node];
            let parent_vals: Vec<usize> =
                bn.dag.parents[node].iter().map(|&p| assignment[p]).collect();
            match evidence.get(&node) {
                Some(&observed) => {
                    assignment[node] = observed;
                    let p = cpd.prob(observed, &parent_vals);
                    if p <= 0.0 {
                        // Impossible observation: this sample gets zero weight.
                        log_w = f64::NEG_INFINITY;
                    } else {
                        log_w += p.ln();
                    }
                }
                None => {
                    let card = cpd.cardinality();
                    if card == 0 {
                        // Continuous node: left at the 0 placeholder.
                        assignment[node] = 0;
                    } else {
                        let dist: Vec<f64> =
                            (0..card).map(|v| cpd.prob(v, &parent_vals)).collect();
                        assignment[node] = sample_categorical(&dist, rng);
                    }
                }
            }
        }
        Ok((assignment, log_w.exp()))
    }
}
/// Mean-field variational inference: approximates the posterior by a fully
/// factorized distribution optimized by coordinate ascent.
pub struct MeanFieldVI {
    /// Maximum number of coordinate-ascent sweeps over all nodes.
    pub max_iter: usize,
    /// Convergence threshold on the largest per-entry marginal change.
    pub tol: f64,
}
impl Default for MeanFieldVI {
fn default() -> Self {
Self {
max_iter: 100,
tol: 1e-6,
}
}
}
impl MeanFieldVI {
pub fn new(max_iter: usize, tol: f64) -> Self {
Self { max_iter, tol }
}
pub fn run(
&self,
bn: &BayesianNetwork,
evidence: &HashMap<usize, usize>,
) -> Result<Vec<Vec<f64>>, StatsError> {
let n = bn.dag.n_nodes;
let mut q: Vec<Vec<f64>> = (0..n)
.map(|i| {
let card = bn.cpds[i].cardinality();
if card == 0 {
return vec![1.0];
}
if let Some(&val) = evidence.get(&i) {
let mut v = vec![0.0; card];
v[val] = 1.0;
v
} else {
vec![1.0 / card as f64; card]
}
})
.collect();
let topo = bn.dag.topological_sort();
for _iter in 0..self.max_iter {
let old_q = q.clone();
for &node in &topo {
if evidence.contains_key(&node) {
continue;
}
let card = bn.cpds[node].cardinality();
if card == 0 {
continue;
}
let mut log_q = vec![0.0f64; card];
for val in 0..card {
log_q[val] += self.expected_log_cpd(bn, node, val, &q);
for &ch in &bn.dag.children[node] {
log_q[val] += self.expected_log_child(bn, ch, node, val, &q);
}
}
let max_l = log_q.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = log_q.iter().map(|&l| (l - max_l).exp()).collect();
let sum: f64 = exps.iter().sum();
q[node] = if sum > 1e-300 {
exps.iter().map(|e| e / sum).collect()
} else {
vec![1.0 / card as f64; card]
};
}
let max_change = q
.iter()
.zip(&old_q)
.flat_map(|(qi, qi_old)| qi.iter().zip(qi_old).map(|(&a, &b)| (a - b).abs()))
.fold(0.0f64, f64::max);
if max_change < self.tol {
break;
}
}
Ok(q)
}
fn expected_log_cpd(
&self,
bn: &BayesianNetwork,
node: usize,
val: usize,
q: &[Vec<f64>],
) -> f64 {
let cpd = &bn.cpds[node];
let parents = &bn.dag.parents[node];
if parents.is_empty() {
let p = cpd.prob(val, &[]);
return if p > 0.0 { p.ln() } else { -1e10 };
}
let parent_cards: Vec<usize> = parents.iter().map(|&p| q[p].len()).collect();
let n_configs: usize = parent_cards.iter().product();
let mut expected = 0.0f64;
for config_idx in 0..n_configs {
let pa_vals = decode_config(config_idx, &parent_cards);
let weight: f64 = parents
.iter()
.zip(&pa_vals)
.map(|(&p, &pv)| q[p][pv])
.product();
let p = cpd.prob(val, &pa_vals);
let log_p = if p > 0.0 { p.ln() } else { -1e10 };
expected += weight * log_p;
}
expected
}
fn expected_log_child(
&self,
bn: &BayesianNetwork,
child: usize,
node: usize,
node_val: usize,
q: &[Vec<f64>],
) -> f64 {
let cpd = &bn.cpds[child];
let ch_card = cpd.cardinality();
if ch_card == 0 {
return 0.0;
}
let parents = &bn.dag.children[node]; let ch_parents = &bn.dag.parents[child];
let other_parents: Vec<usize> = ch_parents.iter().copied().filter(|&p| p != node).collect();
let other_cards: Vec<usize> = other_parents.iter().map(|&p| q[p].len()).collect();
let n_configs: usize = if other_cards.is_empty() {
1
} else {
other_cards.iter().product()
};
let mut expected = 0.0f64;
for config_idx in 0..n_configs {
let other_vals = decode_config(config_idx, &other_cards);
let weight: f64 = other_parents
.iter()
.zip(&other_vals)
.map(|(&p, &pv)| q[p][pv])
.product::<f64>();
let pa_vals: Vec<usize> = ch_parents
.iter()
.map(|&p| {
if p == node {
node_val
} else {
let pos = other_parents.iter().position(|&op| op == p).unwrap_or(0);
other_vals[pos]
}
})
.collect();
let mut child_expected = 0.0f64;
for ch_val in 0..ch_card {
let q_ch = q[child][ch_val];
let p = cpd.prob(ch_val, &pa_vals);
let log_p = if p > 0.0 { p.ln() } else { -1e10 };
child_expected += q_ch * log_p;
}
expected += weight * child_expected;
}
let _ = parents; expected
}
}
/// Decodes a flat configuration index into one value per dimension, with the
/// last entry of `cards` varying fastest (mixed-radix, least-significant
/// digit last). Dimensions with cardinality 0 are skipped and left at 0.
fn decode_config(idx: usize, cards: &[usize]) -> Vec<usize> {
    let mut remaining = idx;
    let mut values = vec![0usize; cards.len()];
    for (slot, &card) in values.iter_mut().zip(cards.iter()).rev() {
        if card > 0 {
            *slot = remaining % card;
            remaining /= card;
        }
    }
    values
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bayesian_network::cpd::TabularCPD;
    use crate::bayesian_network::dag::DAG;
    use crate::bayesian_network::exact_inference::BayesianNetwork;

    /// Classic sprinkler network: Rain (node 0) and Sprinkler (node 1) are
    /// roots, WetGrass (node 2) depends on both.
    /// Priors: P(Rain=1) = 0.2, P(Sprinkler=1) = 0.5.
    // NOTE(review): the WetGrass table appears to make the grass almost surely
    // wet whenever rain or the sprinkler is on — the exact row-to-parent-config
    // mapping follows TabularCPD's convention; confirm against its docs.
    fn wet_grass_network() -> Arc<BayesianNetwork> {
        let mut dag = DAG::new(3);
        dag.add_edge(0, 2).unwrap();
        dag.add_edge(1, 2).unwrap();
        let cpd_rain = TabularCPD::new(0, 2, vec![], vec![], vec![vec![0.8, 0.2]]).unwrap();
        let cpd_spr = TabularCPD::new(1, 2, vec![], vec![], vec![vec![0.5, 0.5]]).unwrap();
        let cpd_wg = TabularCPD::new(
            2,
            2,
            vec![0, 1],
            vec![2, 2],
            vec![
                vec![0.99, 0.01],
                vec![0.01, 0.99],
                vec![0.01, 0.99],
                vec![0.01, 0.99],
            ],
        )
        .unwrap();
        let cpds: Vec<Box<dyn CPD>> = vec![Box::new(cpd_rain), Box::new(cpd_spr), Box::new(cpd_wg)];
        Arc::new(BayesianNetwork::new(dag, cpds).unwrap())
    }

    /// With no evidence, Gibbs sampling should recover the Rain prior
    /// (statistical check with a 0.05 tolerance over 5000 samples).
    #[test]
    fn test_gibbs_prior_rain() {
        let bn = wet_grass_network();
        let sampler = GibbsSampler::new(Arc::clone(&bn), 5000, 500);
        let mut rng = LcgRng::new(42);
        let probs = sampler.query(0, &HashMap::new(), &mut rng).unwrap();
        assert!(
            (probs[0] - 0.8).abs() < 0.05,
            "P(Rain=0) ≈ 0.8, got {}",
            probs[0]
        );
    }

    /// Likelihood weighting with no evidence should also recover the prior.
    #[test]
    fn test_likelihood_weighting_prior() {
        let bn = wet_grass_network();
        let lw = LikelihoodWeighting::new(5000);
        let mut rng = LcgRng::new(42);
        let probs = lw.query(&bn, 0, &HashMap::new(), &mut rng).unwrap();
        assert!(
            (probs[0] - 0.8).abs() < 0.05,
            "P(Rain=0) ≈ 0.8, got {}",
            probs[0]
        );
    }

    /// Observing wet grass should raise the posterior probability of rain
    /// above its 0.2 prior (loose lower-bound check).
    #[test]
    fn test_likelihood_weighting_conditional() {
        let bn = wet_grass_network();
        let lw = LikelihoodWeighting::new(5000);
        let mut rng = LcgRng::new(99);
        let mut evidence = HashMap::new();
        // WetGrass = 1 observed.
        evidence.insert(2usize, 1usize);
        let probs = lw.query(&bn, 0, &evidence, &mut rng).unwrap();
        assert!(
            probs[1] > 0.2,
            "P(Rain=1|WG=1) should be > 0.2, got {}",
            probs[1]
        );
    }

    /// Mean-field VI without evidence: the q-marginal for a root node should
    /// sit near its prior (looser 0.1 tolerance — VI is approximate).
    #[test]
    fn test_mean_field_prior() {
        let bn = wet_grass_network();
        let mf = MeanFieldVI::default();
        let q = mf.run(&bn, &HashMap::new()).unwrap();
        assert!(
            (q[0][0] - 0.8).abs() < 0.1,
            "q(Rain=0) ≈ 0.8, got {}",
            q[0][0]
        );
    }

    /// Evidence nodes must keep an exact point-mass marginal in the output.
    #[test]
    fn test_mean_field_with_evidence() {
        let bn = wet_grass_network();
        let mf = MeanFieldVI::default();
        let mut evidence = HashMap::new();
        // WetGrass = 1 observed.
        evidence.insert(2usize, 1usize);
        let q = mf.run(&bn, &evidence).unwrap();
        assert!((q[2][1] - 1.0).abs() < 1e-9);
    }
}