use std::cell::Cell;
use std::collections::HashMap;
/// A single binary relation `predicate(subject, object)`.
///
/// `subject` and `object` are terms: a string beginning with `?` is a logic variable,
/// anything else is a constant.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Goal {
    /// Name of the relation.
    pub predicate: String,
    /// First argument: a constant or a `?variable`.
    pub subject: String,
    /// Second argument: a constant or a `?variable`.
    pub object: String,
}

impl Goal {
    /// Builds a goal from any three values convertible into owned strings.
    pub fn new(
        predicate: impl Into<String>,
        subject: impl Into<String>,
        object: impl Into<String>,
    ) -> Self {
        let predicate: String = predicate.into();
        let subject: String = subject.into();
        let object: String = object.into();
        Goal {
            predicate,
            subject,
            object,
        }
    }
}
/// A Horn clause `head :- body[0], body[1], ...`.
///
/// A clause with an empty `body` is a fact.
#[derive(Clone, Debug)]
pub struct Clause {
    /// The conclusion that holds once every body goal has been proven.
    pub head: Goal,
    /// The premises, proven left to right; empty for facts.
    pub body: Vec<Goal>,
}
/// A set of variable bindings.
///
/// Keys are variable names *without* the leading `?`; values are terms, which may
/// themselves be `?variables` (i.e. bindings can form chains).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Substitution(pub HashMap<String, String>);

impl Substitution {
    /// Creates a substitution with no bindings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `var` (no leading `?`) to `val`, replacing any earlier binding of `var`.
    pub fn bind(&mut self, var: String, val: String) {
        self.0.insert(var, val);
    }

    /// Returns the term bound to `var` (no leading `?`), if any.
    pub fn get(&self, var: &str) -> Option<&str> {
        self.0.get(var).map(String::as_str)
    }

    /// Composes `self` with `other`.
    ///
    /// Every binding of `self` whose value is a variable that `other` binds is rewritten to
    /// `other`'s value (one step only). Bindings from `other` are then added for variables
    /// that `self` does not already bind; when both bind the same variable, `self` wins.
    pub fn compose(&self, other: &Substitution) -> Substitution {
        let mut merged: HashMap<String, String> = self
            .0
            .iter()
            .map(|(var, val)| {
                let rewritten = val
                    .strip_prefix('?')
                    .and_then(|name| other.0.get(name))
                    .unwrap_or(val)
                    .clone();
                (var.clone(), rewritten)
            })
            .collect();
        for (var, val) in &other.0 {
            merged.entry(var.clone()).or_insert_with(|| val.clone());
        }
        Substitution(merged)
    }

    /// Returns `true` when `s` is a logic variable, i.e. starts with `?`.
    pub fn is_variable(s: &str) -> bool {
        s.starts_with('?')
    }
}
/// One node of a proof: the goal that was established and how it was established.
#[derive(Clone, Debug)]
pub struct ProofTree {
    /// The goal with the bindings known *before* this node was proven already applied;
    /// bindings discovered while proving it are recorded in `substitution`, not here.
    pub goal: Goal,
    /// `None` when a fact closed the goal; otherwise a summary of the rule used, written
    /// as `head :- body1, body2` with predicate names only.
    pub used_clause: Option<String>,
    /// Proofs of the rule's body goals, in body order (empty for facts).
    pub children: Vec<ProofTree>,
    /// The variable bindings in effect once this node and all its children were proven.
    pub substitution: Substitution,
}
/// A depth-bounded backward-chaining prover over binary-predicate Horn clauses.
///
/// Clauses are tried in insertion order and rule bodies are solved left to right with
/// chronological backtracking. This is the same search order as Prolog's SLD resolution.
/// `max_depth` bounds how deeply rule applications may nest below the query, which
/// guarantees termination even for recursive (including left-recursive) rule sets.
#[derive(Debug, Clone)]
pub struct BackwardChainer {
    clauses: Vec<Clause>,
    max_depth: usize,
}

impl BackwardChainer {
    /// Creates an empty knowledge base whose proofs may nest at most `max_depth` rule
    /// applications below the query goal.
    pub fn new(max_depth: usize) -> Self {
        Self {
            clauses: Vec::new(),
            max_depth,
        }
    }

    /// Appends a fact or rule. Clauses are tried in the order they were added.
    pub fn add_clause(&mut self, clause: Clause) {
        self.clauses.push(clause);
    }

    /// Appends the fact `predicate(subject, object)`, i.e. a clause with an empty body.
    pub fn add_fact(&mut self, predicate: &str, subject: &str, object: &str) {
        self.add_clause(Clause {
            head: Goal::new(predicate, subject, object),
            body: vec![],
        });
    }

    /// Returns the first proof of `goal` in search order, or `None` if no proof exists
    /// within `max_depth`.
    ///
    /// Bindings for the query's own variables are available through
    /// `tree.substitution` (possibly as a chain through internal `?name#n` variables).
    pub fn prove(&self, goal: &Goal) -> Option<ProofTree> {
        let fresh = Cell::new(0);
        self.solve_goal(goal, &Substitution::new(), 0, &fresh, &mut |tree: ProofTree| {
            Some(tree)
        })
    }

    /// Returns every proof of `goal` within `max_depth`, in search order.
    pub fn prove_all(&self, goal: &Goal) -> Vec<ProofTree> {
        let fresh = Cell::new(0);
        let mut proofs = Vec::new();
        self.solve_goal::<()>(goal, &Substitution::new(), 0, &fresh, &mut |tree: ProofTree| {
            proofs.push(tree);
            // Returning `None` tells the search to keep backtracking for more proofs.
            None
        });
        proofs
    }

    /// Returns `true` when `goal` has at least one proof within `max_depth`.
    pub fn can_prove(&self, goal: &Goal) -> bool {
        self.prove(goal).is_some()
    }

    /// Number of facts and rules in the knowledge base.
    pub fn clause_count(&self) -> usize {
        self.clauses.len()
    }

    /// Computes a most general unifier of `goal` and `head`, or `None` if they clash.
    ///
    /// Predicates must match exactly. Subjects and objects are unified pairwise. When two
    /// distinct terms meet and the `goal`-side term is a variable, that variable is bound.
    pub fn unify(goal: &Goal, head: &Goal) -> Option<Substitution> {
        if goal.predicate != head.predicate {
            return None;
        }
        let mut sub = Substitution::new();
        Self::unify_term(&goal.subject, &head.subject, &mut sub)?;
        Self::unify_term(&goal.object, &head.object, &mut sub)?;
        Some(sub)
    }

    /// Returns `goal` with every bound variable replaced by the end of its binding chain.
    pub fn apply_substitution(goal: &Goal, sub: &Substitution) -> Goal {
        Goal {
            predicate: goal.predicate.clone(),
            subject: Self::apply_term(&goal.subject, sub),
            object: Self::apply_term(&goal.object, sub),
        }
    }

    /// Resolves `term` by following bindings in `sub` until reaching a constant or an
    /// unbound variable.
    ///
    /// The walk takes at most one step per binding. A cyclic substitution (for example
    /// `X -> ?Y, Y -> ?X`, which callers of the public API can build) therefore yields a
    /// variable on the cycle instead of recursing until the stack overflows.
    fn apply_term(term: &str, sub: &Substitution) -> String {
        let mut current = term;
        for _ in 0..=sub.0.len() {
            match current.strip_prefix('?').and_then(|name| sub.get(name)) {
                Some(next) => current = next,
                None => break,
            }
        }
        current.to_string()
    }

    /// Unifies two terms under `sub`, extending `sub` with at most one new binding.
    fn unify_term(t1: &str, t2: &str, sub: &mut Substitution) -> Option<()> {
        let r1 = Self::apply_term(t1, sub);
        let r2 = Self::apply_term(t2, sub);
        if r1 == r2 {
            return Some(());
        }
        let (var, value) = if Substitution::is_variable(&r1) {
            (r1, r2)
        } else if Substitution::is_variable(&r2) {
            (r2, r1)
        } else {
            // Two distinct constants can never unify.
            return None;
        };
        sub.bind(var[1..].to_string(), value);
        Some(())
    }

    /// Enumerates proofs of `goal` under `current_sub`, handing each to `on_proof`.
    ///
    /// The search stops as soon as `on_proof` returns `Some`, and that value is
    /// returned. `None` means the search space was exhausted, or `on_proof` asked to
    /// continue every time.
    fn solve_goal<R>(
        &self,
        goal: &Goal,
        current_sub: &Substitution,
        depth: usize,
        fresh: &Cell<usize>,
        on_proof: &mut dyn FnMut(ProofTree) -> Option<R>,
    ) -> Option<R> {
        if depth > self.max_depth {
            return None;
        }
        let resolved_goal = Self::apply_substitution(goal, current_sub);
        for clause in &self.clauses {
            // Cheap pre-filter: unification can only succeed on a matching predicate.
            if clause.head.predicate != resolved_goal.predicate {
                continue;
            }
            let renamed = Self::rename_clause(clause, fresh);
            if let Some(mgu) = Self::unify(&resolved_goal, &renamed.head) {
                let combined = current_sub.compose(&mgu);
                let label = Self::clause_label(&renamed);
                let outcome = self.solve_body(
                    &renamed.body,
                    &combined,
                    depth + 1,
                    fresh,
                    &mut |children: Vec<ProofTree>, final_sub: Substitution| {
                        on_proof(ProofTree {
                            goal: resolved_goal.clone(),
                            used_clause: label.clone(),
                            children,
                            substitution: final_sub,
                        })
                    },
                );
                if outcome.is_some() {
                    return outcome;
                }
            }
        }
        None
    }

    /// Enumerates joint proofs of `body` (left to right, backtracking into earlier goals
    /// when later ones fail) and hands each one's child trees and bindings to `on_proof`.
    fn solve_body<R>(
        &self,
        body: &[Goal],
        sub: &Substitution,
        depth: usize,
        fresh: &Cell<usize>,
        on_proof: &mut dyn FnMut(Vec<ProofTree>, Substitution) -> Option<R>,
    ) -> Option<R> {
        match body.split_first() {
            None => on_proof(Vec::new(), sub.clone()),
            Some((first, rest)) => {
                self.solve_goal::<R>(first, sub, depth, fresh, &mut |tree: ProofTree| {
                    let new_sub = sub.compose(&tree.substitution);
                    self.solve_body::<R>(
                        rest,
                        &new_sub,
                        depth,
                        fresh,
                        &mut |rest_trees: Vec<ProofTree>, final_sub: Substitution| {
                            let mut children = Vec::with_capacity(rest_trees.len() + 1);
                            children.push(tree.clone());
                            children.extend(rest_trees);
                            on_proof(children, final_sub)
                        },
                    )
                })
            }
        }
    }

    /// Human-readable summary of a rule (`head :- p1, p2`), or `None` for a fact.
    fn clause_label(clause: &Clause) -> Option<String> {
        if clause.body.is_empty() {
            return None;
        }
        let body = clause
            .body
            .iter()
            .map(|g| g.predicate.as_str())
            .collect::<Vec<_>>()
            .join(", ");
        Some(format!("{} :- {}", clause.head.predicate, body))
    }

    /// Returns a copy of `clause` whose variables carry a suffix (`#n`) unique within the
    /// current search.
    ///
    /// Every use of a clause gets its own variables, so sibling subgoals proven at the
    /// same depth cannot clobber each other's bindings. Query goals should avoid `#` in
    /// variable names so they cannot collide with these generated names.
    fn rename_clause(clause: &Clause, fresh: &Cell<usize>) -> Clause {
        let id = fresh.get();
        fresh.set(id + 1);
        let suffix = format!("#{id}");
        let rename_goal = |g: &Goal| Goal {
            predicate: g.predicate.clone(),
            subject: Self::rename_term(&g.subject, &suffix),
            object: Self::rename_term(&g.object, &suffix),
        };
        Clause {
            head: rename_goal(&clause.head),
            body: clause.body.iter().map(rename_goal).collect(),
        }
    }

    /// Appends `suffix` to `term` if it is a variable; constants are returned unchanged.
    fn rename_term(term: &str, suffix: &str) -> String {
        if Substitution::is_variable(term) {
            format!("{}{}", term, suffix)
        } else {
            term.to_string()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Family knowledge base shared by most tests: `parent(tom, bob)`, `parent(bob, ann)`,
    /// and the base and recursive `ancestor` rules, with the given depth bound.
    fn chainer_with_depth(max_depth: usize) -> BackwardChainer {
        let mut bc = BackwardChainer::new(max_depth);
        bc.add_fact("parent", "tom", "bob");
        bc.add_fact("parent", "bob", "ann");
        bc.add_clause(Clause {
            head: Goal::new("ancestor", "?X", "?Y"),
            body: vec![Goal::new("parent", "?X", "?Y")],
        });
        bc.add_clause(Clause {
            head: Goal::new("ancestor", "?X", "?Z"),
            body: vec![
                Goal::new("parent", "?X", "?Y"),
                Goal::new("ancestor", "?Y", "?Z"),
            ],
        });
        bc
    }

    /// The family knowledge base with a depth bound generous enough for every query here.
    fn simple_chainer() -> BackwardChainer {
        chainer_with_depth(10)
    }

    #[test]
    fn test_prove_simple_fact() {
        let bc = simple_chainer();
        let goal = Goal::new("parent", "tom", "bob");
        assert!(bc.can_prove(&goal));
    }

    #[test]
    fn test_prove_unknown_fact() {
        let bc = simple_chainer();
        let goal = Goal::new("parent", "ann", "tom");
        assert!(!bc.can_prove(&goal));
    }

    #[test]
    fn test_prove_direct_ancestor() {
        let bc = simple_chainer();
        let goal = Goal::new("ancestor", "tom", "bob");
        assert!(bc.can_prove(&goal));
    }

    #[test]
    fn test_prove_chained_ancestor() {
        let bc = simple_chainer();
        let goal = Goal::new("ancestor", "tom", "ann");
        assert!(bc.can_prove(&goal));
    }

    #[test]
    fn test_prove_returns_proof_tree() -> Result<(), Box<dyn std::error::Error>> {
        let bc = simple_chainer();
        let goal = Goal::new("parent", "tom", "bob");
        let tree = bc.prove(&goal).ok_or("expected Some value")?;
        assert_eq!(tree.goal.predicate, "parent");
        Ok(())
    }

    #[test]
    fn test_proof_tree_has_correct_goal() -> Result<(), Box<dyn std::error::Error>> {
        let bc = simple_chainer();
        let goal = Goal::new("ancestor", "tom", "ann");
        let tree = bc.prove(&goal).ok_or("expected Some value")?;
        assert_eq!(tree.goal.subject, "tom");
        assert_eq!(tree.goal.object, "ann");
        Ok(())
    }

    #[test]
    fn test_clause_count_initial() {
        let bc = BackwardChainer::new(5);
        assert_eq!(bc.clause_count(), 0);
    }

    #[test]
    fn test_clause_count_after_facts() {
        let mut bc = BackwardChainer::new(5);
        bc.add_fact("p", "a", "b");
        bc.add_fact("p", "b", "c");
        assert_eq!(bc.clause_count(), 2);
    }

    #[test]
    fn test_clause_count_with_rules() {
        let bc = simple_chainer();
        assert_eq!(bc.clause_count(), 4);
    }

    #[test]
    fn test_add_fact_increases_count() {
        let mut bc = BackwardChainer::new(5);
        bc.add_fact("likes", "alice", "bob");
        assert_eq!(bc.clause_count(), 1);
    }

    #[test]
    fn test_prove_added_fact() {
        let mut bc = BackwardChainer::new(5);
        bc.add_fact("likes", "alice", "bob");
        assert!(bc.can_prove(&Goal::new("likes", "alice", "bob")));
    }

    #[test]
    fn test_unify_same_ground_terms() {
        let g1 = Goal::new("p", "a", "b");
        let g2 = Goal::new("p", "a", "b");
        assert!(BackwardChainer::unify(&g1, &g2).is_some());
    }

    #[test]
    fn test_unify_different_predicate() {
        let g1 = Goal::new("p", "a", "b");
        let g2 = Goal::new("q", "a", "b");
        assert!(BackwardChainer::unify(&g1, &g2).is_none());
    }

    #[test]
    fn test_unify_variable_binds() -> Result<(), Box<dyn std::error::Error>> {
        let g1 = Goal::new("p", "?X", "b");
        let g2 = Goal::new("p", "a", "b");
        let sub = BackwardChainer::unify(&g1, &g2).ok_or("expected Some value")?;
        assert_eq!(sub.get("X"), Some("a"));
        Ok(())
    }

    #[test]
    fn test_unify_conflict_fails() {
        let g1 = Goal::new("p", "a", "b");
        let g2 = Goal::new("p", "c", "b");
        assert!(BackwardChainer::unify(&g1, &g2).is_none());
    }

    #[test]
    fn test_unify_both_variables() -> Result<(), Box<dyn std::error::Error>> {
        let g1 = Goal::new("p", "?X", "?Y");
        let g2 = Goal::new("p", "tom", "ann");
        let sub = BackwardChainer::unify(&g1, &g2).ok_or("expected Some value")?;
        assert_eq!(sub.get("X"), Some("tom"));
        assert_eq!(sub.get("Y"), Some("ann"));
        Ok(())
    }

    #[test]
    fn test_apply_substitution_binds_variable() {
        let goal = Goal::new("p", "?X", "?Y");
        let mut sub = Substitution::new();
        sub.bind("X".to_string(), "a".to_string());
        sub.bind("Y".to_string(), "b".to_string());
        let result = BackwardChainer::apply_substitution(&goal, &sub);
        assert_eq!(result.subject, "a");
        assert_eq!(result.object, "b");
    }

    #[test]
    fn test_apply_substitution_keeps_constant() {
        let goal = Goal::new("p", "const", "?Y");
        let mut sub = Substitution::new();
        sub.bind("Y".to_string(), "val".to_string());
        let result = BackwardChainer::apply_substitution(&goal, &sub);
        assert_eq!(result.subject, "const");
        assert_eq!(result.object, "val");
    }

    #[test]
    fn test_apply_substitution_unbound_variable_stays() {
        let goal = Goal::new("p", "?X", "b");
        let sub = Substitution::new();
        let result = BackwardChainer::apply_substitution(&goal, &sub);
        assert_eq!(result.subject, "?X");
    }

    #[test]
    fn test_substitution_is_variable_true() {
        assert!(Substitution::is_variable("?X"));
    }

    #[test]
    fn test_substitution_is_variable_false() {
        assert!(!Substitution::is_variable("alice"));
    }

    #[test]
    fn test_substitution_bind_and_get() {
        let mut sub = Substitution::new();
        sub.bind("X".to_string(), "hello".to_string());
        assert_eq!(sub.get("X"), Some("hello"));
    }

    #[test]
    fn test_substitution_get_missing() {
        let sub = Substitution::new();
        assert!(sub.get("Z").is_none());
    }

    #[test]
    fn test_substitution_compose() {
        let mut s1 = Substitution::new();
        s1.bind("X".to_string(), "a".to_string());
        let mut s2 = Substitution::new();
        s2.bind("Y".to_string(), "b".to_string());
        let composed = s1.compose(&s2);
        assert_eq!(composed.get("X"), Some("a"));
        assert_eq!(composed.get("Y"), Some("b"));
    }

    #[test]
    fn test_substitution_compose_resolves_variable_chain() {
        let mut s1 = Substitution::new();
        s1.bind("X".to_string(), "?Y".to_string());
        let mut s2 = Substitution::new();
        s2.bind("Y".to_string(), "val".to_string());
        let composed = s1.compose(&s2);
        assert_eq!(composed.get("X"), Some("val"));
    }

    #[test]
    fn test_prove_all_multiple_solutions() {
        let mut bc = BackwardChainer::new(5);
        bc.add_fact("color", "rose", "red");
        bc.add_fact("color", "sky", "blue");
        bc.add_fact("color", "grass", "green");
        let goal = Goal::new("color", "?X", "?Y");
        let proofs = bc.prove_all(&goal);
        assert_eq!(proofs.len(), 3);
    }

    #[test]
    fn test_prove_all_no_solutions() {
        let bc = BackwardChainer::new(5);
        let goal = Goal::new("unknown", "x", "y");
        let proofs = bc.prove_all(&goal);
        assert!(proofs.is_empty());
    }

    #[test]
    fn test_max_depth_limits_recursion() {
        // ancestor(tom, ann) uses the recursive rule at depth 0 and the base rule at
        // depth 1, whose body goal is proven at depth 2.
        let goal = Goal::new("ancestor", "tom", "ann");
        assert!(!chainer_with_depth(1).can_prove(&goal));
        assert!(chainer_with_depth(2).can_prove(&goal));
        // Facts need no rule application, so they stay provable even at depth 0,
        // while any rule-derived goal does not.
        assert!(chainer_with_depth(0).can_prove(&Goal::new("parent", "tom", "bob")));
        assert!(!chainer_with_depth(0).can_prove(&Goal::new("ancestor", "tom", "bob")));
    }

    #[test]
    fn test_max_depth_terminates_left_recursion() {
        // Without a depth bound this rule would recurse forever.
        let mut bc = BackwardChainer::new(8);
        bc.add_clause(Clause {
            head: Goal::new("loops", "?X", "?Y"),
            body: vec![Goal::new("loops", "?X", "?Y")],
        });
        assert!(bc.prove(&Goal::new("loops", "a", "b")).is_none());
        assert!(bc.prove_all(&Goal::new("loops", "?X", "?Y")).is_empty());
    }

    #[test]
    fn test_proof_tree_used_clause_none_for_fact() -> Result<(), Box<dyn std::error::Error>> {
        let mut bc = BackwardChainer::new(5);
        bc.add_fact("likes", "alice", "bob");
        let tree = bc
            .prove(&Goal::new("likes", "alice", "bob"))
            .ok_or("expected Some value")?;
        assert!(tree.used_clause.is_none());
        Ok(())
    }

    #[test]
    fn test_proof_tree_used_clause_some_for_rule() -> Result<(), Box<dyn std::error::Error>> {
        let bc = simple_chainer();
        let tree = bc
            .prove(&Goal::new("ancestor", "tom", "bob"))
            .ok_or("expected Some value")?;
        assert!(tree.used_clause.is_some());
        Ok(())
    }

    #[test]
    fn test_can_prove_false() {
        let bc = simple_chainer();
        assert!(!bc.can_prove(&Goal::new("parent", "ann", "tom")));
    }

    #[test]
    fn test_prove_all_single_solution() {
        let bc = simple_chainer();
        let goal = Goal::new("parent", "tom", "bob");
        let proofs = bc.prove_all(&goal);
        assert!(!proofs.is_empty());
    }

    #[test]
    fn test_proof_tree_children_for_rule() -> Result<(), Box<dyn std::error::Error>> {
        let bc = simple_chainer();
        let tree = bc
            .prove(&Goal::new("ancestor", "tom", "bob"))
            .ok_or("expected Some value")?;
        assert!(!tree.children.is_empty());
        Ok(())
    }

    #[test]
    fn test_goal_new() {
        let g = Goal::new("foo", "a", "b");
        assert_eq!(g.predicate, "foo");
        assert_eq!(g.subject, "a");
        assert_eq!(g.object, "b");
    }
}