use scirs2_core::ndarray::Array;
use tensorlogic_quantrs_hooks::{
BayesianNetwork, InferenceEngine, MarginalizationQuery, SumProductAlgorithm,
VariableElimination,
};
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== Bayesian Network: Student Performance Model ===\n");
let mut bn = BayesianNetwork::new();
bn.add_variable("Difficulty".to_string(), 2); bn.add_variable("Intelligence".to_string(), 2); bn.add_variable("Grade".to_string(), 3); bn.add_variable("SAT".to_string(), 2);
println!("Building network structure...");
let p_difficulty = Array::from_shape_vec(vec![2], vec![0.6, 0.4])?.into_dyn();
bn.add_prior("Difficulty".to_string(), p_difficulty)?;
let p_intelligence = Array::from_shape_vec(vec![2], vec![0.7, 0.3])?.into_dyn();
bn.add_prior("Intelligence".to_string(), p_intelligence)?;
let p_grade = Array::from_shape_vec(
vec![2, 2, 3],
vec![
0.3, 0.4, 0.3, 0.9, 0.08, 0.02, 0.05, 0.25, 0.7, 0.5, 0.3, 0.2, ],
)?
.into_dyn();
bn.add_cpd(
"Grade".to_string(),
vec!["Difficulty".to_string(), "Intelligence".to_string()],
p_grade,
)?;
let p_sat = Array::from_shape_vec(
vec![2, 2],
vec![
0.95, 0.05, 0.2, 0.8, ],
)?
.into_dyn();
bn.add_cpd("SAT".to_string(), vec!["Intelligence".to_string()], p_sat)?;
println!(
"Network has {} variables and {} factors",
bn.graph().num_variables(),
bn.graph().num_factors()
);
assert!(bn.is_acyclic(), "Network must be acyclic!");
println!("✓ Network is a valid DAG\n");
let topo_order = bn.topological_order()?;
println!("Topological order: {:?}\n", topo_order);
println!("=== Query 1: What's the probability distribution over Grades? ===");
let algorithm = Box::new(SumProductAlgorithm::default());
let engine = InferenceEngine::new(bn.graph().clone(), algorithm);
let query = MarginalizationQuery {
variable: "Grade".to_string(),
};
let grade_marginal = engine.marginalize(&query)?;
println!("P(Grade):");
println!(" P(Grade=A) = {:.3}", grade_marginal[[0]]);
println!(" P(Grade=B) = {:.3}", grade_marginal[[1]]);
println!(" P(Grade=C) = {:.3}\n", grade_marginal[[2]]);
println!("=== Query 2: What's the probability distribution over SAT scores? ===");
let ve = VariableElimination::new();
let sat_marginal = ve.marginalize(bn.graph(), "SAT")?;
println!("P(SAT):");
println!(" P(SAT=Low) = {:.3}", sat_marginal[[0]]);
println!(" P(SAT=High) = {:.3}\n", sat_marginal[[1]]);
println!("=== Query 3: If we observe SAT=High, what's the distribution over Intelligence? ===");
use tensorlogic_quantrs_hooks::Factor;
let mut evidence_graph = bn.graph().clone();
let evidence_values = Array::from_shape_vec(vec![2], vec![0.0, 1.0])?.into_dyn();
let evidence = Factor::new(
"Evidence_SAT".to_string(),
vec!["SAT".to_string()],
evidence_values,
)?;
evidence_graph.add_factor(evidence)?;
let ve = VariableElimination::new();
let intel_given_sat = ve.marginalize(&evidence_graph, "Intelligence")?;
println!("P(Intelligence | SAT=High):");
println!(
" P(Intelligence=Low | SAT=High) = {:.3}",
intel_given_sat[[0]]
);
println!(
" P(Intelligence=High | SAT=High) = {:.3}\n",
intel_given_sat[[1]]
);
println!("=== Query 4: Computing joint probability ===");
let joint = engine.joint()?;
println!("Joint distribution shape: {:?}", joint.shape());
let joint_sum: f64 = joint.iter().sum();
println!(
"Joint distribution sums to: {:.6} (should be ~1.0)\n",
joint_sum
);
println!("=== Summary ===");
println!("✓ Built a 4-variable Bayesian Network");
println!("✓ Verified DAG property and computed topological order");
println!("✓ Performed marginal inference using Sum-Product");
println!("✓ Performed marginal inference using Variable Elimination");
println!("✓ Computed conditional probabilities with evidence");
println!("✓ Computed joint distribution over all variables");
Ok(())
}