use std::collections::{HashMap, HashSet};
use thiserror::Error;
/// Errors produced by probabilistic RDF reasoning and Bayesian network operations.
#[derive(Debug, Error)]
pub enum ProbError {
    /// A probability, confidence, or threshold was outside `[0, 1]`. The range checks in
    /// this file use `(0.0..=1.0).contains(..)`, which also rejects NaN.
    #[error("Probability {0} is out of range [0, 1]")]
    InvalidProbability(f64),
    /// The network's edges contain a cycle, so no topological order exists.
    #[error("Circular dependency detected in Bayesian network: {0}")]
    CircularDependency(String),
    /// A node id (for example, a declared parent) is not present in the network.
    #[error("Variable not found in network: {0}")]
    VariableNotFound(String),
    /// Forward-chaining inference did not reach a fixpoint within the configured
    /// iteration limit. The payload is that limit.
    #[error("Maximum inference iterations exceeded ({0})")]
    MaxIterationsExceeded(usize),
    /// A conditional probability table was built from inconsistent inputs.
    // NOTE(review): the message says "for node {0}", but `new_noisy_or` passes a
    // description ("inhibition_probs length mismatch") rather than a node name.
    #[error("Inconsistent conditional probability table for node {0}")]
    InconsistentCpt(String),
}
/// An RDF triple annotated with the probability that it holds.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbabilisticTriple {
    pub subject: String,
    pub predicate: String,
    pub object: String,
    /// Probability in `[0, 1]`. The constructors enforce this; the public field does not.
    pub probability: f64,
    /// Provenance. For inferred triples this holds the `key()` strings of the supporting
    /// base facts plus `rule:<id>` markers for the rules that fired. It is empty for
    /// triples built with `new_fact`.
    pub evidence: Vec<String>,
    /// `true` for asserted facts, `false` for triples produced by inference.
    pub is_base_fact: bool,
}
impl ProbabilisticTriple {
pub fn new_fact(s: &str, p: &str, o: &str, probability: f64) -> Result<Self, ProbError> {
if !(0.0..=1.0).contains(&probability) {
return Err(ProbError::InvalidProbability(probability));
}
Ok(Self {
subject: s.to_string(),
predicate: p.to_string(),
object: o.to_string(),
probability,
evidence: Vec::new(),
is_base_fact: true,
})
}
pub fn new_inferred(
s: &str,
p: &str,
o: &str,
probability: f64,
evidence: Vec<String>,
) -> Result<Self, ProbError> {
if !(0.0..=1.0).contains(&probability) {
return Err(ProbError::InvalidProbability(probability));
}
Ok(Self {
subject: s.to_string(),
predicate: p.to_string(),
object: o.to_string(),
probability,
evidence,
is_base_fact: false,
})
}
pub fn key(&self) -> String {
format!("({}, {}, {})", self.subject, self.predicate, self.object)
}
}
/// One position of a rule pattern: either a named variable or a fixed term.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PatternVar {
    /// A variable. It binds to the term in this position, or must equal its existing binding.
    Var(String),
    /// A constant. It matches only an identical term.
    Const(String),
}
impl PatternVar {
    /// Shorthand for [`PatternVar::Var`] built from a borrowed name.
    pub fn var(name: &str) -> Self {
        PatternVar::Var(String::from(name))
    }
    /// Shorthand for [`PatternVar::Const`]. It is spelled `konst` because `const` is a keyword.
    pub fn konst(value: &str) -> Self {
        PatternVar::Const(value.to_owned())
    }
}
/// A triple pattern `(subject, predicate, object)` used as a rule antecedent or consequent.
#[derive(Debug, Clone)]
pub struct RulePattern {
    pub subject: PatternVar,
    pub predicate: PatternVar,
    pub object: PatternVar,
}
impl RulePattern {
    /// Assembles a pattern from its subject, predicate, and object positions.
    pub fn new(s: PatternVar, p: PatternVar, o: PatternVar) -> Self {
        RulePattern {
            object: o,
            predicate: p,
            subject: s,
        }
    }
}
/// A Horn-style rule: when every antecedent pattern matches, the consequent is derived.
///
/// The reasoner scores a derivation as the minimum antecedent probability multiplied by
/// `confidence`.
#[derive(Debug, Clone)]
pub struct ProbabilisticRule {
    pub id: String,
    pub name: String,
    pub antecedents: Vec<RulePattern>,
    pub consequent: RulePattern,
    pub confidence: f64,
}
impl ProbabilisticRule {
    /// Builds a rule after checking that `confidence` lies in `[0, 1]`.
    ///
    /// # Errors
    /// [`ProbError::InvalidProbability`] if `confidence` is outside `[0, 1]` or is NaN.
    pub fn new(
        id: &str,
        name: &str,
        antecedents: Vec<RulePattern>,
        consequent: RulePattern,
        confidence: f64,
    ) -> Result<Self, ProbError> {
        if (0.0..=1.0).contains(&confidence) {
            Ok(ProbabilisticRule {
                id: id.to_owned(),
                name: name.to_owned(),
                antecedents,
                consequent,
                confidence,
            })
        } else {
            Err(ProbError::InvalidProbability(confidence))
        }
    }
}
/// How the reasoner merges several derivations of the same triple into one probability.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CombinationStrategy {
    /// `1 - Π(1 - pᵢ)`: treats each derivation as an independent cause.
    NoisyOr,
    /// The largest individual probability.
    Maximum,
    /// The arithmetic mean of the probabilities.
    // NOTE(review): despite the name, no weights are applied — every derivation counts equally.
    WeightedAverage,
    /// The smallest individual probability.
    Minimum,
}
/// Conditional probability table giving `P(node = true | parent assignment)`.
///
/// Keys are `Vec<bool>` with one entry per parent, in the same order as `parent_names`.
#[derive(Debug, Clone)]
pub struct ConditionalProbabilityTable {
    pub parent_names: Vec<String>,
    probabilities: HashMap<Vec<bool>, f64>,
    /// Value returned for assignments that are missing from the table. For a noisy-OR
    /// table it equals the probability of the node being true when no parent is active.
    pub leak_prob: f64,
}
impl ConditionalProbabilityTable {
    /// Builds a full noisy-OR table with one entry per parent assignment (`2^n` entries).
    ///
    /// For each assignment, `P(true) = 1 - (1 - leak_prob) * Π q_i`, where the product
    /// runs over the inhibition probabilities `q_i` of the parents that are active.
    ///
    /// # Errors
    /// - [`ProbError::InconsistentCpt`] if `inhibition_probs` and `parent_names` differ
    ///   in length.
    /// - [`ProbError::InvalidProbability`] if `leak_prob` or any inhibition probability
    ///   is outside `[0, 1]` or is NaN.
    pub fn new_noisy_or(
        parent_names: Vec<String>,
        inhibition_probs: Vec<f64>,
        leak_prob: f64,
    ) -> Result<Self, ProbError> {
        if parent_names.len() != inhibition_probs.len() {
            return Err(ProbError::InconsistentCpt(
                "inhibition_probs length mismatch".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&leak_prob) {
            return Err(ProbError::InvalidProbability(leak_prob));
        }
        // An out-of-range inhibition probability was previously clamped into a
        // plausible-looking entry. NaN was propagated into the table, because `clamp`
        // keeps NaN. Reject both up front, consistent with the leak check.
        if let Some(&bad) = inhibition_probs
            .iter()
            .find(|q| !(0.0..=1.0).contains(*q))
        {
            return Err(ProbError::InvalidProbability(bad));
        }
        let n = parent_names.len();
        let probabilities = (0..(1usize << n))
            .map(|config_idx| {
                // Bit i of `config_idx` is the truth value of parent i.
                let config: Vec<bool> = (0..n).map(|i| (config_idx >> i) & 1 == 1).collect();
                // The node stays false only if the leak and every active cause are inhibited.
                let inhibited: f64 = config
                    .iter()
                    .zip(&inhibition_probs)
                    .filter(|&(&active, _)| active)
                    .map(|(_, &q)| q)
                    .product();
                let prob = 1.0 - (1.0 - leak_prob) * inhibited;
                (config, prob.clamp(0.0, 1.0))
            })
            .collect();
        Ok(Self {
            parent_names,
            probabilities,
            leak_prob,
        })
    }

    /// A parentless table whose only entry (and fallback) is `1.0` or `0.0`.
    pub fn deterministic(value: bool) -> Self {
        let p = if value { 1.0 } else { 0.0 };
        let mut probabilities = HashMap::new();
        probabilities.insert(Vec::new(), p);
        Self {
            parent_names: Vec::new(),
            probabilities,
            leak_prob: p,
        }
    }

    /// Looks up `P(true)` for a parent assignment.
    ///
    /// Falls back to `leak_prob` when the assignment is not in the table, for example
    /// when its length does not match the number of parents.
    pub fn get_probability(&self, parent_values: &[bool]) -> f64 {
        self.probabilities
            .get(parent_values)
            .copied()
            .unwrap_or(self.leak_prob)
    }

    /// Returns the prior probability.
    ///
    /// For a parentless table this is the entry for the empty assignment. Otherwise it
    /// is `leak_prob`.
    pub fn get_prior(&self) -> f64 {
        if self.parent_names.is_empty() {
            self.probabilities
                .get(&Vec::<bool>::new())
                .copied()
                .unwrap_or(self.leak_prob)
        } else {
            self.leak_prob
        }
    }
}
/// A node in a [`BayesianRdfNetwork`], associated with one RDF triple.
#[derive(Debug, Clone)]
pub struct BayesianRdfNode {
    /// Node id. The network stores nodes by id, so reusing an id replaces the node.
    pub id: String,
    /// Identifier of the represented triple. The network itself never reads it.
    pub triple_key: String,
    /// Ids of parent nodes, which must already exist when the node is added. Position
    /// `i` here is position `i` of the `Vec<bool>` assignments looked up in `cpt`.
    pub parents: Vec<String>,
    /// Conditional probability of this node given its parents.
    pub cpt: ConditionalProbabilityTable,
    /// Current belief `P(node = true)`. Belief propagation and evidence overwrite it.
    pub marginal_prob: f64,
}
pub struct BayesianRdfNetwork {
nodes: HashMap<String, BayesianRdfNode>,
edges: HashMap<String, HashSet<String>>,
}
impl BayesianRdfNetwork {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
edges: HashMap::new(),
}
}
pub fn add_node(&mut self, node: BayesianRdfNode) -> Result<(), ProbError> {
for parent_id in &node.parents {
if !self.nodes.contains_key(parent_id) {
return Err(ProbError::VariableNotFound(parent_id.clone()));
}
self.edges
.entry(parent_id.clone())
.or_default()
.insert(node.id.clone());
}
self.nodes.insert(node.id.clone(), node);
Ok(())
}
pub fn add_root_node(
&mut self,
id: &str,
triple_key: &str,
prior_prob: f64,
) -> Result<(), ProbError> {
if !(0.0..=1.0).contains(&prior_prob) {
return Err(ProbError::InvalidProbability(prior_prob));
}
let cpt = ConditionalProbabilityTable::deterministic(false);
let mut node = BayesianRdfNode {
id: id.to_string(),
triple_key: triple_key.to_string(),
parents: Vec::new(),
cpt,
marginal_prob: prior_prob,
};
node.cpt.probabilities.insert(vec![], prior_prob);
node.cpt.leak_prob = prior_prob;
self.nodes.insert(id.to_string(), node);
Ok(())
}
pub fn propagate_beliefs(&mut self) -> Result<(), ProbError> {
let order = self.topological_order()?;
for node_id in &order {
let (parent_probs, parent_ids): (Vec<f64>, Vec<String>) = {
let node = self
.nodes
.get(node_id)
.ok_or_else(|| ProbError::VariableNotFound(node_id.clone()))?;
let parent_ids = node.parents.clone();
let parent_probs: Vec<f64> = parent_ids
.iter()
.map(|pid| self.nodes.get(pid).map(|n| n.marginal_prob).unwrap_or(0.0))
.collect();
(parent_probs, parent_ids)
};
if parent_ids.is_empty() {
continue;
}
let n_parents = parent_ids.len();
let mut marginal = 0.0;
for config_idx in 0..(1usize << n_parents) {
let config: Vec<bool> =
(0..n_parents).map(|i| (config_idx >> i) & 1 == 1).collect();
let config_prob: f64 = config
.iter()
.enumerate()
.map(|(i, &active)| {
let p = parent_probs[i];
if active {
p
} else {
1.0 - p
}
})
.product();
let node = self
.nodes
.get(node_id)
.ok_or_else(|| ProbError::VariableNotFound(node_id.clone()))?;
let cond_prob = node.cpt.get_probability(&config);
marginal += config_prob * cond_prob;
}
if let Some(node) = self.nodes.get_mut(node_id) {
node.marginal_prob = marginal.clamp(0.0, 1.0);
}
}
Ok(())
}
pub fn get_marginal_prob(&self, node_id: &str) -> Option<f64> {
self.nodes.get(node_id).map(|n| n.marginal_prob)
}
pub fn set_evidence(&mut self, node_id: &str, observed: bool) -> Result<(), ProbError> {
let node = self
.nodes
.get_mut(node_id)
.ok_or_else(|| ProbError::VariableNotFound(node_id.to_string()))?;
node.marginal_prob = if observed { 1.0 } else { 0.0 };
Ok(())
}
fn topological_order(&self) -> Result<Vec<String>, ProbError> {
let mut in_degree: HashMap<&str, usize> = HashMap::new();
for id in self.nodes.keys() {
in_degree.entry(id).or_insert(0);
}
for children in self.edges.values() {
for child in children {
*in_degree.entry(child).or_insert(0) += 1;
}
}
let mut queue: Vec<String> = in_degree
.iter()
.filter(|(_, &d)| d == 0)
.map(|(id, _)| id.to_string())
.collect();
queue.sort();
let mut order = Vec::new();
let mut remaining_degrees = in_degree
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect::<HashMap<_, _>>();
while !queue.is_empty() {
queue.sort();
let node_id = queue.remove(0);
order.push(node_id.clone());
if let Some(children) = self.edges.get(&node_id) {
for child in children {
let deg = remaining_degrees.entry(child.clone()).or_insert(0);
if *deg > 0 {
*deg -= 1;
}
if *deg == 0 {
queue.push(child.clone());
}
}
}
}
if order.len() != self.nodes.len() {
return Err(ProbError::CircularDependency(
"Graph contains cycles".to_string(),
));
}
Ok(order)
}
}
impl Default for BayesianRdfNetwork {
fn default() -> Self {
Self::new()
}
}
/// Summary statistics of a single `ProbabilisticRdfReasoner::infer` run.
#[derive(Debug, Clone)]
pub struct ProbabilisticInferenceReport {
    /// Number of inferred (non-base) triples returned.
    pub inferred_count: usize,
    /// Total rule matches summed over all iterations. Re-derivations in later iterations
    /// are counted again.
    pub rules_fired: usize,
    /// Fixpoint iterations executed, including the final one that changed nothing.
    pub iterations: usize,
    /// Wall-clock time spent in `infer`.
    pub duration: std::time::Duration,
    /// The reasoner's acceptance threshold.
    pub threshold: f64,
    /// Mean probability of the returned triples, or `0.0` when none were inferred.
    pub avg_probability: f64,
}
/// Forward-chaining reasoner over probabilistic RDF facts and rules.
pub struct ProbabilisticRdfReasoner {
    /// Asserted base facts, in insertion order (duplicates allowed).
    facts: Vec<ProbabilisticTriple>,
    /// Rules applied on every inference iteration.
    rules: Vec<ProbabilisticRule>,
    /// Minimum probability for a derivation or an inferred triple to be kept.
    threshold: f64,
    /// How multiple derivations of one triple are merged.
    combination_strategy: CombinationStrategy,
    /// Iteration cap for `infer` before it gives up with `MaxIterationsExceeded`.
    max_iterations: usize,
}
impl ProbabilisticRdfReasoner {
    /// Creates a reasoner that only keeps conclusions with probability at least `threshold`.
    ///
    /// The reasoner starts without facts or rules. It merges derivations with
    /// [`CombinationStrategy::NoisyOr`] and allows 100 fixpoint iterations.
    ///
    /// # Errors
    /// [`ProbError::InvalidProbability`] if `threshold` is outside `[0, 1]` or is NaN.
    pub fn new(threshold: f64) -> Result<Self, ProbError> {
        if !(0.0..=1.0).contains(&threshold) {
            return Err(ProbError::InvalidProbability(threshold));
        }
        Ok(Self {
            facts: Vec::new(),
            rules: Vec::new(),
            threshold,
            combination_strategy: CombinationStrategy::NoisyOr,
            max_iterations: 100,
        })
    }

    /// Builder: sets how several derivations of the same triple are merged.
    pub fn with_combination_strategy(mut self, strategy: CombinationStrategy) -> Self {
        self.combination_strategy = strategy;
        self
    }

    /// Builder: sets the maximum number of fixpoint iterations `infer` may run.
    pub fn with_max_iterations(mut self, n: usize) -> Self {
        self.max_iterations = n;
        self
    }

    /// Asserts the base fact `(s, p, o)` with the given probability.
    ///
    /// Adding the same triple again keeps both entries. The most recently added
    /// probability is the one used by `infer` and returned by `get_fact_probability`.
    ///
    /// # Errors
    /// [`ProbError::InvalidProbability`] if `probability` is outside `[0, 1]` or is NaN.
    pub fn add_fact(
        &mut self,
        s: &str,
        p: &str,
        o: &str,
        probability: f64,
    ) -> Result<(), ProbError> {
        let fact = ProbabilisticTriple::new_fact(s, p, o, probability)?;
        self.facts.push(fact);
        Ok(())
    }

    /// Registers a rule for `infer` to apply.
    pub fn add_rule(&mut self, rule: ProbabilisticRule) {
        self.rules.push(rule);
    }

    /// Runs forward chaining to a fixpoint and returns the inferred (non-base) triples.
    ///
    /// On each iteration, every rule match yields a derivation scored as
    /// `min(antecedent probabilities) * rule.confidence`. Derivations below the
    /// threshold are dropped. The rest are merged per triple with the combination
    /// strategy. A triple is updated only when the merged value beats its current value
    /// by more than `f64::EPSILON`, and iteration stops once nothing changes.
    ///
    /// Results are sorted by descending probability, with ties broken by
    /// (subject, predicate, object). Evidence lists are sorted and deduplicated.
    ///
    /// # Errors
    /// [`ProbError::MaxIterationsExceeded`] if no fixpoint is reached within
    /// `max_iterations`.
    pub fn infer(
        &self,
    ) -> Result<(Vec<ProbabilisticTriple>, ProbabilisticInferenceReport), ProbError> {
        let start = std::time::Instant::now();
        // Keyed by the (subject, predicate, object) terms themselves. The old
        // "(s, p, o)" string key collided when terms contained ", ". Its parser lost
        // facts containing unbalanced brackets and mis-sliced non-ASCII terms.
        let mut working_set: HashMap<(String, String, String), (f64, Vec<String>)> =
            HashMap::new();
        for fact in &self.facts {
            // A later duplicate of the same triple overwrites an earlier one.
            working_set.insert(
                (
                    fact.subject.clone(),
                    fact.predicate.clone(),
                    fact.object.clone(),
                ),
                (fact.probability, vec![fact.key()]),
            );
        }
        let mut iterations = 0usize;
        let mut total_rules_fired = 0usize;
        loop {
            if iterations >= self.max_iterations {
                return Err(ProbError::MaxIterationsExceeded(self.max_iterations));
            }
            iterations += 1;
            let mut new_support: HashMap<(String, String, String), Vec<(f64, Vec<String>)>> =
                HashMap::new();
            for rule in &self.rules {
                let matches = self.match_rule(rule, &working_set);
                total_rules_fired += matches.len();
                for (bindings, antecedent_probs, mut evidence) in matches {
                    if let Some(triple) = self.instantiate_pattern(&rule.consequent, &bindings) {
                        let antecedent_prob =
                            antecedent_probs.iter().copied().fold(1.0f64, f64::min);
                        let derived_prob = antecedent_prob * rule.confidence;
                        if derived_prob >= self.threshold {
                            evidence.push(format!("rule:{}", rule.id));
                            new_support
                                .entry(triple)
                                .or_default()
                                .push((derived_prob, evidence));
                        }
                    }
                }
            }
            let mut changed = false;
            for (triple, supports) in new_support {
                let probs: Vec<f64> = supports.iter().map(|(p, _)| *p).collect();
                let new_prob = self.combine_probabilities(&probs);
                let existing_prob = working_set.get(&triple).map_or(0.0, |(p, _)| *p);
                if new_prob > existing_prob + f64::EPSILON {
                    // Sorted rather than HashSet-collected, so that output is deterministic.
                    let mut evidence: Vec<String> =
                        supports.into_iter().flat_map(|(_, evs)| evs).collect();
                    evidence.sort();
                    evidence.dedup();
                    working_set.insert(triple, (new_prob, evidence));
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }
        let base_facts: HashSet<(&str, &str, &str)> = self
            .facts
            .iter()
            .map(|f| (f.subject.as_str(), f.predicate.as_str(), f.object.as_str()))
            .collect();
        let mut inferred: Vec<ProbabilisticTriple> = Vec::new();
        for ((s, p, o), (prob, evidence)) in working_set {
            let is_base = base_facts.contains(&(s.as_str(), p.as_str(), o.as_str()));
            if !is_base && prob >= self.threshold {
                // Every probability reaching this point lies in [0, 1] (validated inputs,
                // min/product/clamp/mean), so an error here would be an internal bug.
                inferred.push(ProbabilisticTriple::new_inferred(
                    &s, &p, &o, prob, evidence,
                )?);
            }
        }
        inferred.sort_by(|a, b| {
            b.probability
                .partial_cmp(&a.probability)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| {
                    (&a.subject, &a.predicate, &a.object).cmp(&(
                        &b.subject,
                        &b.predicate,
                        &b.object,
                    ))
                })
        });
        let avg_prob = if inferred.is_empty() {
            0.0
        } else {
            inferred.iter().map(|t| t.probability).sum::<f64>() / inferred.len() as f64
        };
        let report = ProbabilisticInferenceReport {
            inferred_count: inferred.len(),
            rules_fired: total_rules_fired,
            iterations,
            duration: start.elapsed(),
            threshold: self.threshold,
            avg_probability: avg_prob,
        };
        Ok((inferred, report))
    }

    /// Merges the probabilities of several derivations of one triple according to the
    /// configured [`CombinationStrategy`]. Returns `0.0` for an empty slice.
    ///
    /// `WeightedAverage` is an unweighted arithmetic mean.
    pub fn combine_probabilities(&self, probs: &[f64]) -> f64 {
        if probs.is_empty() {
            return 0.0;
        }
        match self.combination_strategy {
            CombinationStrategy::NoisyOr => {
                let complement_product: f64 = probs.iter().map(|&p| 1.0 - p).product();
                (1.0 - complement_product).clamp(0.0, 1.0)
            }
            CombinationStrategy::Maximum => probs.iter().copied().fold(0.0f64, f64::max),
            CombinationStrategy::WeightedAverage => probs.iter().sum::<f64>() / probs.len() as f64,
            CombinationStrategy::Minimum => probs.iter().copied().fold(1.0f64, f64::min),
        }
    }

    /// Returns the probability of the asserted base fact `(s, p, o)`, if there is one.
    ///
    /// Only base facts are consulted, not inferred triples. Duplicates are searched
    /// newest-first, so the answer agrees with the value that `infer` actually uses.
    pub fn get_fact_probability(&self, s: &str, p: &str, o: &str) -> Option<f64> {
        self.facts
            .iter()
            .rev()
            .find(|f| f.subject == s && f.predicate == p && f.object == o)
            .map(|f| f.probability)
    }

    /// Enumerates every way to satisfy all antecedents of `rule` against `working_set`.
    ///
    /// Each match consists of the variable bindings, the probability of each matched
    /// triple (in antecedent order), and the concatenated evidence of those triples.
    /// A rule with no antecedents yields exactly one empty match.
    #[allow(clippy::type_complexity)]
    fn match_rule(
        &self,
        rule: &ProbabilisticRule,
        working_set: &HashMap<(String, String, String), (f64, Vec<String>)>,
    ) -> Vec<(HashMap<String, String>, Vec<f64>, Vec<String>)> {
        let mut current: Vec<(HashMap<String, String>, Vec<f64>, Vec<String>)> =
            vec![(HashMap::new(), Vec::new(), Vec::new())];
        for pattern in &rule.antecedents {
            let mut next = Vec::new();
            for (bindings, probs, evidence) in &current {
                for (triple, (prob, evids)) in working_set {
                    if let Some(extended) = self.try_match_pattern(pattern, triple, bindings) {
                        let mut new_probs = probs.clone();
                        new_probs.push(*prob);
                        let mut new_evidence = evidence.clone();
                        new_evidence.extend(evids.iter().cloned());
                        next.push((extended, new_probs, new_evidence));
                    }
                }
            }
            current = next;
            if current.is_empty() {
                break;
            }
        }
        current
    }

    /// Attempts to unify `pattern` with `triple` under the existing `bindings`.
    ///
    /// Constants must equal the term. An unbound variable is bound to the term, and a
    /// bound variable must equal it. Returns the extended bindings, or `None` on mismatch.
    fn try_match_pattern(
        &self,
        pattern: &RulePattern,
        triple: &(String, String, String),
        bindings: &HashMap<String, String>,
    ) -> Option<HashMap<String, String>> {
        let mut extended = bindings.clone();
        let check =
            |elem: &PatternVar, value: &str, bindings: &mut HashMap<String, String>| -> bool {
                match elem {
                    PatternVar::Const(c) => c == value,
                    PatternVar::Var(v) => {
                        if let Some(bound) = bindings.get(v) {
                            bound == value
                        } else {
                            bindings.insert(v.clone(), value.to_string());
                            true
                        }
                    }
                }
            };
        if check(&pattern.subject, &triple.0, &mut extended)
            && check(&pattern.predicate, &triple.1, &mut extended)
            && check(&pattern.object, &triple.2, &mut extended)
        {
            Some(extended)
        } else {
            None
        }
    }

    /// Substitutes `bindings` into `pattern`. Returns `None` if any variable is unbound.
    fn instantiate_pattern(
        &self,
        pattern: &RulePattern,
        bindings: &HashMap<String, String>,
    ) -> Option<(String, String, String)> {
        let resolve = |elem: &PatternVar| -> Option<String> {
            match elem {
                PatternVar::Const(c) => Some(c.clone()),
                PatternVar::Var(v) => bindings.get(v).cloned(),
            }
        };
        let s = resolve(&pattern.subject)?;
        let p = resolve(&pattern.predicate)?;
        let o = resolve(&pattern.object)?;
        Some((s, p, o))
    }
}
/// Builds the rule `?x rdf:type sub_class  =>  ?x rdf:type sup_class` with the given
/// confidence.
///
/// # Errors
/// [`ProbError::InvalidProbability`] if `confidence` is outside `[0, 1]`.
pub fn make_subclass_rule(
    id: &str,
    sub_class: &str,
    sup_class: &str,
    confidence: f64,
) -> Result<ProbabilisticRule, ProbError> {
    const RDF_TYPE_IRI: &str = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";
    let label = format!("{} ⊑ {}", sub_class, sup_class);
    // Both sides share the shape `?x rdf:type <class>`.
    let typed_as = |class: &str| {
        RulePattern::new(
            PatternVar::var("x"),
            PatternVar::konst(RDF_TYPE_IRI),
            PatternVar::konst(class),
        )
    };
    let premise = typed_as(sub_class);
    let conclusion = typed_as(sup_class);
    ProbabilisticRule::new(id, &label, vec![premise], conclusion, confidence)
}
/// Builds the rule `?x property ?y  =>  ?y property ?x` with the given confidence.
///
/// # Errors
/// [`ProbError::InvalidProbability`] if `confidence` is outside `[0, 1]`.
pub fn make_symmetric_rule(
    id: &str,
    property: &str,
    confidence: f64,
) -> Result<ProbabilisticRule, ProbError> {
    let label = format!("symmetric({})", property);
    // `?from property ?to`
    let edge = |from: &str, to: &str| {
        RulePattern::new(
            PatternVar::var(from),
            PatternVar::konst(property),
            PatternVar::var(to),
        )
    };
    ProbabilisticRule::new(id, &label, vec![edge("x", "y")], edge("y", "x"), confidence)
}
/// Builds the rule `?x property ?y ∧ ?y property ?z  =>  ?x property ?z` with the given
/// confidence.
///
/// # Errors
/// [`ProbError::InvalidProbability`] if `confidence` is outside `[0, 1]`.
pub fn make_transitive_rule(
    id: &str,
    property: &str,
    confidence: f64,
) -> Result<ProbabilisticRule, ProbError> {
    let label = format!("transitive({})", property);
    // `?from property ?to`
    let edge = |from: &str, to: &str| {
        RulePattern::new(
            PatternVar::var(from),
            PatternVar::konst(property),
            PatternVar::var(to),
        )
    };
    let premises = vec![edge("x", "y"), edge("y", "z")];
    ProbabilisticRule::new(id, &label, premises, edge("x", "z"), confidence)
}
/// Builds the rule `?x property ?y  =>  ?x rdf:type domain_class` with the given
/// confidence.
///
/// # Errors
/// [`ProbError::InvalidProbability`] if `confidence` is outside `[0, 1]`.
pub fn make_domain_rule(
    id: &str,
    property: &str,
    domain_class: &str,
    confidence: f64,
) -> Result<ProbabilisticRule, ProbError> {
    const RDF_TYPE_IRI: &str = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";
    let label = format!("domain({}, {})", property, domain_class);
    let uses_property = RulePattern::new(
        PatternVar::var("x"),
        PatternVar::konst(property),
        PatternVar::var("y"),
    );
    let has_domain_type = RulePattern::new(
        PatternVar::var("x"),
        PatternVar::konst(RDF_TYPE_IRI),
        PatternVar::konst(domain_class),
    );
    ProbabilisticRule::new(id, &label, vec![uses_property], has_domain_type, confidence)
}
#[cfg(test)]
mod tests {
    use super::*;

    const RDF_TYPE: &str = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";

    /// Builds a reasoner with a threshold that is known to be valid.
    fn reasoner(threshold: f64) -> ProbabilisticRdfReasoner {
        ProbabilisticRdfReasoner::new(threshold).expect("valid threshold")
    }

    #[test]
    fn test_invalid_probability() {
        assert!(ProbabilisticTriple::new_fact("s", "p", "o", 1.5).is_err());
        assert!(ProbabilisticTriple::new_fact("s", "p", "o", -0.1).is_err());
        assert!(ProbabilisticTriple::new_fact("s", "p", "o", 0.8).is_ok());
    }

    #[test]
    fn test_subclass_inference() -> Result<(), Box<dyn std::error::Error>> {
        let mut r = reasoner(0.1);
        r.add_fact("fido", RDF_TYPE, "Dog", 0.9).expect("add fact");
        r.add_rule(make_subclass_rule("r1", "Dog", "Animal", 0.95).expect("rule"));
        let (inferred, report) = r.infer().expect("infer");
        assert!(!inferred.is_empty(), "Expected at least one inferred triple");
        assert!(report.rules_fired > 0);
        let fido_animal = inferred
            .iter()
            .find(|t| t.subject == "fido" && t.predicate == RDF_TYPE && t.object == "Animal");
        assert!(
            fido_animal.is_some(),
            "Expected fido rdf:type Animal to be inferred"
        );
        assert!(fido_animal.ok_or("expected Some value")?.probability > 0.5);
        Ok(())
    }

    #[test]
    fn test_symmetric_inference() -> Result<(), Box<dyn std::error::Error>> {
        let knows = "https://example.org/knows";
        let mut r = reasoner(0.1);
        r.add_fact("alice", knows, "bob", 0.9).expect("add fact");
        r.add_rule(make_symmetric_rule("r1", knows, 0.95).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        let bob_knows_alice = inferred
            .iter()
            .find(|t| t.subject == "bob" && t.predicate == knows && t.object == "alice");
        assert!(bob_knows_alice.is_some(), "Expected symmetric inference");
        assert!(bob_knows_alice.ok_or("expected Some value")?.probability > 0.5);
        Ok(())
    }

    #[test]
    fn test_transitive_inference() {
        let ancestor_of = "https://example.org/ancestorOf";
        let mut r = reasoner(0.1);
        r.add_fact("grandpa", ancestor_of, "parent", 0.99)
            .expect("add fact");
        r.add_fact("parent", ancestor_of, "child", 0.99)
            .expect("add fact");
        r.add_rule(make_transitive_rule("r1", ancestor_of, 0.95).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        let grandpa_child = inferred
            .iter()
            .find(|t| t.subject == "grandpa" && t.predicate == ancestor_of && t.object == "child");
        assert!(
            grandpa_child.is_some(),
            "Expected transitive inference. Inferred: {:?}",
            inferred
                .iter()
                .map(|t| format!("({},{},{})", t.subject, t.predicate, t.object))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_domain_inference() {
        let has_parent = "https://example.org/hasParent";
        let mut r = reasoner(0.1);
        r.add_fact("alice", has_parent, "bob", 0.85)
            .expect("add fact");
        r.add_rule(make_domain_rule("r1", has_parent, "Person", 0.9).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        let alice_person = inferred
            .iter()
            .find(|t| t.subject == "alice" && t.predicate == RDF_TYPE && t.object == "Person");
        assert!(
            alice_person.is_some(),
            "Expected alice rdf:type Person from domain"
        );
    }

    #[test]
    fn test_noisy_or_combination() {
        let r = reasoner(0.0);
        let result = r.combine_probabilities(&[0.6, 0.7]);
        let expected = 1.0 - (1.0 - 0.6) * (1.0 - 0.7);
        assert!(
            (result - expected).abs() < 1e-10,
            "Expected {}, got {}",
            expected,
            result
        );
    }

    #[test]
    fn test_maximum_combination() {
        let r = reasoner(0.0).with_combination_strategy(CombinationStrategy::Maximum);
        let result = r.combine_probabilities(&[0.3, 0.8, 0.5]);
        assert!((result - 0.8).abs() < 1e-10, "Expected 0.8 max");
    }

    #[test]
    fn test_weighted_average_combination() {
        let r = reasoner(0.0).with_combination_strategy(CombinationStrategy::WeightedAverage);
        let result = r.combine_probabilities(&[0.4, 0.6]);
        assert!((result - 0.5).abs() < 1e-10, "Expected 0.5 average");
    }

    #[test]
    fn test_threshold_filtering() {
        let mut r = reasoner(0.8);
        r.add_fact("fido", RDF_TYPE, "Dog", 0.5).expect("add fact");
        r.add_rule(make_subclass_rule("r1", "Dog", "Animal", 0.7).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        // 0.5 * 0.7 = 0.35 < 0.8, so nothing may be inferred at all.
        assert!(
            inferred.is_empty(),
            "Expected no triples below threshold, got {:?}",
            inferred
        );
    }

    #[test]
    fn test_cpt_noisy_or() {
        let cpt = ConditionalProbabilityTable::new_noisy_or(
            vec!["RainFall".to_string(), "Sprinkler".to_string()],
            vec![0.1, 0.2],
            0.01,
        )
        .expect("valid cpt");
        // With no active parent, only the leak can make the node true.
        let p_no_parents = cpt.get_probability(&[false, false]);
        assert!(
            (p_no_parents - 0.01).abs() < 1e-12,
            "P(wet | no rain, no sprinkler) ≈ 0.01, got {}",
            p_no_parents
        );
        let p_both = cpt.get_probability(&[true, true]);
        assert!(
            p_both > 0.9,
            "P(wet | rain + sprinkler) should be > 0.9, got {}",
            p_both
        );
    }

    #[test]
    fn test_cpt_length_mismatch() {
        let result =
            ConditionalProbabilityTable::new_noisy_or(vec!["a".to_string()], vec![0.1, 0.2], 0.0);
        assert!(matches!(result, Err(ProbError::InconsistentCpt(_))));
    }

    #[test]
    fn test_bayesian_network_propagation() {
        let mut bn = BayesianRdfNetwork::new();
        bn.add_root_node("rain", "rain-triple", 0.3)
            .expect("add rain");
        bn.add_root_node("sprinkler", "sprinkler-triple", 0.4)
            .expect("add sprinkler");
        let cpt = ConditionalProbabilityTable::new_noisy_or(
            vec!["rain".to_string(), "sprinkler".to_string()],
            vec![0.1, 0.2],
            0.01,
        )
        .expect("cpt");
        bn.add_node(BayesianRdfNode {
            id: "wet_grass".to_string(),
            triple_key: "wet-triple".to_string(),
            parents: vec!["rain".to_string(), "sprinkler".to_string()],
            cpt,
            marginal_prob: 0.0,
        })
        .expect("add wet grass");
        bn.propagate_beliefs().expect("propagate");
        let wet_prob = bn.get_marginal_prob("wet_grass").expect("wet prob");
        assert!(
            wet_prob > 0.0 && wet_prob < 1.0,
            "WetGrass probability should be in (0,1), got {}",
            wet_prob
        );
        assert!(wet_prob > 0.3, "Expected wet grass > 0.3, got {}", wet_prob);
    }

    #[test]
    fn test_evidence_update() {
        let mut bn = BayesianRdfNetwork::new();
        bn.add_root_node("cause", "cause-triple", 0.5)
            .expect("add cause");
        let cpt =
            ConditionalProbabilityTable::new_noisy_or(vec!["cause".to_string()], vec![0.1], 0.05)
                .expect("cpt");
        bn.add_node(BayesianRdfNode {
            id: "effect".to_string(),
            triple_key: "effect-triple".to_string(),
            parents: vec!["cause".to_string()],
            cpt,
            marginal_prob: 0.0,
        })
        .expect("add effect");
        bn.set_evidence("cause", true).expect("set evidence");
        bn.propagate_beliefs().expect("propagate");
        let effect_prob = bn.get_marginal_prob("effect").expect("effect prob");
        assert!(
            effect_prob > 0.7,
            "With observed cause, effect should be > 0.7, got {}",
            effect_prob
        );
    }

    #[test]
    fn test_inference_report() {
        let mut r = reasoner(0.1);
        r.add_fact("alice", RDF_TYPE, "Person", 0.9)
            .expect("add fact");
        r.add_rule(make_subclass_rule("r1", "Person", "Agent", 0.85).expect("rule"));
        let (inferred, report) = r.infer().expect("infer");
        assert!(report.iterations >= 1);
        assert!((report.threshold - 0.1).abs() < f64::EPSILON);
        assert_eq!(report.inferred_count, inferred.len());
        assert_eq!(report.inferred_count, 1);
        // `Instant` resolution is platform-dependent, so a zero duration is legitimate
        // and is not asserted. The average must be the single triple's 0.9 * 0.85.
        assert!((report.avg_probability - 0.9 * 0.85).abs() < 1e-12);
    }

    #[test]
    fn test_chained_rules() {
        let mut r = reasoner(0.1);
        r.add_fact("fido", RDF_TYPE, "Labrador", 0.95)
            .expect("add fact");
        r.add_rule(make_subclass_rule("r1", "Labrador", "Dog", 0.99).expect("rule"));
        r.add_rule(make_subclass_rule("r2", "Dog", "Mammal", 0.99).expect("rule"));
        r.add_rule(make_subclass_rule("r3", "Mammal", "Animal", 0.99).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        let fido_animal = inferred
            .iter()
            .find(|t| t.subject == "fido" && t.predicate == RDF_TYPE && t.object == "Animal");
        assert!(
            fido_animal.is_some(),
            "Expected fido rdf:type Animal via chained rules. Inferred: {:?}",
            inferred
                .iter()
                .map(|t| t.object.clone())
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_multiple_facts_symmetric() {
        let knows = "https://example.org/knows";
        let mut r = reasoner(0.1);
        r.add_fact("alice", knows, "bob", 0.9).expect("add fact");
        r.add_fact("alice", knows, "carol", 0.7).expect("add fact");
        r.add_rule(make_symmetric_rule("r1", knows, 0.9).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        assert!(inferred.len() >= 2, "Expected at least 2 symmetric triples");
    }

    #[test]
    fn test_zero_threshold() {
        let mut r = reasoner(0.0);
        r.add_fact("x", "p", "y", 0.01).expect("add fact");
        r.add_rule(make_symmetric_rule("r1", "p", 0.5).expect("rule"));
        let (inferred, _) = r.infer().expect("infer");
        // With a zero threshold, even a weak conclusion (0.01 * 0.5) must be kept.
        let y_p_x = inferred
            .iter()
            .find(|t| t.subject == "y" && t.predicate == "p" && t.object == "x")
            .expect("y p x should be inferred at threshold 0");
        assert!(
            (y_p_x.probability - 0.005).abs() < 1e-12,
            "Expected 0.005, got {}",
            y_p_x.probability
        );
    }

    #[test]
    fn test_get_fact_probability() {
        let mut r = reasoner(0.1);
        r.add_fact("fido", RDF_TYPE, "Dog", 0.9).expect("add fact");
        assert_eq!(r.get_fact_probability("fido", RDF_TYPE, "Dog"), Some(0.9));
        assert_eq!(r.get_fact_probability("fido", RDF_TYPE, "Cat"), None);
    }
}