use crate::forward::Substitution;
use crate::{Rule, RuleAtom, Term};
use anyhow::{anyhow, Result};
use scirs2_core::metrics::{Counter, Timer};
use scirs2_core::random::{Distribution, Uniform};
use std::collections::{HashMap, HashSet};
// Global metrics for the ProbLog engine, initialised on first use by `lazy_static`.
lazy_static::lazy_static! {
// Bumped at the top of every `query_probability` call (nested calls included).
static ref PROBLOG_QUERIES: Counter = Counter::new("problog_queries".to_string());
// Bumped once for every rule head that unifies with a query during top-down derivation.
static ref PROBLOG_INFERENCES: Counter = Counter::new("problog_inferences".to_string());
// Started at the top of every `query_probability` call.
static ref PROBLOG_QUERY_TIME: Timer = Timer::new("problog_query_time".to_string());
}
/// An atom paired with the probability attached to it.
///
/// Both fields are public, so a value built with a struct literal bypasses the
/// range check performed by [`ProbabilisticFact::new`].
#[derive(Debug, Clone)]
pub struct ProbabilisticFact {
/// Probability attached to `fact`.
pub probability: f64,
/// The atom this probability applies to.
pub fact: RuleAtom,
}
impl ProbabilisticFact {
    /// Creates a probabilistic fact, validating the probability first.
    ///
    /// # Errors
    ///
    /// Fails when `probability` is not contained in `[0, 1]`. NaN is rejected
    /// too, because it is not contained in the range.
    pub fn new(probability: f64, fact: RuleAtom) -> Result<Self> {
        if (0.0..=1.0).contains(&probability) {
            Ok(Self { probability, fact })
        } else {
            Err(anyhow!(
                "Probability must be in [0, 1], got {}",
                probability
            ))
        }
    }
}
/// A rule with an optional probability.
///
/// The engine multiplies a derivation by `probability.unwrap_or(1.0)`, so a
/// rule without a probability behaves as a certain rule.
#[derive(Debug, Clone)]
pub struct ProbabilisticRule {
/// `None` for a deterministic rule, `Some(p)` for a rule that holds with probability `p`.
pub probability: Option<f64>,
/// The underlying rule (head and body atoms).
pub rule: Rule,
}
impl ProbabilisticRule {
    /// Wraps `rule` without attaching a probability.
    pub fn deterministic(rule: Rule) -> Self {
        let probability = None;
        Self { probability, rule }
    }

    /// Wraps `rule` with the given probability.
    ///
    /// # Errors
    ///
    /// Fails when `probability` is not contained in `[0, 1]` (NaN included).
    pub fn probabilistic(probability: f64, rule: Rule) -> Result<Self> {
        let in_range = (0.0..=1.0).contains(&probability);
        if in_range {
            Ok(Self {
                probability: Some(probability),
                rule,
            })
        } else {
            Err(anyhow!(
                "Probability must be in [0, 1], got {}",
                probability
            ))
        }
    }
}
/// A proof tree: an atom, the probability attached to it, and the sub-trees
/// for the premises it was derived from.
#[derive(Debug, Clone)]
pub struct DerivationTree {
/// The atom this node justifies.
pub fact: RuleAtom,
/// Probability recorded for this node.
pub probability: f64,
/// Sub-derivations this node rests on; empty for a leaf.
pub premises: Vec<DerivationTree>,
}
impl DerivationTree {
pub fn leaf(fact: RuleAtom, probability: f64) -> Self {
Self {
fact,
probability,
premises: Vec::new(),
}
}
pub fn node(fact: RuleAtom, probability: f64, premises: Vec<DerivationTree>) -> Self {
Self {
fact,
probability,
premises,
}
}
}
/// Selects how `ProbLogEngine::query_probability` answers a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvaluationStrategy {
/// Recursive, query-driven derivation with memoisation.
TopDown,
/// Look the query up in the materialised fixpoint.
BottomUp,
/// Bottom-up when some rule mentions its head predicate in its own body, top-down otherwise.
Auto,
}
/// Probabilistic inference engine over facts and (optionally weighted) rules.
///
/// Queries are answered either top-down (recursive derivation with a cache)
/// or bottom-up (fixpoint materialisation), depending on `strategy`.
pub struct ProbLogEngine {
/// Facts stored together with their probability.
probabilistic_facts: HashMap<RuleAtom, f64>,
/// Facts treated as having probability 1.0.
deterministic_facts: HashSet<RuleAtom>,
/// Rules in insertion order.
probabilistic_rules: Vec<ProbabilisticRule>,
/// Memoised top-down results; cleared whenever a fact or rule is added.
query_cache: HashMap<RuleAtom, f64>,
/// Atoms currently being derived top-down; re-entering one falls back to the materialised lookup.
recursion_stack: HashSet<RuleAtom>,
/// Limit on nested top-down derivations (100 from `new`).
max_depth: usize,
/// Number of top-down derivations currently in progress.
current_depth: usize,
/// Result of the most recent fixpoint computation.
materialized_facts: HashMap<RuleAtom, f64>,
/// Whether `materialized_facts` reflects the current facts and rules.
materialization_valid: bool,
/// How queries are evaluated (`Auto` from `new`).
strategy: EvaluationStrategy,
/// Limit on fixpoint rounds before materialisation fails (1000 from `new`).
max_fixpoint_iterations: usize,
/// Counters describing the work done so far.
pub stats: ProbLogStats,
}
/// Counters maintained by `ProbLogEngine`.
#[derive(Debug, Clone, Default)]
pub struct ProbLogStats {
/// Calls to `query_probability`, including the nested calls made while evaluating rule bodies.
pub queries: usize,
/// Rule heads that unified with a query during top-down derivation.
pub inferences: usize,
/// Top-down queries answered from the query cache.
pub cache_hits: usize,
/// Top-down queries that were not found in the query cache.
pub cache_misses: usize,
/// Rounds performed by the most recent materialisation.
pub fixpoint_iterations: usize,
/// Number of facts held after the most recent successful materialisation.
pub materialized_facts_count: usize,
}
impl Default for ProbLogEngine {
fn default() -> Self {
Self::new()
}
}
impl ProbLogEngine {
/// Creates an empty engine: `Auto` strategy, a recursion limit of 100 and a
/// fixpoint limit of 1000 iterations.
pub fn new() -> Self {
    let max_depth = 100;
    let max_fixpoint_iterations = 1000;
    Self {
        // Configuration.
        strategy: EvaluationStrategy::Auto,
        max_depth,
        max_fixpoint_iterations,
        // Knowledge base.
        probabilistic_facts: HashMap::new(),
        deterministic_facts: HashSet::new(),
        probabilistic_rules: Vec::new(),
        // Top-down bookkeeping.
        query_cache: HashMap::new(),
        recursion_stack: HashSet::new(),
        current_depth: 0,
        // Bottom-up bookkeeping.
        materialized_facts: HashMap::new(),
        materialization_valid: false,
        stats: ProbLogStats::default(),
    }
}
pub fn with_max_depth(mut self, max_depth: usize) -> Self {
self.max_depth = max_depth;
self
}
pub fn with_strategy(mut self, strategy: EvaluationStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn with_max_fixpoint_iterations(mut self, max_iterations: usize) -> Self {
self.max_fixpoint_iterations = max_iterations;
self
}
pub fn add_probabilistic_fact(&mut self, fact: ProbabilisticFact) {
if (fact.probability - 1.0).abs() < 1e-10 {
self.deterministic_facts.insert(fact.fact);
} else {
self.probabilistic_facts.insert(fact.fact, fact.probability);
}
self.query_cache.clear(); self.materialization_valid = false; }
/// Registers a deterministic fact and invalidates cached results.
pub fn add_fact(&mut self, fact: RuleAtom) {
    // Drop stale results first; the two steps are independent of the insert.
    self.materialization_valid = false;
    self.query_cache.clear();
    self.deterministic_facts.insert(fact);
}
/// Appends a rule and invalidates cached results.
pub fn add_rule(&mut self, rule: ProbabilisticRule) {
    // Drop stale results first; the two steps are independent of the push.
    self.materialization_valid = false;
    self.query_cache.clear();
    self.probabilistic_rules.push(rule);
}
/// Returns the probability of `query`.
///
/// With `BottomUp` (or `Auto` when a directly recursive rule exists) the
/// answer is read from the materialised fixpoint. Otherwise the query is
/// answered top-down: cache, then deterministic facts, then probabilistic
/// facts, then derivation through the rules. A query that is already being
/// derived further up the call chain is answered from the materialised
/// fixpoint instead of recursing again.
///
/// # Errors
///
/// Fails when the recursion depth limit is exceeded, or when materialisation
/// or a nested derivation fails. The recursion bookkeeping is restored before
/// an error is returned, so the engine stays usable afterwards.
pub fn query_probability(&mut self, query: &RuleAtom) -> Result<f64> {
    let _timer = PROBLOG_QUERY_TIME.start();
    self.stats.queries += 1;
    PROBLOG_QUERIES.inc();

    let use_bottom_up = match self.strategy {
        EvaluationStrategy::TopDown => false,
        EvaluationStrategy::BottomUp => true,
        EvaluationStrategy::Auto => self.has_recursive_rules(),
    };
    if use_bottom_up {
        return self.query_materialized(query);
    }

    if self.current_depth > self.max_depth {
        return Err(anyhow!(
            "Maximum recursion depth exceeded: {}",
            self.max_depth
        ));
    }
    // Cycle: this atom is already being derived by a caller.
    if self.recursion_stack.contains(query) {
        return self.query_materialized(query);
    }

    if let Some(&prob) = self.query_cache.get(query) {
        self.stats.cache_hits += 1;
        return Ok(prob);
    }
    self.stats.cache_misses += 1;

    if self.deterministic_facts.contains(query) {
        self.query_cache.insert(query.clone(), 1.0);
        return Ok(1.0);
    }
    if let Some(&prob) = self.probabilistic_facts.get(query) {
        self.query_cache.insert(query.clone(), prob);
        return Ok(prob);
    }

    self.recursion_stack.insert(query.clone());
    self.current_depth += 1;
    let outcome = self.derive_probability(query);
    // Unwind the bookkeeping before looking at the outcome: returning early
    // with `?` here would leave `query` on the recursion stack and the depth
    // counter permanently raised.
    self.recursion_stack.remove(query);
    self.current_depth -= 1;

    let prob = outcome?;
    self.query_cache.insert(query.clone(), prob);
    Ok(prob)
}
/// Reports whether some rule has a triple in its head whose predicate also
/// appears as the predicate of a triple in the same rule's body.
///
/// Only this direct form of recursion is detected.
fn has_recursive_rules(&self) -> bool {
    self.probabilistic_rules.iter().any(|candidate| {
        let body = &candidate.rule.body;
        candidate.rule.head.iter().any(|head_atom| match head_atom {
            RuleAtom::Triple {
                predicate: head_pred,
                ..
            } => body.iter().any(|body_atom| {
                matches!(
                    body_atom,
                    RuleAtom::Triple { predicate: body_pred, .. } if head_pred == body_pred
                )
            }),
            _ => false,
        })
    })
}
/// Derives the probability of `query` through the rules.
///
/// Every rule head that unifies with `query` contributes one derivation whose
/// probability is the body probability times the rule probability (1.0 for a
/// deterministic rule). Derivations are combined as `p + q - p * q`.
fn derive_probability(&mut self, query: &RuleAtom) -> Result<f64> {
    // Work on a snapshot: evaluating a body needs `&mut self`.
    let rules = self.probabilistic_rules.clone();
    let mut combined = 0.0;
    for candidate in &rules {
        for head_atom in &candidate.rule.head {
            let bindings = match self.unify_atoms(head_atom, query) {
                Some(bindings) => bindings,
                None => continue,
            };
            let grounded_body = self.apply_substitution_to_body(&candidate.rule.body, &bindings);
            let body_prob = self.evaluate_body(&grounded_body)?;
            let derivation = body_prob * candidate.probability.unwrap_or(1.0);
            combined = combined + derivation - (combined * derivation);
            self.stats.inferences += 1;
            PROBLOG_INFERENCES.inc();
        }
    }
    Ok(combined)
}
/// Multiplies the probabilities of all body atoms, querying each in order.
///
/// An empty body yields 1.0. The first failing query aborts the evaluation.
fn evaluate_body(&mut self, body: &[RuleAtom]) -> Result<f64> {
    body.iter().try_fold(1.0_f64, |product, atom| {
        self.query_probability(atom)
            .map(|atom_prob| product * atom_prob)
    })
}
/// Unifies two triples component by component (subject, predicate, object).
///
/// Returns the resulting substitution, or `None` when a component fails to
/// unify or when either atom is not a triple.
fn unify_atoms(&self, pattern: &RuleAtom, target: &RuleAtom) -> Option<Substitution> {
    if let (
        RuleAtom::Triple {
            subject: pattern_subject,
            predicate: pattern_predicate,
            object: pattern_object,
        },
        RuleAtom::Triple {
            subject: target_subject,
            predicate: target_predicate,
            object: target_object,
        },
    ) = (pattern, target)
    {
        let mut bindings = HashMap::new();
        // `&&` stops at the first component that fails, in the same order.
        let unified = self.unify_terms(pattern_subject, target_subject, &mut bindings)
            && self.unify_terms(pattern_predicate, target_predicate, &mut bindings)
            && self.unify_terms(pattern_object, target_object, &mut bindings);
        if unified {
            Some(bindings)
        } else {
            None
        }
    } else {
        None
    }
}
/// Unifies two terms, extending `subst` with any new variable binding.
///
/// Both terms are first rewritten with the bindings collected so far. A
/// variable is bound to the other term unless it occurs inside that term.
/// Constants and literals must be equal; functions must agree on name and
/// arity and unify argument by argument. Returns `false` on failure;
/// bindings added before the failure stay in `subst`.
fn unify_terms(&self, t1: &Term, t2: &Term, subst: &mut Substitution) -> bool {
    let left = self.apply_substitution_to_term(t1, subst);
    let right = self.apply_substitution_to_term(t2, subst);
    match (&left, &right) {
        (Term::Variable(a), Term::Variable(b)) if a == b => true,
        (Term::Variable(var), other) | (other, Term::Variable(var)) => {
            // Occurs check: refuse to bind a variable to a term containing it.
            let cyclic = self.occurs_in_term(var, other);
            if !cyclic {
                subst.insert(var.clone(), other.clone());
            }
            !cyclic
        }
        (Term::Constant(a), Term::Constant(b)) => a == b,
        (Term::Literal(a), Term::Literal(b)) => a == b,
        (
            Term::Function {
                name: left_name,
                args: left_args,
            },
            Term::Function {
                name: right_name,
                args: right_args,
            },
        ) => {
            left_name == right_name
                && left_args.len() == right_args.len()
                && left_args
                    .iter()
                    .zip(right_args.iter())
                    .all(|(l, r)| self.unify_terms(l, r, subst))
        }
        _ => false,
    }
}
/// Reports whether the variable named `var` appears anywhere inside `term`.
// `self` is only threaded through the recursion, which clippy would flag.
#[allow(clippy::only_used_in_recursion)]
fn occurs_in_term(&self, var: &str, term: &Term) -> bool {
    if let Term::Function { args, .. } = term {
        args.iter().any(|nested| self.occurs_in_term(var, nested))
    } else if let Term::Variable(v) = term {
        v == var
    } else {
        false
    }
}
/// Rewrites `term` with the bindings in `subst`.
///
/// A bound variable is replaced by a clone of its binding (the binding
/// itself is not rewritten again); function arguments are rewritten
/// recursively; every other term is cloned unchanged.
// `self` is only threaded through the recursion, which clippy would flag.
#[allow(clippy::only_used_in_recursion)]
fn apply_substitution_to_term(&self, term: &Term, subst: &Substitution) -> Term {
    match term {
        Term::Function { name, args } => {
            let rewritten_args = args
                .iter()
                .map(|arg| self.apply_substitution_to_term(arg, subst))
                .collect();
            Term::Function {
                name: name.clone(),
                args: rewritten_args,
            }
        }
        Term::Variable(v) => match subst.get(v) {
            Some(bound) => bound.clone(),
            None => term.clone(),
        },
        _ => term.clone(),
    }
}
fn apply_substitution_to_atom(&self, atom: &RuleAtom, subst: &Substitution) -> RuleAtom {
match atom {
RuleAtom::Triple {
subject,
predicate,
object,
} => RuleAtom::Triple {
subject: self.apply_substitution_to_term(subject, subst),
predicate: self.apply_substitution_to_term(predicate, subst),
object: self.apply_substitution_to_term(object, subst),
},
RuleAtom::Builtin { name, args } => RuleAtom::Builtin {
name: name.clone(),
args: args
.iter()
.map(|arg| self.apply_substitution_to_term(arg, subst))
.collect(),
},
RuleAtom::NotEqual { left, right } => RuleAtom::NotEqual {
left: self.apply_substitution_to_term(left, subst),
right: self.apply_substitution_to_term(right, subst),
},
RuleAtom::GreaterThan { left, right } => RuleAtom::GreaterThan {
left: self.apply_substitution_to_term(left, subst),
right: self.apply_substitution_to_term(right, subst),
},
RuleAtom::LessThan { left, right } => RuleAtom::LessThan {
left: self.apply_substitution_to_term(left, subst),
right: self.apply_substitution_to_term(right, subst),
},
}
}
/// Rewrites every atom of `body` with the bindings in `subst`, preserving order.
fn apply_substitution_to_body(&self, body: &[RuleAtom], subst: &Substitution) -> Vec<RuleAtom> {
    let mut rewritten = Vec::with_capacity(body.len());
    for atom in body {
        rewritten.push(self.apply_substitution_to_atom(atom, subst));
    }
    rewritten
}
pub fn materialize(&mut self) -> Result<()> {
if self.materialization_valid {
return Ok(()); }
self.materialized_facts.clear();
self.stats.fixpoint_iterations = 0;
let mut current_facts = HashMap::new();
for (fact, &prob) in &self.probabilistic_facts {
current_facts.insert(fact.clone(), prob);
}
for fact in &self.deterministic_facts {
current_facts.insert(fact.clone(), 1.0);
}
let mut iteration = 0;
loop {
iteration += 1;
self.stats.fixpoint_iterations = iteration;
if iteration > self.max_fixpoint_iterations {
return Err(anyhow!(
"Maximum fixpoint iterations exceeded: {}",
self.max_fixpoint_iterations
));
}
let previous_size = current_facts.len();
let mut new_facts = HashMap::new();
for prob_rule in &self.probabilistic_rules {
let derived = self.apply_rule_for_materialization(
&prob_rule.rule,
¤t_facts,
prob_rule.probability.unwrap_or(1.0),
)?;
for (fact, prob) in derived {
let entry = new_facts.entry(fact).or_insert(0.0);
*entry = *entry + prob - (*entry * prob);
}
}
let mut changed = false;
for (fact, new_prob) in &new_facts {
let existing_prob = current_facts.get(fact).copied().unwrap_or(0.0);
if (new_prob - existing_prob).abs() > 1e-10 {
changed = true;
current_facts.insert(fact.clone(), *new_prob);
}
}
if !changed && current_facts.len() == previous_size {
break;
}
}
self.materialized_facts = current_facts;
self.stats.materialized_facts_count = self.materialized_facts.len();
self.materialization_valid = true;
Ok(())
}
/// Applies one rule to `facts` and returns the head atoms it derives.
///
/// For every binding of the body against `facts`, the derivation probability
/// is the product of the bound body atoms' probabilities (0.0 for an atom not
/// in `facts`) times `rule_prob`. Derivations of the same head atom are
/// combined as `p + q - p * q`.
fn apply_rule_for_materialization(
    &self,
    rule: &Rule,
    facts: &HashMap<RuleAtom, f64>,
    rule_prob: f64,
) -> Result<HashMap<RuleAtom, f64>> {
    let mut derived: HashMap<RuleAtom, f64> = HashMap::new();
    for binding in self.find_all_bindings(&rule.body, facts)? {
        let body_prob = rule.body.iter().fold(1.0_f64, |product, body_atom| {
            let grounded = self.apply_substitution_to_atom(body_atom, &binding);
            product * facts.get(&grounded).copied().unwrap_or(0.0)
        });
        let derivation_prob = body_prob * rule_prob;
        for head_atom in &rule.head {
            let grounded_head = self.apply_substitution_to_atom(head_atom, &binding);
            let slot = derived.entry(grounded_head).or_insert(0.0);
            *slot = if *slot > 0.0 {
                *slot + derivation_prob - (*slot * derivation_prob)
            } else {
                derivation_prob
            };
        }
    }
    Ok(derived)
}
/// Enumerates every substitution under which all atoms of `body` unify with
/// some key of `facts`.
///
/// The first atom is matched against every fact; for each match the
/// remaining atoms are rewritten with that binding and solved recursively,
/// and the two substitutions are merged (entries from the recursive result
/// take precedence). An empty body yields one empty substitution.
fn find_all_bindings(
    &self,
    body: &[RuleAtom],
    facts: &HashMap<RuleAtom, f64>,
) -> Result<Vec<Substitution>> {
    let (leading_atom, remaining) = match body.split_first() {
        Some(parts) => parts,
        None => return Ok(vec![HashMap::new()]),
    };
    let mut collected = Vec::new();
    for candidate in facts.keys() {
        let binding = match self.unify_atoms(leading_atom, candidate) {
            Some(binding) => binding,
            None => continue,
        };
        if remaining.is_empty() {
            collected.push(binding);
            continue;
        }
        let narrowed = self.apply_substitution_to_body(remaining, &binding);
        for tail_binding in self.find_all_bindings(&narrowed, facts)? {
            let mut merged = binding.clone();
            merged.extend(tail_binding);
            collected.push(merged);
        }
    }
    Ok(collected)
}
/// Looks `query` up in the materialised fixpoint, computing it first when it
/// is missing or stale. Atoms absent from the fixpoint have probability 0.0.
///
/// # Errors
///
/// Propagates any failure of `materialize`.
pub fn query_materialized(&mut self, query: &RuleAtom) -> Result<f64> {
    // `materialize` returns immediately when the fixpoint is still valid.
    self.materialize()?;
    Ok(self.materialized_facts.get(query).copied().unwrap_or(0.0))
}
/// Draws one random set of facts.
///
/// Each probabilistic fact is included when a uniform draw from `[0, 1)` is
/// below its probability; every deterministic fact is always included.
/// Rules are not applied.
pub fn sample(&mut self) -> HashSet<RuleAtom> {
    use scirs2_core::random::rng;
    let mut generator = rng();
    let unit_interval = Uniform::new(0.0, 1.0).expect("distribution parameters are valid");
    // One draw per probabilistic fact, in map iteration order.
    let mut world: HashSet<RuleAtom> = self
        .probabilistic_facts
        .iter()
        .filter_map(|(fact, &prob)| {
            if unit_interval.sample(&mut generator) < prob {
                Some(fact.clone())
            } else {
                None
            }
        })
        .collect();
    world.extend(self.deterministic_facts.iter().cloned());
    world
}
pub fn monte_carlo_query(&mut self, query: &RuleAtom, samples: usize) -> Result<f64> {
let mut successes = 0;
for _ in 0..samples {
let sampled_world = self.sample();
if sampled_world.contains(query) {
successes += 1;
}
}
Ok(successes as f64 / samples as f64)
}
pub fn explain(&mut self, query: &RuleAtom) -> Result<Option<DerivationTree>> {
if self.deterministic_facts.contains(query) {
return Ok(Some(DerivationTree::leaf(query.clone(), 1.0)));
}
if let Some(&prob) = self.probabilistic_facts.get(query) {
return Ok(Some(DerivationTree::leaf(query.clone(), prob)));
}
for prob_rule in &self.probabilistic_rules.clone() {
for head_atom in &prob_rule.rule.head {
if let Some(substitution) = self.unify_atoms(head_atom, query) {
let instantiated_body =
self.apply_substitution_to_body(&prob_rule.rule.body, &substitution);
let mut premises = Vec::new();
let mut body_prob = 1.0;
for body_atom in &instantiated_body {
if let Some(tree) = self.explain(body_atom)? {
body_prob *= tree.probability;
premises.push(tree);
} else {
body_prob = 0.0;
break;
}
}
if body_prob > 0.0 {
let total_prob = body_prob * prob_rule.probability.unwrap_or(1.0);
return Ok(Some(DerivationTree::node(
query.clone(),
total_prob,
premises,
)));
}
}
}
}
Ok(None)
}
pub fn clear_cache(&mut self) {
self.query_cache.clear();
}
pub fn reset_stats(&mut self) {
self.stats = ProbLogStats::default();
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a ground triple whose three components are constants.
fn create_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
    let constant = |value: &str| Term::Constant(value.to_string());
    RuleAtom::Triple {
        subject: constant(subject),
        predicate: constant(predicate),
        object: constant(object),
    }
}
/// A probability inside `[0, 1]` is accepted and stored unchanged.
#[test]
fn test_probabilistic_fact_creation() -> Result<()> {
    let atom = create_triple("john", "parent", "mary");
    let created = ProbabilisticFact::new(0.8, atom)?;
    assert_eq!(created.probability, 0.8);
    Ok(())
}
/// A probability above 1.0 is rejected.
#[test]
fn test_invalid_probability() {
    let atom = create_triple("john", "parent", "mary");
    assert!(ProbabilisticFact::new(1.5, atom).is_err());
}
/// A deterministic fact is answered with probability 1.0.
#[test]
fn test_query_deterministic_fact() -> Result<()> {
    let known = create_triple("john", "parent", "mary");
    let mut engine = ProbLogEngine::new();
    engine.add_fact(known.clone());
    assert_eq!(engine.query_probability(&known)?, 1.0);
    Ok(())
}
/// A probabilistic fact is answered with its own probability.
#[test]
fn test_query_probabilistic_fact() -> Result<()> {
    let uncertain = create_triple("john", "parent", "mary");
    let mut engine = ProbLogEngine::new();
    engine.add_probabilistic_fact(ProbabilisticFact::new(0.8, uncertain.clone())?);
    assert_eq!(engine.query_probability(&uncertain)?, 0.8);
    Ok(())
}
/// An atom the engine knows nothing about has probability 0.0.
#[test]
fn test_query_unknown_fact() -> Result<()> {
    let unknown = create_triple("john", "parent", "mary");
    let mut engine = ProbLogEngine::new();
    assert_eq!(engine.query_probability(&unknown)?, 0.0);
    Ok(())
}
/// A deterministic ground rule passes the body probability on to its head.
#[test]
fn test_rule_derivation() -> Result<()> {
    let parent = create_triple("john", "parent", "mary");
    let ancestor = create_triple("john", "ancestor", "mary");

    let mut engine = ProbLogEngine::new();
    engine.add_probabilistic_fact(ProbabilisticFact::new(0.8, parent.clone())?);
    let rule = Rule {
        name: "ancestor".to_string(),
        body: vec![parent],
        head: vec![ancestor.clone()],
    };
    engine.add_rule(ProbabilisticRule::deterministic(rule));

    let prob = engine.query_probability(&ancestor)?;
    assert!((prob - 0.8).abs() < 0.001);
    Ok(())
}
/// The first query is a cache miss, the repeated one a cache hit.
#[test]
fn test_query_caching() -> Result<()> {
    let atom = create_triple("john", "parent", "mary");
    let mut engine = ProbLogEngine::new();
    engine.add_probabilistic_fact(ProbabilisticFact::new(0.7, atom.clone())?);

    // First lookup populates the cache.
    engine.query_probability(&atom)?;
    assert_eq!(engine.stats.cache_misses, 1);
    assert_eq!(engine.stats.cache_hits, 0);

    // Second lookup is served from it.
    engine.query_probability(&atom)?;
    assert_eq!(engine.stats.cache_hits, 1);
    Ok(())
}
/// Deterministic facts appear in every sampled world; a 0.5 fact appears in
/// roughly half of them.
#[test]
fn test_sampling() -> Result<()> {
    let person = create_triple("john", "person", "true");
    let tall = create_triple("john", "tall", "true");

    let mut engine = ProbLogEngine::new();
    engine.add_fact(person.clone());
    engine.add_probabilistic_fact(ProbabilisticFact::new(0.5, tall.clone())?);

    let samples = 1000;
    let mut tall_count = 0;
    for _ in 0..samples {
        let world = engine.sample();
        if world.contains(&tall) {
            tall_count += 1;
        }
        assert!(world.contains(&person));
    }
    let proportion = tall_count as f64 / samples as f64;
    assert!((proportion - 0.5).abs() < 0.1);
    Ok(())
}
/// Explaining a probabilistic fact yields a leaf carrying its probability.
#[test]
fn test_explanation_tree() -> Result<()> {
    let atom = create_triple("john", "parent", "mary");
    let mut engine = ProbLogEngine::new();
    engine.add_probabilistic_fact(ProbabilisticFact::new(0.9, atom.clone())?);

    let explanation = engine.explain(&atom)?;
    assert!(explanation.is_some());
    let tree = explanation.ok_or_else(|| anyhow::anyhow!("expected Some value"))?;
    assert_eq!(tree.probability, 0.9);
    assert!(tree.premises.is_empty());
    Ok(())
}
/// A probabilistic rule over a certain body yields the rule's probability.
#[test]
fn test_probabilistic_rule() -> Result<()> {
    let parent = create_triple("john", "parent", "mary");
    let ancestor = create_triple("john", "ancestor", "mary");

    let mut engine = ProbLogEngine::new();
    engine.add_fact(parent.clone());
    let rule = Rule {
        name: "ancestor".to_string(),
        body: vec![parent],
        head: vec![ancestor.clone()],
    };
    engine.add_rule(ProbabilisticRule::probabilistic(0.9, rule)?);

    let prob = engine.query_probability(&ancestor)?;
    assert!((prob - 0.9).abs() < 0.001);
    Ok(())
}
/// A rule with variables in head and body derives the matching ground atom.
#[test]
fn test_variable_unification_simple() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new();
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.8,
        create_triple("john", "parent", "mary"),
    )?);
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_rule".to_string(),
        body: vec![var_triple("X", "parent", "Y")],
        head: vec![var_triple("X", "ancestor", "Y")],
    }));

    let goal = create_triple("john", "ancestor", "mary");
    let prob = engine.query_probability(&goal)?;
    assert!((prob - 0.8).abs() < 0.001, "Expected 0.8, got {}", prob);
    Ok(())
}
/// One variable rule answers queries for several different base facts.
#[test]
fn test_variable_unification_multiple_facts() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new();
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.9,
        create_triple("john", "parent", "mary"),
    )?);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.7,
        create_triple("mary", "parent", "bob"),
    )?);
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor".to_string(),
        body: vec![var_triple("X", "parent", "Y")],
        head: vec![var_triple("X", "ancestor", "Y")],
    }));

    let prob1 = engine.query_probability(&create_triple("john", "ancestor", "mary"))?;
    assert!((prob1 - 0.9).abs() < 0.001, "Expected 0.9, got {}", prob1);
    let prob2 = engine.query_probability(&create_triple("mary", "ancestor", "bob"))?;
    assert!((prob2 - 0.7).abs() < 0.001, "Expected 0.7, got {}", prob2);
    Ok(())
}
/// With the `Auto` strategy, a recursive ancestor rule is answered through
/// fixpoint iteration and multiplies the probabilities along the chain.
#[test]
fn test_variable_unification_transitive() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::Auto);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.9,
        create_triple("john", "parent", "mary"),
    )?);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.8,
        create_triple("mary", "parent", "bob"),
    )?);
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_base".to_string(),
        body: vec![var_triple("X", "parent", "Y")],
        head: vec![var_triple("X", "ancestor", "Y")],
    }));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_trans".to_string(),
        body: vec![
            var_triple("X", "parent", "Y"),
            var_triple("Y", "ancestor", "Z"),
        ],
        head: vec![var_triple("X", "ancestor", "Z")],
    }));

    let prob = engine.query_probability(&create_triple("john", "ancestor", "bob"))?;
    assert!((prob - 0.72).abs() < 0.001, "Expected 0.72, got {}", prob);
    assert!(
        engine.stats.fixpoint_iterations > 0,
        "Should have used fixpoint iteration for recursive rules"
    );
    Ok(())
}
/// A probabilistic variable rule over a certain fact yields the rule probability.
#[test]
fn test_variable_unification_with_probabilistic_rule() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new();
    engine.add_fact(create_triple("john", "parent", "mary"));
    let rule = Rule {
        name: "related_rule".to_string(),
        body: vec![var_triple("X", "parent", "Y")],
        head: vec![var_triple("X", "related", "Y")],
    };
    engine.add_rule(ProbabilisticRule::probabilistic(0.95, rule)?);

    let prob = engine.query_probability(&create_triple("john", "related", "mary"))?;
    assert!((prob - 0.95).abs() < 0.001, "Expected 0.95, got {}", prob);
    Ok(())
}
/// With only a recursive `path` rule and no base rule, nothing is derivable,
/// so a query on a cyclic edge graph terminates with 0.0.
#[test]
fn test_cycle_detection() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new();
    for (from, to) in [("a", "b"), ("b", "c"), ("c", "a")] {
        engine.add_fact(create_triple(from, "edge", to));
    }
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path_transitive".to_string(),
        body: vec![var_triple("X", "edge", "Y"), var_triple("Y", "path", "Z")],
        head: vec![var_triple("X", "path", "Z")],
    }));

    let prob = engine.query_probability(&create_triple("a", "path", "a"))?;
    assert_eq!(
        prob, 0.0,
        "Cycle detection should return 0.0 for cyclic queries"
    );
    Ok(())
}
/// A rule whose body atom is not a known fact contributes probability 0.
#[test]
fn test_unification_failure() -> Result<()> {
    let goal = create_triple("john", "ancestor", "bob");

    let mut engine = ProbLogEngine::new();
    engine.add_fact(create_triple("john", "parent", "mary"));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "specific_rule".to_string(),
        body: vec![create_triple("john", "parent", "bob")],
        head: vec![goal.clone()],
    }));

    let prob = engine.query_probability(&goal)?;
    assert!(
        prob.abs() < 0.001,
        "Expected 0.0, got {} - unification should fail",
        prob
    );
    Ok(())
}
/// Bottom-up evaluation materialises the full ancestor closure with the
/// expected probabilities and records its statistics.
#[test]
fn test_fixpoint_transitive_closure() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::BottomUp);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.9,
        create_triple("john", "parent", "mary"),
    )?);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.8,
        create_triple("mary", "parent", "bob"),
    )?);
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_base".to_string(),
        body: vec![var_triple("X", "parent", "Y")],
        head: vec![var_triple("X", "ancestor", "Y")],
    }));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_trans".to_string(),
        body: vec![
            var_triple("X", "parent", "Y"),
            var_triple("Y", "ancestor", "Z"),
        ],
        head: vec![var_triple("X", "ancestor", "Z")],
    }));

    let prob1 = engine.query_probability(&create_triple("john", "ancestor", "mary"))?;
    assert!((prob1 - 0.9).abs() < 0.001, "Expected 0.9, got {}", prob1);
    let prob2 = engine.query_probability(&create_triple("mary", "ancestor", "bob"))?;
    assert!((prob2 - 0.8).abs() < 0.001, "Expected 0.8, got {}", prob2);
    let prob3 = engine.query_probability(&create_triple("john", "ancestor", "bob"))?;
    assert!(
        (prob3 - 0.72).abs() < 0.001,
        "Expected 0.72 (transitive), got {}",
        prob3
    );

    let stats = &engine.stats;
    assert!(
        stats.fixpoint_iterations > 0,
        "Should have used fixpoint iteration"
    );
    assert!(
        stats.materialized_facts_count >= 5,
        "Should have materialized at least 5 facts (2 parent + 3 ancestor)"
    );
    Ok(())
}
/// On a three-node cycle of certain edges, every path (including the loop
/// back to the start) is derived with probability 1.0.
#[test]
fn test_fixpoint_cyclic_graph() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::BottomUp);
    for (from, to) in [("a", "b"), ("b", "c"), ("c", "a")] {
        engine.add_fact(create_triple(from, "edge", to));
    }
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path_base".to_string(),
        body: vec![var_triple("X", "edge", "Y")],
        head: vec![var_triple("X", "path", "Y")],
    }));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path_transitive".to_string(),
        body: vec![var_triple("X", "edge", "Y"), var_triple("Y", "path", "Z")],
        head: vec![var_triple("X", "path", "Z")],
    }));

    let prob = engine.query_probability(&create_triple("a", "path", "a"))?;
    assert_eq!(
        prob, 1.0,
        "Fixpoint iteration should correctly compute cyclic path (got {})",
        prob
    );

    let prob_ab = engine.query_probability(&create_triple("a", "path", "b"))?;
    let prob_bc = engine.query_probability(&create_triple("b", "path", "c"))?;
    let prob_ca = engine.query_probability(&create_triple("c", "path", "a"))?;
    assert_eq!(prob_ab, 1.0, "path(a,b) should be 1.0");
    assert_eq!(prob_bc, 1.0, "path(b,c) should be 1.0");
    assert_eq!(prob_ca, 1.0, "path(c,a) should be 1.0");
    Ok(())
}
/// `Auto` switches to fixpoint iteration when a recursive rule is present.
#[test]
fn test_fixpoint_auto_strategy() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::Auto);
    engine.add_fact(create_triple("john", "parent", "mary"));
    engine.add_fact(create_triple("mary", "parent", "bob"));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_base".to_string(),
        body: vec![var_triple("X", "parent", "Y")],
        head: vec![var_triple("X", "ancestor", "Y")],
    }));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "ancestor_trans".to_string(),
        body: vec![
            var_triple("X", "parent", "Y"),
            var_triple("Y", "ancestor", "Z"),
        ],
        head: vec![var_triple("X", "ancestor", "Z")],
    }));

    let prob = engine.query_probability(&create_triple("john", "ancestor", "bob"))?;
    assert_eq!(prob, 1.0, "Auto strategy should compute transitive closure");
    assert!(
        engine.stats.fixpoint_iterations > 0,
        "Auto strategy should have used fixpoint iteration for recursive rules"
    );
    Ok(())
}
/// Two independent derivations of the same atom (a direct edge and a
/// two-hop route) are combined as `p + q - p * q`.
#[test]
fn test_fixpoint_probabilistic_combination() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::BottomUp);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.6,
        create_triple("a", "edge", "b"),
    )?);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.7,
        create_triple("a", "edge", "c"),
    )?);
    engine.add_probabilistic_fact(ProbabilisticFact::new(
        0.8,
        create_triple("c", "edge", "b"),
    )?);
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "connected_direct".to_string(),
        body: vec![var_triple("X", "edge", "Y")],
        head: vec![var_triple("X", "connected", "Y")],
    }));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "connected_indirect".to_string(),
        body: vec![var_triple("X", "edge", "Z"), var_triple("Z", "edge", "Y")],
        head: vec![var_triple("X", "connected", "Y")],
    }));

    let prob = engine.query_probability(&create_triple("a", "connected", "b"))?;
    // Direct route 0.6; indirect route 0.7 * 0.8 = 0.56.
    let expected = 0.6 + 0.56 - (0.6 * 0.56);
    assert!(
        (prob - expected).abs() < 0.001,
        "Expected {} (disjunctive combination), got {}",
        expected,
        prob
    );
    Ok(())
}
/// A non-recursive program still converges with an iteration limit of 2.
#[test]
fn test_fixpoint_max_iterations() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new()
        .with_strategy(EvaluationStrategy::BottomUp)
        .with_max_fixpoint_iterations(2);
    engine.add_fact(create_triple("a", "edge", "b"));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path".to_string(),
        body: vec![var_triple("X", "edge", "Y")],
        head: vec![var_triple("X", "path", "Y")],
    }));

    let result = engine.query_probability(&create_triple("a", "path", "b"));
    assert!(result.is_ok(), "Should succeed with low iteration limit");
    Ok(())
}
/// Adding a fact after a bottom-up query makes the new fact's consequences
/// visible to the next query.
#[test]
fn test_materialization_invalidation() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::BottomUp);
    engine.add_fact(create_triple("a", "edge", "b"));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path".to_string(),
        body: vec![var_triple("X", "edge", "Y")],
        head: vec![var_triple("X", "path", "Y")],
    }));

    let prob1 = engine.query_probability(&create_triple("a", "path", "b"))?;
    assert_eq!(prob1, 1.0);
    let iter1 = engine.stats.fixpoint_iterations;

    // A new fact must invalidate the previous materialisation.
    engine.add_fact(create_triple("b", "edge", "c"));
    let prob2 = engine.query_probability(&create_triple("b", "path", "c"))?;
    assert_eq!(prob2, 1.0);
    assert!(
        engine.stats.fixpoint_iterations >= iter1,
        "Should have re-materialized after adding fact"
    );
    Ok(())
}
/// A bottom-up query records fixpoint rounds, materialised facts and the query count.
#[test]
fn test_fixpoint_statistics() -> Result<()> {
    // Triple pattern with variable subject/object and a constant predicate.
    fn var_triple(subject: &str, predicate: &str, object: &str) -> RuleAtom {
        RuleAtom::Triple {
            subject: Term::Variable(subject.to_string()),
            predicate: Term::Constant(predicate.to_string()),
            object: Term::Variable(object.to_string()),
        }
    }

    let mut engine = ProbLogEngine::new().with_strategy(EvaluationStrategy::BottomUp);
    engine.add_fact(create_triple("a", "edge", "b"));
    engine.add_fact(create_triple("b", "edge", "c"));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path_base".to_string(),
        body: vec![var_triple("X", "edge", "Y")],
        head: vec![var_triple("X", "path", "Y")],
    }));
    engine.add_rule(ProbabilisticRule::deterministic(Rule {
        name: "path_trans".to_string(),
        body: vec![var_triple("X", "path", "Y"), var_triple("Y", "edge", "Z")],
        head: vec![var_triple("X", "path", "Z")],
    }));

    let _prob = engine.query_probability(&create_triple("a", "path", "c"))?;

    let stats = &engine.stats;
    assert!(
        stats.fixpoint_iterations > 0,
        "Should have recorded fixpoint iterations"
    );
    assert!(
        stats.materialized_facts_count > 0,
        "Should have recorded materialized facts count"
    );
    assert!(stats.queries > 0, "Should have recorded query count");
    Ok(())
}
}