use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::Array;
use tensorlogic_quantrs_hooks::{ExpectationPropagation, Factor, FactorGraph};
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== Expectation Propagation Example ===\n");
println!("Example 1: Simple Disease Diagnosis Network");
println!("-------------------------------------------");
simple_disease_network()?;
println!("\n");
println!("Example 2: EP vs Mean-Field Comparison");
println!("--------------------------------------");
compare_inference_methods()?;
println!("\n");
println!("Example 3: EP with Different Parameters");
println!("---------------------------------------");
ep_with_parameters()?;
Ok(())
}
/// Builds a three-variable flu network (`Flu`, with `Fever` and `Cough` each
/// conditioned on `Flu`), runs Expectation Propagation on it, prints every
/// marginal, and checks that each marginal sums to 1.
///
/// Index 0 of each variable is reported as "no" and index 1 as "yes".
///
/// # Errors
/// Propagates any error from building the probability tables, constructing or
/// adding a factor, or running EP.
///
/// # Panics
/// Panics if any marginal's entries do not sum to 1 within `1e-4`.
fn simple_disease_network() -> Result<(), Box<dyn std::error::Error>> {
    let mut network = FactorGraph::new();

    // All three variables are registered with cardinality 2.
    for (name, domain) in [("Flu", "Disease"), ("Fever", "Symptom"), ("Cough", "Symptom")] {
        network.add_variable_with_card(name.to_string(), domain.to_string(), 2);
    }

    // Prior over Flu: [no, yes].
    let flu_table = Array::from_shape_vec(vec![2], vec![0.95, 0.05])?.into_dyn();
    let flu_prior = Factor::new(String::from("P(Flu)"), vec![String::from("Flu")], flu_table)?;
    network.add_factor(flu_prior)?;

    // Table over (Flu, Fever), flattened from a 2x2 shape.
    let fever_table = Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])?.into_dyn();
    let fever_factor = Factor::new(
        String::from("P(Fever|Flu)"),
        vec![String::from("Flu"), String::from("Fever")],
        fever_table,
    )?;
    network.add_factor(fever_factor)?;

    // Table over (Flu, Cough), flattened from a 2x2 shape.
    let cough_table = Array::from_shape_vec(vec![2, 2], vec![0.95, 0.05, 0.3, 0.7])?.into_dyn();
    let cough_factor = Factor::new(
        String::from("P(Cough|Flu)"),
        vec![String::from("Flu"), String::from("Cough")],
        cough_table,
    )?;
    network.add_factor(cough_factor)?;

    let marginals = ExpectationPropagation::new(100, 1e-6, 0.0).run(&network)?;

    println!("Prior probability of flu: P(Flu=yes) = 0.05");
    println!("\nMarginal probabilities:");
    for (name, dist) in &marginals {
        println!(" P({}=no) = {:.4}", name, dist[[0]]);
        println!(" P({}=yes) = {:.4}", name, dist[[1]]);
    }

    // Checked in a second pass so every marginal is printed before any assertion can fire.
    for dist in marginals.values() {
        let total: f64 = dist.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-4);
    }
    println!("\n✓ All marginals are properly normalized");
    Ok(())
}
/// Runs Expectation Propagation and mean-field variational inference on the
/// same two-variable graph (`X`, and `Y` conditioned on `X`) and prints a table
/// with each method's P(=1) per variable and the absolute difference.
///
/// # Errors
/// Propagates errors from building the probability tables, constructing or
/// adding factors, and from either inference run. Also returns an error if
/// either result has no marginal for `X` or `Y`, rather than printing a
/// fabricated probability of 0.
fn compare_inference_methods() -> Result<(), Box<dyn std::error::Error>> {
    use tensorlogic_quantrs_hooks::MeanFieldInference;

    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("X".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("Y".to_string(), "Binary".to_string(), 2);

    let px = Factor::new(
        "P(X)".to_string(),
        vec!["X".to_string()],
        Array::from_shape_vec(vec![2], vec![0.6, 0.4])?.into_dyn(),
    )?;
    graph.add_factor(px)?;

    let pyx = Factor::new(
        "P(Y|X)".to_string(),
        vec!["X".to_string(), "Y".to_string()],
        Array::from_shape_vec(vec![2, 2], vec![0.8, 0.2, 0.2, 0.8])?.into_dyn(),
    )?;
    graph.add_factor(pyx)?;

    println!("Running Expectation Propagation...");
    let ep = ExpectationPropagation::new(100, 1e-6, 0.0);
    let ep_marginals = ep.run(&graph)?;

    println!("Running Mean-Field Variational Inference...");
    let mf = MeanFieldInference::new(1000, 1e-6);
    let mf_marginals = mf.run(&graph)?;

    println!("\nComparison of marginal probabilities:");
    // The header and the rows share the same column widths so the table lines up;
    // the middle column is 18 wide because "Mean-Field P(=1)" is 16 characters.
    println!(
        "{:<8} {:<15} {:<18} {:<15}",
        "Variable", "EP P(=1)", "Mean-Field P(=1)", "Difference"
    );
    println!("{}", "-".repeat(60));

    for var in ["X", "Y"] {
        // A missing marginal is an error: defaulting to 0.0 would print a
        // plausible-looking probability and a misleading difference.
        let ep_prob = ep_marginals
            .get(var)
            .map(|m| m[[1]])
            .ok_or_else(|| format!("EP produced no marginal for variable {}", var))?;
        let mf_prob = mf_marginals
            .get(var)
            .map(|p| p[[1]])
            .ok_or_else(|| format!("mean-field produced no marginal for variable {}", var))?;
        let diff = (ep_prob - mf_prob).abs();
        println!(
            "{:<8} {:<15.6} {:<18.6} {:<15.6}",
            var, ep_prob, mf_prob, diff
        );
    }

    // Informational only: the differences above are printed, not asserted.
    println!("\n✓ Both methods produce similar results (as expected for this simple graph)");
    Ok(())
}
/// Runs Expectation Propagation three times on the same two-variable graph
/// (`X`, and `Y` conditioned on `X`) with different tolerance and damping
/// settings, printing the resulting P(Y=1) for each run.
///
/// # Errors
/// Propagates errors from building the probability tables, constructing or
/// adding factors, and from any EP run. Also returns an error if a run yields
/// no marginal for `Y`, rather than printing a fabricated probability of 0.
fn ep_with_parameters() -> Result<(), Box<dyn std::error::Error>> {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("X".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("Y".to_string(), "Binary".to_string(), 2);

    let px = Factor::new(
        "P(X)".to_string(),
        vec!["X".to_string()],
        Array::from_shape_vec(vec![2], vec![0.3, 0.7])?.into_dyn(),
    )?;
    graph.add_factor(px)?;

    let pyx = Factor::new(
        "P(Y|X)".to_string(),
        vec!["X".to_string(), "Y".to_string()],
        Array::from_shape_vec(vec![2, 2], vec![0.9, 0.1, 0.2, 0.8])?.into_dyn(),
    )?;
    graph.add_factor(pyx)?;

    println!("Testing EP with different convergence tolerances:\n");

    // (heading, tolerance, damping) for each run. Headings after the first
    // start with a newline so the runs are separated by a blank line.
    let configurations = [
        ("1. Strict tolerance (1e-8):", 1e-8, 0.0),
        ("\n2. Loose tolerance (1e-4):", 1e-4, 0.0),
        ("\n3. With damping (damping=0.5):", 1e-6, 0.5),
    ];

    for (heading, tolerance, damping) in configurations {
        println!("{}", heading);
        // The first argument (100) is the same for every run; presumably an
        // iteration cap — confirm against the ExpectationPropagation docs.
        let ep = ExpectationPropagation::new(100, tolerance, damping);
        let marginals = ep.run(&graph)?;
        // A missing marginal is an error: defaulting to 0.0 would look like a real result.
        let p_y = marginals
            .get("Y")
            .map(|m| m[[1]])
            .ok_or("EP produced no marginal for variable Y")?;
        println!(" P(Y=1) = {:.8}", p_y);
    }

    println!("\n✓ EP parameters can be tuned for different convergence behaviors");
    println!(" (Stricter tolerance may require more iterations but gives more precise results)");
    Ok(())
}