ipfrs_tensorlogic/
recursive_reasoning.rs

1//! Recursive Query Support with Tabling
2//!
3//! This module implements advanced recursive query handling including:
4//! - Tabling/tabulation for efficient recursive queries
5//! - Stratified evaluation
6//! - Support for left-recursive rules
7//! - Fixpoint computation
8//!
9//! # Tabling
10//!
11//! Tabling (also called tabled resolution or SLG resolution) is a technique
12//! for evaluating logic programs that improves on standard SLD resolution
13//! by memoizing intermediate results and detecting loops.
14//!
15//! # Example
16//!
17//! ```
18//! use ipfrs_tensorlogic::{TabledInferenceEngine, KnowledgeBase, Predicate, Rule, Term, Constant};
19//!
20//! let mut kb = KnowledgeBase::new();
21//!
22//! // Define ancestor relation: ancestor(X, Y) :- parent(X, Y).
23//! // ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
24//! // This is recursive and benefits from tabling
25//!
26//! // Add parent facts
27//! kb.add_fact(Predicate::new("parent".to_string(), vec![
28//!     Term::Const(Constant::String("alice".to_string())),
29//!     Term::Const(Constant::String("bob".to_string())),
30//! ]));
31//! kb.add_fact(Predicate::new("parent".to_string(), vec![
32//!     Term::Const(Constant::String("bob".to_string())),
33//!     Term::Const(Constant::String("charlie".to_string())),
34//! ]));
35//!
36//! // Add base rule: ancestor(X, Y) :- parent(X, Y)
37//! kb.add_rule(Rule::new(
38//!     Predicate::new("ancestor".to_string(), vec![
39//!         Term::Var("X".to_string()),
40//!         Term::Var("Y".to_string()),
41//!     ]),
42//!     vec![Predicate::new("parent".to_string(), vec![
43//!         Term::Var("X".to_string()),
44//!         Term::Var("Y".to_string()),
45//!     ])],
46//! ));
47//!
48//! // Add recursive rule: ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z)
49//! kb.add_rule(Rule::new(
50//!     Predicate::new("ancestor".to_string(), vec![
51//!         Term::Var("X".to_string()),
52//!         Term::Var("Z".to_string()),
53//!     ]),
54//!     vec![
55//!         Predicate::new("parent".to_string(), vec![
56//!             Term::Var("X".to_string()),
57//!             Term::Var("Y".to_string()),
58//!         ]),
59//!         Predicate::new("ancestor".to_string(), vec![
60//!             Term::Var("Y".to_string()),
61//!             Term::Var("Z".to_string()),
62//!         ]),
63//!     ],
64//! ));
65//!
66//! // Create tabled engine
67//! let engine = TabledInferenceEngine::new();
68//!
69//! // Query for all ancestors of alice
70//! let goal = Predicate::new("ancestor".to_string(), vec![
71//!     Term::Const(Constant::String("alice".to_string())),
72//!     Term::Var("Z".to_string()),
73//! ]);
74//!
75//! let solutions = engine.query(&goal, &kb).unwrap();
76//! // Should find at least bob as an ancestor
77//! assert!(!solutions.is_empty());
78//! ```
79
80use crate::ir::{KnowledgeBase, Predicate, Rule};
81use crate::reasoning::{apply_subst_predicate, unify_predicates, Substitution};
82use ipfrs_core::error::Result;
83use std::collections::{HashMap, HashSet};
84
85/// Table entry for memoized subgoals
86#[derive(Debug, Clone)]
87struct TableEntry {
88    /// The subgoal being solved
89    #[allow(dead_code)]
90    goal: Predicate,
91    /// Solutions found so far
92    solutions: Vec<Substitution>,
93    /// Whether this entry is complete
94    complete: bool,
95    /// Depth at which this was tabled
96    #[allow(dead_code)]
97    depth: usize,
98}
99
100/// Tabled inference engine using SLG resolution
101pub struct TabledInferenceEngine {
102    /// Table for memoizing subgoals
103    table: HashMap<String, TableEntry>,
104    /// Maximum depth
105    max_depth: usize,
106    /// Maximum solutions per subgoal
107    max_solutions: usize,
108}
109
110impl TabledInferenceEngine {
111    /// Create a new tabled inference engine
112    pub fn new() -> Self {
113        Self {
114            table: HashMap::new(),
115            max_depth: 100,
116            max_solutions: 1000,
117        }
118    }
119
120    /// Create with custom limits
121    pub fn with_limits(max_depth: usize, max_solutions: usize) -> Self {
122        Self {
123            table: HashMap::new(),
124            max_depth,
125            max_solutions,
126        }
127    }
128
129    /// Query with tabling
130    pub fn query(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Vec<Substitution>> {
131        let mut engine = Self {
132            table: HashMap::new(),
133            max_depth: self.max_depth,
134            max_solutions: self.max_solutions,
135        };
136
137        engine.solve_tabled(goal, &Substitution::new(), kb, 0)
138    }
139
140    /// Solve a goal with tabling
141    fn solve_tabled(
142        &mut self,
143        goal: &Predicate,
144        subst: &Substitution,
145        kb: &KnowledgeBase,
146        depth: usize,
147    ) -> Result<Vec<Substitution>> {
148        // Check depth limit
149        if depth > self.max_depth {
150            return Ok(Vec::new());
151        }
152
153        // Apply substitution to goal
154        let goal = apply_subst_predicate(goal, subst);
155
156        // Create table key
157        let key = self.goal_key(&goal);
158
159        // Check if goal is already tabled
160        if let Some(entry) = self.table.get(&key) {
161            // If complete, return cached solutions
162            if entry.complete {
163                return Ok(entry.solutions.clone());
164            }
165            // If incomplete, we have a loop - return empty for now
166            return Ok(Vec::new());
167        }
168
169        // Create new table entry
170        let mut entry = TableEntry {
171            goal: goal.clone(),
172            solutions: Vec::new(),
173            complete: false,
174            depth,
175        };
176
177        // Insert incomplete entry to detect loops
178        self.table.insert(key.clone(), entry.clone());
179
180        // Solve using standard backward chaining
181        let mut solutions = Vec::new();
182
183        // Try facts
184        for fact in kb.get_predicates(&goal.name) {
185            if let Some(new_subst) = unify_predicates(&goal, fact, &Substitution::new()) {
186                solutions.push(new_subst);
187                if solutions.len() >= self.max_solutions {
188                    break;
189                }
190            }
191        }
192
193        // Try rules
194        for rule in kb.get_rules(&goal.name) {
195            if solutions.len() >= self.max_solutions {
196                break;
197            }
198
199            // Rename variables in rule
200            let renamed_rule = self.rename_rule(rule, depth);
201
202            // Try to unify with rule head
203            if let Some(new_subst) =
204                unify_predicates(&goal, &renamed_rule.head, &Substitution::new())
205            {
206                // Solve rule body
207                let body_solutions =
208                    self.solve_conjunction(&renamed_rule.body, &new_subst, kb, depth + 1)?;
209                solutions.extend(body_solutions);
210            }
211        }
212
213        // Mark entry as complete and update solutions
214        entry.solutions = solutions.clone();
215        entry.complete = true;
216        self.table.insert(key, entry);
217
218        Ok(solutions)
219    }
220
221    /// Solve a conjunction of goals
222    fn solve_conjunction(
223        &mut self,
224        goals: &[Predicate],
225        subst: &Substitution,
226        kb: &KnowledgeBase,
227        depth: usize,
228    ) -> Result<Vec<Substitution>> {
229        if goals.is_empty() {
230            return Ok(vec![subst.clone()]);
231        }
232
233        let first = &goals[0];
234        let rest = &goals[1..];
235
236        let first_solutions = self.solve_tabled(first, subst, kb, depth)?;
237
238        let mut all_solutions = Vec::new();
239        for first_subst in first_solutions {
240            let rest_solutions = self.solve_conjunction(rest, &first_subst, kb, depth)?;
241            all_solutions.extend(rest_solutions);
242
243            if all_solutions.len() >= self.max_solutions {
244                break;
245            }
246        }
247
248        Ok(all_solutions)
249    }
250
251    /// Generate a unique key for a goal
252    fn goal_key(&self, goal: &Predicate) -> String {
253        format!("{}({})", goal.name, goal.args.len())
254    }
255
256    /// Rename variables in a rule
257    fn rename_rule(&self, rule: &Rule, suffix: usize) -> Rule {
258        let var_map: HashMap<String, String> = rule
259            .variables()
260            .into_iter()
261            .map(|v| (v.clone(), format!("{}_{}", v, suffix)))
262            .collect();
263
264        let rename_subst: Substitution = var_map
265            .into_iter()
266            .map(|(old, new)| (old, crate::ir::Term::Var(new)))
267            .collect();
268
269        Rule {
270            head: apply_subst_predicate(&rule.head, &rename_subst),
271            body: rule
272                .body
273                .iter()
274                .map(|p| apply_subst_predicate(p, &rename_subst))
275                .collect(),
276        }
277    }
278
279    /// Get table statistics
280    pub fn table_stats(&self) -> TableStats {
281        TableStats {
282            entries: self.table.len(),
283            complete_entries: self.table.values().filter(|e| e.complete).count(),
284            total_solutions: self.table.values().map(|e| e.solutions.len()).sum(),
285        }
286    }
287
288    /// Clear the table
289    pub fn clear_table(&mut self) {
290        self.table.clear();
291    }
292}
293
294impl Default for TabledInferenceEngine {
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
/// Statistics about the tabling system.
///
/// `Default` yields all-zero counters; `PartialEq`/`Eq` allow direct comparison
/// of snapshots (for example in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TableStats {
    /// Number of table entries (distinct tabled subgoals)
    pub entries: usize,
    /// Number of entries marked complete
    pub complete_entries: usize,
    /// Total solutions stored across all entries
    pub total_solutions: usize,
}
310
/// Engine that saturates a knowledge base by repeatedly applying its rules
/// (fixpoint computation for stratified programs).
pub struct FixpointEngine {
    /// Upper bound on the number of derivation rounds
    max_iterations: usize,
}
316
317impl FixpointEngine {
318    /// Create a new fixpoint engine
319    pub fn new() -> Self {
320        Self {
321            max_iterations: 100,
322        }
323    }
324
325    /// Create with custom iteration limit
326    pub fn with_max_iterations(max_iterations: usize) -> Self {
327        Self { max_iterations }
328    }
329
330    /// Compute fixpoint for a set of rules
331    pub fn compute_fixpoint(&self, kb: &KnowledgeBase) -> Result<KnowledgeBase> {
332        let mut current_kb = kb.clone();
333        let mut iteration = 0;
334
335        loop {
336            iteration += 1;
337            if iteration > self.max_iterations {
338                break;
339            }
340
341            let mut new_facts = Vec::new();
342            let mut changed = false;
343
344            // Apply all rules to derive new facts
345            // Collect unique predicate names from rules
346            let predicate_names: std::collections::HashSet<String> = current_kb
347                .rules
348                .iter()
349                .map(|r| r.head.name.clone())
350                .collect();
351
352            for predicate_name in predicate_names {
353                for rule in current_kb.get_rules(&predicate_name) {
354                    let derived = self.derive_facts_from_rule(rule, &current_kb)?;
355                    for fact in derived {
356                        // Check if fact already exists
357                        if !current_kb.facts.contains(&fact) {
358                            new_facts.push(fact);
359                            changed = true;
360                        }
361                    }
362                }
363            }
364
365            // Add new facts to KB
366            for fact in new_facts {
367                current_kb.add_fact(fact);
368            }
369
370            // If no new facts, we've reached fixpoint
371            if !changed {
372                break;
373            }
374        }
375
376        Ok(current_kb)
377    }
378
379    /// Derive facts from a single rule
380    fn derive_facts_from_rule(&self, _rule: &Rule, _kb: &KnowledgeBase) -> Result<Vec<Predicate>> {
381        let derived = Vec::new();
382
383        // This is a simplified implementation
384        // A full implementation would do proper unification and substitution
385
386        // For now, just return empty
387        // TODO: Implement full derivation
388
389        Ok(derived)
390    }
391}
392
393impl Default for FixpointEngine {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
/// Analyzer that checks whether a logic program can be split into strata.
pub struct StratificationAnalyzer {
    /// For each rule-head predicate, the set of predicates its rule bodies use
    dependencies: HashMap<String, HashSet<String>>,
}
404
405impl StratificationAnalyzer {
406    /// Create a new stratification analyzer
407    pub fn new() -> Self {
408        Self {
409            dependencies: HashMap::new(),
410        }
411    }
412
413    /// Analyze a knowledge base for stratification
414    pub fn analyze(&mut self, kb: &KnowledgeBase) -> StratificationResult {
415        self.build_dependency_graph(kb);
416
417        // Check for cycles
418        if self.has_cycles() {
419            StratificationResult::NonStratifiable
420        } else {
421            // Compute stratification levels
422            let strata = self.compute_strata();
423            StratificationResult::Stratifiable(strata)
424        }
425    }
426
427    /// Build dependency graph from KB
428    fn build_dependency_graph(&mut self, kb: &KnowledgeBase) {
429        // Collect unique predicate names from rules
430        let predicate_names: HashSet<String> =
431            kb.rules.iter().map(|r| r.head.name.clone()).collect();
432
433        for predicate_name in predicate_names {
434            for rule in kb.get_rules(&predicate_name) {
435                let head = &rule.head.name;
436                let deps: HashSet<String> = rule.body.iter().map(|p| p.name.clone()).collect();
437
438                self.dependencies
439                    .entry(head.clone())
440                    .or_default()
441                    .extend(deps);
442            }
443        }
444    }
445
446    /// Check if dependency graph has cycles
447    fn has_cycles(&self) -> bool {
448        let mut visited = HashSet::new();
449        let mut rec_stack = HashSet::new();
450
451        for node in self.dependencies.keys() {
452            if self.has_cycle_util(node, &mut visited, &mut rec_stack) {
453                return true;
454            }
455        }
456
457        false
458    }
459
460    /// Utility for cycle detection (DFS)
461    fn has_cycle_util(
462        &self,
463        node: &str,
464        visited: &mut HashSet<String>,
465        rec_stack: &mut HashSet<String>,
466    ) -> bool {
467        if rec_stack.contains(node) {
468            return true;
469        }
470
471        if visited.contains(node) {
472            return false;
473        }
474
475        visited.insert(node.to_string());
476        rec_stack.insert(node.to_string());
477
478        if let Some(neighbors) = self.dependencies.get(node) {
479            for neighbor in neighbors {
480                if self.has_cycle_util(neighbor, visited, rec_stack) {
481                    return true;
482                }
483            }
484        }
485
486        rec_stack.remove(node);
487        false
488    }
489
490    /// Compute stratification levels
491    fn compute_strata(&self) -> Vec<Vec<String>> {
492        let mut strata = Vec::new();
493        let mut remaining: HashSet<String> = self.dependencies.keys().cloned().collect();
494
495        while !remaining.is_empty() {
496            // Find predicates with no dependencies on remaining predicates
497            let mut current_stratum = Vec::new();
498
499            for pred in &remaining {
500                let has_remaining_deps = self
501                    .dependencies
502                    .get(pred)
503                    .map(|deps| deps.iter().any(|d| remaining.contains(d)))
504                    .unwrap_or(false);
505
506                if !has_remaining_deps {
507                    current_stratum.push(pred.clone());
508                }
509            }
510
511            if current_stratum.is_empty() {
512                // Shouldn't happen if no cycles, but break to avoid infinite loop
513                break;
514            }
515
516            for pred in &current_stratum {
517                remaining.remove(pred);
518            }
519
520            strata.push(current_stratum);
521        }
522
523        strata
524    }
525}
526
527impl Default for StratificationAnalyzer {
528    fn default() -> Self {
529        Self::new()
530    }
531}
532
/// Result of stratification analysis.
///
/// `PartialEq`/`Eq` allow results to be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StratificationResult {
    /// Program is stratifiable. Each inner vector is one stratum of rule-head
    /// predicate names, listed so that a stratum only depends on earlier strata.
    Stratifiable(Vec<Vec<String>>),
    /// The predicate dependency graph contains a cycle (the analyzer reports
    /// any recursion here, not only recursion through negation).
    NonStratifiable,
}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Constant, Term};

    /// A string-constant term.
    fn constant(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// A logic-variable term.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// The binary predicate `name(first, second)`.
    fn binary(name: &str, first: Term, second: Term) -> Predicate {
        Predicate::new(name.to_string(), vec![first, second])
    }

    #[test]
    fn test_tabled_inference_basic() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(binary("parent", constant("alice"), constant("bob")));
        kb.add_fact(binary("parent", constant("bob"), constant("charlie")));

        // ancestor(X, Y) :- parent(X, Y).
        kb.add_rule(Rule::new(
            binary("ancestor", var("X"), var("Y")),
            vec![binary("parent", var("X"), var("Y"))],
        ));
        // ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
        kb.add_rule(Rule::new(
            binary("ancestor", var("X"), var("Z")),
            vec![
                binary("parent", var("X"), var("Y")),
                binary("ancestor", var("Y"), var("Z")),
            ],
        ));

        let query = binary("ancestor", constant("alice"), var("Z"));
        let answers = TabledInferenceEngine::new().query(&query, &kb).unwrap();
        assert!(!answers.is_empty());
    }

    #[test]
    fn test_table_stats() {
        let stats = TabledInferenceEngine::new().table_stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.complete_entries, 0);
    }

    #[test]
    fn test_stratification_no_cycles() {
        let mut kb = KnowledgeBase::new();

        // grandparent(X, Z) :- parent(X, Y), parent(Y, Z).  (non-recursive)
        kb.add_rule(Rule::new(
            binary("grandparent", var("X"), var("Z")),
            vec![
                binary("parent", var("X"), var("Y")),
                binary("parent", var("Y"), var("Z")),
            ],
        ));

        let outcome = StratificationAnalyzer::new().analyze(&kb);
        if let StratificationResult::Stratifiable(strata) = outcome {
            assert!(!strata.is_empty());
        } else {
            panic!("Expected stratifiable result");
        }
    }

    #[test]
    fn test_fixpoint_engine() {
        // An empty knowledge base is already its own fixpoint.
        let empty = KnowledgeBase::new();
        let saturated = FixpointEngine::new().compute_fixpoint(&empty).unwrap();
        assert_eq!(saturated.facts.len(), empty.facts.len());
    }
}
664}