use approx::assert_abs_diff_eq;
use proptest::prelude::*;
use scirs2_core::ndarray::Array;
use tensorlogic_quantrs_hooks::{
BetheApproximation, Factor, FactorGraph, JunctionTree, MeanFieldInference,
MessagePassingAlgorithm, SumProductAlgorithm, VariableElimination,
};
fn create_chain_graph(length: usize, card: usize) -> FactorGraph {
let mut graph = FactorGraph::new();
for i in 0..length {
graph.add_variable_with_card(format!("X_{}", i), "Domain".to_string(), card);
}
for i in 0..(length - 1) {
let size = card * card;
let values: Vec<f64> = (0..size)
.map(|i| (i as f64 + 1.0) / size as f64 + 0.1)
.collect();
let shape = vec![card, card];
let array = Array::from_shape_vec(shape, values)
.expect("create_chain_graph: Failed to create factor array")
.into_dyn();
let factor = Factor::new(
format!("psi_{}_{}", i, i + 1),
vec![format!("X_{}", i), format!("X_{}", i + 1)],
array,
)
.expect("create_chain_graph: Failed to create factor");
graph
.add_factor(factor)
.expect("create_chain_graph: Failed to add factor");
}
graph
}
proptest! {
#[test]
fn factor_product_commutative(
values1 in prop::collection::vec(0.1f64..10.0, 2),
values2 in prop::collection::vec(0.1f64..10.0, 2)
) {
let f1 = Factor::new(
"f1".to_string(),
vec!["X".to_string()],
Array::from_shape_vec(vec![2], values1)
.expect("factor_product_commutative: f1 array failed")
.into_dyn()
).expect("factor_product_commutative: f1 factor failed");
let f2 = Factor::new(
"f2".to_string(),
vec!["Y".to_string()],
Array::from_shape_vec(vec![2], values2)
.expect("factor_product_commutative: f2 array failed")
.into_dyn()
).expect("factor_product_commutative: f2 factor failed");
let p1 = f1.product(&f2).expect("factor_product_commutative: p1 product failed");
let p2 = f2.product(&f1).expect("factor_product_commutative: p2 product failed");
assert_eq!(p1.variables.len(), 2);
assert_eq!(p2.variables.len(), 2);
let sum1: f64 = p1.values.iter().sum();
let sum2: f64 = p2.values.iter().sum();
assert_abs_diff_eq!(sum1, sum2, epsilon = 1e-10);
}
#[test]
fn factor_product_associative(
values1 in prop::collection::vec(0.1f64..10.0, 2),
values2 in prop::collection::vec(0.1f64..10.0, 2),
values3 in prop::collection::vec(0.1f64..10.0, 2)
) {
let f1 = Factor::new(
"f1".to_string(),
vec!["X".to_string()],
Array::from_shape_vec(vec![2], values1)
.expect("factor_product_associative: f1 array failed")
.into_dyn()
).expect("factor_product_associative: f1 factor failed");
let f2 = Factor::new(
"f2".to_string(),
vec!["Y".to_string()],
Array::from_shape_vec(vec![2], values2)
.expect("factor_product_associative: f2 array failed")
.into_dyn()
).expect("factor_product_associative: f2 factor failed");
let f3 = Factor::new(
"f3".to_string(),
vec!["Z".to_string()],
Array::from_shape_vec(vec![2], values3)
.expect("factor_product_associative: f3 array failed")
.into_dyn()
).expect("factor_product_associative: f3 factor failed");
let p1 = f1.product(&f2)
.expect("factor_product_associative: (f1*f2) failed")
.product(&f3)
.expect("factor_product_associative: (f1*f2)*f3 failed");
let f2f3 = f2.product(&f3).expect("factor_product_associative: (f2*f3) failed");
let p2 = f1.product(&f2f3).expect("factor_product_associative: f1*(f2*f3) failed");
prop_assert_eq!(p1.variables.len(), 3);
prop_assert_eq!(p2.variables.len(), 3);
let sum1: f64 = p1.values.iter().sum();
let sum2: f64 = p2.values.iter().sum();
assert_abs_diff_eq!(sum1, sum2, epsilon = 1e-10);
}
#[test]
fn marginalization_normalizes(
values in prop::collection::vec(0.1f64..10.0, 4)
) {
let factor = Factor::new(
"joint".to_string(),
vec!["X".to_string(), "Y".to_string()],
Array::from_shape_vec(vec![2, 2], values)
.expect("marginalization_normalizes: array failed")
.into_dyn()
).expect("marginalization_normalizes: factor failed");
let mut marginal = factor.marginalize_out("Y")
.expect("marginalization_normalizes: marginalize_out failed");
marginal.normalize();
let sum: f64 = marginal.values.iter().sum();
assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
}
#[test]
fn marginalization_order_independent(
values in prop::collection::vec(0.1f64..10.0, 8)
) {
let factor = Factor::new(
"joint".to_string(),
vec!["X".to_string(), "Y".to_string(), "Z".to_string()],
Array::from_shape_vec(vec![2, 2, 2], values)
.expect("marginalization_order_independent: array failed")
.into_dyn()
).expect("marginalization_order_independent: factor failed");
let m1 = factor.marginalize_out("Y")
.expect("marginalization_order_independent: marginalize Y failed")
.marginalize_out("Z")
.expect("marginalization_order_independent: marginalize Z (after Y) failed");
let m2 = factor.marginalize_out("Z")
.expect("marginalization_order_independent: marginalize Z failed")
.marginalize_out("Y")
.expect("marginalization_order_independent: marginalize Y (after Z) failed");
prop_assert_eq!(m1.variables.len(), 1);
prop_assert_eq!(m2.variables.len(), 1);
for i in 0..m1.values.len() {
assert_abs_diff_eq!(m1.values[[i]], m2.values[[i]], epsilon = 1e-10);
}
}
#[test]
fn factor_division_inverse(
values1 in prop::collection::vec(0.5f64..10.0, 2),
values2 in prop::collection::vec(0.1f64..2.0, 2)
) {
let f1 = Factor::new(
"f1".to_string(),
vec!["X".to_string()],
Array::from_shape_vec(vec![2], values1.clone())
.expect("factor_division_inverse: f1 array failed")
.into_dyn()
).expect("factor_division_inverse: f1 factor failed");
let f2 = Factor::new(
"f2".to_string(),
vec!["X".to_string()],
Array::from_shape_vec(vec![2], values2)
.expect("factor_division_inverse: f2 array failed")
.into_dyn()
).expect("factor_division_inverse: f2 factor failed");
let product = f1.product(&f2).expect("factor_division_inverse: product failed");
let quotient = product.divide(&f2).expect("factor_division_inverse: divide failed");
prop_assert_eq!(quotient.variables.len(), 1);
for (i, &expected) in values1.iter().enumerate().take(2) {
assert_abs_diff_eq!(quotient.values[[i]], expected, epsilon = 1e-6);
}
}
#[test]
fn reduction_preserves_normalization(
values in prop::collection::vec(0.1f64..10.0, 4),
evidence_val in 0usize..2
) {
let mut factor = Factor::new(
"joint".to_string(),
vec!["X".to_string(), "Y".to_string()],
Array::from_shape_vec(vec![2, 2], values)
.expect("reduction_preserves_normalization: array failed")
.into_dyn()
).expect("reduction_preserves_normalization: factor failed");
factor.normalize();
let mut reduced = factor.reduce("Y", evidence_val)
.expect("reduction_preserves_normalization: reduce failed");
reduced.normalize();
let sum: f64 = reduced.values.iter().sum();
assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
}
#[test]
#[ignore]
fn product_marginalize_commute(
values1 in prop::collection::vec(0.1f64..10.0, 4),
values2 in prop::collection::vec(0.1f64..10.0, 4)
) {
let f1 = Factor::new(
"f1".to_string(),
vec!["X".to_string(), "Y".to_string()],
Array::from_shape_vec(vec![2, 2], values1)
.expect("product_marginalize_commute: f1 array failed")
.into_dyn()
).expect("product_marginalize_commute: f1 factor failed");
let f2 = Factor::new(
"f2".to_string(),
vec!["Y".to_string(), "Z".to_string()],
Array::from_shape_vec(vec![2, 2], values2)
.expect("product_marginalize_commute: f2 array failed")
.into_dyn()
).expect("product_marginalize_commute: f2 factor failed");
let p1 = f1.product(&f2)
.expect("product_marginalize_commute: (f1*f2) failed")
.marginalize_out("Z")
.expect("product_marginalize_commute: marginalize Z failed");
let f2_marg = f2.marginalize_out("Z")
.expect("product_marginalize_commute: f2 marginalize Z failed");
let p2 = f1.product(&f2_marg)
.expect("product_marginalize_commute: f1*(f2_marg) failed");
prop_assert_eq!(p1.variables.len(), 2);
prop_assert_eq!(p2.variables.len(), 2);
let mut p1_norm = p1.clone();
p1_norm.normalize();
let mut p2_norm = p2.clone();
p2_norm.normalize();
for i in 0..p1_norm.values.len() {
assert_abs_diff_eq!(p1_norm.values[[i]], p2_norm.values[[i]], epsilon = 1e-6);
}
}
#[test]
fn conditioning_reduces_variables(
values in prop::collection::vec(0.1f64..10.0, 4),
evidence in 0usize..2
) {
let factor = Factor::new(
"joint".to_string(),
vec!["X".to_string(), "Y".to_string()],
Array::from_shape_vec(vec![2, 2], values)
.expect("conditioning_reduces_variables: array failed")
.into_dyn()
).expect("conditioning_reduces_variables: factor failed");
let reduced = factor.reduce("Y", evidence)
.expect("conditioning_reduces_variables: reduce failed");
prop_assert_eq!(factor.variables.len(), 2);
prop_assert_eq!(reduced.variables.len(), 1);
}
}
proptest! {
#[test]
fn sum_product_produces_normalized_marginals(length in 3usize..6) {
let graph = create_chain_graph(length, 2);
let sp = SumProductAlgorithm::default();
let marginals = sp.run(&graph)
.expect("sum_product_produces_normalized_marginals: run failed");
for (var_name, marginal) in &marginals {
let sum: f64 = marginal.iter().sum();
prop_assert!(
(sum - 1.0).abs() < 1e-6,
"Marginal for {} does not sum to 1: {}",
var_name,
sum
);
}
}
#[test] #[ignore]
#[ignore]
fn ve_equals_bp_on_trees(length in 3usize..5) {
let graph = create_chain_graph(length, 2);
let sp = SumProductAlgorithm::default();
let bp_marginals = sp.run(&graph)
.expect("ve_equals_bp_on_trees: sp.run failed");
let ve = VariableElimination::new();
for var_name in graph.variable_names() {
let ve_marginal = ve.marginalize(&graph, var_name)
.expect("ve_equals_bp_on_trees: ve.marginalize failed");
let bp_marginal = bp_marginals.get(var_name)
.expect("ve_equals_bp_on_trees: bp_marginals.get failed");
for i in 0..ve_marginal.len() {
assert_abs_diff_eq!(ve_marginal[[i]], bp_marginal[[i]], epsilon = 1e-2);
}
}
}
#[test] #[ignore]
#[ignore]
fn junction_tree_produces_consistent_marginals(length in 3usize..5) {
let graph = create_chain_graph(length, 2);
let mut jt = JunctionTree::from_factor_graph(&graph)
.expect("junction_tree_produces_consistent_marginals: from_factor_graph failed");
jt.calibrate()
.expect("junction_tree_produces_consistent_marginals: calibrate failed");
for var_name in graph.variable_names() {
let marginal = jt.query_marginal(var_name)
.expect("junction_tree_produces_consistent_marginals: query_marginal failed");
let sum: f64 = marginal.iter().sum();
prop_assert!(
(sum - 1.0).abs() < 1e-6,
"Marginal for {} does not sum to 1: {}",
var_name,
sum
);
for val in marginal.iter() {
prop_assert!(
*val >= -1e-10,
"Marginal contains negative value: {}",
val
);
}
}
}
#[test]
fn mean_field_produces_valid_distributions(length in 3usize..5) {
let graph = create_chain_graph(length, 2);
let mf = MeanFieldInference::new(100, 1e-4);
let result = mf.run(&graph);
if let Ok(marginals) = result {
for (var_name, marginal) in &marginals {
let sum: f64 = marginal.iter().sum();
prop_assert!(
(sum - 1.0).abs() < 1e-3,
"Mean-field marginal for {} does not sum to 1: {}",
var_name,
sum
);
for val in marginal.iter() {
prop_assert!(
*val >= -1e-10,
"Mean-field marginal contains negative value: {}",
val
);
}
}
}
}
#[test]
fn bethe_produces_valid_distributions(length in 3usize..4) {
let graph = create_chain_graph(length, 2);
let bethe = BetheApproximation::new(50, 1e-4, 0.0);
let result = bethe.run(&graph);
if let Ok(marginals) = result {
for (var_name, marginal) in &marginals {
let sum: f64 = marginal.iter().sum();
prop_assert!(
(sum - 1.0).abs() < 1e-3,
"Bethe marginal for {} does not sum to 1: {}",
var_name,
sum
);
for val in marginal.iter() {
prop_assert!(
*val >= -1e-10,
"Bethe marginal contains negative value: {}",
val
);
}
}
}
}
}
proptest! {
#[test] #[ignore]
#[ignore]
fn exact_inference_consistency_on_trees(length in 3usize..4) {
let graph = create_chain_graph(length, 2);
let sp = SumProductAlgorithm::default();
let bp_marginals = sp.run(&graph)
.expect("exact_inference_consistency_on_trees: sp.run failed");
let ve = VariableElimination::new();
let mut jt = JunctionTree::from_factor_graph(&graph)
.expect("exact_inference_consistency_on_trees: from_factor_graph failed");
jt.calibrate()
.expect("exact_inference_consistency_on_trees: calibrate failed");
for var_name in graph.variable_names() {
let bp_m = bp_marginals.get(var_name)
.expect("exact_inference_consistency_on_trees: bp_marginals.get failed");
let ve_m = ve.marginalize(&graph, var_name)
.expect("exact_inference_consistency_on_trees: ve.marginalize failed");
let jt_m = jt.query_marginal(var_name)
.expect("exact_inference_consistency_on_trees: jt.query_marginal failed");
for i in 0..bp_m.len() {
assert_abs_diff_eq!(bp_m[[i]], ve_m[[i]], epsilon = 1e-2);
}
for i in 0..bp_m.len() {
assert_abs_diff_eq!(bp_m[[i]], jt_m[[i]], epsilon = 1e-2);
}
}
}
}