use approx::assert_abs_diff_eq;
use tensorlogic_ir::{TLExpr, Term};
use tensorlogic_quantrs_hooks::{
expr_to_factor_graph, InferenceEngine, MarginalizationQuery, MessagePassingAlgorithm,
ParallelSumProduct, SumProductAlgorithm, VariableElimination,
};
/// A lone unary predicate `P(x)` should convert into a graph with exactly
/// one variable node and one factor node.
#[test]
fn test_single_predicate_conversion() {
    let predicate = TLExpr::pred("P", vec![Term::var("x")]);
    let factor_graph = expr_to_factor_graph(&predicate)
        .expect("Failed to convert single predicate to factor graph");

    assert_eq!(factor_graph.num_variables(), 1);
    assert_eq!(factor_graph.num_factors(), 1);
}
/// `P(x) ∧ Q(y)` over two distinct variables should convert into a graph
/// with two variable nodes and two factor nodes.
#[test]
fn test_conjunction_conversion() {
    let left = TLExpr::pred("P", vec![Term::var("x")]);
    let right = TLExpr::pred("Q", vec![Term::var("y")]);
    let conjunction = TLExpr::and(left, right);

    let factor_graph = expr_to_factor_graph(&conjunction)
        .expect("Failed to convert conjunction to factor graph");

    assert_eq!(factor_graph.num_variables(), 2);
    assert_eq!(factor_graph.num_factors(), 2);
}
/// `exists("x", "Domain", P(x))` should convert into a graph whose single
/// variable node is the bound variable `x`.
#[test]
fn test_existential_quantification() {
    let body = TLExpr::pred("P", vec![Term::var("x")]);
    let quantified = TLExpr::exists("x", "Domain", body);

    let factor_graph = expr_to_factor_graph(&quantified)
        .expect("Failed to convert existential quantification to factor graph");

    assert_eq!(factor_graph.num_variables(), 1);
    assert!(factor_graph.get_variable("x").is_some());
}
/// `(P(x) ∧ Q(x)) ∧ R(y)`: nested conjunctions should be flattened into one
/// factor per predicate.
///
/// `x` is shared by `P` and `Q`, so the graph has only two distinct
/// variables (`x`, `y`) but three factors (`P`, `Q`, `R`).
#[test]
fn test_nested_expressions() {
    let inner = TLExpr::and(
        TLExpr::pred("P", vec![Term::var("x")]),
        TLExpr::pred("Q", vec![Term::var("x")]),
    );
    let expr = TLExpr::and(inner, TLExpr::pred("R", vec![Term::var("y")]));
    let graph =
        expr_to_factor_graph(&expr).expect("Failed to convert nested expressions to factor graph");

    // Shared `x` must be deduplicated into a single variable node.
    assert_eq!(graph.num_variables(), 2);
    assert_eq!(graph.num_factors(), 3);
}
/// Full pipeline: convert `P(x) ∧ Q(x, y)` to a factor graph, then run
/// serial sum-product and check that every variable gets a marginal summing
/// to one.
#[test]
fn test_end_to_end_inference() {
    let unary = TLExpr::pred("P", vec![Term::var("x")]);
    let binary = TLExpr::pred("Q", vec![Term::var("x"), Term::var("y")]);
    let expr = TLExpr::and(unary, binary);
    let graph =
        expr_to_factor_graph(&expr).expect("Failed to convert P(x) ∧ Q(x,y) to factor graph");

    let bp = SumProductAlgorithm::default();
    let marginals = bp
        .run(&graph)
        .expect("Failed to run sum-product on P(x) ∧ Q(x,y)");

    // Both variables must be present in the result...
    assert!(marginals.contains_key("x"));
    assert!(marginals.contains_key("y"));
    // ...and each marginal must be normalized.
    marginals.values().for_each(|distribution| {
        let total: f64 = distribution.iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
    });
}
/// Parallel sum-product on `P(x) ∧ Q(y)` should return exactly two
/// marginals, each summing to one.
#[test]
fn test_parallel_inference_from_tlexpr() {
    let p_of_x = TLExpr::pred("P", vec![Term::var("x")]);
    let q_of_y = TLExpr::pred("Q", vec![Term::var("y")]);
    let expr = TLExpr::and(p_of_x, q_of_y);
    let graph = expr_to_factor_graph(&expr)
        .expect("Failed to convert P(x) ∧ Q(y) to factor graph for parallel inference");

    let engine = ParallelSumProduct::default();
    let results = engine
        .run_parallel(&graph)
        .expect("Failed to run parallel sum-product");

    assert_eq!(results.len(), 2);
    results.values().for_each(|distribution| {
        let total: f64 = distribution.iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
    });
}
/// Exact variable elimination on `P(x) ∧ Q(x, y)` should produce a
/// normalized marginal for `x`.
#[test]
fn test_variable_elimination_from_tlexpr() {
    let expr = TLExpr::and(
        TLExpr::pred("P", vec![Term::var("x")]),
        TLExpr::pred("Q", vec![Term::var("x"), Term::var("y")]),
    );
    let graph =
        expr_to_factor_graph(&expr).expect("Failed to convert expression to factor graph for VE");

    let ve = VariableElimination::new();
    let marginal_x = ve
        .marginalize(&graph, "x")
        .expect("Failed to marginalize x with variable elimination");

    // NOTE(review): expects the converter to give `x` a two-state domain by
    // default — confirm against `expr_to_factor_graph`.
    assert_eq!(marginal_x.len(), 2);
    let sum: f64 = marginal_x.iter().sum();
    assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
}
/// Querying the marginal of `x` through an `InferenceEngine` backed by
/// sum-product on `P(x) ∧ Q(y)` should yield a distribution summing to one.
#[test]
fn test_inference_engine_with_tlexpr() {
    let p_of_x = TLExpr::pred("P", vec![Term::var("x")]);
    let q_of_y = TLExpr::pred("Q", vec![Term::var("y")]);
    let expr = TLExpr::and(p_of_x, q_of_y);
    let graph = expr_to_factor_graph(&expr)
        .expect("Failed to convert expression to factor graph for inference engine");

    let engine = InferenceEngine::new(graph, Box::new(SumProductAlgorithm::default()));
    let query = MarginalizationQuery {
        variable: String::from("x"),
    };

    let marginal = engine
        .marginalize(&query)
        .expect("Failed to compute marginal for x with inference engine");
    let total: f64 = marginal.iter().sum();
    assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
}
/// `P(x) → Q(x)` should convert into a graph with at least two factors and
/// with `x` as a variable node.
#[test]
fn test_implication_conversion() {
    let antecedent = TLExpr::pred("P", vec![Term::var("x")]);
    let consequent = TLExpr::pred("Q", vec![Term::var("x")]);
    let implication = TLExpr::imply(antecedent, consequent);

    let factor_graph = expr_to_factor_graph(&implication)
        .expect("Failed to convert implication to factor graph");

    assert!(factor_graph.num_factors() >= 2);
    assert!(factor_graph.get_variable("x").is_some());
}
/// `forall("x", "Domain", P(x))` should keep the bound variable `x` as a
/// variable node of the resulting graph.
#[test]
fn test_universal_quantification() {
    let body = TLExpr::pred("P", vec![Term::var("x")]);
    let quantified = TLExpr::forall("x", "Domain", body);

    let factor_graph = expr_to_factor_graph(&quantified)
        .expect("Failed to convert universal quantification to factor graph");

    assert!(factor_graph.get_variable("x").is_some());
}
/// `¬P(x)` should still expose `x` as a variable node after conversion.
#[test]
fn test_negation_conversion() {
    let atom = TLExpr::pred("P", vec![Term::var("x")]);
    let negated = TLExpr::negate(atom);

    let factor_graph = expr_to_factor_graph(&negated)
        .expect("Failed to convert negation to factor graph");

    assert!(factor_graph.get_variable("x").is_some());
}
/// `exists x. forall y. P(x, y)` should register both bound variables `x`
/// and `y` in the resulting graph.
#[test]
fn test_nested_quantifiers() {
    let atom = TLExpr::pred("P", vec![Term::var("x"), Term::var("y")]);
    let universal = TLExpr::forall("y", "Domain", atom);
    let expr = TLExpr::exists("x", "Domain", universal);

    let graph =
        expr_to_factor_graph(&expr).expect("Failed to convert nested quantifiers to factor graph");

    // Check the variables in the same order as they are bound: outer x, then inner y.
    for name in ["x", "y"] {
        assert!(graph.get_variable(name).is_some());
    }
}
/// `(P(x) ∧ Q(x)) ∧ R(x, y)`: `x` occurs in all three predicates.
///
/// The graph should therefore have two variables and three factors, and
/// `x` must be adjacent to every one of those factors.
#[test]
fn test_shared_variables() {
    let expr = TLExpr::and(
        TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        ),
        TLExpr::pred("R", vec![Term::var("x"), Term::var("y")]),
    );
    let graph = expr_to_factor_graph(&expr)
        .expect("Failed to convert shared-variable expression to factor graph");

    assert_eq!(graph.num_variables(), 2);
    assert_eq!(graph.num_factors(), 3);

    // `expect` keeps the same failure message as the former if-let/panic.
    let factors = graph
        .get_adjacent_factors("x")
        .expect("x should be connected to factors");
    assert_eq!(factors.len(), 3);
}
/// Serial and parallel sum-product on `P(x) ∧ Q(x, y)` must agree.
///
/// For every variable, both runs must produce marginals of the same length
/// whose entries match element-wise within `1e-5`.
#[test]
fn test_probabilistic_reasoning() {
    let expr = TLExpr::and(
        TLExpr::pred("P", vec![Term::var("x")]),
        TLExpr::pred("Q", vec![Term::var("x"), Term::var("y")]),
    );
    let graph = expr_to_factor_graph(&expr)
        .expect("Failed to convert expression to factor graph for probabilistic reasoning");

    let serial_bp = SumProductAlgorithm::default();
    let serial_marginals = serial_bp
        .run(&graph)
        .expect("Failed to run serial sum-product for probabilistic reasoning");

    let parallel_bp = ParallelSumProduct::default();
    let parallel_marginals = parallel_bp
        .run_parallel(&graph)
        .expect("Failed to run parallel sum-product for probabilistic reasoning");

    for var in ["x", "y"] {
        let serial_m = &serial_marginals[var];
        let parallel_m = &parallel_marginals[var];
        // Compare lengths first: the old index loop over `serial_m.len()` ignored
        // extra parallel entries and hit an out-of-bounds panic on shorter ones.
        assert_eq!(
            serial_m.len(),
            parallel_m.len(),
            "marginal length mismatch for {}",
            var
        );
        for (s, p) in serial_m.iter().zip(parallel_m.iter()) {
            assert_abs_diff_eq!(*s, *p, epsilon = 1e-5);
        }
    }
}