use crate::cache::{CacheManager, QueryKey};
use crate::ir::{KnowledgeBase, Predicate, Rule, Term};
use ipfrs_core::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Variable bindings produced by unification: maps a variable name to the term bound to it.
pub type Substitution = HashMap<String, Term>;
/// A node of a proof tree: the goal it establishes, the rule (or fact) used, and the
/// proofs of that rule's body predicates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Proof {
/// The predicate this node claims to prove.
pub goal: Predicate,
/// The fact or rule applied at this node; verification rejects a proof where this is `None`.
pub rule: Option<ProofRule>,
/// Proofs of the applied rule's body predicates, in body order (empty for facts).
pub subproofs: Vec<Proof>,
}
/// Serializable snapshot of the fact or rule applied at a proof node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProofRule {
/// Head of the applied rule (for facts: a copy of the proven goal).
pub head: Predicate,
/// Body predicates of the applied rule (empty for facts).
pub body: Vec<Predicate>,
/// `true` when the node was justified directly by a fact rather than by a rule.
pub is_fact: bool,
}
impl Proof {
pub fn fact(goal: Predicate) -> Self {
Self {
rule: Some(ProofRule {
head: goal.clone(),
body: Vec::new(),
is_fact: true,
}),
goal,
subproofs: Vec::new(),
}
}
pub fn from_rule(goal: Predicate, rule: &Rule, subproofs: Vec<Proof>) -> Self {
Self {
goal,
rule: Some(ProofRule {
head: rule.head.clone(),
body: rule.body.clone(),
is_fact: false,
}),
subproofs,
}
}
pub fn depth(&self) -> usize {
if self.subproofs.is_empty() {
1
} else {
1 + self.subproofs.iter().map(|p| p.depth()).max().unwrap_or(0)
}
}
#[inline]
pub fn size(&self) -> usize {
1 + self.subproofs.iter().map(|p| p.size()).sum::<usize>()
}
#[inline]
pub fn is_fact(&self) -> bool {
self.subproofs.is_empty()
}
pub fn all_goals(&self) -> Vec<&Predicate> {
let mut goals = vec![&self.goal];
for subproof in &self.subproofs {
goals.extend(subproof.all_goals());
}
goals
}
}
/// Record of how one goal was broken into subgoals and which of them have been solved.
#[derive(Debug, Clone)]
pub struct GoalDecomposition {
/// The goal being decomposed.
pub goal: Predicate,
/// Body predicates of the applied rule (empty until a rule is applied).
pub subgoals: Vec<Predicate>,
/// The rule most recently applied to this goal, if any.
pub rule_applied: Option<Rule>,
/// One flag per subgoal, `true` once that subgoal has been marked solved.
pub local_solutions: Vec<bool>,
/// Depth value supplied by the creator of this record.
pub depth: usize,
}
impl GoalDecomposition {
    /// Creates an empty decomposition for `goal` at the given `depth`: no rule applied,
    /// no subgoals, no solved flags.
    pub fn new(goal: Predicate, depth: usize) -> Self {
        Self {
            goal,
            depth,
            rule_applied: None,
            subgoals: Vec::new(),
            local_solutions: Vec::new(),
        }
    }

    /// Records `rule` as the applied rule and resets the subgoal list to its body,
    /// with every subgoal marked unsolved. Any previous state is overwritten.
    pub fn apply_rule(&mut self, rule: &Rule) {
        let pending = rule.body.len();
        self.subgoals = rule.body.clone();
        self.local_solutions = vec![false; pending];
        self.rule_applied = Some(rule.clone());
    }

    /// Marks the subgoal at `index` as solved; out-of-range indices are ignored.
    pub fn mark_solved(&mut self, index: usize) {
        if let Some(flag) = self.local_solutions.get_mut(index) {
            *flag = true;
        }
    }

    /// Returns `true` when no subgoal is flagged unsolved.
    ///
    /// This is also `true` when there are no subgoals at all (for example before any
    /// rule has been applied).
    #[inline]
    pub fn is_complete(&self) -> bool {
        !self.local_solutions.contains(&false)
    }

    /// Returns the subgoals whose solved flag is still `false`, in body order.
    pub fn unsolved_subgoals(&self) -> Vec<&Predicate> {
        self.subgoals
            .iter()
            .zip(&self.local_solutions)
            .filter_map(|(subgoal, &solved)| (!solved).then_some(subgoal))
            .collect()
    }
}
/// Tracks the goals currently being explored so that re-entering one can be detected.
#[derive(Debug, Clone, Default)]
pub struct CycleDetector {
// Keys of the active goals in push order; `pop` removes the most recent one.
goal_stack: Vec<String>,
// Same keys as `goal_stack`, kept as a set for membership tests.
goal_set: HashSet<String>,
}
impl CycleDetector {
    /// Creates a detector with no active goals.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `goal` as active.
    ///
    /// Returns `false` (and leaves the detector unchanged) when a goal with the same
    /// key is already active; returns `true` after recording it otherwise.
    #[inline]
    pub fn push(&mut self, goal: &Predicate) -> bool {
        let key = goal_to_key(goal);
        // `HashSet::insert` returns false without modifying the set when the key is present.
        if !self.goal_set.insert(key.clone()) {
            return false;
        }
        self.goal_stack.push(key);
        true
    }

    /// Removes the most recently pushed goal, if any.
    #[inline]
    pub fn pop(&mut self) {
        let Some(key) = self.goal_stack.pop() else {
            return;
        };
        self.goal_set.remove(&key);
    }

    /// Returns `true` when pushing `goal` would be rejected as a cycle.
    #[inline]
    pub fn would_cycle(&self, goal: &Predicate) -> bool {
        self.goal_set.contains(&goal_to_key(goal))
    }

    /// Number of goals currently active.
    #[inline]
    pub fn depth(&self) -> usize {
        self.goal_stack.len()
    }

    /// Forgets all active goals.
    pub fn clear(&mut self) {
        self.goal_set.clear();
        self.goal_stack.clear();
    }
}
/// Builds the cycle-detection key for `goal`, formatted as `name(arity)`.
///
/// Only the predicate name and the number of arguments go into the key; argument
/// values do not, so goals sharing name and arity map to the same key.
fn goal_to_key(goal: &Predicate) -> String {
    let arity = goal.args.len();
    format!("{}({})", goal.name, arity)
}
/// Inference engine that answers queries and builds proofs against a `KnowledgeBase`,
/// bounded by a recursion depth and a solution-count limit.
pub struct InferenceEngine {
    /// Search is abandoned on a branch once its depth exceeds this value.
    max_depth: usize,
    /// Upper bound on the number of solutions a query collects.
    max_solutions: usize,
    /// Flag set through `with_cycle_detection`.
    // NOTE(review): no search routine visible in this file reads this flag — confirm intent.
    cycle_detection: bool,
}

impl Default for InferenceEngine {
    /// Returns an engine with the same limits as `InferenceEngine::new()`.
    ///
    /// A derived `Default` would zero both limits (so every query would return no
    /// solutions) and disable cycle detection; keep these values in sync with `new()`.
    fn default() -> Self {
        Self {
            max_depth: 100,
            max_solutions: 100,
            cycle_detection: true,
        }
    }
}
impl InferenceEngine {
#[inline]
pub fn new() -> Self {
Self {
max_depth: 100,
max_solutions: 100,
cycle_detection: true,
}
}
/// Creates an engine with caller-chosen depth and solution limits; cycle detection
/// starts enabled.
#[inline]
pub fn with_limits(max_depth: usize, max_solutions: usize) -> Self {
    let cycle_detection = true;
    Self {
        max_depth,
        max_solutions,
        cycle_detection,
    }
}
#[inline]
pub fn with_cycle_detection(mut self, enabled: bool) -> Self {
self.cycle_detection = enabled;
self
}
/// Finds substitutions under which `goal` follows from `kb`.
///
/// The search starts at depth 0 with an empty substitution and returns whatever
/// solutions the bounded search collected.
pub fn query(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Vec<Substitution>> {
    let mut found = Vec::new();
    self.solve_goal(goal, &Substitution::new(), kb, 0, &mut found)?;
    Ok(found)
}
/// Tries to build a proof tree for `goal` from `kb`, starting at depth 0 with no bindings.
///
/// Returns `Ok(None)` when no proof is found within the depth limit.
pub fn prove(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Option<Proof>> {
    let no_bindings = Substitution::new();
    self.prove_goal(goal, &no_bindings, kb, 0)
}
/// Appends to `solutions` the substitutions under which `goal` holds.
///
/// `subst` is applied to `goal` first. The instantiated goal is then matched against
/// the facts stored under its name, and afterwards against each rule for that name
/// (with rule variables renamed using `depth`), whose body is solved one level deeper.
///
/// Nothing is added when `depth` exceeds `max_depth` or when `solutions` already
/// holds `max_solutions` entries. The cap is re-checked before every rule so that
/// `solutions` never grows past `max_solutions` and no further rules are explored
/// once it is reached.
fn solve_goal(
    &self,
    goal: &Predicate,
    subst: &Substitution,
    kb: &KnowledgeBase,
    depth: usize,
    solutions: &mut Vec<Substitution>,
) -> Result<()> {
    if depth > self.max_depth || solutions.len() >= self.max_solutions {
        return Ok(());
    }
    let goal = apply_subst_predicate(goal, subst);

    // Facts: every successful unification is a solution on its own.
    for fact in kb.get_predicates(&goal.name) {
        if let Some(new_subst) = unify_predicates(&goal, fact, subst) {
            solutions.push(new_subst);
            if solutions.len() >= self.max_solutions {
                return Ok(());
            }
        }
    }

    // Rules: unify with the (renamed) head, then solve the body as a conjunction.
    for rule in kb.get_rules(&goal.name) {
        // Without this check a rule with an empty body would push one more solution
        // after the cap was reached by an earlier rule.
        if solutions.len() >= self.max_solutions {
            return Ok(());
        }
        let renamed_rule = rename_rule_vars(rule, depth);
        if let Some(new_subst) = unify_predicates(&goal, &renamed_rule.head, subst) {
            self.solve_conjunction(&renamed_rule.body, &new_subst, kb, depth + 1, solutions)?;
        }
    }
    Ok(())
}
/// Solves `goals` left to right, appending every substitution that satisfies all of
/// them to `solutions`.
///
/// An empty goal list is trivially satisfied by `subst` itself. Otherwise the first
/// goal is solved into a scratch list, and each resulting substitution is used to
/// solve the remaining goals, stopping once `solutions` has reached `max_solutions`.
fn solve_conjunction(
    &self,
    goals: &[Predicate],
    subst: &Substitution,
    kb: &KnowledgeBase,
    depth: usize,
    solutions: &mut Vec<Substitution>,
) -> Result<()> {
    let Some((first_goal, rest_goals)) = goals.split_first() else {
        solutions.push(subst.clone());
        return Ok(());
    };

    let mut candidates = Vec::new();
    self.solve_goal(first_goal, subst, kb, depth, &mut candidates)?;

    for candidate in &candidates {
        if solutions.len() >= self.max_solutions {
            break;
        }
        self.solve_conjunction(rest_goals, candidate, kb, depth, solutions)?;
    }
    Ok(())
}
/// Builds a proof for `goal` (after applying `subst`), or returns `Ok(None)`.
///
/// A matching fact yields a leaf proof immediately. Otherwise the rules for the
/// goal's name are tried in order; the first one whose head unifies and whose body
/// can be proven one level deeper yields the proof. Gives up with `None` once
/// `depth` exceeds `max_depth`.
fn prove_goal(
    &self,
    goal: &Predicate,
    subst: &Substitution,
    kb: &KnowledgeBase,
    depth: usize,
) -> Result<Option<Proof>> {
    if depth > self.max_depth {
        return Ok(None);
    }
    let goal = apply_subst_predicate(goal, subst);

    let is_known_fact = kb
        .get_predicates(&goal.name)
        .into_iter()
        .any(|fact| unify_predicates(&goal, fact, subst).is_some());
    if is_known_fact {
        return Ok(Some(Proof::fact(goal)));
    }

    for rule in kb.get_rules(&goal.name) {
        let renamed_rule = rename_rule_vars(rule, depth);
        let Some(head_subst) = unify_predicates(&goal, &renamed_rule.head, subst) else {
            continue;
        };
        let body_proofs =
            self.prove_conjunction(&renamed_rule.body, &head_subst, kb, depth + 1)?;
        if let Some(subproofs) = body_proofs {
            return Ok(Some(Proof::from_rule(goal, &renamed_rule, subproofs)));
        }
    }
    Ok(None)
}
/// Proves every goal in `goals`, returning the proofs in goal order, or `Ok(None)`
/// as soon as one goal cannot be proven.
///
/// Every goal is proven against the same incoming `subst`: bindings discovered
/// while proving one goal are not handed on to the following goals.
fn prove_conjunction(
    &self,
    goals: &[Predicate],
    subst: &Substitution,
    kb: &KnowledgeBase,
    depth: usize,
) -> Result<Option<Vec<Proof>>> {
    let Some((first_goal, rest_goals)) = goals.split_first() else {
        return Ok(Some(Vec::new()));
    };
    let Some(first_proof) = self.prove_goal(first_goal, subst, kb, depth)? else {
        return Ok(None);
    };
    let Some(rest_proofs) = self.prove_conjunction(rest_goals, subst, kb, depth)? else {
        return Ok(None);
    };

    let mut all_proofs = Vec::with_capacity(rest_proofs.len() + 1);
    all_proofs.push(first_proof);
    all_proofs.extend(rest_proofs);
    Ok(Some(all_proofs))
}
/// Checks `proof` against `kb`, starting the recursive check at depth 0.
///
/// Returns `Ok(true)` only when every node of the proof passes the structural checks.
pub fn verify(&self, proof: &Proof, kb: &KnowledgeBase) -> Result<bool> {
    let root_depth = 0;
    self.verify_proof_recursive(proof, kb, root_depth)
}
fn verify_proof_recursive(
&self,
proof: &Proof,
kb: &KnowledgeBase,
depth: usize,
) -> Result<bool> {
if depth > self.max_depth {
return Ok(false);
}
let Some(ref rule) = proof.rule else {
return Ok(false);
};
if rule.is_fact {
let facts = kb.get_predicates(&proof.goal.name);
for fact in facts {
if unify_predicates(&proof.goal, fact, &Substitution::new()).is_some() {
return Ok(true);
}
}
return Ok(false);
}
let rules = kb.get_rules(&proof.goal.name);
let mut rule_exists = false;
for kb_rule in rules {
if kb_rule.head.name == rule.head.name
&& kb_rule.head.args.len() == rule.head.args.len()
&& kb_rule.body.len() == rule.body.len()
{
let bodies_match = kb_rule
.body
.iter()
.zip(rule.body.iter())
.all(|(b1, b2)| b1.name == b2.name && b1.args.len() == b2.args.len());
if bodies_match {
rule_exists = true;
break;
}
}
}
if !rule_exists {
return Ok(false);
}
if proof.subproofs.len() != rule.body.len() {
return Ok(false);
}
for (i, subproof) in proof.subproofs.iter().enumerate() {
let body_predicate = &rule.body[i];
if subproof.goal.name != body_predicate.name {
return Ok(false);
}
if !self.verify_proof_recursive(subproof, kb, depth + 1)? {
return Ok(false);
}
}
Ok(true)
}
}
/// Unifies two terms under an existing substitution.
///
/// Both terms are first rewritten with `subst` (a single lookup per variable, as done
/// by `apply_subst_term`). The result is:
/// - `Some(subst)` unchanged for equal constants, identical variables, or references
///   with the same `cid`;
/// - `Some(subst + {v -> t})` when one side is a variable `v` that does not occur in
///   the other side `t`;
/// - the accumulated substitution after unifying arguments pairwise, for function
///   terms with the same symbol and arity;
/// - `None` in every other case, including a failed occurs check.
pub fn unify(t1: &Term, t2: &Term, subst: &Substitution) -> Option<Substitution> {
    let t1 = apply_subst_term(t1, subst);
    let t2 = apply_subst_term(t2, subst);
    match (&t1, &t2) {
        (Term::Const(c1), Term::Const(c2)) if c1 == c2 => Some(subst.clone()),
        (Term::Var(v), t) | (t, Term::Var(v)) => {
            // Same variable on both sides: nothing to bind. This must come before the
            // occurs check, which would otherwise reject `X = X`.
            if let Term::Var(v2) = t {
                if v == v2 {
                    return Some(subst.clone());
                }
            }
            // Occurs check: refuse to bind a variable to a term containing it.
            if occurs_in(v, t) {
                return None;
            }
            let mut new_subst = subst.clone();
            new_subst.insert(v.clone(), t.clone());
            Some(new_subst)
        }
        (Term::Fun(f1, args1), Term::Fun(f2, args2)) if f1 == f2 && args1.len() == args2.len() => {
            // Thread the substitution through the argument pairs; the first failure aborts.
            args1
                .iter()
                .zip(args2.iter())
                .try_fold(subst.clone(), |current_subst, (a1, a2)| {
                    unify(a1, a2, &current_subst)
                })
        }
        (Term::Ref(r1), Term::Ref(r2)) if r1.cid == r2.cid => Some(subst.clone()),
        _ => None,
    }
}
/// Unifies two predicates under an existing substitution.
///
/// Returns `None` when the names or argument counts differ, or when any pair of
/// corresponding arguments fails to unify. Otherwise returns `subst` extended with
/// the bindings accumulated while unifying the arguments left to right.
pub fn unify_predicates(
    p1: &Predicate,
    p2: &Predicate,
    subst: &Substitution,
) -> Option<Substitution> {
    if p1.name != p2.name || p1.args.len() != p2.args.len() {
        return None;
    }
    // Thread the substitution through the argument pairs; the first failure aborts.
    p1.args
        .iter()
        .zip(p2.args.iter())
        .try_fold(subst.clone(), |current_subst, (a1, a2)| {
            unify(a1, a2, &current_subst)
        })
}
/// Occurs check: reports whether the variable named `var` appears anywhere in `term`,
/// looking inside function arguments recursively.
fn occurs_in(var: &str, term: &Term) -> bool {
    match term {
        Term::Fun(_, args) => args.iter().any(|arg| occurs_in(var, arg)),
        Term::Var(name) => name == var,
        _ => false,
    }
}
/// Returns a copy of `term` with bound variables replaced according to `subst`.
///
/// A variable is replaced by the term it maps to (looked up once; the replacement is
/// not rewritten again), function arguments are rewritten recursively, and every
/// other kind of term is cloned unchanged.
pub fn apply_subst_term(term: &Term, subst: &Substitution) -> Term {
    match term {
        Term::Var(name) => match subst.get(name) {
            Some(bound) => bound.clone(),
            None => term.clone(),
        },
        Term::Fun(symbol, args) => {
            let mut rewritten = Vec::with_capacity(args.len());
            for arg in args {
                rewritten.push(apply_subst_term(arg, subst));
            }
            Term::Fun(symbol.clone(), rewritten)
        }
        _ => term.clone(),
    }
}
/// Returns a copy of `pred` with `subst` applied to each of its arguments; the name
/// is kept as is.
pub fn apply_subst_predicate(pred: &Predicate, subst: &Substitution) -> Predicate {
    let args = pred
        .args
        .iter()
        .map(|arg| apply_subst_term(arg, subst))
        .collect();
    let name = pred.name.clone();
    Predicate { name, args }
}
fn rename_rule_vars(rule: &Rule, suffix: usize) -> Rule {
let var_map: HashMap<String, String> = rule
.variables()
.into_iter()
.map(|v| (v.clone(), format!("{}_{}", v, suffix)))
.collect();
let rename_subst: Substitution = var_map
.into_iter()
.map(|(old, new)| (old, Term::Var(new)))
.collect();
Rule {
head: apply_subst_predicate(&rule.head, &rename_subst),
body: rule
.body
.iter()
.map(|p| apply_subst_predicate(p, &rename_subst))
.collect(),
}
}
/// An `InferenceEngine` paired with a shared cache used to reuse query results.
pub struct MemoizedInferenceEngine {
// Engine that performs the actual query and proof search.
engine: InferenceEngine,
// Shared cache manager; its `query_cache` stores query results.
cache: Arc<CacheManager>,
}
impl MemoizedInferenceEngine {
    /// Wraps a default-configured engine around the shared `cache`.
    pub fn new(cache: Arc<CacheManager>) -> Self {
        let engine = InferenceEngine::new();
        Self { engine, cache }
    }

    /// Wraps an engine with the given depth and solution limits around the shared `cache`.
    pub fn with_limits(max_depth: usize, max_solutions: usize, cache: Arc<CacheManager>) -> Self {
        let engine = InferenceEngine::with_limits(max_depth, max_solutions);
        Self { engine, cache }
    }

    /// Answers `goal`, serving from the query cache when an entry exists.
    ///
    /// On a miss the engine is queried and a non-empty result is stored under the
    /// goal's key; empty results are returned but not cached. The cache key is
    /// derived from `goal` alone.
    pub fn query(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Vec<Substitution>> {
        let key = QueryKey::from_predicate(goal);
        if let Some(hit) = self.cache.query_cache.get(&key) {
            return Ok(hit);
        }

        let solutions = self.engine.query(goal, kb)?;
        if solutions.is_empty() {
            return Ok(solutions);
        }
        self.cache.query_cache.insert(key, solutions.clone());
        Ok(solutions)
    }

    /// Builds a proof for `goal` by delegating to the wrapped engine; no caching is involved.
    pub fn prove(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Option<Proof>> {
        self.engine.prove(goal, kb)
    }

    /// Returns the statistics reported by the cache manager.
    #[inline]
    pub fn cache_stats(&self) -> crate::cache::CombinedCacheStats {
        self.cache.stats()
    }

    /// Empties the query cache.
    pub fn clear_cache(&self) {
        self.cache.query_cache.clear();
    }
}
/// Reasoner that wraps an `InferenceEngine`, optionally caches query results, and can
/// record how goals were decomposed while proving.
pub struct DistributedReasoner {
// Engine used for plain queries and proofs.
engine: InferenceEngine,
// Optional shared cache for query results; `None` disables caching.
cache: Option<Arc<CacheManager>>,
// Decompositions recorded by the most recent tracked proof attempt.
decompositions: Vec<GoalDecomposition>,
}
impl DistributedReasoner {
pub fn new() -> Result<Self> {
Ok(Self {
engine: InferenceEngine::new(),
cache: None,
decompositions: Vec::new(),
})
}
pub fn with_cache(cache: Arc<CacheManager>) -> Result<Self> {
Ok(Self {
engine: InferenceEngine::new(),
cache: Some(cache),
decompositions: Vec::new(),
})
}
/// Answers `goal` against `kb`, going through the query cache when one is configured.
///
/// Without a cache the engine is queried directly. With a cache, a stored entry is
/// returned if present; otherwise the engine's result is returned and, when
/// non-empty, stored under the goal's key.
pub async fn query(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Vec<Substitution>> {
    let Some(cache) = &self.cache else {
        return self.engine.query(goal, kb);
    };

    let key = QueryKey::from_predicate(goal);
    if let Some(cached) = cache.query_cache.get(&key) {
        return Ok(cached);
    }

    let solutions = self.engine.query(goal, kb)?;
    if !solutions.is_empty() {
        cache.query_cache.insert(key, solutions.clone());
    }
    Ok(solutions)
}
/// Builds a proof for `goal` by delegating to the wrapped engine.
pub async fn prove(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Option<Proof>> {
    let engine = &self.engine;
    engine.prove(goal, kb)
}
/// Proves `goal` while recording a `GoalDecomposition` for every goal visited.
///
/// Previously recorded decompositions are discarded first. On success the proof
/// (if any) is returned together with the recorded decompositions, which are moved
/// out of the reasoner.
pub async fn prove_with_decomposition(
    &mut self,
    goal: &Predicate,
    kb: &KnowledgeBase,
) -> Result<(Option<Proof>, Vec<GoalDecomposition>)> {
    self.decompositions.clear();
    let proof = self.prove_tracking(goal, kb, 0)?;
    Ok((proof, std::mem::take(&mut self.decompositions)))
}
/// Proves `goal` recursively, pushing one `GoalDecomposition` onto
/// `self.decompositions` for every goal visited (whether or not it was proven).
///
/// Facts are tried first, then each rule for the goal's name with its variables
/// renamed by `depth`. For a rule whose head unifies, the body predicates are proven
/// in order with the head bindings applied; the first unprovable one abandons that rule.
///
/// The search is cut off once `depth` exceeds the wrapped engine's `max_depth`, the
/// same bound `InferenceEngine::prove_goal` uses, so recursive rules cannot recurse
/// without limit. A goal that hits the cut-off is still recorded, and `Ok(None)` is
/// returned for it.
fn prove_tracking(
    &mut self,
    goal: &Predicate,
    kb: &KnowledgeBase,
    depth: usize,
) -> Result<Option<Proof>> {
    let mut decomp = GoalDecomposition::new(goal.clone(), depth);

    // Depth guard: record the goal so callers can see where the search stopped.
    if depth > self.engine.max_depth {
        self.decompositions.push(decomp);
        return Ok(None);
    }

    for fact in kb.get_predicates(&goal.name) {
        if unify_predicates(goal, fact, &Substitution::new()).is_some() {
            self.decompositions.push(decomp);
            return Ok(Some(Proof::fact(goal.clone())));
        }
    }

    for rule in kb.get_rules(&goal.name) {
        let renamed_rule = rename_rule_vars(rule, depth);
        if let Some(subst) = unify_predicates(goal, &renamed_rule.head, &Substitution::new()) {
            // Overwrites any state left behind by a previously attempted rule.
            decomp.apply_rule(&renamed_rule);
            let mut subproofs = Vec::new();
            let mut all_proved = true;
            for (i, subgoal) in renamed_rule.body.iter().enumerate() {
                let subgoal = apply_subst_predicate(subgoal, &subst);
                if let Some(subproof) = self.prove_tracking(&subgoal, kb, depth + 1)? {
                    subproofs.push(subproof);
                    decomp.mark_solved(i);
                } else {
                    all_proved = false;
                    break;
                }
            }
            if all_proved {
                self.decompositions.push(decomp);
                return Ok(Some(Proof::from_rule(
                    goal.clone(),
                    &renamed_rule,
                    subproofs,
                )));
            }
        }
    }

    self.decompositions.push(decomp);
    Ok(None)
}
/// Returns the unsolved subgoals of every decomposition currently stored in the
/// reasoner, in recording order.
///
/// `prove_with_decomposition` moves its recorded decompositions out on success, so
/// after such a call this list is empty.
pub fn get_unsolved_goals(&self) -> Vec<&Predicate> {
    let mut pending = Vec::new();
    for decomposition in &self.decompositions {
        pending.extend(decomposition.unsolved_subgoals());
    }
    pending
}
/// Returns the cache manager's statistics, or `None` when no cache is configured.
pub fn cache_stats(&self) -> Option<crate::cache::CombinedCacheStats> {
    let cache = self.cache.as_ref()?;
    Some(cache.stats())
}
}
impl Default for DistributedReasoner {
    /// Builds a reasoner with a default-configured engine, no cache and no recorded
    /// decompositions.
    fn default() -> Self {
        let engine = InferenceEngine::new();
        Self {
            engine,
            cache: None,
            decompositions: Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Constant;

    /// String constant term.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Variable term.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Predicate with the given name and arguments.
    fn pred(name: &str, args: Vec<Term>) -> Predicate {
        Predicate::new(name.to_string(), args)
    }

    /// Equal constants unify.
    #[test]
    fn test_unify_constants() {
        let outcome = unify(&text("Alice"), &text("Alice"), &Substitution::new());
        assert!(outcome.is_some());
    }

    /// A variable unifies with a constant and is bound to it.
    #[test]
    fn test_unify_var_const() {
        let outcome = unify(&var("X"), &text("Alice"), &Substitution::new());
        assert!(outcome.is_some());
        let bindings = outcome.unwrap();
        assert_eq!(bindings.get("X"), Some(&text("Alice")));
    }

    /// Function terms with the same symbol and compatible arguments unify.
    #[test]
    fn test_unify_functions() {
        let pattern = Term::Fun(
            "f".to_string(),
            vec![var("X"), Term::Const(Constant::Int(1))],
        );
        let ground = Term::Fun(
            "f".to_string(),
            vec![text("a"), Term::Const(Constant::Int(1))],
        );
        assert!(unify(&pattern, &ground, &Substitution::new()).is_some());
    }

    /// `X` must not unify with `f(X)`.
    #[test]
    fn test_occurs_check() {
        let inner = var("X");
        let wrapper = Term::Fun("f".to_string(), vec![var("X")]);
        assert!(unify(&inner, &wrapper, &Substitution::new()).is_none());
    }

    /// Querying a stored fact binds the open variable.
    #[test]
    fn test_inference_fact() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(pred("parent", vec![text("Alice"), text("Bob")]));

        let goal = pred("parent", vec![text("Alice"), var("X")]);
        let solutions = InferenceEngine::new().query(&goal, &kb).unwrap();

        assert_eq!(solutions.len(), 1);
        assert_eq!(solutions[0].get("X"), Some(&text("Bob")));
    }

    /// A two-step rule (grandparent via parent/parent) produces at least one solution.
    #[test]
    fn test_inference_rule() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(pred("parent", vec![text("alice"), text("bob")]));
        kb.add_fact(pred("parent", vec![text("bob"), text("charlie")]));

        let head = pred("grandparent", vec![var("X"), var("Z")]);
        let body = vec![
            pred("parent", vec![var("X"), var("Y")]),
            pred("parent", vec![var("Y"), var("Z")]),
        ];
        kb.add_rule(Rule::new(head, body));

        let goal = pred("grandparent", vec![text("alice"), var("Z")]);
        let solutions = InferenceEngine::new().query(&goal, &kb).unwrap();
        assert!(!solutions.is_empty());
    }
}