use tensorlogic_quantrs_hooks::{
BetheApproximation, FactorGraph, MeanFieldInference, TreeReweightedBP,
};
fn main() -> anyhow::Result<()> {
println!("=== Structured Variational Inference Example ===\n");
println!("Creating 3x3 grid Markov Random Field...");
let graph = create_grid_mrf(3, 3);
println!(" - 9 binary variables (grid positions)");
println!(
" - {} factors (node potentials + edge potentials)\n",
graph.factor_ids().count()
);
println!("=== Method 1: Mean-Field Approximation ===");
println!("(Assumes complete independence between variables)\n");
let mean_field = MeanFieldInference::new(100, 1e-6);
let mf_start = std::time::Instant::now();
let mf_beliefs = mean_field.run(&graph)?;
let mf_time = mf_start.elapsed();
println!("Converged in {:?}", mf_time);
println!("\nMean-Field Marginals (first 3 variables):");
for var_name in graph.variable_names().take(3) {
if let Some(belief) = mf_beliefs.get(var_name) {
println!(" P({}=1) = {:.4}", var_name, belief[[1]]);
}
}
let mf_elbo = mean_field.compute_elbo(&graph, &mf_beliefs)?;
println!("\nMean-Field ELBO: {:.4}\n", mf_elbo);
println!("=== Method 2: Bethe Approximation ===");
println!("(Uses factor graph structure for structured approximation)\n");
let bethe = BetheApproximation::new(100, 1e-6, 0.0);
let bethe_start = std::time::Instant::now();
let bethe_beliefs = bethe.run(&graph)?;
let bethe_time = bethe_start.elapsed();
println!("Converged in {:?}", bethe_time);
println!("\nBethe Marginals (first 3 variables):");
for var_name in graph.variable_names().take(3) {
if let Some(belief) = bethe_beliefs.get(var_name) {
println!(" P({}=1) = {:.4}", var_name, belief[[1]]);
}
}
let factor_beliefs = bethe.compute_factor_beliefs(&graph, &bethe_beliefs)?;
let bethe_free_energy = bethe.compute_free_energy(&graph, &bethe_beliefs, &factor_beliefs)?;
println!("\nBethe Free Energy: {:.4}\n", bethe_free_energy);
println!("=== Method 3: Tree-Reweighted BP ===");
println!("(Provides upper bounds on log partition function)\n");
let mut trw_bp = TreeReweightedBP::new(100, 1e-6);
let trw_start = std::time::Instant::now();
let trw_beliefs = trw_bp.run(&graph)?;
let trw_time = trw_start.elapsed();
println!("Converged in {:?}", trw_time);
println!("\nTRW-BP Marginals (first 3 variables):");
for var_name in graph.variable_names().take(3) {
if let Some(belief) = trw_beliefs.get(var_name) {
println!(" P({}=1) = {:.4}", var_name, belief[[1]]);
}
}
println!();
println!("=== Comparison Summary ===\n");
println!("Convergence Times:");
println!(" Mean-Field: {:?}", mf_time);
println!(" Bethe: {:?}", bethe_time);
println!(" TRW-BP: {:?}", trw_time);
println!();
println!("Variational Objectives:");
println!(" Mean-Field ELBO: {:.4}", mf_elbo);
println!(" Bethe Free Energy: {:.4}", bethe_free_energy);
println!();
println!("Marginal Differences (L1 norm):");
let mf_bethe_diff = compute_marginal_difference(&mf_beliefs, &bethe_beliefs);
let mf_trw_diff = compute_marginal_difference(&mf_beliefs, &trw_beliefs);
let bethe_trw_diff = compute_marginal_difference(&bethe_beliefs, &trw_beliefs);
println!(" Mean-Field vs Bethe: {:.4}", mf_bethe_diff);
println!(" Mean-Field vs TRW-BP: {:.4}", mf_trw_diff);
println!(" Bethe vs TRW-BP: {:.4}", bethe_trw_diff);
println!();
println!("=== Analysis ===\n");
println!("1. **Mean-Field** assumes complete independence, which is");
println!(" inappropriate for the grid structure. This leads to");
println!(" potential inaccuracies in marginal estimates.");
println!();
println!("2. **Bethe Approximation** respects the factor graph structure,");
println!(" providing more accurate marginals at similar computational cost.");
println!(" The Bethe free energy is typically tighter than mean-field ELBO.");
println!();
println!("3. **TRW-BP** uses edge reweighting to provide upper bounds on");
println!(" the log partition function, making it particularly robust for");
println!(" loopy graphs. Convergence is guaranteed for convex combinations");
println!(" of spanning trees.");
println!();
println!("For grid-structured MRFs with loops, Bethe and TRW-BP typically");
println!("outperform mean-field in terms of accuracy, while maintaining");
println!("similar computational efficiency.\n");
println!("✓ Structured variational inference demonstration completed!");
Ok(())
}
/// Builds a `rows` x `cols` grid-shaped factor graph over binary variables.
///
/// For every cell `(i, j)` the graph receives:
/// - a variable named `x_{i}_{j}` with domain label `"Binary"` and cardinality 2,
/// - a single-variable factor `node_{i}_{j}` with values `[0.4, 0.6]`.
///
/// Neighbouring cells are then linked by two-variable factors whose values
/// `[0.8, 0.2, 0.2, 0.8]` are shaped as a 2x2 table (larger entries on the
/// diagonal):
/// - `h_edge_{i}_{j}_{j+1}` between `x_{i}_{j}` and `x_{i}_{j+1}`,
/// - `v_edge_{i}_{i+1}_{j}` between `x_{i}_{j}` and `x_{i+1}_{j}`.
///
/// If either dimension is zero, no variables or factors are added. The edge
/// loops use saturating subtraction, so a zero dimension no longer underflows
/// `usize`.
///
/// # Panics
///
/// Panics if a potential array cannot be built from its shape and values, or
/// if `FactorGraph::add_factor` returns an error.
fn create_grid_mrf(rows: usize, cols: usize) -> FactorGraph {
    use scirs2_core::ndarray::Array;
    use tensorlogic_quantrs_hooks::Factor;

    // Single source of truth for the variable naming scheme.
    let cell = |i: usize, j: usize| format!("x_{}_{}", i, j);

    let mut graph = FactorGraph::new();

    // All variables are registered before any factor refers to them.
    for i in 0..rows {
        for j in 0..cols {
            graph.add_variable_with_card(cell(i, j), "Binary".to_string(), 2);
        }
    }

    // One single-variable factor per cell.
    for i in 0..rows {
        for j in 0..cols {
            let potential = Array::from_shape_vec(vec![2], vec![0.4, 0.6])
                .expect("create_grid_mrf: Failed to create node potential array")
                .into_dyn();
            let factor = Factor {
                name: format!("node_{}_{}", i, j),
                variables: vec![cell(i, j)],
                values: potential,
            };
            graph
                .add_factor(factor)
                .expect("create_grid_mrf: Failed to add node factor");
        }
    }

    // Horizontal edges: (i, j) -- (i, j + 1).
    // `saturating_sub` keeps the range empty when `cols == 0` instead of
    // underflowing.
    for i in 0..rows {
        for j in 0..cols.saturating_sub(1) {
            let potential = Array::from_shape_vec(vec![2, 2], vec![0.8, 0.2, 0.2, 0.8])
                .expect("create_grid_mrf: Failed to create horizontal edge potential array")
                .into_dyn();
            let factor = Factor {
                name: format!("h_edge_{}_{}_{}", i, j, j + 1),
                variables: vec![cell(i, j), cell(i, j + 1)],
                values: potential,
            };
            graph
                .add_factor(factor)
                .expect("create_grid_mrf: Failed to add horizontal edge factor");
        }
    }

    // Vertical edges: (i, j) -- (i + 1, j), with the same guard for `rows == 0`.
    for i in 0..rows.saturating_sub(1) {
        for j in 0..cols {
            let potential = Array::from_shape_vec(vec![2, 2], vec![0.8, 0.2, 0.2, 0.8])
                .expect("create_grid_mrf: Failed to create vertical edge potential array")
                .into_dyn();
            let factor = Factor {
                name: format!("v_edge_{}_{}_{}", i, i + 1, j),
                variables: vec![cell(i, j), cell(i + 1, j)],
                values: potential,
            };
            graph
                .add_factor(factor)
                .expect("create_grid_mrf: Failed to add vertical edge factor");
        }
    }

    graph
}
/// Averages the per-variable L1 distance between two sets of beliefs.
///
/// For every variable name in `beliefs1` that also has an entry in
/// `beliefs2`, the absolute element-wise differences of the two arrays are
/// summed; the result is the mean of those sums over the shared variables.
/// Variables present in only one map are ignored. Returns `0.0` when the maps
/// share no variable.
fn compute_marginal_difference(
    beliefs1: &std::collections::HashMap<String, scirs2_core::ndarray::ArrayD<f64>>,
    beliefs2: &std::collections::HashMap<String, scirs2_core::ndarray::ArrayD<f64>>,
) -> f64 {
    // Walk `beliefs1`, keep only pairs that exist on both sides, and
    // accumulate (running L1 total, number of shared variables) in one pass.
    let (l1_total, shared) = beliefs1
        .iter()
        .filter_map(|(name, lhs)| beliefs2.get(name).map(|rhs| (lhs, rhs)))
        .fold((0.0_f64, 0_i32), |(acc, seen), (lhs, rhs)| {
            let distance: f64 = (lhs - rhs).mapv(|delta| delta.abs()).sum();
            (acc + distance, seen + 1)
        });

    match shared {
        0 => 0.0,
        n => l1_total / n as f64,
    }
}