ipfrs_tensorlogic/
utils.rs

1//! Utility functions for common TensorLogic operations
2//!
3//! This module provides helper functions that make it easier to work with
4//! TensorLogic predicates, terms, and knowledge bases.
5
6use crate::ir::{Constant, KnowledgeBase, Predicate, Rule, Term};
7use crate::reasoning::{InferenceEngine, Substitution};
8use ipfrs_core::error::Result;
9use std::collections::HashMap;
10
11/// Builder for creating predicates more easily
12pub struct PredicateBuilder {
13    name: String,
14    args: Vec<Term>,
15}
16
17impl PredicateBuilder {
18    /// Create a new predicate builder
19    pub fn new(name: impl Into<String>) -> Self {
20        Self {
21            name: name.into(),
22            args: Vec::new(),
23        }
24    }
25
26    /// Add a constant string argument
27    pub fn arg_str(mut self, value: impl Into<String>) -> Self {
28        self.args.push(Term::Const(Constant::String(value.into())));
29        self
30    }
31
32    /// Add a constant integer argument
33    pub fn arg_int(mut self, value: i64) -> Self {
34        self.args.push(Term::Const(Constant::Int(value)));
35        self
36    }
37
38    /// Add a constant boolean argument
39    pub fn arg_bool(mut self, value: bool) -> Self {
40        self.args.push(Term::Const(Constant::Bool(value)));
41        self
42    }
43
44    /// Add a variable argument
45    pub fn arg_var(mut self, name: impl Into<String>) -> Self {
46        self.args.push(Term::Var(name.into()));
47        self
48    }
49
50    /// Add any term as an argument
51    pub fn arg(mut self, term: Term) -> Self {
52        self.args.push(term);
53        self
54    }
55
56    /// Build the predicate
57    pub fn build(self) -> Predicate {
58        Predicate::new(self.name, self.args)
59    }
60}
61
62/// Builder for creating rules more easily
63pub struct RuleBuilder {
64    head: Option<Predicate>,
65    body: Vec<Predicate>,
66}
67
68impl RuleBuilder {
69    /// Create a new rule builder
70    pub fn new() -> Self {
71        Self {
72            head: None,
73            body: Vec::new(),
74        }
75    }
76
77    /// Set the rule head
78    pub fn head(mut self, predicate: Predicate) -> Self {
79        self.head = Some(predicate);
80        self
81    }
82
83    /// Add a body predicate
84    pub fn body(mut self, predicate: Predicate) -> Self {
85        self.body.push(predicate);
86        self
87    }
88
89    /// Add multiple body predicates
90    pub fn bodies(mut self, predicates: Vec<Predicate>) -> Self {
91        self.body.extend(predicates);
92        self
93    }
94
95    /// Build the rule
96    pub fn build(self) -> Rule {
97        Rule::new(self.head.expect("Rule head must be set"), self.body)
98    }
99}
100
101impl Default for RuleBuilder {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107/// Utility functions for knowledge base operations
108pub struct KnowledgeBaseUtils;
109
110impl KnowledgeBaseUtils {
111    /// Create a knowledge base from a list of facts
112    pub fn from_facts(facts: Vec<Predicate>) -> KnowledgeBase {
113        let mut kb = KnowledgeBase::new();
114        for fact in facts {
115            kb.add_fact(fact);
116        }
117        kb
118    }
119
120    /// Merge two knowledge bases
121    pub fn merge(kb1: &KnowledgeBase, kb2: &KnowledgeBase) -> KnowledgeBase {
122        let mut merged = kb1.clone();
123        for fact in &kb2.facts {
124            if !merged.facts.contains(fact) {
125                merged.add_fact(fact.clone());
126            }
127        }
128        for rule in &kb2.rules {
129            merged.add_rule(rule.clone());
130        }
131        merged
132    }
133
134    /// Filter facts by predicate name
135    pub fn filter_facts(kb: &KnowledgeBase, predicate_name: &str) -> Vec<Predicate> {
136        kb.facts
137            .iter()
138            .filter(|p| p.name == predicate_name)
139            .cloned()
140            .collect()
141    }
142
143    /// Count predicates by name
144    pub fn count_predicates(kb: &KnowledgeBase) -> HashMap<String, usize> {
145        let mut counts = HashMap::new();
146        for fact in &kb.facts {
147            *counts.entry(fact.name.clone()).or_insert(0) += 1;
148        }
149        counts
150    }
151
152    /// Get all unique predicate names in the knowledge base
153    pub fn predicate_names(kb: &KnowledgeBase) -> Vec<String> {
154        let mut names: Vec<String> = kb
155            .facts
156            .iter()
157            .map(|p| p.name.clone())
158            .chain(kb.rules.iter().map(|r| r.head.name.clone()))
159            .collect();
160        names.sort_unstable();
161        names.dedup();
162        names
163    }
164
165    /// Check if a fact exists in the knowledge base
166    pub fn contains_fact(kb: &KnowledgeBase, fact: &Predicate) -> bool {
167        kb.facts.contains(fact)
168    }
169
170    /// Remove duplicate facts from a knowledge base
171    pub fn deduplicate(kb: &mut KnowledgeBase) {
172        kb.facts.sort_by(|a, b| {
173            a.name
174                .cmp(&b.name)
175                .then_with(|| a.args.len().cmp(&b.args.len()))
176        });
177        kb.facts.dedup();
178    }
179}
180
181/// Utility functions for query operations
182pub struct QueryUtils;
183
184impl QueryUtils {
185    /// Execute a simple query and return only the first solution
186    pub fn query_one(predicate: &Predicate, kb: &KnowledgeBase) -> Result<Option<Substitution>> {
187        let engine = InferenceEngine::new();
188        let solutions = engine.query(predicate, kb)?;
189        Ok(solutions.into_iter().next())
190    }
191
192    /// Execute a query and extract values for a specific variable
193    pub fn query_var(
194        predicate: &Predicate,
195        kb: &KnowledgeBase,
196        var_name: &str,
197    ) -> Result<Vec<Term>> {
198        let engine = InferenceEngine::new();
199        let solutions = engine.query(predicate, kb)?;
200        Ok(solutions
201            .into_iter()
202            .filter_map(|subst| subst.get(var_name).cloned())
203            .collect())
204    }
205
206    /// Check if a goal is provable
207    pub fn is_provable(predicate: &Predicate, kb: &KnowledgeBase) -> Result<bool> {
208        let engine = InferenceEngine::new();
209        let solutions = engine.query(predicate, kb)?;
210        Ok(!solutions.is_empty())
211    }
212
213    /// Count the number of solutions for a query
214    pub fn count_solutions(predicate: &Predicate, kb: &KnowledgeBase) -> Result<usize> {
215        let engine = InferenceEngine::new();
216        let solutions = engine.query(predicate, kb)?;
217        Ok(solutions.len())
218    }
219}
220
221/// Utility functions for term manipulation
222pub struct TermUtils;
223
224impl TermUtils {
225    /// Create a constant string term
226    pub fn string(value: impl Into<String>) -> Term {
227        Term::Const(Constant::String(value.into()))
228    }
229
230    /// Create a constant integer term
231    pub fn int(value: i64) -> Term {
232        Term::Const(Constant::Int(value))
233    }
234
235    /// Create a constant boolean term
236    pub fn bool(value: bool) -> Term {
237        Term::Const(Constant::Bool(value))
238    }
239
240    /// Create a variable term
241    pub fn var(name: impl Into<String>) -> Term {
242        Term::Var(name.into())
243    }
244
245    /// Extract string value from a constant term
246    pub fn as_string(term: &Term) -> Option<&str> {
247        match term {
248            Term::Const(Constant::String(s)) => Some(s),
249            _ => None,
250        }
251    }
252
253    /// Extract integer value from a constant term
254    pub fn as_int(term: &Term) -> Option<i64> {
255        match term {
256            Term::Const(Constant::Int(i)) => Some(*i),
257            _ => None,
258        }
259    }
260
261    /// Extract boolean value from a constant term
262    pub fn as_bool(term: &Term) -> Option<bool> {
263        match term {
264            Term::Const(Constant::Bool(b)) => Some(*b),
265            _ => None,
266        }
267    }
268
269    /// Check if term is ground (contains no variables)
270    pub fn is_ground(term: &Term) -> bool {
271        term.is_ground()
272    }
273
274    /// Get all variables in a term
275    pub fn variables(term: &Term) -> Vec<String> {
276        term.variables()
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a two-argument ground fact such as `parent("alice", "bob")`.
    fn fact2(name: &str, a: &str, b: &str) -> Predicate {
        PredicateBuilder::new(name).arg_str(a).arg_str(b).build()
    }

    #[test]
    fn test_predicate_builder() {
        let pred = fact2("parent", "alice", "bob");

        assert_eq!(pred.name, "parent");
        assert_eq!(pred.args.len(), 2);
        assert!(pred.is_ground());
    }

    #[test]
    fn test_predicate_builder_with_vars() {
        let pred = PredicateBuilder::new("parent")
            .arg_str("alice")
            .arg_var("X")
            .build();

        assert_eq!(pred.name, "parent");
        assert_eq!(pred.args.len(), 2);
        assert!(!pred.is_ground());
    }

    #[test]
    fn test_predicate_builder_mixed_arguments() {
        let pred = PredicateBuilder::new("p")
            .arg_int(7)
            .arg_bool(false)
            .arg(TermUtils::var("X"))
            .build();

        assert_eq!(pred.args.len(), 3);
        assert_eq!(TermUtils::as_int(&pred.args[0]), Some(7));
        assert_eq!(TermUtils::as_bool(&pred.args[1]), Some(false));
        assert!(!pred.is_ground());
    }

    #[test]
    fn test_rule_builder() {
        let head = PredicateBuilder::new("grandparent")
            .arg_var("X")
            .arg_var("Z")
            .build();

        let body1 = PredicateBuilder::new("parent")
            .arg_var("X")
            .arg_var("Y")
            .build();

        let body2 = PredicateBuilder::new("parent")
            .arg_var("Y")
            .arg_var("Z")
            .build();

        let rule = RuleBuilder::new()
            .head(head)
            .body(body1)
            .body(body2)
            .build();

        assert_eq!(rule.head.name, "grandparent");
        assert_eq!(rule.body.len(), 2);
    }

    #[test]
    fn test_rule_builder_bodies() {
        let rule = RuleBuilder::default()
            .head(PredicateBuilder::new("q").arg_var("X").build())
            .bodies(vec![
                PredicateBuilder::new("p").arg_var("X").build(),
                PredicateBuilder::new("r").arg_var("X").build(),
            ])
            .build();

        assert_eq!(rule.head.name, "q");
        assert_eq!(rule.body.len(), 2);
    }

    #[test]
    #[should_panic(expected = "Rule head must be set")]
    fn test_rule_builder_without_head_panics() {
        let _ = RuleBuilder::new().body(fact2("parent", "a", "b")).build();
    }

    #[test]
    fn test_kb_from_facts() {
        let facts = vec![
            fact2("parent", "alice", "bob"),
            fact2("parent", "bob", "charlie"),
        ];

        let kb = KnowledgeBaseUtils::from_facts(facts);
        assert_eq!(kb.facts.len(), 2);
    }

    #[test]
    fn test_kb_merge() {
        let mut kb1 = KnowledgeBase::new();
        kb1.add_fact(PredicateBuilder::new("test").arg_str("a").build());

        let mut kb2 = KnowledgeBase::new();
        kb2.add_fact(PredicateBuilder::new("test").arg_str("b").build());

        let merged = KnowledgeBaseUtils::merge(&kb1, &kb2);
        assert_eq!(merged.facts.len(), 2);
    }

    #[test]
    fn test_kb_merge_skips_duplicate_facts() {
        let kb1 = KnowledgeBaseUtils::from_facts(vec![fact2("test", "a", "x")]);
        let kb2 = KnowledgeBaseUtils::from_facts(vec![
            fact2("test", "a", "x"),
            fact2("test", "b", "x"),
        ]);

        let merged = KnowledgeBaseUtils::merge(&kb1, &kb2);
        assert_eq!(merged.facts.len(), 2);
        assert!(KnowledgeBaseUtils::contains_fact(
            &merged,
            &fact2("test", "b", "x")
        ));
        // The inputs are left unchanged.
        assert_eq!(kb1.facts.len(), 1);
    }

    #[test]
    fn test_filter_facts() {
        let kb = KnowledgeBaseUtils::from_facts(vec![
            fact2("parent", "a", "b"),
            fact2("parent", "b", "c"),
            fact2("knows", "a", "c"),
        ]);

        let parents = KnowledgeBaseUtils::filter_facts(&kb, "parent");
        assert_eq!(parents.len(), 2);

        let knows = KnowledgeBaseUtils::filter_facts(&kb, "knows");
        assert_eq!(knows.len(), 1);

        let missing = KnowledgeBaseUtils::filter_facts(&kb, "nobody");
        assert!(missing.is_empty());
    }

    #[test]
    fn test_count_predicates() {
        let kb = KnowledgeBaseUtils::from_facts(vec![
            fact2("parent", "a", "b"),
            fact2("parent", "b", "c"),
            fact2("knows", "a", "c"),
        ]);

        let counts = KnowledgeBaseUtils::count_predicates(&kb);
        assert_eq!(counts.get("parent"), Some(&2));
        assert_eq!(counts.get("knows"), Some(&1));
    }

    #[test]
    fn test_predicate_names_sorted_and_unique() {
        let kb = KnowledgeBaseUtils::from_facts(vec![
            fact2("parent", "a", "b"),
            fact2("knows", "a", "c"),
            fact2("parent", "b", "c"),
        ]);

        assert_eq!(
            KnowledgeBaseUtils::predicate_names(&kb),
            vec!["knows".to_string(), "parent".to_string()]
        );
    }

    #[test]
    fn test_contains_fact() {
        let kb = KnowledgeBaseUtils::from_facts(vec![fact2("parent", "a", "b")]);

        assert!(KnowledgeBaseUtils::contains_fact(&kb, &fact2("parent", "a", "b")));
        assert!(!KnowledgeBaseUtils::contains_fact(&kb, &fact2("parent", "b", "a")));
    }

    #[test]
    fn test_deduplicate_adjacent_duplicates() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(fact2("parent", "a", "b"));
        kb.add_fact(fact2("parent", "a", "b"));

        KnowledgeBaseUtils::deduplicate(&mut kb);
        assert_eq!(kb.facts.len(), 1);
    }

    #[test]
    fn test_term_utils() {
        let str_term = TermUtils::string("alice");
        assert_eq!(TermUtils::as_string(&str_term), Some("alice"));
        assert_eq!(TermUtils::as_int(&str_term), None);

        let int_term = TermUtils::int(42);
        assert_eq!(TermUtils::as_int(&int_term), Some(42));
        assert_eq!(TermUtils::as_bool(&int_term), None);

        let bool_term = TermUtils::bool(true);
        assert_eq!(TermUtils::as_bool(&bool_term), Some(true));
        assert_eq!(TermUtils::as_string(&bool_term), None);

        let var_term = TermUtils::var("X");
        assert!(!TermUtils::is_ground(&var_term));
        assert!(TermUtils::is_ground(&int_term));
    }

    #[test]
    fn test_query_utils() {
        let kb = KnowledgeBaseUtils::from_facts(vec![fact2("parent", "alice", "bob")]);

        let query = PredicateBuilder::new("parent")
            .arg_str("alice")
            .arg_var("X")
            .build();

        let is_provable = QueryUtils::is_provable(&query, &kb).unwrap();
        assert!(is_provable);

        let count = QueryUtils::count_solutions(&query, &kb).unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_query_one_and_query_var() {
        let kb = KnowledgeBaseUtils::from_facts(vec![fact2("parent", "alice", "bob")]);
        let query = PredicateBuilder::new("parent")
            .arg_str("alice")
            .arg_var("X")
            .build();

        assert!(QueryUtils::query_one(&query, &kb).unwrap().is_some());

        let bound = QueryUtils::query_var(&query, &kb, "X").unwrap();
        assert_eq!(bound.len(), 1);
        assert_eq!(TermUtils::as_string(&bound[0]), Some("bob"));
    }

    #[test]
    fn test_query_without_solutions() {
        let kb = KnowledgeBaseUtils::from_facts(vec![fact2("parent", "alice", "bob")]);
        let query = PredicateBuilder::new("parent")
            .arg_str("carol")
            .arg_var("X")
            .build();

        assert!(!QueryUtils::is_provable(&query, &kb).unwrap());
        assert_eq!(QueryUtils::count_solutions(&query, &kb).unwrap(), 0);
        assert!(QueryUtils::query_one(&query, &kb).unwrap().is_none());
        assert!(QueryUtils::query_var(&query, &kb, "X").unwrap().is_empty());
    }
}