use crate::formats::n3_reasoning::{Matcher, Substitution, VariableBindings};
use crate::formats::n3_types::{N3Formula, N3Implication, N3Statement, N3Term, N3Variable};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide counter that `rename_rule` increments once per call; the value
/// read by a call becomes the numeric suffix of the variable names it generates.
static RULE_INSTANCE_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Produces a copy of `rule` whose variables carry a suffix unique to this call.
///
/// A fresh id is taken from `RULE_INSTANCE_COUNTER`, every statement of the
/// antecedent and then of the consequent is scanned with
/// `collect_vars_in_term` to build an old-name -> new-name table, and both
/// formulas are rebuilt through `rename_formula` using that table.
fn rename_rule(rule: &N3Implication) -> N3Implication {
    let instance_id = RULE_INSTANCE_COUNTER.fetch_add(1, Ordering::Relaxed);

    // Antecedent statements are visited first, then consequent statements.
    let all_statements = rule
        .antecedent
        .triples
        .iter()
        .chain(rule.consequent.triples.iter());

    let mut name_table: HashMap<String, String> = HashMap::new();
    for statement in all_statements {
        let positions = [&statement.subject, &statement.predicate, &statement.object];
        for term in positions.iter() {
            collect_vars_in_term(*term, &mut name_table, instance_id);
        }
    }

    let antecedent = rename_formula(&rule.antecedent, &name_table);
    let consequent = rename_formula(&rule.consequent, &name_table);
    N3Implication::new(antecedent, consequent)
}
/// Registers a rename entry for every variable found in `term`.
///
/// * `term` - the term to scan.
/// * `renames` - old-name -> new-name table; an existing entry is never
///   overwritten, so a variable keeps the first name assigned to it.
/// * `id` - rule-instance id appended to the generated name as `<name>_r<id>`.
///
/// Nested formulas are scanned recursively. `rename_term` rewrites variables
/// inside `N3Term::Formula` using this table, so a variable that occurs only
/// inside a nested formula must be collected here as well; otherwise it would
/// keep its original name in every rule instance.
fn collect_vars_in_term(term: &N3Term, renames: &mut HashMap<String, String>, id: usize) {
    match term {
        N3Term::Variable(v) => {
            renames
                .entry(v.name.clone())
                .or_insert_with(|| format!("{}_r{}", v.name, id));
        }
        N3Term::Formula(formula) => {
            // Walk every statement of the quoted formula, position by position.
            for stmt in formula.triples.iter() {
                collect_vars_in_term(&stmt.subject, renames, id);
                collect_vars_in_term(&stmt.predicate, renames, id);
                collect_vars_in_term(&stmt.object, renames, id);
            }
        }
        // Every other term kind carries no variables to rename.
        _ => {}
    }
}
/// Returns a copy of `term` with variable names replaced according to `renames`.
///
/// * A variable is rebuilt with `N3Variable::universal`, using the mapped name
///   when the table has one and its current name otherwise.
/// * A formula is rebuilt through `rename_formula`.
/// * Any other term is cloned unchanged.
fn rename_term(term: &N3Term, renames: &HashMap<String, String>) -> N3Term {
    if let N3Term::Variable(variable) = term {
        // Fall back to the existing name when no rename was registered.
        let target_name = renames
            .get(&variable.name)
            .cloned()
            .unwrap_or_else(|| variable.name.clone());
        return N3Term::Variable(N3Variable::universal(&target_name));
    }

    if let N3Term::Formula(inner) = term {
        let renamed_inner = rename_formula(inner, renames);
        return N3Term::Formula(Box::new(renamed_inner));
    }

    term.clone()
}
/// Builds a new statement by applying `rename_term` to the subject, predicate
/// and object of `stmt`, in that order.
fn rename_stmt(stmt: &N3Statement, renames: &HashMap<String, String>) -> N3Statement {
    let subject = rename_term(&stmt.subject, renames);
    let predicate = rename_term(&stmt.predicate, renames);
    let object = rename_term(&stmt.object, renames);
    N3Statement::new(subject, predicate, object)
}
/// Builds a new formula whose statements are the statements of `formula`, in
/// the same order, each passed through `rename_stmt`.
fn rename_formula(formula: &N3Formula, renames: &HashMap<String, String>) -> N3Formula {
    let mut renamed_statements = Vec::new();
    for statement in formula.triples.iter() {
        renamed_statements.push(rename_stmt(statement, renames));
    }
    N3Formula::with_statements(renamed_statements)
}
/// Per-branch state carried through the recursive goal resolution.
#[derive(Debug, Clone)]
pub struct SolvingContext {
    /// Nesting level of this context; a context made by `new` starts at 0.
    pub depth: u32,
    /// Largest `depth` that is still accepted by `is_too_deep`.
    pub max_depth: u32,
    /// String keys recorded for this branch.
    pub visited: HashSet<String>,
}

impl SolvingContext {
    /// Creates a context at depth 0 with an empty visited set.
    pub fn new(max_depth: u32) -> Self {
        let visited = HashSet::new();
        Self {
            depth: 0,
            max_depth,
            visited,
        }
    }

    /// Returns a copy of this context one level deeper; the visited set is
    /// cloned, so changes to the copy do not affect `self`.
    fn descend(&self) -> Self {
        let mut deeper = self.clone();
        deeper.depth += 1;
        deeper
    }

    /// True once `depth` is strictly greater than `max_depth`.
    fn is_too_deep(&self) -> bool {
        self.max_depth < self.depth
    }
}
/// One entry of a proof trace.
#[derive(Debug, Clone)]
pub struct ProofStep {
    /// Label of what was applied at this step.
    pub rule_applied: String,
    /// Free-form text stored alongside the step.
    pub bindings_summary: String,
    /// Depth value supplied by the caller when the step was created.
    pub depth: u32,
}

impl ProofStep {
    /// Builds a step, converting both text arguments into owned `String`s.
    fn new(
        rule_applied: impl Into<String>,
        bindings_summary: impl Into<String>,
        depth: u32,
    ) -> Self {
        let rule_applied = rule_applied.into();
        let bindings_summary = bindings_summary.into();
        Self {
            rule_applied,
            bindings_summary,
            depth,
        }
    }
}
/// Ordered log of proof steps recorded for one query.
#[derive(Debug, Clone)]
pub struct ProofTrace {
    /// Steps in the order they were recorded.
    pub steps: Vec<ProofStep>,
    /// Text describing the goal this trace belongs to.
    pub goal_summary: String,
    /// Outcome flag; `new` initialises it to `false`.
    pub succeeded: bool,
}

impl ProofTrace {
    /// Creates an empty trace for the given goal description, with
    /// `succeeded` set to `false`.
    fn new(goal_summary: impl Into<String>) -> Self {
        let goal_summary = goal_summary.into();
        Self {
            steps: Vec::new(),
            goal_summary,
            succeeded: false,
        }
    }

    /// Appends `step` to the end of the trace.
    pub fn record_step(&mut self, step: ProofStep) {
        self.steps.push(step);
    }
}
/// One-way match of `pattern` against `concrete`, extending `bindings`.
///
/// * Variable pattern: if the variable already has a binding, succeeds only
///   when that binding equals `concrete`; otherwise binds it to a clone of
///   `concrete` and succeeds.
/// * Formula pattern: never matches.
/// * Any other pattern: succeeds only when it equals `concrete`.
///
/// Only variables on the `pattern` side are ever bound.
fn unify_term(pattern: &N3Term, concrete: &N3Term, bindings: &mut VariableBindings) -> bool {
    let variable = match pattern {
        N3Term::Variable(variable) => variable,
        N3Term::Formula(_) => return false,
        other => return other == concrete,
    };

    match bindings.get(&variable.name) {
        Some(existing) => existing == concrete,
        None => {
            bindings.bind(variable.name.clone(), concrete.clone());
            true
        }
    }
}
/// Matches `pattern` against `concrete` position by position (subject,
/// predicate, object), sharing one binding set across the three positions.
///
/// Returns the accumulated bindings when every position matches, and `None`
/// as soon as one position fails.
fn unify_statements(pattern: &N3Statement, concrete: &N3Statement) -> Option<VariableBindings> {
    let positions = [
        (&pattern.subject, &concrete.subject),
        (&pattern.predicate, &concrete.predicate),
        (&pattern.object, &concrete.object),
    ];

    let mut bindings = VariableBindings::new();
    for &(pattern_term, concrete_term) in positions.iter() {
        if !unify_term(pattern_term, concrete_term, &mut bindings) {
            return None;
        }
    }
    Some(bindings)
}
/// Builds a new statement from `stmt` by passing its subject, predicate and
/// object through `bindings.substitute`, in that order.
fn apply_bindings_to_statement(stmt: &N3Statement, bindings: &VariableBindings) -> N3Statement {
    let subject = bindings.substitute(&stmt.subject);
    let predicate = bindings.substitute(&stmt.predicate);
    let object = bindings.substitute(&stmt.object);
    N3Statement::new(subject, predicate, object)
}
/// Renders `stmt` as a single string: the `Display` text of its subject,
/// predicate and object separated by single spaces.
fn statement_key(stmt: &N3Statement) -> String {
    let rendered = [
        stmt.subject.to_string(),
        stmt.predicate.to_string(),
        stmt.object.to_string(),
    ];
    rendered.join(" ")
}
/// Combines two binding sets into a new one.
///
/// Returns `None` when `a.is_compatible(b)` is false; otherwise returns a
/// clone of `a` with `b` merged into it. Neither input is modified.
fn merge_bindings(a: &VariableBindings, b: &VariableBindings) -> Option<VariableBindings> {
    if !a.is_compatible(b) {
        return None;
    }
    let mut combined = a.clone();
    combined.merge(b);
    Some(combined)
}
/// Finds every binding set under which `goal` holds, given `facts` and `rules`.
///
/// Returns an empty vector immediately when `ctx` is too deep or when the
/// goal's key is already in `ctx.visited`. Otherwise:
///
/// 1. Each fact that unifies with the goal (goal as pattern) contributes its
///    bindings and a `"fact"` trace step.
/// 2. For each rule, a freshly renamed instance is made; for each consequent
///    statement that unifies with the goal (consequent as pattern), the
///    antecedent statements are instantiated with the head bindings and solved
///    as a conjunction in a child context that has the goal's key marked as
///    visited. Every body solution compatible with the head bindings
///    contributes the merged bindings and a rule trace step.
///
/// Fact solutions come first in the result, followed by rule solutions in
/// rule order. Trace steps are recorded at `ctx.depth`.
fn resolve_goal(
    goal: &N3Statement,
    rules: &[N3Implication],
    facts: &[N3Statement],
    ctx: &SolvingContext,
    trace: &mut ProofTrace,
) -> Vec<VariableBindings> {
    if ctx.is_too_deep() {
        return Vec::new();
    }

    let goal_key = statement_key(goal);
    if ctx.visited.contains(&goal_key) {
        return Vec::new();
    }

    // Rule bodies are solved one level deeper, with this goal marked visited.
    let mut body_ctx = ctx.descend();
    body_ctx.visited.insert(goal_key);

    let mut solutions: Vec<VariableBindings> = Vec::new();

    // Direct matches against the fact base.
    for fact in facts {
        let fact_bindings = match unify_statements(goal, fact) {
            Some(found) => found,
            None => continue,
        };
        let summary = format!("bound via fact: {}", statement_key(fact));
        trace.record_step(ProofStep::new("fact", summary, ctx.depth));
        solutions.push(fact_bindings);
    }

    // Derivations through rules whose consequent matches the goal.
    for rule in rules {
        let instance = rename_rule(rule);
        for head in instance.consequent.triples.iter() {
            let head_bindings = match unify_statements(head, goal) {
                Some(found) => found,
                None => continue,
            };

            let mut body_goals: Vec<N3Statement> = Vec::new();
            for body_stmt in instance.antecedent.triples.iter() {
                body_goals.push(apply_bindings_to_statement(body_stmt, &head_bindings));
            }

            let body_solutions =
                resolve_goals_conjunction(&body_goals, rules, facts, &body_ctx, trace);
            for body_bindings in body_solutions {
                let merged = match merge_bindings(&head_bindings, &body_bindings) {
                    Some(found) => found,
                    None => continue,
                };
                // The trace shows the rule as written, not the renamed instance.
                let label = format!("rule: {} => {}", rule.antecedent, rule.consequent);
                let summary = format!("depth={}", ctx.depth);
                trace.record_step(ProofStep::new(label, summary, ctx.depth));
                solutions.push(merged);
            }
        }
    }

    solutions
}
/// Solves `goals` as a conjunction, left to right.
///
/// An empty goal list yields exactly one empty binding set. Otherwise the
/// first goal is solved with `resolve_goal`; for each of its solutions the
/// remaining goals are instantiated with that solution and solved
/// recursively, and every compatible pair is merged into one result.
/// All goals are solved with the same `ctx`.
fn resolve_goals_conjunction(
    goals: &[N3Statement],
    rules: &[N3Implication],
    facts: &[N3Statement],
    ctx: &SolvingContext,
    trace: &mut ProofTrace,
) -> Vec<VariableBindings> {
    let (head_goal, tail_goals) = match goals.split_first() {
        Some(parts) => parts,
        None => return vec![VariableBindings::new()],
    };

    let mut solutions: Vec<VariableBindings> = Vec::new();
    for head_solution in resolve_goal(head_goal, rules, facts, ctx, trace) {
        // Push the bindings found so far into the goals still to be solved.
        let mut instantiated_tail: Vec<N3Statement> = Vec::new();
        for pending in tail_goals.iter() {
            instantiated_tail.push(apply_bindings_to_statement(pending, &head_solution));
        }

        let tail_solutions =
            resolve_goals_conjunction(&instantiated_tail, rules, facts, ctx, trace);
        for tail_solution in tail_solutions {
            if let Some(merged) = merge_bindings(&head_solution, &tail_solution) {
                solutions.push(merged);
            }
        }
    }
    solutions
}
/// Entry point for backward-chaining queries with a configurable depth limit.
#[derive(Debug, Clone)]
pub struct BackwardChainer {
    /// Depth limit handed to every `SolvingContext` this chainer creates.
    max_depth: u32,
    /// Constructed in `new` but not read by any method in this type.
    #[allow(dead_code)]
    matcher: Matcher,
}

impl BackwardChainer {
    /// Creates a chainer with the given depth limit.
    pub fn new(max_depth: u32) -> Self {
        let matcher = Matcher::new();
        Self { max_depth, matcher }
    }

    /// Creates a chainer with a depth limit of 100.
    pub fn default_depth() -> Self {
        Self::new(100)
    }

    /// Resolves a single goal and returns every binding set found; the proof
    /// trace built along the way is discarded.
    pub fn resolve(
        &self,
        goal: &N3Statement,
        rules: &[N3Implication],
        facts: &[N3Statement],
    ) -> Vec<VariableBindings> {
        let (solutions, _trace) = self.resolve_with_trace(goal, rules, facts);
        solutions
    }

    /// Resolves `goals` as a conjunction and returns every binding set found;
    /// the proof trace built along the way is discarded.
    pub fn resolve_all(
        &self,
        goals: &[N3Statement],
        rules: &[N3Implication],
        facts: &[N3Statement],
    ) -> Vec<VariableBindings> {
        let root_ctx = SolvingContext::new(self.max_depth);
        let mut scratch_trace = ProofTrace::new("multi-goal conjunction");
        resolve_goals_conjunction(goals, rules, facts, &root_ctx, &mut scratch_trace)
    }

    /// Resolves a single goal and returns the binding sets together with the
    /// proof trace; `succeeded` is set to true when at least one binding set
    /// was found.
    pub fn resolve_with_trace(
        &self,
        goal: &N3Statement,
        rules: &[N3Implication],
        facts: &[N3Statement],
    ) -> (Vec<VariableBindings>, ProofTrace) {
        let root_ctx = SolvingContext::new(self.max_depth);
        let mut trace = ProofTrace::new(statement_key(goal));
        let solutions = resolve_goal(goal, rules, facts, &root_ctx, &mut trace);
        trace.succeeded = !solutions.is_empty();
        (solutions, trace)
    }
}

impl Default for BackwardChainer {
    /// Same as [`BackwardChainer::default_depth`].
    fn default() -> Self {
        Self::default_depth()
    }
}
/// Thin facade over [`BackwardChainer`]; every method forwards to it.
#[derive(Debug, Clone, Default)]
pub struct BackwardChainingEngine {
    /// The chainer that performs the actual resolution.
    chainer: BackwardChainer,
}

impl BackwardChainingEngine {
    /// Creates an engine whose chainer uses the given depth limit.
    pub fn new(max_depth: u32) -> Self {
        let chainer = BackwardChainer::new(max_depth);
        Self { chainer }
    }

    /// Forwards to `BackwardChainer::resolve_with_trace`.
    pub fn solve_with_trace(
        &self,
        goal: &N3Statement,
        rules: &[N3Implication],
        facts: &[N3Statement],
    ) -> (Vec<VariableBindings>, ProofTrace) {
        let Self { chainer } = self;
        chainer.resolve_with_trace(goal, rules, facts)
    }

    /// Forwards to `BackwardChainer::resolve`.
    pub fn solve(
        &self,
        goal: &N3Statement,
        rules: &[N3Implication],
        facts: &[N3Statement],
    ) -> Vec<VariableBindings> {
        let Self { chainer } = self;
        chainer.resolve(goal, rules, facts)
    }
}
// Unit tests for the backward chainer: fact matching, rule application,
// conjunctions, proof traces, and the depth / visited-set guards.
#[cfg(test)]
mod tests {
use super::*;
use oxirs_core::model::NamedNode;
// Builds a named-node term from an IRI string; panics if `NamedNode::new` rejects it.
fn iri(s: &str) -> N3Term {
N3Term::NamedNode(NamedNode::new(s).expect("valid IRI"))
}
// Builds a universal variable term with the given name.
fn var(name: &str) -> N3Term {
use crate::formats::n3_types::N3Variable;
N3Term::Variable(N3Variable::universal(name))
}
// Shorthand for `N3Statement::new`.
fn stmt(s: N3Term, p: N3Term, o: N3Term) -> N3Statement {
N3Statement::new(s, p, o)
}
// Builds an implication whose antecedent and consequent formulas hold the given statements.
fn rule(ant_stmts: Vec<N3Statement>, con_stmts: Vec<N3Statement>) -> N3Implication {
use crate::formats::n3_types::N3Formula;
N3Implication::new(
N3Formula::with_statements(ant_stmts),
N3Formula::with_statements(con_stmts),
)
}
// A ground goal identical to the only fact yields exactly one result.
#[test]
fn test_resolve_ground_fact() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let results = chainer.resolve(&stmt(a, p, b), &[], &facts);
assert_eq!(
results.len(),
1,
"ground goal should match exactly one fact"
);
}
// A ground goal whose object differs from the only fact yields no result.
#[test]
fn test_resolve_no_match() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let c = iri("http://ex.org/c");
let facts = vec![stmt(a.clone(), p.clone(), b)];
let chainer = BackwardChainer::new(10);
let results = chainer.resolve(&stmt(a, p, c), &[], &facts);
assert!(
results.is_empty(),
"goal with no matching fact returns empty"
);
}
// A variable in subject position is bound to the fact's subject.
#[test]
fn test_resolve_variable_subject() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let goal = stmt(var("x"), p, b);
let results = chainer.resolve(&goal, &[], &facts);
assert!(!results.is_empty());
let binding = results[0].get("x").expect("x must be bound");
assert_eq!(binding, &a);
}
// A variable in object position is bound to the fact's object.
#[test]
fn test_resolve_variable_object() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let goal = stmt(a, p, var("y"));
let results = chainer.resolve(&goal, &[], &facts);
assert!(!results.is_empty());
let binding = results[0].get("y").expect("y must be bound");
assert_eq!(binding, &b);
}
// With variables in both subject and object, each of the two facts gives one result.
#[test]
fn test_resolve_both_variables() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![
stmt(a.clone(), p.clone(), b.clone()),
stmt(b.clone(), p.clone(), a.clone()),
];
let chainer = BackwardChainer::new(10);
let goal = stmt(var("x"), p, var("y"));
let results = chainer.resolve(&goal, &[], &facts);
assert_eq!(results.len(), 2, "should match both facts");
}
// Rule `?x knows ?y => ?y knows ?x` derives the reversed statement from one fact.
#[test]
fn test_symmetry_rule() {
let knows = iri("http://ex.org/knows");
let alice = iri("http://ex.org/alice");
let bob = iri("http://ex.org/bob");
let facts = vec![stmt(alice.clone(), knows.clone(), bob.clone())];
let rules = vec![rule(
vec![stmt(var("x"), knows.clone(), var("y"))],
vec![stmt(var("y"), knows.clone(), var("x"))],
)];
let chainer = BackwardChainer::new(10);
let goal = stmt(bob.clone(), knows.clone(), alice.clone());
let results = chainer.resolve(&goal, &rules, &facts);
assert!(
!results.is_empty(),
"symmetry rule should derive bob knows alice"
);
}
// A two-statement rule body (parent, parent) derives grandparent across two facts.
#[test]
fn test_transitivity_two_hops() {
let parent = iri("http://ex.org/parent");
let grandparent = iri("http://ex.org/grandparent");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let c = iri("http://ex.org/c");
let facts = vec![
stmt(a.clone(), parent.clone(), b.clone()),
stmt(b.clone(), parent.clone(), c.clone()),
];
let rules = vec![rule(
vec![
stmt(var("x"), parent.clone(), var("y")),
stmt(var("y"), parent.clone(), var("z")),
],
vec![stmt(var("x"), grandparent.clone(), var("z"))],
)];
let chainer = BackwardChainer::new(20);
let goal = stmt(a.clone(), grandparent.clone(), c.clone());
let results = chainer.resolve(&goal, &rules, &facts);
assert!(
!results.is_empty(),
"grandparent should be derivable via two hops"
);
}
// A base rule plus a recursive rule derive `ancestor` across a three-fact parent chain.
#[test]
fn test_transitivity_three_hops() {
let ancestor = iri("http://ex.org/ancestor");
let parent = iri("http://ex.org/parent");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let c = iri("http://ex.org/c");
let d = iri("http://ex.org/d");
let facts = vec![
stmt(a.clone(), parent.clone(), b.clone()),
stmt(b.clone(), parent.clone(), c.clone()),
stmt(c.clone(), parent.clone(), d.clone()),
];
let rules = vec![
rule(
vec![stmt(var("x"), parent.clone(), var("y"))],
vec![stmt(var("x"), ancestor.clone(), var("y"))],
),
rule(
vec![
stmt(var("x"), ancestor.clone(), var("y")),
stmt(var("y"), parent.clone(), var("z")),
],
vec![stmt(var("x"), ancestor.clone(), var("z"))],
),
];
let chainer = BackwardChainer::new(30);
let goal = stmt(a.clone(), ancestor.clone(), d.clone());
let results = chainer.resolve(&goal, &rules, &facts);
assert!(
!results.is_empty(),
"ancestor should be derivable across three hops"
);
}
// Runs a self-referential symmetric rule with max depth 3; only checks that the
// call returns (the result is discarded, nothing is asserted about it).
#[test]
fn test_depth_limit_prevents_stack_overflow() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let rules = vec![rule(
vec![stmt(var("x"), p.clone(), var("y"))],
vec![stmt(var("y"), p.clone(), var("x"))],
)];
let chainer = BackwardChainer::new(3);
let goal = stmt(b.clone(), p.clone(), a.clone());
let results = chainer.resolve(&goal, &rules, &facts);
let _ = results;
}
// With max depth 0 a direct fact match must still succeed; the rule-based query
// is only checked to return (its result is discarded).
#[test]
fn test_zero_depth_limit() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let rules = vec![rule(
vec![stmt(var("x"), p.clone(), var("y"))],
vec![stmt(var("y"), p.clone(), var("x"))],
)];
let chainer = BackwardChainer::new(0);
let goal = stmt(a.clone(), p.clone(), b.clone());
let results = chainer.resolve(&goal, &[], &facts);
assert!(!results.is_empty());
let rule_goal = stmt(b.clone(), p.clone(), a.clone());
let rule_results = chainer.resolve(&rule_goal, &rules, &facts);
let _ = rule_results;
}
// Two goals sharing ?x and ?y are satisfied together by two facts; both
// variables are bound in the first result.
#[test]
fn test_resolve_all_conjunction() {
let p = iri("http://ex.org/p");
let q = iri("http://ex.org/q");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![
stmt(a.clone(), p.clone(), b.clone()),
stmt(b.clone(), q.clone(), a.clone()),
];
let chainer = BackwardChainer::new(10);
let goals = vec![
stmt(var("x"), p.clone(), var("y")),
stmt(var("y"), q.clone(), var("x")),
];
let results = chainer.resolve_all(&goals, &[], &facts);
assert!(!results.is_empty(), "conjunction should be satisfiable");
let binding = &results[0];
assert!(binding.is_bound("x"));
assert!(binding.is_bound("y"));
}
// An empty goal list yields exactly one result with no bindings.
#[test]
fn test_resolve_all_empty_goals() {
let chainer = BackwardChainer::new(10);
let results = chainer.resolve_all(&[], &[], &[]);
assert_eq!(results.len(), 1);
assert!(results[0].all_bindings().is_empty());
}
// The second goal uses predicate q, for which there is no fact, so the
// conjunction has no solution. `c` is created but only discarded at the end.
#[test]
fn test_resolve_all_unsatisfiable() {
let p = iri("http://ex.org/p");
let q = iri("http://ex.org/q");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let c = iri("http://ex.org/c");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let goals = vec![
stmt(var("x"), p.clone(), var("y")),
stmt(var("y"), q.clone(), var("z")),
];
let results = chainer.resolve_all(&goals, &[], &facts);
assert!(results.is_empty(), "conjunction should be unsatisfiable");
let _ = c;
}
// A fact match marks the trace as succeeded and records a first step labelled "fact".
#[test]
fn test_proof_trace_fact_match() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let (results, trace) = chainer.resolve_with_trace(&stmt(a, p, b), &[], &facts);
assert!(!results.is_empty());
assert!(trace.succeeded);
assert!(!trace.steps.is_empty());
assert_eq!(trace.steps[0].rule_applied, "fact");
}
// A rule-based derivation through the engine facade succeeds and records steps.
#[test]
fn test_proof_trace_rule_application() {
let knows = iri("http://ex.org/knows");
let alice = iri("http://ex.org/alice");
let bob = iri("http://ex.org/bob");
let facts = vec![stmt(alice.clone(), knows.clone(), bob.clone())];
let rules = vec![rule(
vec![stmt(var("x"), knows.clone(), var("y"))],
vec![stmt(var("y"), knows.clone(), var("x"))],
)];
let engine = BackwardChainingEngine::new(10);
let goal = stmt(bob.clone(), knows.clone(), alice.clone());
let (results, trace) = engine.solve_with_trace(&goal, &rules, &facts);
assert!(!results.is_empty());
assert!(trace.succeeded);
assert!(!trace.steps.is_empty());
}
// A goal with no match yields no results and leaves `succeeded` false.
#[test]
fn test_proof_trace_failed_goal() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let c = iri("http://ex.org/c");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let (results, trace) = chainer.resolve_with_trace(&stmt(a, p, c), &[], &facts);
assert!(results.is_empty());
assert!(!trace.succeeded);
}
// `BackwardChainingEngine::solve` finds a result for a goal with a variable object.
#[test]
fn test_engine_solve() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let engine = BackwardChainingEngine::new(10);
let results = engine.solve(&stmt(a, p, var("y")), &[], &facts);
assert!(!results.is_empty());
}
// A default-constructed engine resolves a ground goal against a matching fact.
#[test]
fn test_engine_default() {
let engine = BackwardChainingEngine::default();
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let results = engine.solve(&stmt(a, p, b), &[], &facts);
assert!(!results.is_empty());
}
// Two different rules share the same consequent shape; at least one derivation must be found.
#[test]
fn test_multiple_applicable_rules() {
let child = iri("http://ex.org/child");
let parent = iri("http://ex.org/parent");
let offspring = iri("http://ex.org/offspring");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![
stmt(a.clone(), child.clone(), b.clone()),
stmt(b.clone(), parent.clone(), a.clone()),
];
let rules = vec![
rule(
vec![stmt(var("x"), child.clone(), var("y"))],
vec![stmt(var("x"), offspring.clone(), var("y"))],
),
rule(
vec![stmt(var("y"), parent.clone(), var("x"))],
vec![stmt(var("x"), offspring.clone(), var("y"))],
),
];
let chainer = BackwardChainer::new(10);
let goal = stmt(a.clone(), offspring.clone(), b.clone());
let results = chainer.resolve(&goal, &rules, &facts);
assert!(!results.is_empty(), "at least one derivation should exist");
}
// With neither rules nor facts, nothing can be resolved.
#[test]
fn test_no_rules_no_facts() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let chainer = BackwardChainer::new(10);
let results = chainer.resolve(&stmt(a, p, b), &[], &[]);
assert!(results.is_empty());
}
// A variable in predicate position is bound to the fact's predicate.
#[test]
fn test_variable_predicate() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let goal = stmt(a.clone(), var("pred"), b.clone());
let results = chainer.resolve(&goal, &[], &facts);
assert!(!results.is_empty());
let binding = results[0].get("pred").expect("pred must be bound");
assert_eq!(binding, &p);
}
// `descend` increments depth by one; six descents from a max-depth-5 context are too deep.
#[test]
fn test_solving_context_depth_tracking() {
let ctx = SolvingContext::new(5);
assert_eq!(ctx.depth, 0);
assert!(!ctx.is_too_deep());
let child = ctx.descend();
assert_eq!(child.depth, 1);
let mut deep = ctx.clone();
for _ in 0..6 {
deep = deep.descend();
}
assert!(deep.is_too_deep());
}
// Runs a symmetric rule that can re-derive its own goal; only checks that the
// call returns (the result is discarded, nothing is asserted about it).
#[test]
fn test_visited_set_prevents_cycles() {
let r = iri("http://ex.org/r");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), r.clone(), b.clone())];
let rules = vec![rule(
vec![stmt(var("x"), r.clone(), var("y"))],
vec![stmt(var("y"), r.clone(), var("x"))],
)];
let chainer = BackwardChainer::new(5);
let goal = stmt(b.clone(), r.clone(), a.clone());
let results = chainer.resolve(&goal, &rules, &facts);
let _ = results;
}
// The step recorded for a top-level fact match carries depth 0.
#[test]
fn test_proof_step_depth_recorded() {
let p = iri("http://ex.org/p");
let a = iri("http://ex.org/a");
let b = iri("http://ex.org/b");
let facts = vec![stmt(a.clone(), p.clone(), b.clone())];
let chainer = BackwardChainer::new(10);
let (_, trace) = chainer.resolve_with_trace(&stmt(a, p, b), &[], &facts);
assert!(!trace.steps.is_empty());
assert_eq!(trace.steps[0].depth, 0, "top-level match is at depth 0");
}
}