use crate::ir::{KnowledgeBase, Predicate, Rule};
use crate::reasoning::{apply_subst_predicate, unify_predicates, Substitution};
use ipfrs_core::error::Result;
use std::collections::{HashMap, HashSet};
/// One slot of the memo table used by `TabledInferenceEngine`.
///
/// A slot is inserted with `complete == false` before its goal is evaluated
/// and is overwritten with `complete == true` once evaluation has finished.
#[derive(Debug, Clone)]
struct TableEntry {
/// The goal this slot was created for, with the caller's substitution
/// already applied. Written on insertion; not read by the code visible here,
/// hence the lint allowance.
#[allow(dead_code)]
goal: Predicate,
/// Substitutions recorded as answers for the goal.
solutions: Vec<Substitution>,
/// `false` while the goal is still being evaluated, `true` afterwards.
complete: bool,
/// Recursion depth at which the goal was first evaluated. Written on
/// insertion; not read by the code visible here, hence the lint allowance.
#[allow(dead_code)]
depth: usize,
}
/// Goal-directed inference engine that memoizes the answers of each goal in
/// a table and bounds its search by depth and by number of solutions.
pub struct TabledInferenceEngine {
/// Memo table, keyed by the string that `goal_key` builds for a goal.
table: HashMap<String, TableEntry>,
/// Goals reached at a recursion depth greater than this yield no answers.
max_depth: usize,
/// Threshold at which the solver stops collecting further solutions.
max_solutions: usize,
}
impl TabledInferenceEngine {
pub fn new() -> Self {
Self {
table: HashMap::new(),
max_depth: 100,
max_solutions: 1000,
}
}
pub fn with_limits(max_depth: usize, max_solutions: usize) -> Self {
Self {
table: HashMap::new(),
max_depth,
max_solutions,
}
}
pub fn query(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Vec<Substitution>> {
let mut engine = Self {
table: HashMap::new(),
max_depth: self.max_depth,
max_solutions: self.max_solutions,
};
engine.solve_tabled(goal, &Substitution::new(), kb, 0)
}
fn solve_tabled(
&mut self,
goal: &Predicate,
subst: &Substitution,
kb: &KnowledgeBase,
depth: usize,
) -> Result<Vec<Substitution>> {
if depth > self.max_depth {
return Ok(Vec::new());
}
let goal = apply_subst_predicate(goal, subst);
let key = self.goal_key(&goal);
if let Some(entry) = self.table.get(&key) {
if entry.complete {
return Ok(entry.solutions.clone());
}
return Ok(Vec::new());
}
let mut entry = TableEntry {
goal: goal.clone(),
solutions: Vec::new(),
complete: false,
depth,
};
self.table.insert(key.clone(), entry.clone());
let mut solutions = Vec::new();
for fact in kb.get_predicates(&goal.name) {
if let Some(new_subst) = unify_predicates(&goal, fact, &Substitution::new()) {
solutions.push(new_subst);
if solutions.len() >= self.max_solutions {
break;
}
}
}
for rule in kb.get_rules(&goal.name) {
if solutions.len() >= self.max_solutions {
break;
}
let renamed_rule = self.rename_rule(rule, depth);
if let Some(new_subst) =
unify_predicates(&goal, &renamed_rule.head, &Substitution::new())
{
let body_solutions =
self.solve_conjunction(&renamed_rule.body, &new_subst, kb, depth + 1)?;
solutions.extend(body_solutions);
}
}
entry.solutions = solutions.clone();
entry.complete = true;
self.table.insert(key, entry);
Ok(solutions)
}
fn solve_conjunction(
&mut self,
goals: &[Predicate],
subst: &Substitution,
kb: &KnowledgeBase,
depth: usize,
) -> Result<Vec<Substitution>> {
if goals.is_empty() {
return Ok(vec![subst.clone()]);
}
let first = &goals[0];
let rest = &goals[1..];
let first_solutions = self.solve_tabled(first, subst, kb, depth)?;
let mut all_solutions = Vec::new();
for first_subst in first_solutions {
let rest_solutions = self.solve_conjunction(rest, &first_subst, kb, depth)?;
all_solutions.extend(rest_solutions);
if all_solutions.len() >= self.max_solutions {
break;
}
}
Ok(all_solutions)
}
fn goal_key(&self, goal: &Predicate) -> String {
format!("{}({})", goal.name, goal.args.len())
}
fn rename_rule(&self, rule: &Rule, suffix: usize) -> Rule {
let var_map: HashMap<String, String> = rule
.variables()
.into_iter()
.map(|v| (v.clone(), format!("{}_{}", v, suffix)))
.collect();
let rename_subst: Substitution = var_map
.into_iter()
.map(|(old, new)| (old, crate::ir::Term::Var(new)))
.collect();
Rule {
head: apply_subst_predicate(&rule.head, &rename_subst),
body: rule
.body
.iter()
.map(|p| apply_subst_predicate(p, &rename_subst))
.collect(),
}
}
pub fn table_stats(&self) -> TableStats {
TableStats {
entries: self.table.len(),
complete_entries: self.table.values().filter(|e| e.complete).count(),
total_solutions: self.table.values().map(|e| e.solutions.len()).sum(),
}
}
pub fn clear_table(&mut self) {
self.table.clear();
}
}
impl Default for TabledInferenceEngine {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a `TabledInferenceEngine` table, as produced by `table_stats`.
#[derive(Debug, Clone)]
pub struct TableStats {
/// Number of entries in the table.
pub entries: usize,
/// Number of entries whose `complete` flag is set.
pub complete_entries: usize,
/// Sum of the stored solution counts over all entries.
pub total_solutions: usize,
}
/// Engine that repeatedly extends a knowledge base with derived facts until
/// nothing new appears or an iteration limit is reached.
pub struct FixpointEngine {
/// Maximum number of rounds `compute_fixpoint` performs before giving up.
max_iterations: usize,
}
impl FixpointEngine {
pub fn new() -> Self {
Self {
max_iterations: 100,
}
}
pub fn with_max_iterations(max_iterations: usize) -> Self {
Self { max_iterations }
}
pub fn compute_fixpoint(&self, kb: &KnowledgeBase) -> Result<KnowledgeBase> {
let mut current_kb = kb.clone();
let mut iteration = 0;
loop {
iteration += 1;
if iteration > self.max_iterations {
break;
}
let mut new_facts = Vec::new();
let mut changed = false;
let predicate_names: std::collections::HashSet<String> = current_kb
.rules
.iter()
.map(|r| r.head.name.clone())
.collect();
for predicate_name in predicate_names {
for rule in current_kb.get_rules(&predicate_name) {
let derived = self.derive_facts_from_rule(rule, ¤t_kb)?;
for fact in derived {
if !current_kb.facts.contains(&fact) {
new_facts.push(fact);
changed = true;
}
}
}
}
for fact in new_facts {
current_kb.add_fact(fact);
}
if !changed {
break;
}
}
Ok(current_kb)
}
fn derive_facts_from_rule(&self, _rule: &Rule, _kb: &KnowledgeBase) -> Result<Vec<Predicate>> {
let derived = Vec::new();
Ok(derived)
}
}
impl Default for FixpointEngine {
fn default() -> Self {
Self::new()
}
}
/// Analyzer that builds a predicate dependency graph from the rules of a
/// knowledge base and groups predicates into strata.
pub struct StratificationAnalyzer {
/// Maps the predicate name of a rule head to the set of predicate names
/// appearing in the bodies of the rules for that head.
dependencies: HashMap<String, HashSet<String>>,
}
impl StratificationAnalyzer {
    /// Creates an analyzer with an empty dependency graph.
    pub fn new() -> Self {
        Self {
            dependencies: HashMap::new(),
        }
    }

    /// Analyzes the rules of `kb`.
    ///
    /// Returns `NonStratifiable` when the dependency graph between rule heads
    /// and body predicates contains any cycle (plain recursion included), and
    /// `Stratifiable` with the computed strata otherwise.
    ///
    /// The graph is rebuilt from scratch on every call, so one analyzer can
    /// be reused for several knowledge bases.
    pub fn analyze(&mut self, kb: &KnowledgeBase) -> StratificationResult {
        self.build_dependency_graph(kb);
        if self.has_cycles() {
            StratificationResult::NonStratifiable
        } else {
            StratificationResult::Stratifiable(self.compute_strata())
        }
    }

    /// Replaces `self.dependencies` with the graph of `kb`: one node per rule
    /// head, with an edge to every predicate named in that rule's body.
    fn build_dependency_graph(&mut self, kb: &KnowledgeBase) {
        // Drop edges left over from an earlier `analyze` call; otherwise a
        // cycle in a previous knowledge base would leak into this result.
        self.dependencies.clear();

        let head_names: HashSet<String> =
            kb.rules.iter().map(|r| r.head.name.clone()).collect();
        for head_name in head_names {
            for rule in kb.get_rules(&head_name) {
                let body_names = rule.body.iter().map(|p| p.name.clone());
                self.dependencies
                    .entry(rule.head.name.clone())
                    .or_default()
                    .extend(body_names);
            }
        }
    }

    /// Reports whether the dependency graph contains a cycle.
    fn has_cycles(&self) -> bool {
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        self.dependencies
            .keys()
            .any(|node| self.has_cycle_util(node, &mut visited, &mut rec_stack))
    }

    /// Depth-first search step for cycle detection.
    ///
    /// `visited` holds every node already explored; `rec_stack` holds the
    /// nodes on the current search path. Reaching a node that is on the
    /// current path means a cycle exists.
    fn has_cycle_util(
        &self,
        node: &str,
        visited: &mut HashSet<String>,
        rec_stack: &mut HashSet<String>,
    ) -> bool {
        if rec_stack.contains(node) {
            return true;
        }
        if visited.contains(node) {
            return false;
        }
        visited.insert(node.to_string());
        rec_stack.insert(node.to_string());

        if let Some(neighbors) = self.dependencies.get(node) {
            for neighbor in neighbors {
                if self.has_cycle_util(neighbor, visited, rec_stack) {
                    return true;
                }
            }
        }

        rec_stack.remove(node);
        false
    }

    /// Groups the rule-head predicates into strata.
    ///
    /// A predicate is placed in the first stratum in which none of its
    /// dependencies is still unassigned. Predicates that never occur as a
    /// rule head are not part of any stratum. Names inside a stratum are
    /// sorted so the result is reproducible.
    fn compute_strata(&self) -> Vec<Vec<String>> {
        let mut strata = Vec::new();
        let mut remaining: HashSet<String> = self.dependencies.keys().cloned().collect();

        while !remaining.is_empty() {
            let mut current_stratum: Vec<String> = remaining
                .iter()
                .filter(|pred| {
                    self.dependencies
                        .get(*pred)
                        .map_or(true, |deps| !deps.iter().any(|d| remaining.contains(d)))
                })
                .cloned()
                .collect();

            // No progress is only possible on a cyclic graph; stop instead of
            // looping forever.
            if current_stratum.is_empty() {
                break;
            }

            current_stratum.sort();
            for pred in &current_stratum {
                remaining.remove(pred);
            }
            strata.push(current_stratum);
        }

        strata
    }
}
impl Default for StratificationAnalyzer {
fn default() -> Self {
Self::new()
}
}
/// Outcome of `StratificationAnalyzer::analyze`.
#[derive(Debug, Clone)]
pub enum StratificationResult {
/// The dependency graph has no cycle. Each inner vector is one stratum of
/// predicate names; a predicate appears in a stratum only after all of its
/// dependencies that are themselves rule heads were placed in earlier strata.
Stratifiable(Vec<Vec<String>>),
/// The dependency graph contains a cycle.
NonStratifiable,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Constant, Term};

    /// Builds a string-constant term.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Builds a variable term.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Builds a predicate from a name and its argument terms.
    fn atom(name: &str, args: Vec<Term>) -> Predicate {
        Predicate::new(name.to_string(), args)
    }

    /// A recursive `ancestor` definition over two `parent` facts must yield at
    /// least one answer for `ancestor(alice, Z)`.
    #[test]
    fn test_tabled_inference_basic() {
        let mut kb = KnowledgeBase::new();

        kb.add_fact(atom("parent", vec![text("alice"), text("bob")]));
        kb.add_fact(atom("parent", vec![text("bob"), text("charlie")]));

        // ancestor(X, Y) :- parent(X, Y).
        kb.add_rule(Rule::new(
            atom("ancestor", vec![var("X"), var("Y")]),
            vec![atom("parent", vec![var("X"), var("Y")])],
        ));
        // ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
        kb.add_rule(Rule::new(
            atom("ancestor", vec![var("X"), var("Z")]),
            vec![
                atom("parent", vec![var("X"), var("Y")]),
                atom("ancestor", vec![var("Y"), var("Z")]),
            ],
        ));

        let goal = atom("ancestor", vec![text("alice"), var("Z")]);
        let engine = TabledInferenceEngine::new();
        let solutions = engine.query(&goal, &kb).unwrap();

        assert!(!solutions.is_empty());
    }

    /// A freshly created engine reports an empty table.
    #[test]
    fn test_table_stats() {
        let stats = TabledInferenceEngine::new().table_stats();

        assert_eq!(stats.entries, 0);
        assert_eq!(stats.complete_entries, 0);
    }

    /// A single non-recursive rule is stratifiable and yields at least one
    /// stratum.
    #[test]
    fn test_stratification_no_cycles() {
        let mut kb = KnowledgeBase::new();

        // grandparent(X, Z) :- parent(X, Y), parent(Y, Z).
        kb.add_rule(Rule::new(
            atom("grandparent", vec![var("X"), var("Z")]),
            vec![
                atom("parent", vec![var("X"), var("Y")]),
                atom("parent", vec![var("Y"), var("Z")]),
            ],
        ));

        let mut analyzer = StratificationAnalyzer::new();
        if let StratificationResult::Stratifiable(strata) = analyzer.analyze(&kb) {
            assert!(!strata.is_empty());
        } else {
            panic!("Expected stratifiable result");
        }
    }

    /// The fixpoint of an empty knowledge base has as many facts as the input.
    #[test]
    fn test_fixpoint_engine() {
        let kb = KnowledgeBase::new();
        let result = FixpointEngine::new().compute_fixpoint(&kb).unwrap();

        assert_eq!(result.facts.len(), kb.facts.len());
    }
}