use crate::{Factor, FactorGraph, PgmError, Result};
use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
/// An approximating factor ("site") that EP keeps in place of one graph factor.
#[derive(Debug, Clone)]
pub struct Site {
    /// Current value table of the site.
    pub factor: Factor,
    /// Variables of the site; both constructors set this to the variable list
    /// the underlying factor was built with.
    pub variables: Vec<String>,
}
impl Site {
    /// Builds a site over `variables` whose entries all equal
    /// `1 / (product of cardinalities)`.
    ///
    /// `cardinalities[i]` is the axis length used for `variables[i]`.
    ///
    /// # Errors
    /// Propagates any error returned by `Factor::new` (for example a mismatch
    /// between `variables` and `cardinalities`).
    pub fn new_uniform(
        name: String,
        variables: Vec<String>,
        cardinalities: &[usize],
    ) -> Result<Self> {
        let cell_count = cardinalities.iter().fold(1_usize, |acc, &card| acc * card);
        let table = ArrayD::from_elem(cardinalities.to_vec(), (cell_count as f64).recip());
        let factor = Factor::new(name, variables.clone(), table)?;
        Ok(Site { factor, variables })
    }

    /// Wraps an existing factor as a site, copying its variable list.
    pub fn from_factor(factor: Factor) -> Self {
        Site {
            variables: factor.variables.clone(),
            factor,
        }
    }
}
/// Expectation propagation (EP) over a discrete [`FactorGraph`].
///
/// One [`Site`] is kept per graph factor. The global approximation is the
/// normalized product of all sites, and every sweep refits each site from its
/// cavity (approximation with the site divided out) and its tilted
/// distribution (cavity times the true factor).
#[derive(Debug, Clone)]
pub struct ExpectationPropagation {
    /// Upper bound on the number of sweeps over all factors.
    max_iterations: usize,
    /// Sweeps stop once the largest per-site L1 change in a sweep is below this.
    tolerance: f64,
    /// Weight kept from the previous site when blending in an update
    /// (`0.0` disables damping).
    damping: f64,
    /// Floor applied to every entry of a freshly fitted site, keeping later
    /// divisions by that site away from zero.
    min_value: f64,
}
impl Default for ExpectationPropagation {
    /// 100 sweeps, tolerance `1e-6`, no damping.
    fn default() -> Self {
        Self::new(100, 1e-6, 0.0)
    }
}
impl ExpectationPropagation {
    /// Creates a solver; the per-entry floor `min_value` is fixed at `1e-10`.
    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            damping,
            min_value: 1e-10,
        }
    }

    /// Runs EP on `graph` and returns the normalized marginal table of every
    /// graph variable, keyed by variable name.
    ///
    /// Convergence (or hitting `max_iterations`) is reported on stderr.
    ///
    /// # Errors
    /// Returns `PgmError::InvalidGraph` when the graph has no factors, and
    /// propagates any error from factor construction or factor algebra.
    pub fn run(&self, graph: &FactorGraph) -> Result<HashMap<String, ArrayD<f64>>> {
        let mut sites = self.initialize_sites(graph)?;
        let mut approx = self.compute_global_approximation(graph, &sites)?;
        for iteration in 0..self.max_iterations {
            let mut max_change: f64 = 0.0;
            // Parallel update: every site in this sweep is refit against the
            // approximation from the previous sweep; it is rebuilt only afterwards.
            // Sites are matched to factors by position, which assumes
            // `graph.factors()` yields the same order on every call.
            for (factor_idx, factor) in graph.factors().enumerate() {
                let cavity = self.compute_cavity(&approx, &sites[factor_idx])?;
                let tilted = self.compute_tilted(&cavity, factor)?;
                let new_site = self.moment_match(&cavity, &tilted, &sites[factor_idx])?;
                let damped_site = self.apply_damping(&sites[factor_idx], &new_site)?;
                let change = self.compute_site_change(&sites[factor_idx], &damped_site)?;
                max_change = max_change.max(change);
                sites[factor_idx] = damped_site;
            }
            approx = self.compute_global_approximation(graph, &sites)?;
            if max_change < self.tolerance {
                eprintln!(
                    "EP converged in {} iterations (max change: {:.6})",
                    iteration + 1,
                    max_change
                );
                break;
            }
            if iteration == self.max_iterations - 1 {
                eprintln!(
                    "EP reached maximum iterations ({}) with max change: {:.6}",
                    self.max_iterations, max_change
                );
            }
        }
        self.extract_marginals(graph, &approx, &sites)
    }

    /// Creates one uniform site per graph factor, in `graph.factors()` order.
    fn initialize_sites(&self, graph: &FactorGraph) -> Result<Vec<Site>> {
        graph
            .factors()
            .enumerate()
            .map(|(idx, factor)| {
                // Take the site shape from the factor's own table so site and
                // factor are always aligned, rather than guessing a cardinality
                // for variables the graph cannot resolve.
                Site::new_uniform(
                    format!("site_{}", idx),
                    factor.variables.clone(),
                    factor.values.shape(),
                )
            })
            .collect()
    }

    /// Normalized product of all site factors.
    ///
    /// # Errors
    /// `PgmError::InvalidGraph` when `sites` is empty.
    fn compute_global_approximation(&self, _graph: &FactorGraph, sites: &[Site]) -> Result<Factor> {
        let (first, rest) = sites.split_first().ok_or_else(|| {
            PgmError::InvalidGraph("No sites to compute approximation".to_string())
        })?;
        let mut result = rest
            .iter()
            .try_fold(first.factor.clone(), |acc, site| acc.product(&site.factor))?;
        result.normalize();
        Ok(result)
    }

    /// Cavity: the approximation restricted to the site's variables, with the
    /// site divided out.
    fn compute_cavity(&self, approx: &Factor, site: &Site) -> Result<Factor> {
        let cavity = if approx.variables == site.variables {
            // Already over exactly the site's scope: divide without copying.
            approx.divide(&site.factor)?
        } else {
            approx
                .marginalize_out_all_except(&site.variables)?
                .divide(&site.factor)?
        };
        Ok(cavity)
    }

    /// Tilted distribution: cavity times the true factor.
    fn compute_tilted(&self, cavity: &Factor, true_factor: &Factor) -> Result<Factor> {
        let tilted = cavity.product(true_factor)?;
        Ok(tilted)
    }

    /// New site = tilted / cavity, with every entry floored at `min_value`.
    fn moment_match(&self, cavity: &Factor, tilted: &Factor, _old_site: &Site) -> Result<Site> {
        let mut stabilized = tilted.divide(cavity)?;
        stabilized.values.mapv_inplace(|v| v.max(self.min_value));
        Ok(Site::from_factor(stabilized))
    }

    /// Blends `(1 - damping) * new + damping * old` entry-wise; returns `new_site`
    /// unchanged when damping is zero. Assumes both tables have the same axis order.
    fn apply_damping(&self, old_site: &Site, new_site: &Site) -> Result<Site> {
        if self.damping == 0.0 {
            return Ok(new_site.clone());
        }
        let old_values = &old_site.factor.values;
        let new_values = &new_site.factor.values;
        let damped_values = (1.0 - self.damping) * new_values + self.damping * old_values;
        let damped_factor = Factor::new(
            new_site.factor.name.clone(),
            new_site.factor.variables.clone(),
            damped_values,
        )?;
        Ok(Site::from_factor(damped_factor))
    }

    /// L1 distance between the two site tables (entry-wise, same axis order assumed).
    fn compute_site_change(&self, old_site: &Site, new_site: &Site) -> Result<f64> {
        let diff = &new_site.factor.values - &old_site.factor.values;
        Ok(diff.mapv(f64::abs).sum())
    }

    /// Marginalizes the global approximation onto each graph variable and
    /// normalizes the result.
    fn extract_marginals(
        &self,
        graph: &FactorGraph,
        approx: &Factor,
        _sites: &[Site],
    ) -> Result<HashMap<String, ArrayD<f64>>> {
        let mut marginals = HashMap::new();
        for (var, _) in graph.variables() {
            let mut marginal = approx.marginalize_out_all_except(std::slice::from_ref(var))?;
            marginal.normalize();
            marginals.insert(var.clone(), marginal.values);
        }
        Ok(marginals)
    }
}
/// A one-dimensional Gaussian site stored in natural (information) parameters.
///
/// `precision` is `1/σ²` and `precision_mean` is `μ/σ²`. In this form, products
/// and quotients of Gaussians reduce to adding and subtracting the parameters.
#[derive(Debug, Clone)]
pub struct GaussianSite {
    /// Name of the variable this site is defined over.
    pub variable: String,
    /// Precision `1/σ²`; zero (as built by [`GaussianSite::uniform`]) means a flat site.
    pub precision: f64,
    /// Precision-weighted mean `μ/σ²`.
    pub precision_mean: f64,
}
impl GaussianSite {
    /// Creates a site directly from its natural parameters.
    pub fn new(variable: String, precision: f64, precision_mean: f64) -> Self {
        Self {
            variable,
            precision,
            precision_mean,
        }
    }
    /// A flat (uninformative) site: zero precision and zero precision-mean.
    pub fn uniform(variable: String) -> Self {
        Self {
            variable,
            precision: 0.0,
            precision_mean: 0.0,
        }
    }
    /// Mean `precision_mean / precision`, or `0.0` when the precision is at most
    /// `1e-10` (flat or improper site).
    pub fn mean(&self) -> f64 {
        if self.precision > 1e-10 {
            self.precision_mean / self.precision
        } else {
            0.0
        }
    }
    /// Variance `1 / precision`, or `f64::INFINITY` when the precision is at most
    /// `1e-10`.
    pub fn variance(&self) -> f64 {
        if self.precision > 1e-10 {
            1.0 / self.precision
        } else {
            f64::INFINITY
        }
    }
    /// Product of two Gaussians: natural parameters add. Keeps `self.variable`
    /// and does not check that `other` refers to the same variable.
    pub fn product(&self, other: &GaussianSite) -> Self {
        Self {
            variable: self.variable.clone(),
            precision: self.precision + other.precision,
            precision_mean: self.precision_mean + other.precision_mean,
        }
    }
    /// Quotient of two Gaussians: natural parameters subtract. The result may
    /// have negative precision (an improper site); nothing is clamped here.
    pub fn divide(&self, other: &GaussianSite) -> Self {
        Self {
            variable: self.variable.clone(),
            precision: self.precision - other.precision,
            precision_mean: self.precision_mean - other.precision_mean,
        }
    }
}
/// Building blocks for expectation propagation with scalar Gaussian sites.
// `max_iterations` and `tolerance` are stored but not read by any method yet,
// hence the dead-code allowance.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct GaussianEP {
    max_iterations: usize,
    tolerance: f64,
    /// Weight kept from the previous site in [`GaussianEP::damp_site`].
    damping: f64,
}
impl Default for GaussianEP {
    /// 100 iterations, tolerance `1e-6`, no damping.
    fn default() -> Self {
        Self::new(100, 1e-6, 0.0)
    }
}
impl GaussianEP {
    /// Half-width, in cavity standard deviations, of the quadrature window used
    /// by [`GaussianEP::compute_moments`].
    const QUAD_HALF_WIDTH: f64 = 10.0;
    /// Number of trapezoid intervals used by [`GaussianEP::compute_moments`].
    const QUAD_INTERVALS: usize = 4000;

    /// Creates the helper with the given iteration limit, tolerance and damping.
    pub fn new(max_iterations: usize, tolerance: f64, damping: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
            damping,
        }
    }

    /// Mean and variance of the normalized tilted distribution
    /// `cavity(x) * true_factor_callback(x)`.
    ///
    /// The integrals are evaluated with the trapezoid rule on a uniform grid
    /// spanning ±10 cavity standard deviations around the cavity mean. The
    /// factor's contribution must therefore lie essentially inside that window.
    /// `true_factor_callback` is expected to be non-negative.
    ///
    /// Falls back to the cavity's own `(mean, variance)` when the cavity has no
    /// finite positive variance (flat/improper), or when the weighted mass is not
    /// a finite positive number (e.g. the factor is zero on the whole grid).
    pub fn compute_moments(
        &self,
        cavity: &GaussianSite,
        true_factor_callback: impl Fn(f64) -> f64,
    ) -> (f64, f64) {
        let cavity_mean = cavity.mean();
        let cavity_var = cavity.variance();
        if !cavity_var.is_finite() || cavity_var <= 0.0 {
            return (cavity_mean, cavity_var);
        }
        let sd = cavity_var.sqrt();
        let intervals = Self::QUAD_INTERVALS;
        // Work in standardized coordinates t = (x - mean) / sd so the moment sums
        // stay O(1) and do not cancel catastrophically for large means.
        let step = 2.0 * Self::QUAD_HALF_WIDTH / intervals as f64;
        let (mut mass, mut first, mut second) = (0.0_f64, 0.0_f64, 0.0_f64);
        for i in 0..=intervals {
            let t = -Self::QUAD_HALF_WIDTH + i as f64 * step;
            let edge_weight = if i == 0 || i == intervals { 0.5 } else { 1.0 };
            // Unnormalized standard-normal density; constants cancel in the ratios.
            let w = edge_weight * (-0.5 * t * t).exp() * true_factor_callback(cavity_mean + t * sd);
            mass += w;
            first += w * t;
            second += w * t * t;
        }
        if !mass.is_finite() || mass <= 0.0 {
            return (cavity_mean, cavity_var);
        }
        let mean_t = first / mass;
        let var_t = (second / mass - mean_t * mean_t).max(0.0);
        (cavity_mean + mean_t * sd, var_t * cavity_var)
    }

    /// Site that, multiplied by `cavity`, has the given tilted mean and variance
    /// (`tilted_var` is assumed positive).
    ///
    /// The site precision is clamped at `0.0` when the tilted variance exceeds
    /// the cavity variance. `precision_mean` is not clamped, so in that case the
    /// product no longer reproduces the target moments exactly.
    pub fn match_moments(
        &self,
        cavity: &GaussianSite,
        tilted_mean: f64,
        tilted_var: f64,
    ) -> GaussianSite {
        let new_precision = 1.0 / tilted_var - cavity.precision;
        let new_precision_mean = tilted_mean / tilted_var - cavity.precision_mean;
        GaussianSite::new(
            cavity.variable.clone(),
            new_precision.max(0.0),
            new_precision_mean,
        )
    }

    /// Blends natural parameters as `(1 - damping) * new + damping * old`;
    /// returns `new_site` unchanged when damping is zero.
    pub fn damp_site(&self, old_site: &GaussianSite, new_site: &GaussianSite) -> GaussianSite {
        if self.damping == 0.0 {
            return new_site.clone();
        }
        let damped_precision =
            (1.0 - self.damping) * new_site.precision + self.damping * old_site.precision;
        let damped_precision_mean =
            (1.0 - self.damping) * new_site.precision_mean + self.damping * old_site.precision_mean;
        GaussianSite::new(
            new_site.variable.clone(),
            damped_precision,
            damped_precision_mean,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FactorGraph;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array, ArrayD};

    /// Builds a dynamically-dimensioned table from row-major `data`.
    fn table(shape: Vec<usize>, data: Vec<f64>) -> ArrayD<f64> {
        Array::from_shape_vec(shape, data).expect("unwrap").into_dyn()
    }

    /// Registers a binary variable called `name` in `graph`.
    fn add_binary(graph: &mut FactorGraph, name: &str) {
        graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
    }

    /// Builds a factor named `name` over `vars` with the given table.
    fn factor(name: &str, vars: &[&str], values: ArrayD<f64>) -> Factor {
        let vars: Vec<String> = vars.iter().map(|v| v.to_string()).collect();
        Factor::new(name.to_string(), vars, values).expect("unwrap")
    }

    #[test]
    fn test_site_creation() {
        let site = Site::new_uniform("test_site".to_string(), vec!["X".to_string()], &[2])
            .expect("unwrap");
        assert_eq!(site.variables.len(), 1);
        assert_eq!(site.factor.variables[0], "X");
        let total: f64 = site.factor.values.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gaussian_site_moments() {
        let gaussian = GaussianSite::new("X".to_string(), 2.0, 4.0);
        assert_abs_diff_eq!(gaussian.mean(), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gaussian.variance(), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_gaussian_site_product() {
        let lhs = GaussianSite::new("X".to_string(), 2.0, 4.0);
        let rhs = GaussianSite::new("X".to_string(), 3.0, 6.0);
        let joint = lhs.product(&rhs);
        assert_abs_diff_eq!(joint.precision, 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(joint.precision_mean, 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gaussian_site_divide() {
        let numerator = GaussianSite::new("X".to_string(), 5.0, 10.0);
        let denominator = GaussianSite::new("X".to_string(), 2.0, 4.0);
        let ratio = numerator.divide(&denominator);
        assert_abs_diff_eq!(ratio.precision, 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(ratio.precision_mean, 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ep_initialization() {
        let solver = ExpectationPropagation::new(50, 1e-5, 0.5);
        assert_eq!(solver.max_iterations, 50);
        assert_abs_diff_eq!(solver.tolerance, 1e-5, epsilon = 1e-10);
        assert_abs_diff_eq!(solver.damping, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_ep_simple_graph() {
        let mut graph = FactorGraph::new();
        add_binary(&mut graph, "X");
        graph
            .add_factor(factor("P(X)", &["X"], table(vec![2], vec![0.7, 0.3])))
            .expect("unwrap");

        let marginals = ExpectationPropagation::default()
            .run(&graph)
            .expect("unwrap");

        let marginal = marginals.get("X").expect("marginal for X");
        assert_eq!(marginal.ndim(), 1);
        assert_eq!(marginal.len(), 2);
        let total: f64 = marginal.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gaussian_ep_moment_matching() {
        let (target_mean, target_var) = (2.0, 0.5);
        let cavity = GaussianSite::new("X".to_string(), 1.0, 0.0);
        let site = GaussianEP::default().match_moments(&cavity, target_mean, target_var);
        let combined = cavity.product(&site);
        assert_abs_diff_eq!(combined.mean(), target_mean, epsilon = 1e-6);
        assert_abs_diff_eq!(combined.variance(), target_var, epsilon = 1e-6);
    }

    #[test]
    fn test_ep_two_factor_graph() {
        let mut graph = FactorGraph::new();
        for name in ["X", "Y"] {
            add_binary(&mut graph, name);
        }
        graph
            .add_factor(factor("P(X)", &["X"], table(vec![2], vec![0.6, 0.4])))
            .expect("unwrap");
        // First axis indexes X, second axis indexes Y.
        graph
            .add_factor(factor(
                "P(Y|X)",
                &["X", "Y"],
                table(vec![2, 2], vec![0.8, 0.2, 0.3, 0.7]),
            ))
            .expect("unwrap");

        let marginals = ExpectationPropagation::new(100, 1e-6, 0.0)
            .run(&graph)
            .expect("unwrap");

        for name in ["X", "Y"] {
            let total: f64 = marginals.get(name).expect("missing marginal").sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
        }
    }
}