1use crate::ir::{KnowledgeBase, Predicate, Rule};
81use crate::reasoning::{apply_subst_predicate, unify_predicates, Substitution};
82use ipfrs_core::error::Result;
83use std::collections::{HashMap, HashSet};
84
/// One record in the engine's memo table.
#[derive(Debug, Clone)]
struct TableEntry {
    /// The goal this entry was created for. It is stored but never read back,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    goal: Predicate,
    /// Substitutions recorded as solutions of the goal.
    solutions: Vec<Substitution>,
    /// `false` while the goal is still being evaluated, `true` once its
    /// solutions have been stored.
    complete: bool,
    /// Recursion depth at which the entry was created. Stored but never read
    /// back, hence the `dead_code` allowance.
    #[allow(dead_code)]
    depth: usize,
}
99
/// Inference engine that records the solutions of goals in a table so that a
/// goal already seen does not have to be evaluated again.
pub struct TabledInferenceEngine {
    /// Memo table: string key derived from a goal -> its table entry.
    table: HashMap<String, TableEntry>,
    /// Goals reached at a depth greater than this yield no solutions.
    max_depth: usize,
    /// Solving stops collecting further solutions once this many were found.
    max_solutions: usize,
}
109
110impl TabledInferenceEngine {
111 pub fn new() -> Self {
113 Self {
114 table: HashMap::new(),
115 max_depth: 100,
116 max_solutions: 1000,
117 }
118 }
119
120 pub fn with_limits(max_depth: usize, max_solutions: usize) -> Self {
122 Self {
123 table: HashMap::new(),
124 max_depth,
125 max_solutions,
126 }
127 }
128
129 pub fn query(&self, goal: &Predicate, kb: &KnowledgeBase) -> Result<Vec<Substitution>> {
131 let mut engine = Self {
132 table: HashMap::new(),
133 max_depth: self.max_depth,
134 max_solutions: self.max_solutions,
135 };
136
137 engine.solve_tabled(goal, &Substitution::new(), kb, 0)
138 }
139
140 fn solve_tabled(
142 &mut self,
143 goal: &Predicate,
144 subst: &Substitution,
145 kb: &KnowledgeBase,
146 depth: usize,
147 ) -> Result<Vec<Substitution>> {
148 if depth > self.max_depth {
150 return Ok(Vec::new());
151 }
152
153 let goal = apply_subst_predicate(goal, subst);
155
156 let key = self.goal_key(&goal);
158
159 if let Some(entry) = self.table.get(&key) {
161 if entry.complete {
163 return Ok(entry.solutions.clone());
164 }
165 return Ok(Vec::new());
167 }
168
169 let mut entry = TableEntry {
171 goal: goal.clone(),
172 solutions: Vec::new(),
173 complete: false,
174 depth,
175 };
176
177 self.table.insert(key.clone(), entry.clone());
179
180 let mut solutions = Vec::new();
182
183 for fact in kb.get_predicates(&goal.name) {
185 if let Some(new_subst) = unify_predicates(&goal, fact, &Substitution::new()) {
186 solutions.push(new_subst);
187 if solutions.len() >= self.max_solutions {
188 break;
189 }
190 }
191 }
192
193 for rule in kb.get_rules(&goal.name) {
195 if solutions.len() >= self.max_solutions {
196 break;
197 }
198
199 let renamed_rule = self.rename_rule(rule, depth);
201
202 if let Some(new_subst) =
204 unify_predicates(&goal, &renamed_rule.head, &Substitution::new())
205 {
206 let body_solutions =
208 self.solve_conjunction(&renamed_rule.body, &new_subst, kb, depth + 1)?;
209 solutions.extend(body_solutions);
210 }
211 }
212
213 entry.solutions = solutions.clone();
215 entry.complete = true;
216 self.table.insert(key, entry);
217
218 Ok(solutions)
219 }
220
221 fn solve_conjunction(
223 &mut self,
224 goals: &[Predicate],
225 subst: &Substitution,
226 kb: &KnowledgeBase,
227 depth: usize,
228 ) -> Result<Vec<Substitution>> {
229 if goals.is_empty() {
230 return Ok(vec![subst.clone()]);
231 }
232
233 let first = &goals[0];
234 let rest = &goals[1..];
235
236 let first_solutions = self.solve_tabled(first, subst, kb, depth)?;
237
238 let mut all_solutions = Vec::new();
239 for first_subst in first_solutions {
240 let rest_solutions = self.solve_conjunction(rest, &first_subst, kb, depth)?;
241 all_solutions.extend(rest_solutions);
242
243 if all_solutions.len() >= self.max_solutions {
244 break;
245 }
246 }
247
248 Ok(all_solutions)
249 }
250
251 fn goal_key(&self, goal: &Predicate) -> String {
253 format!("{}({})", goal.name, goal.args.len())
254 }
255
256 fn rename_rule(&self, rule: &Rule, suffix: usize) -> Rule {
258 let var_map: HashMap<String, String> = rule
259 .variables()
260 .into_iter()
261 .map(|v| (v.clone(), format!("{}_{}", v, suffix)))
262 .collect();
263
264 let rename_subst: Substitution = var_map
265 .into_iter()
266 .map(|(old, new)| (old, crate::ir::Term::Var(new)))
267 .collect();
268
269 Rule {
270 head: apply_subst_predicate(&rule.head, &rename_subst),
271 body: rule
272 .body
273 .iter()
274 .map(|p| apply_subst_predicate(p, &rename_subst))
275 .collect(),
276 }
277 }
278
279 pub fn table_stats(&self) -> TableStats {
281 TableStats {
282 entries: self.table.len(),
283 complete_entries: self.table.values().filter(|e| e.complete).count(),
284 total_solutions: self.table.values().map(|e| e.solutions.len()).sum(),
285 }
286 }
287
288 pub fn clear_table(&mut self) {
290 self.table.clear();
291 }
292}
293
294impl Default for TabledInferenceEngine {
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
/// Snapshot of an engine's table, as produced by
/// `TabledInferenceEngine::table_stats`.
#[derive(Debug, Clone)]
pub struct TableStats {
    /// Number of entries in the table.
    pub entries: usize,
    /// Number of entries whose evaluation has finished.
    pub complete_entries: usize,
    /// Sum of the solution counts of all entries.
    pub total_solutions: usize,
}
310
/// Repeatedly applies the rules of a knowledge base until a round adds no new
/// fact or the iteration limit is reached.
pub struct FixpointEngine {
    /// Upper bound on the number of rounds `compute_fixpoint` performs.
    max_iterations: usize,
}
316
317impl FixpointEngine {
318 pub fn new() -> Self {
320 Self {
321 max_iterations: 100,
322 }
323 }
324
325 pub fn with_max_iterations(max_iterations: usize) -> Self {
327 Self { max_iterations }
328 }
329
330 pub fn compute_fixpoint(&self, kb: &KnowledgeBase) -> Result<KnowledgeBase> {
332 let mut current_kb = kb.clone();
333 let mut iteration = 0;
334
335 loop {
336 iteration += 1;
337 if iteration > self.max_iterations {
338 break;
339 }
340
341 let mut new_facts = Vec::new();
342 let mut changed = false;
343
344 let predicate_names: std::collections::HashSet<String> = current_kb
347 .rules
348 .iter()
349 .map(|r| r.head.name.clone())
350 .collect();
351
352 for predicate_name in predicate_names {
353 for rule in current_kb.get_rules(&predicate_name) {
354 let derived = self.derive_facts_from_rule(rule, ¤t_kb)?;
355 for fact in derived {
356 if !current_kb.facts.contains(&fact) {
358 new_facts.push(fact);
359 changed = true;
360 }
361 }
362 }
363 }
364
365 for fact in new_facts {
367 current_kb.add_fact(fact);
368 }
369
370 if !changed {
372 break;
373 }
374 }
375
376 Ok(current_kb)
377 }
378
379 fn derive_facts_from_rule(&self, _rule: &Rule, _kb: &KnowledgeBase) -> Result<Vec<Predicate>> {
381 let derived = Vec::new();
382
383 Ok(derived)
390 }
391}
392
393impl Default for FixpointEngine {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
/// Builds a predicate dependency graph from the rules of a knowledge base and
/// orders the predicates into strata.
pub struct StratificationAnalyzer {
    /// For each rule-head predicate name, the names of the predicates that
    /// appear in the bodies of its rules.
    dependencies: HashMap<String, HashSet<String>>,
}
404
405impl StratificationAnalyzer {
406 pub fn new() -> Self {
408 Self {
409 dependencies: HashMap::new(),
410 }
411 }
412
413 pub fn analyze(&mut self, kb: &KnowledgeBase) -> StratificationResult {
415 self.build_dependency_graph(kb);
416
417 if self.has_cycles() {
419 StratificationResult::NonStratifiable
420 } else {
421 let strata = self.compute_strata();
423 StratificationResult::Stratifiable(strata)
424 }
425 }
426
427 fn build_dependency_graph(&mut self, kb: &KnowledgeBase) {
429 let predicate_names: HashSet<String> =
431 kb.rules.iter().map(|r| r.head.name.clone()).collect();
432
433 for predicate_name in predicate_names {
434 for rule in kb.get_rules(&predicate_name) {
435 let head = &rule.head.name;
436 let deps: HashSet<String> = rule.body.iter().map(|p| p.name.clone()).collect();
437
438 self.dependencies
439 .entry(head.clone())
440 .or_default()
441 .extend(deps);
442 }
443 }
444 }
445
446 fn has_cycles(&self) -> bool {
448 let mut visited = HashSet::new();
449 let mut rec_stack = HashSet::new();
450
451 for node in self.dependencies.keys() {
452 if self.has_cycle_util(node, &mut visited, &mut rec_stack) {
453 return true;
454 }
455 }
456
457 false
458 }
459
460 fn has_cycle_util(
462 &self,
463 node: &str,
464 visited: &mut HashSet<String>,
465 rec_stack: &mut HashSet<String>,
466 ) -> bool {
467 if rec_stack.contains(node) {
468 return true;
469 }
470
471 if visited.contains(node) {
472 return false;
473 }
474
475 visited.insert(node.to_string());
476 rec_stack.insert(node.to_string());
477
478 if let Some(neighbors) = self.dependencies.get(node) {
479 for neighbor in neighbors {
480 if self.has_cycle_util(neighbor, visited, rec_stack) {
481 return true;
482 }
483 }
484 }
485
486 rec_stack.remove(node);
487 false
488 }
489
490 fn compute_strata(&self) -> Vec<Vec<String>> {
492 let mut strata = Vec::new();
493 let mut remaining: HashSet<String> = self.dependencies.keys().cloned().collect();
494
495 while !remaining.is_empty() {
496 let mut current_stratum = Vec::new();
498
499 for pred in &remaining {
500 let has_remaining_deps = self
501 .dependencies
502 .get(pred)
503 .map(|deps| deps.iter().any(|d| remaining.contains(d)))
504 .unwrap_or(false);
505
506 if !has_remaining_deps {
507 current_stratum.push(pred.clone());
508 }
509 }
510
511 if current_stratum.is_empty() {
512 break;
514 }
515
516 for pred in ¤t_stratum {
517 remaining.remove(pred);
518 }
519
520 strata.push(current_stratum);
521 }
522
523 strata
524 }
525}
526
527impl Default for StratificationAnalyzer {
528 fn default() -> Self {
529 Self::new()
530 }
531}
532
/// Outcome of `StratificationAnalyzer::analyze`.
#[derive(Debug, Clone)]
pub enum StratificationResult {
    /// The dependency graph has no cycle. Holds the strata in evaluation
    /// order; each stratum is a list of predicate names.
    Stratifiable(Vec<Vec<String>>),
    /// The dependency graph contains a cycle.
    NonStratifiable,
}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Constant, Term};

    /// Variable term named `name`.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// String-constant term holding `value`.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Predicate `name(args...)`.
    fn pred(name: &str, args: Vec<Term>) -> Predicate {
        Predicate::new(name.to_string(), args)
    }

    /// A query over a small family tree with a recursive `ancestor`
    /// definition must produce at least one solution.
    #[test]
    fn test_tabled_inference_basic() {
        let mut kb = KnowledgeBase::new();

        // Facts: alice -> bob -> charlie.
        kb.add_fact(pred("parent", vec![text("alice"), text("bob")]));
        kb.add_fact(pred("parent", vec![text("bob"), text("charlie")]));

        // ancestor(X, Y) :- parent(X, Y).
        kb.add_rule(Rule::new(
            pred("ancestor", vec![var("X"), var("Y")]),
            vec![pred("parent", vec![var("X"), var("Y")])],
        ));

        // ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
        kb.add_rule(Rule::new(
            pred("ancestor", vec![var("X"), var("Z")]),
            vec![
                pred("parent", vec![var("X"), var("Y")]),
                pred("ancestor", vec![var("Y"), var("Z")]),
            ],
        ));

        let goal = pred("ancestor", vec![text("alice"), var("Z")]);

        let engine = TabledInferenceEngine::new();
        let solutions = engine.query(&goal, &kb).unwrap();
        assert!(!solutions.is_empty());
    }

    /// A freshly created engine reports an empty table.
    #[test]
    fn test_table_stats() {
        let stats = TabledInferenceEngine::new().table_stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.complete_entries, 0);
    }

    /// A single non-recursive rule is stratifiable and yields some stratum.
    #[test]
    fn test_stratification_no_cycles() {
        let mut kb = KnowledgeBase::new();

        // grandparent(X, Z) :- parent(X, Y), parent(Y, Z).
        kb.add_rule(Rule::new(
            pred("grandparent", vec![var("X"), var("Z")]),
            vec![
                pred("parent", vec![var("X"), var("Y")]),
                pred("parent", vec![var("Y"), var("Z")]),
            ],
        ));

        let mut analyzer = StratificationAnalyzer::new();
        let strata = match analyzer.analyze(&kb) {
            StratificationResult::Stratifiable(strata) => strata,
            StratificationResult::NonStratifiable => {
                panic!("Expected stratifiable result");
            }
        };
        assert!(!strata.is_empty());
    }

    /// The fixpoint of an empty knowledge base has as many facts as the
    /// input.
    #[test]
    fn test_fixpoint_engine() {
        let kb = KnowledgeBase::new();
        let engine = FixpointEngine::new();

        let result = engine.compute_fixpoint(&kb).unwrap();
        assert_eq!(result.facts.len(), kb.facts.len());
    }
}