ipfrs_tensorlogic/
ir.rs

1//! TensorLogic IR types
2//!
3//! This module defines the Intermediate Representation types for TensorLogic
4//! that can be stored and retrieved via IPFRS.
5
6use ipfrs_core::Cid;
7use serde::{Deserialize, Serialize};
8use std::fmt;
9
10/// A logical term in TensorLogic
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
12pub enum Term {
13    /// Variable (e.g., ?X)
14    Var(String),
15    /// Constant value
16    Const(Constant),
17    /// Function application (e.g., f(X, Y))
18    Fun(String, Vec<Term>),
19    /// Reference to another term via CID
20    Ref(TermRef),
21}
22
23/// Constant value types
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
25pub enum Constant {
26    /// String constant
27    String(String),
28    /// Integer constant
29    Int(i64),
30    /// Boolean constant
31    Bool(bool),
32    /// Floating point constant (stored as string for deterministic hashing)
33    Float(String),
34}
35
36/// Reference to a term via CID
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
38pub struct TermRef {
39    /// CID of the referenced term
40    #[serde(
41        serialize_with = "crate::serialize_cid",
42        deserialize_with = "crate::deserialize_cid"
43    )]
44    pub cid: Cid,
45    /// Optional hint about the term (for optimization)
46    pub hint: Option<String>,
47}
48
49impl TermRef {
50    /// Create a new term reference
51    pub fn new(cid: Cid) -> Self {
52        Self { cid, hint: None }
53    }
54
55    /// Create a term reference with a hint
56    pub fn with_hint(cid: Cid, hint: String) -> Self {
57        Self {
58            cid,
59            hint: Some(hint),
60        }
61    }
62}
63
64impl Term {
65    /// Check if term is a variable
66    #[inline]
67    pub fn is_var(&self) -> bool {
68        matches!(self, Term::Var(_))
69    }
70
71    /// Check if term is a constant
72    #[inline]
73    pub fn is_const(&self) -> bool {
74        matches!(self, Term::Const(_))
75    }
76
77    /// Check if term is ground (contains no variables)
78    #[inline]
79    pub fn is_ground(&self) -> bool {
80        match self {
81            Term::Var(_) => false,
82            Term::Const(_) => true,
83            Term::Fun(_, args) => args.iter().all(|t| t.is_ground()),
84            Term::Ref(_) => true, // References are considered ground
85        }
86    }
87
88    /// Collect all variables in the term
89    pub fn variables(&self) -> Vec<String> {
90        let capacity = self.estimate_var_count();
91        let mut vars = Vec::with_capacity(capacity);
92        self.collect_vars(&mut vars);
93        vars.sort_unstable();
94        vars.dedup();
95        vars
96    }
97
98    /// Estimate the number of unique variables (for capacity hint)
99    #[inline]
100    fn estimate_var_count(&self) -> usize {
101        match self {
102            Term::Var(_) => 1,
103            Term::Const(_) | Term::Ref(_) => 0,
104            Term::Fun(_, args) => args.iter().map(|t| t.estimate_var_count()).sum(),
105        }
106    }
107
108    #[inline]
109    fn collect_vars(&self, vars: &mut Vec<String>) {
110        match self {
111            Term::Var(v) => vars.push(v.clone()),
112            Term::Fun(_, args) => {
113                for arg in args {
114                    arg.collect_vars(vars);
115                }
116            }
117            _ => {}
118        }
119    }
120
121    /// Get the complexity of the term (number of nodes)
122    #[inline]
123    pub fn complexity(&self) -> usize {
124        match self {
125            Term::Var(_) | Term::Const(_) | Term::Ref(_) => 1,
126            Term::Fun(_, args) => 1 + args.iter().map(|t| t.complexity()).sum::<usize>(),
127        }
128    }
129}
130
131impl fmt::Display for Term {
132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        match self {
134            Term::Var(v) => write!(f, "?{}", v),
135            Term::Const(c) => write!(f, "{}", c),
136            Term::Fun(name, args) => {
137                write!(f, "{}(", name)?;
138                for (i, arg) in args.iter().enumerate() {
139                    if i > 0 {
140                        write!(f, ", ")?;
141                    }
142                    write!(f, "{}", arg)?;
143                }
144                write!(f, ")")
145            }
146            Term::Ref(r) => write!(f, "@{}", r.cid),
147        }
148    }
149}
150
151impl fmt::Display for Constant {
152    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153        match self {
154            Constant::String(s) => write!(f, "\"{}\"", s),
155            Constant::Int(i) => write!(f, "{}", i),
156            Constant::Bool(b) => write!(f, "{}", b),
157            Constant::Float(s) => write!(f, "{}", s),
158        }
159    }
160}
161
162/// A logical predicate
163#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
164pub struct Predicate {
165    /// Predicate name
166    pub name: String,
167    /// Arguments
168    pub args: Vec<Term>,
169}
170
171impl Predicate {
172    /// Create a new predicate
173    pub fn new(name: String, args: Vec<Term>) -> Self {
174        Self { name, args }
175    }
176
177    /// Get the arity (number of arguments)
178    #[inline]
179    pub fn arity(&self) -> usize {
180        self.args.len()
181    }
182
183    /// Check if predicate is ground
184    #[inline]
185    pub fn is_ground(&self) -> bool {
186        self.args.iter().all(|t| t.is_ground())
187    }
188
189    /// Collect all variables
190    pub fn variables(&self) -> Vec<String> {
191        let capacity: usize = self.args.iter().map(|t| t.estimate_var_count()).sum();
192        let mut vars = Vec::with_capacity(capacity);
193        for arg in &self.args {
194            arg.collect_vars(&mut vars);
195        }
196        vars.sort_unstable();
197        vars.dedup();
198        vars
199    }
200}
201
202impl fmt::Display for Predicate {
203    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204        write!(f, "{}(", self.name)?;
205        for (i, arg) in self.args.iter().enumerate() {
206            if i > 0 {
207                write!(f, ", ")?;
208            }
209            write!(f, "{}", arg)?;
210        }
211        write!(f, ")")
212    }
213}
214
215/// A logical rule (Horn clause)
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct Rule {
218    /// Head of the rule
219    pub head: Predicate,
220    /// Body of the rule (conjunction)
221    pub body: Vec<Predicate>,
222}
223
224impl Rule {
225    /// Create a new rule
226    pub fn new(head: Predicate, body: Vec<Predicate>) -> Self {
227        Self { head, body }
228    }
229
230    /// Create a fact (rule with empty body)
231    pub fn fact(head: Predicate) -> Self {
232        Self {
233            head,
234            body: Vec::new(),
235        }
236    }
237
238    /// Check if this is a fact
239    #[inline]
240    pub fn is_fact(&self) -> bool {
241        self.body.is_empty()
242    }
243
244    /// Collect all variables in the rule
245    pub fn variables(&self) -> Vec<String> {
246        let mut vars = self.head.variables();
247        for pred in &self.body {
248            for var in pred.variables() {
249                if !vars.contains(&var) {
250                    vars.push(var);
251                }
252            }
253        }
254        vars.sort_unstable();
255        vars
256    }
257}
258
259impl fmt::Display for Rule {
260    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
261        write!(f, "{}", self.head)?;
262        if !self.body.is_empty() {
263            write!(f, " :- ")?;
264            for (i, pred) in self.body.iter().enumerate() {
265                if i > 0 {
266                    write!(f, ", ")?;
267                }
268                write!(f, "{}", pred)?;
269            }
270        }
271        write!(f, ".")
272    }
273}
274
275/// A knowledge base containing facts and rules
276#[derive(Debug, Clone, Default, Serialize, Deserialize)]
277pub struct KnowledgeBase {
278    /// Facts (ground predicates)
279    pub facts: Vec<Predicate>,
280    /// Rules
281    pub rules: Vec<Rule>,
282}
283
284impl KnowledgeBase {
285    /// Create a new empty knowledge base
286    pub fn new() -> Self {
287        Self::default()
288    }
289
290    /// Add a fact
291    pub fn add_fact(&mut self, fact: Predicate) {
292        self.facts.push(fact);
293    }
294
295    /// Add a rule
296    pub fn add_rule(&mut self, rule: Rule) {
297        self.rules.push(rule);
298    }
299
300    /// Get all predicates with a given name
301    #[inline]
302    pub fn get_predicates(&self, name: &str) -> Vec<&Predicate> {
303        self.facts.iter().filter(|p| p.name == name).collect()
304    }
305
306    /// Get all rules with a given head predicate name
307    #[inline]
308    pub fn get_rules(&self, name: &str) -> Vec<&Rule> {
309        self.rules.iter().filter(|r| r.head.name == name).collect()
310    }
311
312    /// Get statistics
313    pub fn stats(&self) -> KnowledgeBaseStats {
314        KnowledgeBaseStats {
315            num_facts: self.facts.len(),
316            num_rules: self.rules.len(),
317        }
318    }
319}
320
/// Knowledge base statistics: a snapshot of how many facts and rules a
/// knowledge base holds.
///
/// Derives `Default` (all counts zero) and `PartialEq`/`Eq` so snapshots can
/// be compared directly, e.g. in `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KnowledgeBaseStats {
    /// Number of facts
    pub num_facts: usize,
    /// Number of rules
    pub num_rules: usize,
}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a variable term named `name`.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Shorthand for a string-constant term holding `value`.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Shorthand for a predicate `name(args...)`.
    fn pred(name: &str, args: Vec<Term>) -> Predicate {
        Predicate::new(name.to_string(), args)
    }

    // Variables are non-ground; constants are ground.
    #[test]
    fn test_term_creation() {
        let variable = var("X");
        assert!(variable.is_var());
        assert!(!variable.is_ground());

        let constant = text("Alice");
        assert!(constant.is_const());
        assert!(constant.is_ground());
    }

    // Arity, groundness and variable collection on a mixed predicate.
    #[test]
    fn test_predicate() {
        let parent = pred("parent", vec![text("Alice"), var("X")]);

        assert_eq!(parent.arity(), 2);
        assert!(!parent.is_ground());
        assert_eq!(parent.variables(), vec!["X".to_string()]);
    }

    // grandparent(X, Z) :- parent(X, Y), parent(Y, Z).
    #[test]
    fn test_rule() {
        let head = pred("grandparent", vec![var("X"), var("Z")]);
        let body = vec![
            pred("parent", vec![var("X"), var("Y")]),
            pred("parent", vec![var("Y"), var("Z")]),
        ];

        let rule = Rule::new(head, body);
        assert!(!rule.is_fact());

        let expected = vec!["X".to_string(), "Y".to_string(), "Z".to_string()];
        assert_eq!(rule.variables(), expected);
    }

    // Adding one fact is reflected in the statistics.
    #[test]
    fn test_knowledge_base() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(pred("parent", vec![text("Alice"), text("Bob")]));

        let stats = kb.stats();
        assert_eq!(stats.num_facts, 1);
        assert_eq!(stats.num_rules, 0);
    }
}