use scirs2_core::ndarray::{Array1, ArrayD};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::graph::FactorGraph;
/// Decomposition of a Bethe free-energy evaluation, as produced by [`bethe_free_energy`].
///
/// Notation: `b_a` / `b_i` are factor / variable beliefs, `φ_a` are factor potentials and
/// `d_i` is the number of factors adjacent to variable `i`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetheFreeEnergy {
    /// Σ_a Σ_{x_a} b_a(x_a) · (ln φ_a(x_a) − ln b_a(x_a)) over the supplied factor beliefs.
    pub factor_energy: f64,
    /// Σ_i (1 − d_i) Σ_{x_i} b_i(x_i) · ln b_i(x_i) over the supplied variable beliefs.
    pub variable_entropy: f64,
    /// The Bethe free-energy estimate.
    pub total: f64,
    /// Approximate log partition function; [`bethe_free_energy`] sets this to `-total`.
    pub log_z: f64,
}
/// Evaluates the Bethe free energy of a set of beliefs on `graph`.
///
/// Uses the Yedidia–Freeman–Weiss form
///
/// ```text
/// F = Σ_a Σ_{x_a} b_a(x_a) ln(b_a(x_a) / φ_a(x_a))
///   + Σ_i (1 - d_i) Σ_{x_i} b_i(x_i) ln b_i(x_i)
/// ```
///
/// where `d_i` is the number of factors adjacent to variable `i`. For exact marginals on a
/// tree-structured graph this gives `-F = ln Z`.
///
/// # Arguments
/// * `graph` - supplies factor potentials (`get_factor`) and variable degrees
///   (`get_adjacent_factors`; a variable it does not know is treated as degree 0).
/// * `beliefs_var` - variable beliefs keyed by variable name.
/// * `beliefs_fac` - factor beliefs keyed by factor id; ids that `graph.get_factor` does not
///   resolve are skipped.
///
/// # Returns
/// A [`BetheFreeEnergy`] with:
/// * `factor_energy = Σ_a Σ b_a (ln φ_a - ln b_a)`
/// * `variable_entropy = Σ_i (1 - d_i) Σ b_i ln b_i`
/// * `total = F = variable_entropy - factor_energy`
/// * `log_z = -total`
///
/// Belief entries `<= 1e-300` contribute nothing, following the convention `0 · ln 0 = 0`.
pub fn bethe_free_energy(
    graph: &FactorGraph,
    beliefs_var: &HashMap<String, Array1<f64>>,
    beliefs_fac: &HashMap<String, ArrayD<f64>>,
) -> BetheFreeEnergy {
    // Values at or below this threshold are treated as exact zeros.
    let eps = 1e-300_f64;
    // Stand-in for ln φ when a potential is (numerically) zero but its belief is not.
    // This keeps the energy finite instead of +∞. The value sits just below ln(1e-300) ≈ -690.8.
    const LOG_ZERO_POTENTIAL: f64 = -700.0;

    // Σ_a Σ_x b_a(x) (ln φ_a(x) - ln b_a(x)), i.e. minus the factor "KL" terms of F.
    let mut factor_energy = 0.0_f64;
    for (fac_id, fac_belief) in beliefs_fac {
        if let Some(factor) = graph.get_factor(fac_id) {
            // NOTE(review): entries are paired positionally. This assumes `fac_belief` and
            // `factor.values` share shape and iteration order. `zip` silently stops at the
            // shorter of the two, so confirm at the call sites.
            for (b, phi) in fac_belief.iter().zip(factor.values.iter()) {
                if *b > eps {
                    let log_phi = if *phi > eps {
                        phi.ln()
                    } else {
                        LOG_ZERO_POTENTIAL
                    };
                    factor_energy += b * (log_phi - b.ln());
                }
            }
        }
    }

    // Σ_i (1 - d_i) Σ_x b_i(x) ln b_i(x) = Σ_i (d_i - 1) H(b_i).
    // This removes the d_i-fold over-counting of each variable's entropy in the factor terms.
    let mut variable_entropy = 0.0_f64;
    for (var_name, belief) in beliefs_var {
        let degree = graph
            .get_adjacent_factors(var_name)
            .map(|v| v.len())
            .unwrap_or(0) as f64;
        // Σ_x b ln b, i.e. the negative entropy of this variable's belief.
        let neg_entropy_i: f64 = belief
            .iter()
            .filter(|&&b| b > eps)
            .map(|&b| b * b.ln())
            .sum::<f64>();
        variable_entropy += (1.0 - degree) * neg_entropy_i;
    }

    // F = Σ_a Σ b_a ln(b_a/φ_a) + Σ_i (1 - d_i) Σ b_i ln b_i.
    // The first term is -factor_energy and the second is +variable_entropy. The variable term
    // must keep its sign here, otherwise F is wrong whenever any variable has degree ≠ 1.
    let total = variable_entropy - factor_energy;
    BetheFreeEnergy {
        factor_energy,
        variable_entropy,
        total,
        log_z: -total,
    }
}