ipfrs_tensorlogic/
utils.rs1use crate::ir::{Constant, KnowledgeBase, Predicate, Rule, Term};
7use crate::reasoning::{InferenceEngine, Substitution};
8use ipfrs_core::error::Result;
9use std::collections::HashMap;
10
11pub struct PredicateBuilder {
13 name: String,
14 args: Vec<Term>,
15}
16
17impl PredicateBuilder {
18 pub fn new(name: impl Into<String>) -> Self {
20 Self {
21 name: name.into(),
22 args: Vec::new(),
23 }
24 }
25
26 pub fn arg_str(mut self, value: impl Into<String>) -> Self {
28 self.args.push(Term::Const(Constant::String(value.into())));
29 self
30 }
31
32 pub fn arg_int(mut self, value: i64) -> Self {
34 self.args.push(Term::Const(Constant::Int(value)));
35 self
36 }
37
38 pub fn arg_bool(mut self, value: bool) -> Self {
40 self.args.push(Term::Const(Constant::Bool(value)));
41 self
42 }
43
44 pub fn arg_var(mut self, name: impl Into<String>) -> Self {
46 self.args.push(Term::Var(name.into()));
47 self
48 }
49
50 pub fn arg(mut self, term: Term) -> Self {
52 self.args.push(term);
53 self
54 }
55
56 pub fn build(self) -> Predicate {
58 Predicate::new(self.name, self.args)
59 }
60}
61
62pub struct RuleBuilder {
64 head: Option<Predicate>,
65 body: Vec<Predicate>,
66}
67
68impl RuleBuilder {
69 pub fn new() -> Self {
71 Self {
72 head: None,
73 body: Vec::new(),
74 }
75 }
76
77 pub fn head(mut self, predicate: Predicate) -> Self {
79 self.head = Some(predicate);
80 self
81 }
82
83 pub fn body(mut self, predicate: Predicate) -> Self {
85 self.body.push(predicate);
86 self
87 }
88
89 pub fn bodies(mut self, predicates: Vec<Predicate>) -> Self {
91 self.body.extend(predicates);
92 self
93 }
94
95 pub fn build(self) -> Rule {
97 Rule::new(self.head.expect("Rule head must be set"), self.body)
98 }
99}
100
101impl Default for RuleBuilder {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107pub struct KnowledgeBaseUtils;
109
110impl KnowledgeBaseUtils {
111 pub fn from_facts(facts: Vec<Predicate>) -> KnowledgeBase {
113 let mut kb = KnowledgeBase::new();
114 for fact in facts {
115 kb.add_fact(fact);
116 }
117 kb
118 }
119
120 pub fn merge(kb1: &KnowledgeBase, kb2: &KnowledgeBase) -> KnowledgeBase {
122 let mut merged = kb1.clone();
123 for fact in &kb2.facts {
124 if !merged.facts.contains(fact) {
125 merged.add_fact(fact.clone());
126 }
127 }
128 for rule in &kb2.rules {
129 merged.add_rule(rule.clone());
130 }
131 merged
132 }
133
134 pub fn filter_facts(kb: &KnowledgeBase, predicate_name: &str) -> Vec<Predicate> {
136 kb.facts
137 .iter()
138 .filter(|p| p.name == predicate_name)
139 .cloned()
140 .collect()
141 }
142
143 pub fn count_predicates(kb: &KnowledgeBase) -> HashMap<String, usize> {
145 let mut counts = HashMap::new();
146 for fact in &kb.facts {
147 *counts.entry(fact.name.clone()).or_insert(0) += 1;
148 }
149 counts
150 }
151
152 pub fn predicate_names(kb: &KnowledgeBase) -> Vec<String> {
154 let mut names: Vec<String> = kb
155 .facts
156 .iter()
157 .map(|p| p.name.clone())
158 .chain(kb.rules.iter().map(|r| r.head.name.clone()))
159 .collect();
160 names.sort_unstable();
161 names.dedup();
162 names
163 }
164
165 pub fn contains_fact(kb: &KnowledgeBase, fact: &Predicate) -> bool {
167 kb.facts.contains(fact)
168 }
169
170 pub fn deduplicate(kb: &mut KnowledgeBase) {
172 kb.facts.sort_by(|a, b| {
173 a.name
174 .cmp(&b.name)
175 .then_with(|| a.args.len().cmp(&b.args.len()))
176 });
177 kb.facts.dedup();
178 }
179}
180
181pub struct QueryUtils;
183
184impl QueryUtils {
185 pub fn query_one(predicate: &Predicate, kb: &KnowledgeBase) -> Result<Option<Substitution>> {
187 let engine = InferenceEngine::new();
188 let solutions = engine.query(predicate, kb)?;
189 Ok(solutions.into_iter().next())
190 }
191
192 pub fn query_var(
194 predicate: &Predicate,
195 kb: &KnowledgeBase,
196 var_name: &str,
197 ) -> Result<Vec<Term>> {
198 let engine = InferenceEngine::new();
199 let solutions = engine.query(predicate, kb)?;
200 Ok(solutions
201 .into_iter()
202 .filter_map(|subst| subst.get(var_name).cloned())
203 .collect())
204 }
205
206 pub fn is_provable(predicate: &Predicate, kb: &KnowledgeBase) -> Result<bool> {
208 let engine = InferenceEngine::new();
209 let solutions = engine.query(predicate, kb)?;
210 Ok(!solutions.is_empty())
211 }
212
213 pub fn count_solutions(predicate: &Predicate, kb: &KnowledgeBase) -> Result<usize> {
215 let engine = InferenceEngine::new();
216 let solutions = engine.query(predicate, kb)?;
217 Ok(solutions.len())
218 }
219}
220
221pub struct TermUtils;
223
224impl TermUtils {
225 pub fn string(value: impl Into<String>) -> Term {
227 Term::Const(Constant::String(value.into()))
228 }
229
230 pub fn int(value: i64) -> Term {
232 Term::Const(Constant::Int(value))
233 }
234
235 pub fn bool(value: bool) -> Term {
237 Term::Const(Constant::Bool(value))
238 }
239
240 pub fn var(name: impl Into<String>) -> Term {
242 Term::Var(name.into())
243 }
244
245 pub fn as_string(term: &Term) -> Option<&str> {
247 match term {
248 Term::Const(Constant::String(s)) => Some(s),
249 _ => None,
250 }
251 }
252
253 pub fn as_int(term: &Term) -> Option<i64> {
255 match term {
256 Term::Const(Constant::Int(i)) => Some(*i),
257 _ => None,
258 }
259 }
260
261 pub fn as_bool(term: &Term) -> Option<bool> {
263 match term {
264 Term::Const(Constant::Bool(b)) => Some(*b),
265 _ => None,
266 }
267 }
268
269 pub fn is_ground(term: &Term) -> bool {
271 term.is_ground()
272 }
273
274 pub fn variables(term: &Term) -> Vec<String> {
276 term.variables()
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `name(a, b)` with two string-constant arguments.
    fn pair(name: &str, a: &str, b: &str) -> Predicate {
        PredicateBuilder::new(name).arg_str(a).arg_str(b).build()
    }

    #[test]
    fn test_predicate_builder() {
        let pred = PredicateBuilder::new("parent")
            .arg_str("alice")
            .arg_str("bob")
            .build();

        assert_eq!(pred.name, "parent");
        assert_eq!(pred.args.len(), 2);
        assert!(pred.is_ground());
    }

    #[test]
    fn test_predicate_builder_with_vars() {
        let pred = PredicateBuilder::new("parent")
            .arg_str("alice")
            .arg_var("X")
            .build();

        assert_eq!(pred.name, "parent");
        assert_eq!(pred.args.len(), 2);
        assert!(!pred.is_ground());
    }

    #[test]
    fn test_rule_builder() {
        let head = PredicateBuilder::new("grandparent")
            .arg_var("X")
            .arg_var("Z")
            .build();

        let body1 = PredicateBuilder::new("parent")
            .arg_var("X")
            .arg_var("Y")
            .build();

        let body2 = PredicateBuilder::new("parent")
            .arg_var("Y")
            .arg_var("Z")
            .build();

        let rule = RuleBuilder::new()
            .head(head)
            .body(body1)
            .body(body2)
            .build();

        assert_eq!(rule.head.name, "grandparent");
        assert_eq!(rule.body.len(), 2);
    }

    #[test]
    fn test_rule_builder_bodies() {
        let rule = RuleBuilder::new()
            .head(PredicateBuilder::new("ancestor").arg_var("X").arg_var("Y").build())
            .bodies(vec![PredicateBuilder::new("parent")
                .arg_var("X")
                .arg_var("Y")
                .build()])
            .build();

        assert_eq!(rule.head.name, "ancestor");
        assert_eq!(rule.body.len(), 1);
        assert_eq!(rule.body[0].name, "parent");
    }

    #[test]
    #[should_panic(expected = "Rule head must be set")]
    fn test_rule_builder_without_head_panics() {
        let _ = RuleBuilder::new().body(pair("parent", "a", "b")).build();
    }

    #[test]
    fn test_kb_from_facts() {
        let facts = vec![
            PredicateBuilder::new("parent")
                .arg_str("alice")
                .arg_str("bob")
                .build(),
            PredicateBuilder::new("parent")
                .arg_str("bob")
                .arg_str("charlie")
                .build(),
        ];

        let kb = KnowledgeBaseUtils::from_facts(facts);
        assert_eq!(kb.facts.len(), 2);
    }

    #[test]
    fn test_kb_merge() {
        let mut kb1 = KnowledgeBase::new();
        kb1.add_fact(PredicateBuilder::new("test").arg_str("a").build());

        let mut kb2 = KnowledgeBase::new();
        kb2.add_fact(PredicateBuilder::new("test").arg_str("b").build());

        let merged = KnowledgeBaseUtils::merge(&kb1, &kb2);
        assert_eq!(merged.facts.len(), 2);
    }

    #[test]
    fn test_kb_merge_skips_duplicate_facts() {
        let kb1 = KnowledgeBaseUtils::from_facts(vec![pair("test", "a", "x")]);
        let kb2 = KnowledgeBaseUtils::from_facts(vec![
            pair("test", "a", "x"),
            pair("test", "b", "x"),
        ]);

        let merged = KnowledgeBaseUtils::merge(&kb1, &kb2);
        assert_eq!(merged.facts.len(), 2);
    }

    #[test]
    fn test_filter_facts() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(
            PredicateBuilder::new("parent")
                .arg_str("a")
                .arg_str("b")
                .build(),
        );
        kb.add_fact(
            PredicateBuilder::new("parent")
                .arg_str("b")
                .arg_str("c")
                .build(),
        );
        kb.add_fact(
            PredicateBuilder::new("knows")
                .arg_str("a")
                .arg_str("c")
                .build(),
        );

        let parents = KnowledgeBaseUtils::filter_facts(&kb, "parent");
        assert_eq!(parents.len(), 2);

        let knows = KnowledgeBaseUtils::filter_facts(&kb, "knows");
        assert_eq!(knows.len(), 1);
    }

    #[test]
    fn test_count_predicates() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(
            PredicateBuilder::new("parent")
                .arg_str("a")
                .arg_str("b")
                .build(),
        );
        kb.add_fact(
            PredicateBuilder::new("parent")
                .arg_str("b")
                .arg_str("c")
                .build(),
        );
        kb.add_fact(
            PredicateBuilder::new("knows")
                .arg_str("a")
                .arg_str("c")
                .build(),
        );

        let counts = KnowledgeBaseUtils::count_predicates(&kb);
        assert_eq!(counts.get("parent"), Some(&2));
        assert_eq!(counts.get("knows"), Some(&1));
    }

    #[test]
    fn test_contains_fact() {
        let kb = KnowledgeBaseUtils::from_facts(vec![pair("parent", "alice", "bob")]);

        assert!(KnowledgeBaseUtils::contains_fact(&kb, &pair("parent", "alice", "bob")));
        assert!(!KnowledgeBaseUtils::contains_fact(&kb, &pair("parent", "bob", "alice")));
    }

    #[test]
    fn test_predicate_names_includes_rule_heads() {
        let mut kb =
            KnowledgeBaseUtils::from_facts(vec![pair("parent", "a", "b"), pair("parent", "b", "c")]);
        let rule = RuleBuilder::new()
            .head(PredicateBuilder::new("grandparent").arg_var("X").arg_var("Z").build())
            .body(PredicateBuilder::new("parent").arg_var("X").arg_var("Y").build())
            .body(PredicateBuilder::new("parent").arg_var("Y").arg_var("Z").build())
            .build();
        kb.add_rule(rule);

        assert_eq!(
            KnowledgeBaseUtils::predicate_names(&kb),
            vec!["grandparent".to_string(), "parent".to_string()]
        );
    }

    #[test]
    fn test_deduplicate_adjacent_duplicates() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(pair("p", "a", "b"));
        kb.add_fact(pair("p", "a", "b"));
        kb.add_fact(pair("q", "a", "b"));

        KnowledgeBaseUtils::deduplicate(&mut kb);
        assert_eq!(kb.facts.len(), 2);
    }

    #[test]
    fn test_term_utils() {
        let str_term = TermUtils::string("alice");
        assert_eq!(TermUtils::as_string(&str_term), Some("alice"));

        let int_term = TermUtils::int(42);
        assert_eq!(TermUtils::as_int(&int_term), Some(42));

        let bool_term = TermUtils::bool(true);
        assert_eq!(TermUtils::as_bool(&bool_term), Some(true));

        let var_term = TermUtils::var("X");
        assert!(!TermUtils::is_ground(&var_term));
    }

    #[test]
    fn test_term_utils_type_mismatch() {
        assert_eq!(TermUtils::as_int(&TermUtils::string("42")), None);
        assert_eq!(TermUtils::as_string(&TermUtils::var("X")), None);
        assert_eq!(TermUtils::as_bool(&TermUtils::int(1)), None);
        assert!(TermUtils::is_ground(&TermUtils::int(1)));
    }

    #[test]
    fn test_query_utils() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(
            PredicateBuilder::new("parent")
                .arg_str("alice")
                .arg_str("bob")
                .build(),
        );

        let query = PredicateBuilder::new("parent")
            .arg_str("alice")
            .arg_var("X")
            .build();

        let is_provable = QueryUtils::is_provable(&query, &kb).unwrap();
        assert!(is_provable);

        let count = QueryUtils::count_solutions(&query, &kb).unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_query_utils_no_match() {
        let kb = KnowledgeBaseUtils::from_facts(vec![pair("parent", "alice", "bob")]);
        let query = PredicateBuilder::new("parent")
            .arg_str("carol")
            .arg_var("X")
            .build();

        assert!(!QueryUtils::is_provable(&query, &kb).unwrap());
        assert_eq!(QueryUtils::count_solutions(&query, &kb).unwrap(), 0);
        assert!(QueryUtils::query_one(&query, &kb).unwrap().is_none());
    }
}