use std::collections::{HashMap, HashSet};
use biodivine_lib_bdd::{Bdd, BddPointer, BddVariable, BddVariableSet};
/// Selects how the per-row terms of a lineage group are combined by
/// `weighted_model_count`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SemiringOp {
/// Rows are OR-ed together; an empty group evaluates to probability 0.0.
Disjunction,
/// Rows are AND-ed together; an empty group evaluates to probability 1.0.
Conjunction,
}
/// Outcome of a `weighted_model_count` call.
///
/// All fields are plain `Copy` values, so the result itself is cheap to copy
/// and can be compared directly (e.g. in tests or caches).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WmcResult {
    /// Probability computed for the group. Set to 0.0 when `approximated` is
    /// true, i.e. when no BDD was evaluated.
    pub probability: f64,
    /// True when the number of distinct facts exceeded the allowed variable
    /// limit and the exact computation was skipped.
    pub approximated: bool,
    /// Number of distinct base facts found across all rows of the group.
    pub variable_count: usize,
}
/// Computes the probability of a lineage group by building a BDD over its
/// base facts and evaluating it with per-fact probabilities.
///
/// Each row of `group_lineage` is treated as the conjunction of its facts
/// (an empty row is the constant `true`). Rows are then combined with OR for
/// [`SemiringOp::Disjunction`] or AND for [`SemiringOp::Conjunction`], so a
/// fact shared between rows is represented by a single BDD variable.
///
/// # Parameters
/// * `group_lineage` - one set of base-fact keys per row.
/// * `base_fact_weights` - probability of each base fact; facts without an
///   entry are assigned 0.5.
/// * `semiring_op` - how the row terms are combined.
/// * `max_bdd_variables` - upper bound on the number of distinct facts for
///   which a BDD is built.
///
/// # Returns
/// * For an empty `group_lineage`: the identity of the operator (0.0 for
///   disjunction, 1.0 for conjunction) with `variable_count == 0`.
/// * If the number of distinct facts exceeds `max_bdd_variables` or
///   `u16::MAX`: `probability == 0.0` with `approximated == true`; no BDD is
///   built in that case.
/// * Otherwise: the probability evaluated on the BDD, `approximated == false`.
pub fn weighted_model_count(
    group_lineage: &[HashSet<Vec<u8>>],
    base_fact_weights: &HashMap<Vec<u8>, f64>,
    semiring_op: SemiringOp,
    max_bdd_variables: usize,
) -> WmcResult {
    if group_lineage.is_empty() {
        return WmcResult {
            probability: match semiring_op {
                SemiringOp::Disjunction => 0.0,
                SemiringOp::Conjunction => 1.0,
            },
            approximated: false,
            variable_count: 0,
        };
    }

    // Distinct facts in sorted order. Sorting fixes the variable order so the
    // result does not depend on hash iteration order. Keys are borrowed, not
    // cloned.
    let mut all_facts: Vec<&Vec<u8>> = group_lineage.iter().flatten().collect();
    all_facts.sort_unstable();
    all_facts.dedup();

    let variable_count = all_facts.len();
    if variable_count > max_bdd_variables || variable_count > u16::MAX as usize {
        return WmcResult {
            probability: 0.0,
            approximated: true,
            variable_count,
        };
    }

    // The guard above guarantees `variable_count` fits into `u16`.
    let vars = BddVariableSet::new_anonymous(variable_count as u16);
    let bdd_vars: Vec<BddVariable> = vars.variables();
    let fact_to_var: HashMap<&Vec<u8>, BddVariable> = all_facts
        .iter()
        .copied()
        .zip(bdd_vars.iter().copied())
        .collect();

    // Start from the operator's identity element so every row, including the
    // first one and empty ones, goes through the same combine step.
    let mut combined: Bdd = match semiring_op {
        SemiringOp::Disjunction => vars.mk_false(),
        SemiringOp::Conjunction => vars.mk_true(),
    };
    for row_facts in group_lineage {
        // Conjunction of the row's facts; stays `true` for an empty row.
        let term = row_facts
            .iter()
            .filter_map(|fact| fact_to_var.get(fact))
            .fold(vars.mk_true(), |acc, &var| acc.and(&vars.mk_var(var)));
        combined = match semiring_op {
            SemiringOp::Disjunction => combined.or(&term),
            SemiringOp::Conjunction => combined.and(&term),
        };
    }

    let prob_map: HashMap<BddVariable, f64> = all_facts
        .iter()
        .zip(&bdd_vars)
        .map(|(fact, &var)| {
            let p = base_fact_weights.get(*fact).copied().unwrap_or(0.5);
            (var, p)
        })
        .collect();

    WmcResult {
        probability: eval_bdd_probability(&combined, &prob_map),
        approximated: false,
        variable_count,
    }
}
/// Evaluates the probability that `bdd` is true, given the probability of
/// each variable in `prob_map`.
///
/// Starts the evaluation at the BDD's root with a fresh per-call cache of
/// already-evaluated nodes.
fn eval_bdd_probability(bdd: &Bdd, prob_map: &HashMap<BddVariable, f64>) -> f64 {
    let root = bdd.root_pointer();
    let mut node_cache = HashMap::<BddPointer, f64>::new();
    eval_ptr(bdd, root, prob_map, &mut node_cache)
}
/// Returns the probability of the sub-BDD rooted at `ptr`.
///
/// The zero terminal evaluates to 0.0 and the one terminal to 1.0. A decision
/// node on a variable with probability `p` evaluates to
/// `(1 - p) * low + p * high`; variables missing from `prob_map` use 0.5.
/// Values of decision nodes are stored in `memo`, so each node is computed
/// once even when it is shared by several parents.
///
/// The traversal uses an explicit work stack instead of recursion, so the
/// depth of the BDD is limited by heap memory and not by the call stack.
fn eval_ptr(
    bdd: &Bdd,
    ptr: BddPointer,
    prob_map: &HashMap<BddVariable, f64>,
    memo: &mut HashMap<BddPointer, f64>,
) -> f64 {
    /// Value of `ptr` if it is a terminal or already memoized.
    fn known(ptr: BddPointer, memo: &HashMap<BddPointer, f64>) -> Option<f64> {
        if ptr.is_zero() {
            Some(0.0)
        } else if ptr.is_one() {
            Some(1.0)
        } else {
            memo.get(&ptr).copied()
        }
    }

    let mut pending = vec![ptr];
    while let Some(&node) = pending.last() {
        if known(node, memo).is_some() {
            // Terminal, or a shared node that another parent already resolved.
            pending.pop();
            continue;
        }
        let low = bdd.low_link_of(node);
        let high = bdd.high_link_of(node);
        match (known(low, memo), known(high, memo)) {
            (Some(lo), Some(hi)) => {
                let p = prob_map.get(&bdd.var_of(node)).copied().unwrap_or(0.5);
                memo.insert(node, (1.0 - p) * lo + p * hi);
                pending.pop();
            }
            (lo, hi) => {
                // Leave `node` on the stack and resolve its missing children
                // first; the low child ends up on top and is handled next.
                if hi.is_none() {
                    pending.push(high);
                }
                if lo.is_none() {
                    pending.push(low);
                }
            }
        }
    }

    known(ptr, memo).expect("requested node is resolved once the work stack is empty")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for probability comparisons.
    const EPS: f64 = 1e-10;

    /// Byte key of a base fact.
    fn fact(name: &str) -> Vec<u8> {
        name.as_bytes().to_vec()
    }

    /// Builds a lineage group: one set of fact keys per row.
    fn lineage(rows: &[&[&str]]) -> Vec<HashSet<Vec<u8>>> {
        rows.iter()
            .map(|row| row.iter().map(|name| fact(name)).collect())
            .collect()
    }

    /// Builds the fact -> probability table.
    fn weights(entries: &[(&str, f64)]) -> HashMap<Vec<u8>, f64> {
        entries.iter().map(|&(name, p)| (fact(name), p)).collect()
    }

    /// Noisy-OR of the given probabilities under an independence assumption.
    fn noisy_or_independent(probs: &[f64]) -> f64 {
        1.0 - probs.iter().map(|p| 1.0 - p).product::<f64>()
    }

    #[test]
    fn independent_facts_mnor_matches_noisy_or() {
        let rows = lineage(&[&["fact_a"], &["fact_b"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 2);
        let expected = noisy_or_independent(&[0.3, 0.5]);
        assert!(
            (result.probability - expected).abs() < EPS,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn shared_facts_mnor_differs_from_independence() {
        let rows = lineage(&[&["fact_a", "fact_c"], &["fact_b", "fact_c"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5), ("fact_c", 0.7)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 3);
        let expected_exact = 0.455;
        assert!(
            (result.probability - expected_exact).abs() < EPS,
            "BDD={}, expected={}",
            result.probability,
            expected_exact
        );
        // Treating the two rows as independent double-counts the shared fact.
        let independence = noisy_or_independent(&[0.3 * 0.7, 0.5 * 0.7]);
        assert!(
            (result.probability - independence).abs() > 0.01,
            "BDD result should differ from independence: BDD={}, indep={}",
            result.probability,
            independence
        );
    }

    #[test]
    fn shared_facts_mprod() {
        let rows = lineage(&[&["fact_a", "fact_c"], &["fact_b", "fact_c"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5), ("fact_c", 0.7)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Conjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 3);
        let expected = 0.3 * 0.5 * 0.7;
        assert!(
            (result.probability - expected).abs() < EPS,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn bdd_limit_exceeded_falls_back() {
        let rows = lineage(&[&["fact_a"], &["fact_b"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1);

        assert!(result.approximated);
        assert_eq!(result.variable_count, 2);
    }

    #[test]
    fn empty_group_returns_identity() {
        let probs: HashMap<Vec<u8>, f64> = HashMap::new();

        let nor_result = weighted_model_count(&[], &probs, SemiringOp::Disjunction, 1000);
        assert!(!nor_result.approximated);
        assert!(nor_result.probability.abs() < EPS);

        let prod_result = weighted_model_count(&[], &probs, SemiringOp::Conjunction, 1000);
        assert!(!prod_result.approximated);
        assert!((prod_result.probability - 1.0).abs() < EPS);
    }

    #[test]
    fn single_row_returns_product_of_base_facts() {
        let rows = lineage(&[&["fact_a", "fact_b"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5)]);

        // With a single row the combining operator has nothing to combine.
        for op in [SemiringOp::Disjunction, SemiringOp::Conjunction] {
            let result = weighted_model_count(&rows, &probs, op, 1000);
            assert!((result.probability - 0.15).abs() < EPS);
        }
    }

    #[test]
    fn independent_facts_mprod_matches_product() {
        let rows = lineage(&[&["fact_a"], &["fact_b"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Conjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 2);
        let expected = 0.3 * 0.5;
        assert!(
            (result.probability - expected).abs() < EPS,
            "MPROD BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn bdd_limit_exceeded_returns_zero_probability() {
        let rows = lineage(&[&["fact_a"], &["fact_b"], &["fact_c"]]);
        let probs = weights(&[("fact_a", 0.3), ("fact_b", 0.5), ("fact_c", 0.7)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 2);

        assert!(
            result.approximated,
            "Expected approximated=true when limit exceeded"
        );
        assert_eq!(result.variable_count, 3);
        assert!(
            result.probability.abs() < EPS,
            "Fallback probability should be 0.0, got {}",
            result.probability
        );
    }

    #[test]
    fn three_way_shared_mnor() {
        let rows = lineage(&[
            &["fact_a", "fact_d"],
            &["fact_b", "fact_d"],
            &["fact_c", "fact_d"],
        ]);
        let probs = weights(&[
            ("fact_a", 0.3),
            ("fact_b", 0.5),
            ("fact_c", 0.4),
            ("fact_d", 0.8),
        ]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 4);
        // d AND (a OR b OR c)
        let expected = 0.8 * (1.0 - (1.0 - 0.3) * (1.0 - 0.5) * (1.0 - 0.4));
        assert!(
            (result.probability - expected).abs() < EPS,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
        let row0_prod = 0.3 * 0.8;
        let row1_prod = 0.5 * 0.8;
        let row2_prod = 0.4 * 0.8;
        let independence = 1.0 - (1.0 - row0_prod) * (1.0 - row1_prod) * (1.0 - row2_prod);
        assert!(
            (result.probability - independence).abs() > 0.01,
            "BDD result should differ from independence: BDD={}, indep={}",
            result.probability,
            independence
        );
    }

    #[test]
    fn partially_overlapping_rows_mnor() {
        let rows = lineage(&[
            &["fact_a", "fact_c"],
            &["fact_b", "fact_c"],
            &["fact_e"],
        ]);
        let probs = weights(&[
            ("fact_a", 0.3),
            ("fact_b", 0.5),
            ("fact_c", 0.7),
            ("fact_e", 0.6),
        ]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 4);
        // (c AND (a OR b)) OR e, where e is independent of the rest.
        let p_c_times_a_or_b = 0.7 * (1.0 - (1.0 - 0.3) * (1.0 - 0.5));
        let expected = 1.0 - (1.0 - p_c_times_a_or_b) * (1.0 - 0.6);
        assert!(
            (result.probability - expected).abs() < EPS,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn missing_probability_defaults_to_half() {
        let rows = lineage(&[&["unknown_fact"]]);
        let probs: HashMap<Vec<u8>, f64> = HashMap::new();

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);

        assert!(!result.approximated);
        assert_eq!(result.variable_count, 1);
        assert!(
            (result.probability - 0.5).abs() < EPS,
            "Expected default probability 0.5, got {}",
            result.probability
        );
    }
}