use crate::forward::Substitution;
use crate::{Rule, RuleAtom, Term};
use anyhow::Result;
use scirs2_core::metrics::{Counter, Gauge};
use std::collections::{HashMap, HashSet};
use tracing::{debug, info, trace, warn};
// Process-wide metrics for the backward chainer; `lazy_static` initializes each one on first access.
lazy_static::lazy_static! {
// Incremented when a goal match hands back a cloned substitution.
static ref SUBSTITUTION_CLONES: Counter = Counter::new("backward_chain_substitution_clones".to_string());
// Incremented each time a ProofContext is cloned/extended to try a rule.
static ref CONTEXT_CLONES: Counter = Counter::new("backward_chain_context_clones".to_string());
// Set to the depth of the most recent rule application. Nothing in this file resets it, so it
// reflects the last depth reached rather than a live value — NOTE(review): confirm that is intended.
static ref ACTIVE_PROOF_DEPTH: Gauge = Gauge::new("backward_chain_active_proof_depth".to_string());
}
/// Per-branch state threaded through a proof search.
#[derive(Debug, Clone, Default)]
pub struct ProofContext {
/// Goals currently being proven on this branch, outermost first; a goal that reappears here
/// is treated as a cycle and fails.
pub path: Vec<RuleAtom>,
/// Variable bindings accumulated so far on this branch.
pub substitution: Substitution,
/// Number of rule applications between the root goal and this point; goals reached with a
/// depth greater than the chainer's `max_depth` fail.
pub depth: usize,
}
/// Outcome of attempting to prove a single goal.
#[derive(Debug, Clone)]
pub enum ProofResult {
/// The goal holds under the contained bindings.
Success(Substitution),
/// The goal could not be proven (this includes depth-limit and cycle cut-offs).
Failure,
/// Carries goals left unproven. Nothing in this file constructs this variant; callers here
/// treat it the same as not proven.
Partial(Vec<RuleAtom>),
}
/// Goal-directed (backward-chaining) prover over a list of rules and a set of facts.
#[derive(Debug)]
pub struct BackwardChainer {
/// Rules in insertion order; they are tried in this order.
rules: Vec<Rule>,
/// Known facts; iteration order is unspecified.
facts: HashSet<RuleAtom>,
/// Goals reached at a depth greater than this fail immediately.
max_depth: usize,
/// Enables the extra `debug!`/`trace!` logging.
debug_mode: bool,
/// Memoized per-goal results; cleared whenever rules or facts are added or facts are cleared.
proof_cache: HashMap<RuleAtom, ProofResult>,
}
impl Default for BackwardChainer {
fn default() -> Self {
Self::new()
}
}
impl BackwardChainer {
pub fn new() -> Self {
Self {
rules: Vec::new(),
facts: HashSet::new(),
max_depth: 100,
debug_mode: false,
proof_cache: HashMap::new(),
}
}
pub fn with_config(max_depth: usize, debug_mode: bool) -> Self {
Self {
rules: Vec::new(),
facts: HashSet::new(),
max_depth,
debug_mode,
proof_cache: HashMap::new(),
}
}
/// Appends a rule and invalidates every cached proof result.
pub fn add_rule(&mut self, rule: Rule) {
    if self.debug_mode {
        debug!("Adding rule: {}", rule.name);
    }
    self.rules.push(rule);
    // A new rule can make previously failed goals provable, so cached results are stale.
    self.proof_cache.clear();
}
/// Appends each rule in order, as if by repeated calls to `add_rule`.
pub fn add_rules(&mut self, rules: Vec<Rule>) {
    rules.into_iter().for_each(|rule| self.add_rule(rule));
}
/// Inserts a fact (duplicates are absorbed by the set) and invalidates every cached proof result.
pub fn add_fact(&mut self, fact: RuleAtom) {
    if self.debug_mode {
        trace!("Adding fact: {:?}", fact);
    }
    self.facts.insert(fact);
    // New facts can turn cached failures into successes.
    self.proof_cache.clear();
}
/// Inserts each fact in order, as if by repeated calls to `add_fact`.
pub fn add_facts(&mut self, facts: Vec<RuleAtom>) {
    facts.into_iter().for_each(|fact| self.add_fact(fact));
}
pub fn get_facts(&self) -> Vec<RuleAtom> {
self.facts.iter().cloned().collect()
}
pub fn clear_facts(&mut self) {
self.facts.clear();
self.proof_cache.clear();
}
pub fn clear_cache(&mut self) {
self.proof_cache.clear();
}
/// Attempts to prove `goal` starting from an empty context.
///
/// Returns `Ok(true)` only when the search yields `ProofResult::Success`; failures and partial
/// results both map to `Ok(false)`. Errors raised during the search are propagated.
pub fn prove(&mut self, goal: &RuleAtom) -> Result<bool> {
    info!("Starting backward chaining proof for goal: {:?}", goal);
    let outcome = self.prove_goal(goal, &ProofContext::default())?;
    let proven = match outcome {
        ProofResult::Success(_) => {
            info!("Goal successfully proven");
            true
        }
        ProofResult::Failure => {
            info!("Goal failed to prove");
            false
        }
        ProofResult::Partial(remaining) => {
            info!(
                "Goal partially proven, {} remaining subgoals",
                remaining.len()
            );
            false
        }
    };
    Ok(proven)
}
/// Collects the substitutions produced by `find_all_proofs` for `goal`, starting from an
/// empty context.
pub fn prove_all(&mut self, goal: &RuleAtom) -> Result<Vec<Substitution>> {
    info!("Finding all proofs for goal: {:?}", goal);
    let mut found: Vec<Substitution> = Vec::new();
    self.find_all_proofs(goal, &ProofContext::default(), &mut found)?;
    info!("Found {} valid proofs", found.len());
    Ok(found)
}
fn prove_goal(&mut self, goal: &RuleAtom, context: &ProofContext) -> Result<ProofResult> {
if context.depth > self.max_depth {
warn!("Maximum proof depth exceeded for goal: {:?}", goal);
return Ok(ProofResult::Failure);
}
if context.path.contains(goal) {
if self.debug_mode {
debug!("Cycle detected for goal: {:?}", goal);
}
return Ok(ProofResult::Failure);
}
if let Some(cached_result) = self.proof_cache.get(goal) {
if self.debug_mode {
trace!("Using cached result for goal: {:?}", goal);
}
return Ok(cached_result.clone());
}
let result = self.prove_goal_internal(goal, context)?;
self.proof_cache.insert(goal.clone(), result.clone());
Ok(result)
}
/// Proves `goal` without consulting the cache: a direct match (facts, builtins, comparisons)
/// is tried first, and rules are used only when that fails.
fn prove_goal_internal(
    &mut self,
    goal: &RuleAtom,
    context: &ProofContext,
) -> Result<ProofResult> {
    if self.debug_mode {
        debug!("Proving goal at depth {}: {:?}", context.depth, goal);
    }
    match self.match_against_facts(goal, &context.substitution)? {
        Some(direct) => {
            if self.debug_mode {
                debug!("Goal proven by direct fact match");
            }
            Ok(ProofResult::Success(direct))
        }
        None => self.prove_using_rules(goal, context),
    }
}
fn match_against_facts(
&self,
goal: &RuleAtom,
context_sub: &Substitution,
) -> Result<Option<Substitution>> {
match goal {
RuleAtom::Triple {
subject,
predicate,
object,
} => {
for fact in &self.facts {
if let RuleAtom::Triple {
subject: fact_subject,
predicate: fact_predicate,
object: fact_object,
} = fact
{
if let Some(substitution) = self.unify_triple(
(subject, predicate, object),
(fact_subject, fact_predicate, fact_object),
context_sub,
)? {
SUBSTITUTION_CLONES.inc();
return Ok(Some(substitution));
}
}
}
Ok(None)
}
RuleAtom::Builtin { name, args } => {
let result = self.evaluate_builtin(name, args, context_sub)?;
if result.is_some() {
SUBSTITUTION_CLONES.inc();
}
Ok(result)
}
RuleAtom::NotEqual { left, right } => {
let left_term = self.substitute_term(left, context_sub);
let right_term = self.substitute_term(right, context_sub);
if !self.terms_equal(&left_term, &right_term) {
SUBSTITUTION_CLONES.inc();
Ok(Some(context_sub.clone()))
} else {
Ok(None)
}
}
RuleAtom::GreaterThan { left, right } => {
let left_term = self.substitute_term(left, context_sub);
let right_term = self.substitute_term(right, context_sub);
if self.compare_terms(&left_term, &right_term) > 0 {
SUBSTITUTION_CLONES.inc();
Ok(Some(context_sub.clone()))
} else {
Ok(None)
}
}
RuleAtom::LessThan { left, right } => {
let left_term = self.substitute_term(left, context_sub);
let right_term = self.substitute_term(right, context_sub);
if self.compare_terms(&left_term, &right_term) < 0 {
SUBSTITUTION_CLONES.inc();
Ok(Some(context_sub.clone()))
} else {
Ok(None)
}
}
}
}
fn prove_using_rules(
&mut self,
goal: &RuleAtom,
context: &ProofContext,
) -> Result<ProofResult> {
let mut applicable_rules = Vec::new();
for rule in &self.rules {
for head_atom in &rule.head {
if let Some(head_substitution) =
self.unify_atoms(goal, head_atom, &context.substitution)?
{
applicable_rules.push((
rule.name.clone(),
rule.body.clone(),
head_substitution,
));
}
}
}
for (rule_name, rule_body, head_substitution) in applicable_rules {
if self.debug_mode {
debug!("Trying rule '{}' for goal: {:?}", rule_name, goal);
}
CONTEXT_CLONES.inc();
let mut new_context = context.clone();
new_context.path.push(goal.clone());
new_context.substitution = head_substitution;
new_context.depth += 1;
ACTIVE_PROOF_DEPTH.set(new_context.depth as f64);
if let Some(final_substitution) = self.prove_rule_body(&rule_body, &new_context)? {
if self.debug_mode {
debug!("Rule '{}' successfully proven", rule_name);
}
return Ok(ProofResult::Success(final_substitution));
}
}
Ok(ProofResult::Failure)
}
fn prove_rule_body(
&mut self,
body: &[RuleAtom],
context: &ProofContext,
) -> Result<Option<Substitution>> {
let mut current_substitution = context.substitution.clone();
for atom in body {
let instantiated_atom = self.apply_substitution(atom, ¤t_substitution)?;
let subgoal_context = ProofContext {
path: context.path.clone(),
substitution: current_substitution.clone(),
depth: context.depth,
};
match self.prove_goal(&instantiated_atom, &subgoal_context)? {
ProofResult::Success(new_substitution) => {
current_substitution =
self.merge_substitutions(current_substitution, new_substitution)?;
}
ProofResult::Failure => {
return Ok(None);
}
ProofResult::Partial(_) => {
return Ok(None);
}
}
}
Ok(Some(current_substitution))
}
fn find_all_proofs(
&mut self,
goal: &RuleAtom,
context: &ProofContext,
results: &mut Vec<Substitution>,
) -> Result<()> {
if context.depth > self.max_depth {
return Ok(());
}
if context.path.contains(goal) {
return Ok(());
}
if let Some(substitution) = self.match_against_facts(goal, &context.substitution)? {
results.push(substitution);
}
let mut applicable_rules = Vec::new();
for rule in &self.rules {
for head_atom in &rule.head {
if let Some(head_substitution) =
self.unify_atoms(goal, head_atom, &context.substitution)?
{
applicable_rules.push((rule.body.clone(), head_substitution));
}
}
}
for (rule_body, head_substitution) in applicable_rules {
CONTEXT_CLONES.inc();
let mut new_context = context.clone();
new_context.path.push(goal.clone());
new_context.substitution = head_substitution;
new_context.depth += 1;
if let Some(final_substitution) = self.prove_rule_body(&rule_body, &new_context)? {
results.push(final_substitution);
}
}
Ok(())
}
/// Unifies two atoms on top of `substitution`, returning the extended bindings on success.
///
/// Triples unify position by position. Builtins unify when their names and arities agree and
/// every argument pair unifies. Every other pairing — including the comparison atoms — never
/// unifies here.
fn unify_atoms(
    &self,
    atom1: &RuleAtom,
    atom2: &RuleAtom,
    substitution: &Substitution,
) -> Result<Option<Substitution>> {
    match (atom1, atom2) {
        (
            RuleAtom::Triple {
                subject: lhs_s,
                predicate: lhs_p,
                object: lhs_o,
            },
            RuleAtom::Triple {
                subject: rhs_s,
                predicate: rhs_p,
                object: rhs_o,
            },
        ) => self.unify_triple((lhs_s, lhs_p, lhs_o), (rhs_s, rhs_p, rhs_o), substitution),
        (
            RuleAtom::Builtin {
                name: lhs_name,
                args: lhs_args,
            },
            RuleAtom::Builtin {
                name: rhs_name,
                args: rhs_args,
            },
        ) if lhs_name == rhs_name && lhs_args.len() == rhs_args.len() => {
            let mut extended = substitution.clone();
            for (lhs, rhs) in lhs_args.iter().zip(rhs_args.iter()) {
                if !self.unify_terms(lhs, rhs, &mut extended)? {
                    return Ok(None);
                }
            }
            Ok(Some(extended))
        }
        _ => Ok(None),
    }
}
/// Unifies two (subject, predicate, object) triples position by position on top of
/// `substitution`; returns the extended bindings, or `None` at the first mismatch.
fn unify_triple(
    &self,
    triple1: (&Term, &Term, &Term),
    triple2: (&Term, &Term, &Term),
    substitution: &Substitution,
) -> Result<Option<Substitution>> {
    let (s1, p1, o1) = triple1;
    let (s2, p2, o2) = triple2;
    let mut extended = substitution.clone();
    for &(lhs, rhs) in [(s1, s2), (p1, p2), (o1, o2)].iter() {
        if !self.unify_terms(lhs, rhs, &mut extended)? {
            return Ok(None);
        }
    }
    Ok(Some(extended))
}
#[allow(clippy::only_used_in_recursion)]
fn unify_terms(
&self,
term1: &Term,
term2: &Term,
substitution: &mut Substitution,
) -> Result<bool> {
match (term1, term2) {
(Term::Variable(var), term) | (term, Term::Variable(var)) => {
if let Some(existing) = substitution.get(var).cloned() {
self.unify_terms(&existing, term, substitution)
} else {
substitution.insert(var.clone(), term.clone());
Ok(true)
}
}
(Term::Constant(c1), Term::Constant(c2)) => Ok(c1 == c2),
(Term::Literal(l1), Term::Literal(l2)) => Ok(l1 == l2),
(Term::Constant(c), Term::Literal(l)) | (Term::Literal(l), Term::Constant(c)) => {
Ok(c == l)
}
_ => Ok(false),
}
}
fn apply_substitution(&self, atom: &RuleAtom, substitution: &Substitution) -> Result<RuleAtom> {
match atom {
RuleAtom::Triple {
subject,
predicate,
object,
} => Ok(RuleAtom::Triple {
subject: self.substitute_term(subject, substitution),
predicate: self.substitute_term(predicate, substitution),
object: self.substitute_term(object, substitution),
}),
RuleAtom::Builtin { name, args } => {
let substituted_args = args
.iter()
.map(|arg| self.substitute_term(arg, substitution))
.collect();
Ok(RuleAtom::Builtin {
name: name.clone(),
args: substituted_args,
})
}
RuleAtom::NotEqual { left, right } => Ok(RuleAtom::NotEqual {
left: self.substitute_term(left, substitution),
right: self.substitute_term(right, substitution),
}),
RuleAtom::GreaterThan { left, right } => Ok(RuleAtom::GreaterThan {
left: self.substitute_term(left, substitution),
right: self.substitute_term(right, substitution),
}),
RuleAtom::LessThan { left, right } => Ok(RuleAtom::LessThan {
left: self.substitute_term(left, substitution),
right: self.substitute_term(right, substitution),
}),
}
}
/// Replaces variables in `term` with their bindings from `substitution`.
///
/// A variable bound to another variable is followed until reaching a non-variable term, an
/// unbound variable, or a self-binding. This takes at most `substitution.len()` hops, so cyclic
/// bindings cannot loop. Previously only one level was followed, so `X -> Y -> c` yielded `Y`.
/// The resolved value is returned as-is (it is not itself re-substituted), and unbound variables
/// are returned unchanged. Function terms have their arguments substituted recursively; all other
/// terms are cloned.
#[allow(clippy::only_used_in_recursion)]
fn substitute_term(&self, term: &Term, substitution: &Substitution) -> Term {
    match term {
        Term::Variable(_) => {
            let mut resolved = term;
            let mut hops_left = substitution.len();
            while let Term::Variable(name) = resolved {
                match substitution.get(name) {
                    Some(bound) if hops_left > 0 && bound != resolved => {
                        resolved = bound;
                        hops_left -= 1;
                    }
                    _ => break,
                }
            }
            resolved.clone()
        }
        Term::Function { name, args } => {
            let substituted_args = args
                .iter()
                .map(|arg| self.substitute_term(arg, substitution))
                .collect();
            Term::Function {
                name: name.clone(),
                args: substituted_args,
            }
        }
        _ => term.clone(),
    }
}
/// Folds the bindings of `sub2` into `sub1`.
///
/// Variables new to `sub1` are inserted. Variables already present must agree according to
/// `terms_equal`, otherwise an "Inconsistent substitutions" error naming the variable is returned.
fn merge_substitutions(&self, sub1: Substitution, sub2: Substitution) -> Result<Substitution> {
    let mut combined = sub1;
    for (var, term) in sub2 {
        match combined.get(&var) {
            None => {
                combined.insert(var, term);
            }
            Some(existing) if self.terms_equal(existing, &term) => {}
            Some(_) => {
                return Err(anyhow::anyhow!(
                    "Inconsistent substitutions for variable {}",
                    var
                ));
            }
        }
    }
    Ok(combined)
}
/// Structural equality where a `Constant` and a `Literal` with the same text are equal.
///
/// Variables are equal only to a variable with the same name. Functions are equal when their
/// names match and their arguments are pairwise equal under this same rule. Previously the
/// arguments were compared with derived `==`, so `f(Constant "a")` differed from
/// `f(Literal "a")` even though the bare terms compare equal.
#[allow(clippy::only_used_in_recursion)]
fn terms_equal(&self, term1: &Term, term2: &Term) -> bool {
    match (term1, term2) {
        (Term::Variable(v1), Term::Variable(v2)) => v1 == v2,
        (Term::Constant(c1), Term::Constant(c2)) => c1 == c2,
        (Term::Literal(l1), Term::Literal(l2)) => l1 == l2,
        (Term::Constant(c), Term::Literal(l)) | (Term::Literal(l), Term::Constant(c)) => c == l,
        (Term::Function { name: n1, args: a1 }, Term::Function { name: n2, args: a2 }) => {
            n1 == n2
                && a1.len() == a2.len()
                && a1.iter().zip(a2.iter()).all(|(x, y)| self.terms_equal(x, y))
        }
        _ => false,
    }
}
/// Three-way comparison used by `GreaterThan` / `LessThan` atoms.
///
/// Returns `-1`, `0` or `1` when `term1` is less than, equal to, or greater than `term2`.
/// Constants and literals, in any mix, are compared numerically when both parse as `f64` and
/// lexicographically otherwise. Functions are ordered by name, then by arity. Any other pairing
/// (e.g. involving variables) and NaN comparisons yield `0`.
fn compare_terms(&self, term1: &Term, term2: &Term) -> i32 {
    use std::cmp::Ordering;

    // Numeric order when both sides parse as numbers, otherwise plain string order.
    fn compare_text(left: &str, right: &str) -> Ordering {
        match (left.parse::<f64>(), right.parse::<f64>()) {
            (Ok(l), Ok(r)) => l.partial_cmp(&r).unwrap_or(Ordering::Equal),
            _ => left.cmp(right),
        }
    }

    let ordering = match (term1, term2) {
        // Bind operands positionally: a Literal on the left must stay on the left.
        // Swapping operands for the mixed case would invert GreaterThan/LessThan.
        (
            Term::Constant(left) | Term::Literal(left),
            Term::Constant(right) | Term::Literal(right),
        ) => compare_text(left, right),
        (Term::Function { name: n1, args: a1 }, Term::Function { name: n2, args: a2 }) => {
            n1.cmp(n2).then_with(|| a1.len().cmp(&a2.len()))
        }
        _ => Ordering::Equal,
    };
    match ordering {
        Ordering::Less => -1,
        Ordering::Equal => 0,
        Ordering::Greater => 1,
    }
}
/// Evaluates a builtin predicate under `substitution`.
///
/// Supported builtins:
/// * `equal/2` and `notEqual/2` compare the substituted arguments with `terms_equal`.
/// * `bound/1` holds for a variable present in `substitution` and for any non-variable term.
/// * `unbound/1` holds only for a variable absent from `substitution`.
///
/// On success the bindings are returned unchanged. A wrong argument count is an error.
/// Unknown names log a warning and evaluate to `Ok(None)`.
fn evaluate_builtin(
    &self,
    name: &str,
    args: &[Term],
    substitution: &Substitution,
) -> Result<Option<Substitution>> {
    let holds = match name {
        "equal" => {
            anyhow::ensure!(args.len() == 2, "equal/2 requires exactly 2 arguments");
            let lhs = self.substitute_term(&args[0], substitution);
            let rhs = self.substitute_term(&args[1], substitution);
            self.terms_equal(&lhs, &rhs)
        }
        "notEqual" => {
            anyhow::ensure!(args.len() == 2, "notEqual/2 requires exactly 2 arguments");
            let lhs = self.substitute_term(&args[0], substitution);
            let rhs = self.substitute_term(&args[1], substitution);
            !self.terms_equal(&lhs, &rhs)
        }
        "bound" => {
            anyhow::ensure!(args.len() == 1, "bound/1 requires exactly 1 argument");
            match &args[0] {
                Term::Variable(var) => substitution.contains_key(var),
                _ => true,
            }
        }
        "unbound" => {
            anyhow::ensure!(args.len() == 1, "unbound/1 requires exactly 1 argument");
            match &args[0] {
                Term::Variable(var) => !substitution.contains_key(var),
                _ => false,
            }
        }
        unknown => {
            warn!("Unknown built-in predicate: {}", unknown);
            false
        }
    };
    Ok(holds.then(|| substitution.clone()))
}
/// Reports the current number of facts, rules, and cached proof results.
pub fn get_stats(&self) -> BackwardChainingStats {
    let total_facts = self.facts.len();
    let total_rules = self.rules.len();
    let cache_size = self.proof_cache.len();
    BackwardChainingStats {
        total_facts,
        total_rules,
        cache_size,
    }
}
pub fn query(&mut self, pattern: &RuleAtom) -> Result<Vec<RuleAtom>> {
let mut results = Vec::new();
let empty_substitution = HashMap::new();
for fact in &self.facts {
if self
.unify_atoms(pattern, fact, &empty_substitution)?
.is_some()
{
results.push(fact.clone());
}
}
let substitutions = self.prove_all(pattern)?;
for substitution in substitutions {
let instantiated = self.apply_substitution(pattern, &substitution)?;
if !results.contains(&instantiated) {
results.push(instantiated);
}
}
Ok(results)
}
}
/// Snapshot of a chainer's size, as returned by `get_stats`.
#[derive(Debug, Clone)]
pub struct BackwardChainingStats {
    /// Number of stored facts.
    pub total_facts: usize,
    /// Number of registered rules.
    pub total_rules: usize,
    /// Number of memoized proof results.
    pub cache_size: usize,
}
impl std::fmt::Display for BackwardChainingStats {
    /// Renders as `Facts: <n>, Rules: <n>, Cache: <n>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let BackwardChainingStats {
            total_facts,
            total_rules,
            cache_size,
        } = self;
        write!(
            f,
            "Facts: {}, Rules: {}, Cache: {}",
            total_facts, total_rules, cache_size
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant term.
    fn constant(value: &str) -> Term {
        Term::Constant(value.to_string())
    }

    /// Builds a variable term.
    fn variable(name: &str) -> Term {
        Term::Variable(name.to_string())
    }

    /// Builds a triple atom whose predicate is a constant.
    fn triple(subject: Term, predicate: &str, object: Term) -> RuleAtom {
        RuleAtom::Triple {
            subject,
            predicate: constant(predicate),
            object,
        }
    }

    #[test]
    fn test_basic_backward_chaining() -> Result<(), Box<dyn std::error::Error>> {
        let mut chainer = BackwardChainer::new();
        chainer.add_rule(Rule {
            name: "mortality_rule".to_string(),
            body: vec![triple(variable("X"), "type", constant("human"))],
            head: vec![triple(variable("X"), "type", constant("mortal"))],
        });
        chainer.add_fact(triple(constant("socrates"), "type", constant("human")));
        let goal = triple(constant("socrates"), "type", constant("mortal"));
        assert!(chainer.prove(&goal)?);
        Ok(())
    }

    #[test]
    fn test_fact_matching() -> Result<(), Box<dyn std::error::Error>> {
        let mut chainer = BackwardChainer::new();
        let fact = triple(constant("socrates"), "type", constant("human"));
        chainer.add_fact(fact.clone());
        assert!(chainer.prove(&fact)?);
        Ok(())
    }

    #[test]
    fn test_variable_substitution() -> Result<(), Box<dyn std::error::Error>> {
        let mut chainer = BackwardChainer::new();
        chainer.add_fact(triple(constant("socrates"), "type", constant("human")));
        let pattern = triple(variable("X"), "type", constant("human"));
        let substitutions = chainer.prove_all(&pattern)?;
        assert_eq!(substitutions.len(), 1);
        assert_eq!(substitutions[0].get("X"), Some(&constant("socrates")));
        Ok(())
    }

    #[test]
    fn test_transitive_proof() -> Result<(), Box<dyn std::error::Error>> {
        let mut chainer = BackwardChainer::with_config(20, true);
        chainer.add_rule(Rule {
            name: "direct_ancestor".to_string(),
            body: vec![triple(variable("X"), "parent", variable("Y"))],
            head: vec![triple(variable("X"), "ancestor", variable("Y"))],
        });
        chainer.add_fact(triple(constant("john"), "parent", constant("mary")));
        let direct_goal = triple(constant("john"), "ancestor", constant("mary"));
        assert!(chainer.prove(&direct_goal)?);
        Ok(())
    }

    #[test]
    fn test_query_functionality() -> Result<(), Box<dyn std::error::Error>> {
        let mut chainer = BackwardChainer::new();
        for person in ["socrates", "plato"].iter() {
            chainer.add_fact(triple(constant(person), "type", constant("human")));
        }
        let pattern = triple(variable("X"), "type", constant("human"));
        let results = chainer.query(&pattern)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }
}