use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::Array;
use tensorlogic_quantrs_hooks::*;
/// Sum-product on the two-node chain `X -> Y` must recover the prior of `X`.
///
/// The factor tables are built row-major over their variable lists, so
/// `P(Y|X)` has rows indexed by `x`: `[0.8, 0.2]` for `x = 0` and
/// `[0.3, 0.7]` for `x = 1`. Each row sums to one, so summing `y` out of
/// `P(X) * P(Y|X)` leaves exactly `P(X) = [0.6, 0.4]`. The test therefore
/// checks those values, not merely that the marginal is normalized.
#[test]
fn test_chain_graph() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    graph.add_variable_with_card("y".to_string(), "Binary".to_string(), 2);

    let px_values = Array::from_shape_vec(vec![2], vec![0.6, 0.4])
        .expect("Failed to create px array")
        .into_dyn();
    let px = Factor::new("P(X)".to_string(), vec!["x".to_string()], px_values)
        .expect("Failed to create P(X) factor");
    graph.add_factor(px).expect("Failed to add P(X) to graph");

    let py_given_x_values = Array::from_shape_vec(
        vec![2, 2],
        vec![
            0.8, 0.2, // P(Y | X = 0)
            0.3, 0.7, // P(Y | X = 1)
        ],
    )
    .expect("Failed to create P(Y|X) array")
    .into_dyn();
    let py_given_x = Factor::new(
        "P(Y|X)".to_string(),
        vec!["x".to_string(), "y".to_string()],
        py_given_x_values,
    )
    .expect("Failed to create P(Y|X) factor");
    graph
        .add_factor(py_given_x)
        .expect("Failed to add P(Y|X) to graph");

    let algorithm = Box::new(SumProductAlgorithm::default());
    let engine = InferenceEngine::new(graph, algorithm);
    let query_x = MarginalizationQuery {
        variable: "x".to_string(),
    };

    // `expect` (instead of a separate `assert!(result.is_ok())`) reports the
    // actual error value if inference fails.
    let marginal_x_values = engine
        .marginalize(&query_x)
        .expect("Failed to compute marginal for x");
    let sum: f64 = marginal_x_values.iter().sum();
    assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    assert_abs_diff_eq!(marginal_x_values[[0]], 0.6, epsilon = 1e-6);
    assert_abs_diff_eq!(marginal_x_values[[1]], 0.4, epsilon = 1e-6);
}
/// Multiplies a factor over `x` by a factor over `y`.
///
/// The result must span both variables as a 2x2 table whose entry `[i, j]`
/// equals `f1[i] * f2[j]`.
#[test]
fn test_factor_product_integration() {
    let x_weights = vec![0.6, 0.4];
    let y_weights = vec![0.7, 0.3];

    let f1 = Factor::new(
        "f1".to_string(),
        vec!["x".to_string()],
        Array::from_shape_vec(vec![2], x_weights.clone())
            .expect("Failed to create f1 array")
            .into_dyn(),
    )
    .expect("Failed to create f1 factor");
    let f2 = Factor::new(
        "f2".to_string(),
        vec!["y".to_string()],
        Array::from_shape_vec(vec![2], y_weights.clone())
            .expect("Failed to create f2 array")
            .into_dyn(),
    )
    .expect("Failed to create f2 factor");

    let product = f1.product(&f2).expect("Failed to compute factor product");
    assert_eq!(product.variables.len(), 2);
    assert_eq!(product.values.shape(), &[2, 2]);

    // Visit cells in row-major order: [0,0], [0,1], [1,0], [1,1].
    for (i, &wx) in x_weights.iter().enumerate() {
        for (j, &wy) in y_weights.iter().enumerate() {
            assert_abs_diff_eq!(product.values[[i, j]], wx * wy, epsilon = 1e-10);
        }
    }
}
/// Sums `y` out of a joint table over `(x, y)` and checks the `x` marginal.
///
/// The table is `[[0.24, 0.06], [0.56, 0.14]]` with rows indexed by `x`, so
/// the row sums are `0.30` and `0.70`.
#[test]
fn test_factor_marginalization_integration() {
    let table = Array::from_shape_vec(vec![2, 2], vec![0.24, 0.06, 0.56, 0.14])
        .expect("Failed to create joint array")
        .into_dyn();
    let scope = vec!["x".to_string(), "y".to_string()];
    let joint =
        Factor::new("P(X,Y)".to_string(), scope, table).expect("Failed to create joint factor");

    let marginal_x = joint
        .marginalize_out("y")
        .expect("Failed to marginalize out y");

    assert_eq!(marginal_x.variables.len(), 1);
    assert_eq!(marginal_x.variables[0], "x");

    let expected = [0.30, 0.70];
    for (x_state, want) in expected.iter().enumerate() {
        assert_abs_diff_eq!(marginal_x.values[[x_state]], *want, epsilon = 1e-10);
    }
}
/// Fixes `y = 0` in a joint `(x, y)` table.
///
/// The result must be a factor over `x` alone that holds the first column
/// of the table, `[0.24, 0.56]`.
#[test]
fn test_factor_reduce_integration() {
    let scope: Vec<String> = ["x", "y"].iter().map(|&v| v.to_string()).collect();
    let joint = Factor::new(
        "P(X,Y)".to_string(),
        scope,
        Array::from_shape_vec(vec![2, 2], vec![0.24, 0.06, 0.56, 0.14])
            .expect("Failed to create joint array")
            .into_dyn(),
    )
    .expect("Failed to create joint factor");

    let reduced = joint.reduce("y", 0).expect("Failed to reduce with Y=0");

    assert_eq!(reduced.variables.len(), 1);
    assert_eq!(reduced.variables[0], "x");
    for (x_state, &want) in [0.24, 0.56].iter().enumerate() {
        assert_abs_diff_eq!(reduced.values[[x_state]], want, epsilon = 1e-10);
    }
}
/// A lone binary variable with no attached factors should end up with the
/// uniform marginal `[0.5, 0.5]` under `SumProductAlgorithm::new(10, 1e-6, 0.0)`.
#[test]
fn test_message_passing_convergence() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    let engine = InferenceEngine::new(graph, Box::new(SumProductAlgorithm::new(10, 1e-6, 0.0)));

    let result = engine.marginalize(&MarginalizationQuery {
        variable: "x".to_string(),
    });
    assert!(result.is_ok());
    let marginal = result.expect("Failed to compute marginal for convergence test");

    for state in 0..2usize {
        assert_abs_diff_eq!(marginal[[state]], 0.5, epsilon = 1e-6);
    }
}
/// Converts the conjunction `P(x) ∧ Q(y)` into a factor graph.
///
/// The resulting graph is expected to contain exactly two variables and
/// two factors.
#[test]
fn test_tlexpr_to_factor_graph() {
    use tensorlogic_ir::{TLExpr, Term};

    let p_of_x = TLExpr::pred("P", vec![Term::var("x")]);
    let q_of_y = TLExpr::pred("Q", vec![Term::var("y")]);
    let conjunction = TLExpr::and(p_of_x, q_of_y);

    let graph = tensorlogic_quantrs_hooks::expr_to_factor_graph(&conjunction);
    assert!(graph.is_ok());
    let graph = graph.expect("Failed to convert TLExpr to factor graph");

    assert_eq!(graph.num_variables(), 2);
    assert_eq!(graph.num_factors(), 2);
}
/// Converts `∃x ∈ Domain. P(x)` into a factor graph.
///
/// Conversion must succeed and yield at least one variable.
#[test]
fn test_tlexpr_with_exists() {
    use tensorlogic_ir::{TLExpr, Term};

    let body = TLExpr::pred("P", vec![Term::var("x")]);
    let quantified = TLExpr::exists("x", "Domain", body);

    let graph = tensorlogic_quantrs_hooks::expr_to_factor_graph(&quantified);
    assert!(graph.is_ok());
    let graph = graph.expect("Failed to convert existential TLExpr to factor graph");

    let var_count = graph.num_variables();
    assert!(var_count > 0);
}
/// Applies the free `marginalize` helper to a uniform 2x2 joint.
///
/// The result must be one-dimensional and sum to one.
#[test]
fn test_marginalization_lib_function() {
    use scirs2_core::ndarray::Array;

    let uniform_joint = Array::from_shape_vec(vec![2, 2], vec![0.25; 4])
        .expect("Failed to create joint array")
        .into_dyn();

    let marginal = tensorlogic_quantrs_hooks::marginalize(&uniform_joint, 0, &[0, 1])
        .expect("Failed to marginalize joint distribution");

    assert_eq!(marginal.ndim(), 1);
    let total = marginal.iter().copied().sum::<f64>();
    assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
}
/// Conditions the joint `[[0.2, 0.3], [0.4, 0.1]]` on axis 1 taking state 0.
///
/// This selects the column `[0.2, 0.4]`, which renormalizes to `[1/3, 2/3]`.
#[test]
fn test_conditioning_lib_function() {
    use scirs2_core::ndarray::Array;
    use std::collections::HashMap;

    let joint = Array::from_shape_vec(vec![2, 2], vec![0.2, 0.3, 0.4, 0.1])
        .expect("Failed to create joint array")
        .into_dyn();

    // Single observation: axis 1 (Y) is fixed to state 0.
    let evidence: HashMap<_, _> = std::iter::once((1, 0)).collect();

    let conditional = tensorlogic_quantrs_hooks::condition(&joint, &evidence)
        .expect("Failed to condition on Y=0");

    assert_eq!(conditional.ndim(), 1);
    let total: f64 = conditional.iter().sum();
    assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);

    let expected = [1.0 / 3.0, 2.0 / 3.0];
    for (x_state, want) in expected.iter().enumerate() {
        assert_abs_diff_eq!(conditional[[x_state]], *want, epsilon = 1e-6);
    }
}
/// Builds a graph of two binary variables with no factors.
///
/// The joint distribution the engine reports must still sum to one.
#[test]
fn test_joint_computation() {
    let mut graph = FactorGraph::new();
    for &name in &["x", "y"] {
        graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
    }

    let engine = InferenceEngine::new(graph, Box::new(SumProductAlgorithm::default()));

    let joint = engine.joint();
    assert!(joint.is_ok());
    let joint_dist = joint.expect("Failed to compute joint distribution");

    let total: f64 = joint_dist.iter().sum();
    assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
}
/// Runs `SumProductAlgorithm::new(50, 1e-6, 0.5)` on a single binary variable
/// that has no attached factors.
///
/// With no factor constraining `x`, its marginal must be normalized and
/// uniform. The same graph run with `0.0` as the third argument yields
/// `[0.5, 0.5]` in `test_message_passing_convergence`. Previously this test
/// only checked that the query succeeded.
///
/// NOTE(review): the graph has no cycles, so this exercises the damped
/// configuration but not genuinely loopy propagation.
#[test]
fn test_loopy_bp_with_damping() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("x".to_string(), "Binary".to_string(), 2);
    let algorithm = Box::new(SumProductAlgorithm::new(50, 1e-6, 0.5));
    let engine = InferenceEngine::new(graph, algorithm);
    let query = MarginalizationQuery {
        variable: "x".to_string(),
    };

    // `expect` surfaces the error value on failure, unlike `assert!(r.is_ok())`.
    let marginal = engine
        .marginalize(&query)
        .expect("Failed to compute marginal with damped sum-product");

    let sum: f64 = marginal.iter().sum();
    assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    assert_abs_diff_eq!(marginal[[0]], 0.5, epsilon = 1e-6);
    assert_abs_diff_eq!(marginal[[1]], 0.5, epsilon = 1e-6);
}
/// Divides two factors over the same variable `x` element-wise.
///
/// `[0.6, 0.8] / [0.3, 0.4]` must give `[2.0, 2.0]`.
#[test]
fn test_factor_division() {
    // Builds a binary factor over `x`, panicking with the supplied messages.
    let unary = |name: &str, probs: Vec<f64>, array_err: &str, factor_err: &str| {
        let table = Array::from_shape_vec(vec![2], probs)
            .expect(array_err)
            .into_dyn();
        Factor::new(name.to_string(), vec!["x".to_string()], table).expect(factor_err)
    };

    let numerator = unary(
        "f1",
        vec![0.6, 0.8],
        "Failed to create f1 array",
        "Failed to create f1 factor",
    );
    let denominator = unary(
        "f2",
        vec![0.3, 0.4],
        "Failed to create f2 array",
        "Failed to create f2 factor",
    );

    let quotient = numerator
        .divide(&denominator)
        .expect("Failed to divide factors");
    for x_state in 0..2usize {
        assert_abs_diff_eq!(quotient.values[[x_state]], 2.0, epsilon = 1e-10);
    }
}
/// Runs sum-product directly on a graph with one binary variable `var_0`
/// and a single unary factor `[0.7, 0.3]`.
///
/// The belief for `var_0` must exist, sum to one, and equal that factor.
#[test]
fn test_message_passing_with_factors() {
    let mut graph = FactorGraph::new();
    graph.add_variable_with_card("var_0".to_string(), "Binary".to_string(), 2);
    let factor_values = Array::from_shape_vec(vec![2], vec![0.7, 0.3])
        .expect("Failed to create factor_values array")
        .into_dyn();
    let factor = Factor::new(
        "factor_0".to_string(),
        vec!["var_0".to_string()],
        factor_values,
    )
    .expect("Failed to create factor_0");
    graph
        .add_factor(factor)
        .expect("Failed to add factor_0 to graph");

    let algorithm = SumProductAlgorithm::default();
    let beliefs = algorithm
        .run(&graph)
        .expect("Failed to run sum-product algorithm");

    // A missing belief must fail the test. An `if let` here would silently
    // skip every assertion and let the test pass vacuously.
    let belief = beliefs
        .get("var_0")
        .expect("No belief computed for var_0");
    let sum: f64 = belief.iter().sum();
    assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
    assert_abs_diff_eq!(belief[[0]], 0.7, epsilon = 1e-6);
    assert_abs_diff_eq!(belief[[1]], 0.3, epsilon = 1e-6);
}