use scirs2_core::ndarray::Array;
use tensorlogic_quantrs_hooks::{Factor, FactorGraph, JunctionTree};
// State labels for each variable, listed in state-index order (index 0 first).
// They are used only to label the probabilities printed by `main`.
const INTELLIGENCE_STATES: &[&str] = &["Low", "High"];
const DIFFICULTY_STATES: &[&str] = &["Easy", "Hard"];
const GRADE_STATES: &[&str] = &["A", "B", "C"];
const LETTER_STATES: &[&str] = &["Weak", "Strong"];

/// Single-variable marginals queried after calibration, in print order.
const MARGINAL_QUERIES: &[(&str, &[&str])] = &[
    ("Intelligence", INTELLIGENCE_STATES),
    ("Difficulty", DIFFICULTY_STATES),
    ("Grade", GRADE_STATES),
    ("Letter", LETTER_STATES),
];

/// Demonstrates exact inference on the student network with a junction tree.
///
/// Builds the factor graph, constructs and calibrates the junction tree, then
/// prints single-variable marginals, one joint marginal, and some structural
/// statistics of the tree.
///
/// # Errors
/// Returns any error raised while building the graph, constructing or
/// calibrating the junction tree, or answering a marginal query.
fn main() -> anyhow::Result<()> {
    println!("=== Junction Tree Exact Inference Example ===\n");

    let graph = build_student_network()?;
    println!("Factor Graph Statistics:");
    println!(" Variables: {}", graph.num_variables());
    println!(" Factors: {}", graph.num_factors());
    println!();

    println!("Constructing junction tree...");
    let mut tree = JunctionTree::from_factor_graph(&graph)?;
    println!("Junction Tree Structure:");
    println!(" Number of cliques: {}", tree.cliques.len());
    println!(" Number of edges: {}", tree.edges.len());
    println!(" Treewidth: {}", tree.treewidth());
    println!();

    println!("Cliques:");
    for (i, clique) in tree.cliques.iter().enumerate() {
        let vars: Vec<&String> = clique.variables.iter().collect();
        println!(" Clique {}: {:?}", i, vars);
    }
    println!();

    let rip_status = if tree.verify_running_intersection_property() {
        "✓ Satisfied"
    } else {
        "✗ Not satisfied"
    };
    println!("Running Intersection Property: {}", rip_status);
    println!();

    println!("Calibrating junction tree...");
    tree.calibrate()?;
    println!("✓ Calibration complete\n");

    // One query per variable. The state labels come from the table, so the
    // indices below are driven by the labels instead of being hard-coded.
    println!("=== Marginal Queries ===\n");
    for &(var, states) in MARGINAL_QUERIES {
        println!("Query: P({})", var);
        let marginal = tree.query_marginal(var)?;
        for (idx, state) in states.iter().enumerate() {
            println!(" P({} = {}) = {:.4}", var, state, marginal[[idx]]);
        }
        println!();
    }

    println!("=== Joint Query ===\n");
    let (row_var, row_states) = ("Intelligence", INTELLIGENCE_STATES);
    let (col_var, col_states) = ("Difficulty", DIFFICULTY_STATES);
    println!("Query: P({}, {})", row_var, col_var);
    let p_joint = tree.query_joint_marginal(&[row_var.to_string(), col_var.to_string()])?;
    println!("Shape: {:?}", p_joint.shape());
    // NOTE(review): assumes the joint's axes follow the order of the queried
    // variables (axis 0 = row_var, axis 1 = col_var) — confirm against the API.
    for (r, row_state) in row_states.iter().enumerate() {
        for (c, col_state) in col_states.iter().enumerate() {
            println!(
                " P({}={}, {}={}) = {:.4}",
                row_var, row_state, col_var, col_state, p_joint[[r, c]]
            );
        }
    }
    println!();

    println!("=== Performance Analysis ===\n");
    println!("Complexity Analysis:");
    println!(" Treewidth: {}", tree.treewidth());
    let max_clique_size = tree
        .cliques
        .iter()
        .map(|c| c.variables.len())
        .max()
        .unwrap_or(0);
    println!(" Max clique size: {}", max_clique_size);
    // Guard against an edgeless tree (a single clique) to avoid dividing by zero.
    let avg_separator_size = if tree.edges.is_empty() {
        0.0
    } else {
        tree.edges
            .iter()
            .map(|e| e.separator.variables.len())
            .sum::<usize>() as f64
            / tree.edges.len() as f64
    };
    println!(" Avg separator size: {:.2}", avg_separator_size);
    println!();

    println!("Advantages of Junction Tree Algorithm:");
    println!(" ✓ Exact inference (no approximation)");
    println!(" ✓ Efficient message passing on tree structure");
    println!(" ✓ Handles any query without recomputation");
    println!(" ✓ Guarantees consistency across marginals");
    println!();
    println!("✓ Example completed successfully!");
    Ok(())
}
/// Assembles the textbook "student" network as a [`FactorGraph`].
///
/// Structure (from the factor scopes): `Intelligence -> Grade <- Difficulty`
/// and `Grade -> Letter`. Each table is handed over as a flat value list
/// reshaped to one axis per scope variable, in the listed order. With
/// `from_shape_vec`'s default row-major layout, the last listed variable (the
/// child) varies fastest.
///
/// # Errors
/// Propagates any error from shaping a table, constructing a `Factor`, or
/// adding it to the graph.
fn build_student_network() -> anyhow::Result<FactorGraph> {
    let mut graph = FactorGraph::new();

    // (name, domain label, cardinality), registered in this order.
    let variables = [
        ("Intelligence", "Binary", 2),
        ("Difficulty", "Binary", 2),
        ("Grade", "Ternary", 3),
        ("Letter", "Binary", 2),
    ];
    for (name, domain, card) in variables {
        graph.add_variable_with_card(name.to_string(), domain.to_string(), card);
    }

    // Builds one tabular factor from a flat value list and adds it to `graph`
    // right away, so construction and insertion stay interleaved per factor.
    let mut add_table =
        |name: &str, scope: &[&str], shape: Vec<usize>, values| -> anyhow::Result<()> {
            let table = Array::from_shape_vec(shape, values)?.into_dyn();
            let scope: Vec<String> = scope.iter().map(|&v| v.to_string()).collect();
            graph.add_factor(Factor::new(name.to_string(), scope, table)?)?;
            Ok(())
        };

    add_table("P(Intelligence)", &["Intelligence"], vec![2], vec![0.7, 0.3])?;
    add_table("P(Difficulty)", &["Difficulty"], vec![2], vec![0.6, 0.4])?;

    // One row per (Intelligence, Difficulty) index pair; columns are Grade states.
    #[rustfmt::skip]
    let grade_table = vec![
        0.3,  0.4,  0.3,   // I=0, D=0
        0.05, 0.25, 0.7,   // I=0, D=1
        0.9,  0.08, 0.02,  // I=1, D=0
        0.5,  0.3,  0.2,   // I=1, D=1
    ];
    add_table(
        "P(Grade|Intelligence,Difficulty)",
        &["Intelligence", "Difficulty", "Grade"],
        vec![2, 2, 3],
        grade_table,
    )?;

    // One row per Grade state; columns are Letter states.
    #[rustfmt::skip]
    let letter_table = vec![
        0.1,  0.9,   // Grade=0
        0.4,  0.6,   // Grade=1
        0.99, 0.01,  // Grade=2
    ];
    add_table("P(Letter|Grade)", &["Grade", "Letter"], vec![3, 2], letter_table)?;

    Ok(graph)
}