1use ipfrs_core::Cid;
7use serde::{Deserialize, Serialize};
8use std::fmt;
9
/// A term of a logic expression: a variable, a constant, a compound
/// (function-symbol) term, or a reference to another term by CID.
///
/// Variant order may matter to serde formats that encode enum variants by
/// index — keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Term {
    /// A logic variable identified by name (rendered as `?name` by `Display`).
    Var(String),
    /// A literal constant value.
    Const(Constant),
    /// A compound term: a function symbol applied to argument terms
    /// (rendered as `name(arg, ...)` by `Display`).
    Fun(String, Vec<Term>),
    /// A reference to a term by content identifier (rendered as `@cid`).
    /// `Term::is_ground` treats every reference as ground without inspecting
    /// what it points to.
    Ref(TermRef),
}
22
/// A literal constant value that can appear inside a `Term::Const`.
///
/// Variant order may matter to serde formats that encode enum variants by
/// index — keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Constant {
    /// A string literal.
    String(String),
    /// A signed 64-bit integer.
    Int(i64),
    /// A boolean.
    Bool(bool),
    /// A floating-point number kept in textual form. Presumably stored as a
    /// `String` so the enum can derive `Eq` and `Hash`, which `f64` does not
    /// implement. The text is not validated as a number anywhere in this file.
    Float(String),
}
35
/// A content-addressed reference (by `Cid`), used as the payload of
/// `Term::Ref`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TermRef {
    /// Content identifier of the referenced term. Serialized through the
    /// crate-level `serialize_cid` / `deserialize_cid` helpers, which are
    /// defined outside this file section.
    #[serde(
        serialize_with = "crate::serialize_cid",
        deserialize_with = "crate::deserialize_cid"
    )]
    pub cid: Cid,
    /// Optional free-form hint attached to the reference. Its intended meaning
    /// (type name? label?) is not visible here — TODO confirm with callers.
    pub hint: Option<String>,
}
48
49impl TermRef {
50 pub fn new(cid: Cid) -> Self {
52 Self { cid, hint: None }
53 }
54
55 pub fn with_hint(cid: Cid, hint: String) -> Self {
57 Self {
58 cid,
59 hint: Some(hint),
60 }
61 }
62}
63
64impl Term {
65 #[inline]
67 pub fn is_var(&self) -> bool {
68 matches!(self, Term::Var(_))
69 }
70
71 #[inline]
73 pub fn is_const(&self) -> bool {
74 matches!(self, Term::Const(_))
75 }
76
77 #[inline]
79 pub fn is_ground(&self) -> bool {
80 match self {
81 Term::Var(_) => false,
82 Term::Const(_) => true,
83 Term::Fun(_, args) => args.iter().all(|t| t.is_ground()),
84 Term::Ref(_) => true, }
86 }
87
88 pub fn variables(&self) -> Vec<String> {
90 let capacity = self.estimate_var_count();
91 let mut vars = Vec::with_capacity(capacity);
92 self.collect_vars(&mut vars);
93 vars.sort_unstable();
94 vars.dedup();
95 vars
96 }
97
98 #[inline]
100 fn estimate_var_count(&self) -> usize {
101 match self {
102 Term::Var(_) => 1,
103 Term::Const(_) | Term::Ref(_) => 0,
104 Term::Fun(_, args) => args.iter().map(|t| t.estimate_var_count()).sum(),
105 }
106 }
107
108 #[inline]
109 fn collect_vars(&self, vars: &mut Vec<String>) {
110 match self {
111 Term::Var(v) => vars.push(v.clone()),
112 Term::Fun(_, args) => {
113 for arg in args {
114 arg.collect_vars(vars);
115 }
116 }
117 _ => {}
118 }
119 }
120
121 #[inline]
123 pub fn complexity(&self) -> usize {
124 match self {
125 Term::Var(_) | Term::Const(_) | Term::Ref(_) => 1,
126 Term::Fun(_, args) => 1 + args.iter().map(|t| t.complexity()).sum::<usize>(),
127 }
128 }
129}
130
131impl fmt::Display for Term {
132 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 match self {
134 Term::Var(v) => write!(f, "?{}", v),
135 Term::Const(c) => write!(f, "{}", c),
136 Term::Fun(name, args) => {
137 write!(f, "{}(", name)?;
138 for (i, arg) in args.iter().enumerate() {
139 if i > 0 {
140 write!(f, ", ")?;
141 }
142 write!(f, "{}", arg)?;
143 }
144 write!(f, ")")
145 }
146 Term::Ref(r) => write!(f, "@{}", r.cid),
147 }
148 }
149}
150
151impl fmt::Display for Constant {
152 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153 match self {
154 Constant::String(s) => write!(f, "\"{}\"", s),
155 Constant::Int(i) => write!(f, "{}", i),
156 Constant::Bool(b) => write!(f, "{}", b),
157 Constant::Float(s) => write!(f, "{}", s),
158 }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
164pub struct Predicate {
165 pub name: String,
167 pub args: Vec<Term>,
169}
170
171impl Predicate {
172 pub fn new(name: String, args: Vec<Term>) -> Self {
174 Self { name, args }
175 }
176
177 #[inline]
179 pub fn arity(&self) -> usize {
180 self.args.len()
181 }
182
183 #[inline]
185 pub fn is_ground(&self) -> bool {
186 self.args.iter().all(|t| t.is_ground())
187 }
188
189 pub fn variables(&self) -> Vec<String> {
191 let capacity: usize = self.args.iter().map(|t| t.estimate_var_count()).sum();
192 let mut vars = Vec::with_capacity(capacity);
193 for arg in &self.args {
194 arg.collect_vars(&mut vars);
195 }
196 vars.sort_unstable();
197 vars.dedup();
198 vars
199 }
200}
201
202impl fmt::Display for Predicate {
203 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204 write!(f, "{}(", self.name)?;
205 for (i, arg) in self.args.iter().enumerate() {
206 if i > 0 {
207 write!(f, ", ")?;
208 }
209 write!(f, "{}", arg)?;
210 }
211 write!(f, ")")
212 }
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct Rule {
218 pub head: Predicate,
220 pub body: Vec<Predicate>,
222}
223
224impl Rule {
225 pub fn new(head: Predicate, body: Vec<Predicate>) -> Self {
227 Self { head, body }
228 }
229
230 pub fn fact(head: Predicate) -> Self {
232 Self {
233 head,
234 body: Vec::new(),
235 }
236 }
237
238 #[inline]
240 pub fn is_fact(&self) -> bool {
241 self.body.is_empty()
242 }
243
244 pub fn variables(&self) -> Vec<String> {
246 let mut vars = self.head.variables();
247 for pred in &self.body {
248 for var in pred.variables() {
249 if !vars.contains(&var) {
250 vars.push(var);
251 }
252 }
253 }
254 vars.sort_unstable();
255 vars
256 }
257}
258
259impl fmt::Display for Rule {
260 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
261 write!(f, "{}", self.head)?;
262 if !self.body.is_empty() {
263 write!(f, " :- ")?;
264 for (i, pred) in self.body.iter().enumerate() {
265 if i > 0 {
266 write!(f, ", ")?;
267 }
268 write!(f, "{}", pred)?;
269 }
270 }
271 write!(f, ".")
272 }
273}
274
/// An in-memory collection of facts and rules.
///
/// Field order may matter to non-self-describing serde formats — keep it
/// stable.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeBase {
    /// Stored facts. `add_fact` appends to the end and does not deduplicate.
    pub facts: Vec<Predicate>,
    /// Stored rules. `add_rule` appends to the end and does not deduplicate.
    pub rules: Vec<Rule>,
}
283
284impl KnowledgeBase {
285 pub fn new() -> Self {
287 Self::default()
288 }
289
290 pub fn add_fact(&mut self, fact: Predicate) {
292 self.facts.push(fact);
293 }
294
295 pub fn add_rule(&mut self, rule: Rule) {
297 self.rules.push(rule);
298 }
299
300 #[inline]
302 pub fn get_predicates(&self, name: &str) -> Vec<&Predicate> {
303 self.facts.iter().filter(|p| p.name == name).collect()
304 }
305
306 #[inline]
308 pub fn get_rules(&self, name: &str) -> Vec<&Rule> {
309 self.rules.iter().filter(|r| r.head.name == name).collect()
310 }
311
312 pub fn stats(&self) -> KnowledgeBaseStats {
314 KnowledgeBaseStats {
315 num_facts: self.facts.len(),
316 num_rules: self.rules.len(),
317 }
318 }
319}
320
/// Summary counts for a `KnowledgeBase`, as returned by `KnowledgeBase::stats`.
///
/// This is a small plain-data value, so it is `Copy` and comparable. The
/// `Default` value has both counts at zero.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct KnowledgeBaseStats {
    /// Number of stored facts.
    pub num_facts: usize,
    /// Number of stored rules.
    pub num_rules: usize,
}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable term.
    fn var_term(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Builds a string-constant term.
    fn str_term(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    #[test]
    fn test_term_creation() {
        let var = Term::Var("X".to_string());
        assert!(var.is_var());
        assert!(!var.is_ground());

        let const_term = Term::Const(Constant::String("Alice".to_string()));
        assert!(const_term.is_const());
        assert!(const_term.is_ground());
    }

    #[test]
    fn test_term_variables_and_complexity() {
        // f(?X, g("a", ?Y, ?X)): ?X occurs twice but is reported once.
        let term = Term::Fun(
            "f".to_string(),
            vec![
                var_term("X"),
                Term::Fun(
                    "g".to_string(),
                    vec![str_term("a"), var_term("Y"), var_term("X")],
                ),
            ],
        );
        assert!(!term.is_ground());
        assert_eq!(term.variables(), vec!["X".to_string(), "Y".to_string()]);
        assert_eq!(term.complexity(), 6);

        let ground = Term::Fun("h".to_string(), vec![str_term("b")]);
        assert!(ground.is_ground());
        assert!(ground.variables().is_empty());
        assert_eq!(ground.complexity(), 2);
    }

    #[test]
    fn test_predicate() {
        let pred = Predicate::new(
            "parent".to_string(),
            vec![
                Term::Const(Constant::String("Alice".to_string())),
                Term::Var("X".to_string()),
            ],
        );

        assert_eq!(pred.arity(), 2);
        assert!(!pred.is_ground());
        assert_eq!(pred.variables(), vec!["X".to_string()]);
    }

    #[test]
    fn test_display() {
        let term = Term::Fun(
            "f".to_string(),
            vec![var_term("X"), Term::Const(Constant::Int(3))],
        );
        assert_eq!(term.to_string(), "f(?X, 3)");

        let pred = Predicate::new(
            "parent".to_string(),
            vec![str_term("Alice"), var_term("X"), Term::Const(Constant::Bool(true))],
        );
        assert_eq!(pred.to_string(), "parent(\"Alice\", ?X, true)");
        assert_eq!(Rule::fact(pred).to_string(), "parent(\"Alice\", ?X, true).");
    }

    #[test]
    fn test_rule() {
        let head = Predicate::new(
            "grandparent".to_string(),
            vec![Term::Var("X".to_string()), Term::Var("Z".to_string())],
        );

        let body = vec![
            Predicate::new(
                "parent".to_string(),
                vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
            ),
            Predicate::new(
                "parent".to_string(),
                vec![Term::Var("Y".to_string()), Term::Var("Z".to_string())],
            ),
        ];

        let rule = Rule::new(head, body);
        assert!(!rule.is_fact());
        assert_eq!(
            rule.variables(),
            vec!["X".to_string(), "Y".to_string(), "Z".to_string()]
        );
        assert_eq!(
            rule.to_string(),
            "grandparent(?X, ?Z) :- parent(?X, ?Y), parent(?Y, ?Z)."
        );
    }

    #[test]
    fn test_knowledge_base() {
        let mut kb = KnowledgeBase::new();

        kb.add_fact(Predicate::new(
            "parent".to_string(),
            vec![
                Term::Const(Constant::String("Alice".to_string())),
                Term::Const(Constant::String("Bob".to_string())),
            ],
        ));

        let stats = kb.stats();
        assert_eq!(stats.num_facts, 1);
        assert_eq!(stats.num_rules, 0);

        kb.add_rule(Rule::new(
            Predicate::new("ancestor".to_string(), vec![var_term("X"), var_term("Y")]),
            vec![Predicate::new(
                "parent".to_string(),
                vec![var_term("X"), var_term("Y")],
            )],
        ));

        assert_eq!(kb.get_predicates("parent").len(), 1);
        assert!(kb.get_predicates("sibling").is_empty());
        assert_eq!(kb.get_rules("ancestor").len(), 1);
        assert!(kb.get_rules("parent").is_empty());
        assert_eq!(kb.stats().num_rules, 1);
    }
}