1use super::{
2 prover::{
3 format::{NamedExprView, NamedTermView},
4 Integer, NameIntMap, NameIntMapState, ProveCx, Prover,
5 },
6 repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9 parse::{
10 repr::{Clause, ClauseDataset, Expr, Predicate, Term},
11 VAR_PREFIX,
12 },
13 prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
14 Atom, IndexMap, IndexSet, Map,
15};
16use core::{
17 fmt::{self, Debug, Display, Write},
18 iter::FusedIterator,
19};
20
21pub struct Database<T> {
22 clauses: IndexMap<Predicate<Integer>, Vec<ClauseId>>,
24
25 table_clauses: IndexSet<Predicate<Integer>>,
27
28 dup_checker: DuplicateClauseChecker,
30
31 stor: TermStorage<Integer>,
33
34 prover: Prover,
36
37 nimap: NameIntMap<T>,
42
43 revert_point: Option<DatabaseState>,
47}
48
49impl<T: Atom> Database<T> {
50 pub fn new() -> Self {
51 Self {
52 clauses: IndexMap::default(),
53 table_clauses: IndexSet::default(),
54 dup_checker: DuplicateClauseChecker::default(),
55 stor: TermStorage::new(),
56 prover: Prover::new(),
57 nimap: NameIntMap::new(),
58 revert_point: None,
59 }
60 }
61
62 pub fn terms(&self) -> NamedTermViewIter<'_, T> {
63 NamedTermViewIter {
64 term_iter: self.stor.terms.terms(),
65 nimap: &self.nimap,
66 }
67 }
68
69 pub fn clauses(&self) -> ClauseIter<'_, T> {
70 ClauseIter {
71 clauses: &self.clauses,
72 stor: &self.stor,
73 nimap: &self.nimap,
74 i: 0,
75 j: 0,
76 }
77 }
78
79 pub fn insert_dataset(&mut self, dataset: ClauseDataset<T>) {
80 for clause in dataset {
81 self.insert_clause(clause);
82 }
83 }
84
85 pub fn insert_clause(&mut self, clause: Clause<T>) {
87 if self.revert_point.is_none() {
89 self.revert_point = Some(self.state());
90 }
91
92 let clause = clause.map(&mut |t| self.nimap.name_to_int(t));
93
94 if clause.needs_tabling() {
96 self.table_clauses.insert(clause.head.predicate());
97 }
98
99 if !self.dup_checker.insert(clause.clone()) {
101 return;
102 }
103
104 let key = clause.head.predicate();
105 let value = ClauseId {
106 head: self.stor.insert_term(clause.head),
107 body: clause.body.map(|expr| self.stor.insert_expr(expr)),
108 };
109
110 self.clauses
111 .entry(key)
112 .and_modify(|similar_clauses| {
113 if similar_clauses.iter().all(|clause| clause != &value) {
114 similar_clauses.push(value);
115 }
116 })
117 .or_insert(vec![value]);
118 }
119
120 pub fn query(&mut self, expr: Expr<T>) -> ProveCx<'_, T> {
121 if let Some(revert_point) = self.revert_point.take() {
123 self.revert(revert_point);
124 }
125
126 self.prover.prove(
127 expr,
128 &self.clauses,
129 &self.table_clauses,
130 &mut self.stor,
131 &mut self.nimap,
132 )
133 }
134
135 pub fn commit(&mut self) {
136 self.revert_point.take();
137 }
138
139 pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String
144 where
145 T: AsRef<str>,
146 {
147 let mut prolog_text = String::new();
148
149 let mut conv_map = ConversionMap {
150 int_to_str: Map::default(),
151 sanitized_to_suffix: Map::default(),
152 nimap: &self.nimap,
153 sanitizer: sanitize,
154 };
155
156 for clauses in self.clauses.values() {
157 for clause in clauses {
158 let head = self.stor.get_term(clause.head);
159 write_term(head, &mut conv_map, &mut prolog_text);
160
161 if let Some(body) = clause.body {
162 prolog_text.push_str(" :- ");
163
164 let body = self.stor.get_expr(body);
165 write_expr(body, &mut conv_map, &mut prolog_text);
166 }
167
168 prolog_text.push_str(".\n");
169 }
170 }
171
172 return prolog_text;
173
174 struct ConversionMap<'a, T, F> {
177 int_to_str: Map<Integer, String>,
178 sanitized_to_suffix: Map<&'a str, u32>,
180 nimap: &'a NameIntMap<T>,
181 sanitizer: F,
182 }
183
184 impl<T, F> ConversionMap<'_, T, F>
185 where
186 T: AsRef<str>,
187 F: FnMut(&str) -> &str,
188 {
189 fn int_to_str(&mut self, int: Integer) -> &str {
190 self.int_to_str.entry(int).or_insert_with(|| {
191 let name = self.nimap.get_name(&int).unwrap();
192 let name: &str = name.as_ref();
193
194 let mut is_var = false;
195
196 let name = if name.starts_with(VAR_PREFIX) {
198 is_var = true;
199 &name[1..]
200 } else {
201 name
202 };
203
204 let pure_name = (self.sanitizer)(name);
206
207 let suffix = self
208 .sanitized_to_suffix
209 .entry(pure_name)
210 .and_modify(|x| *x += 1)
211 .or_insert(0);
212
213 let mut buf = String::new();
214
215 if is_var {
216 let upper = pure_name.chars().next().unwrap().to_uppercase();
217 for c in upper {
218 buf.push(c);
219 }
220 } else {
221 let lower = pure_name.chars().next().unwrap().to_lowercase();
222 for c in lower {
223 buf.push(c);
224 }
225 };
226 buf.push_str(&pure_name[1..]);
227
228 if *suffix == 0 {
229 buf
230 } else {
231 write!(&mut buf, "_{suffix}").unwrap();
232 buf
233 }
234 })
235 }
236 }
237
238 fn write_term<T, F>(
239 term: TermView<'_, Integer>,
240 conv_map: &mut ConversionMap<'_, T, F>,
241 prolog_text: &mut String,
242 ) where
243 T: AsRef<str>,
244 F: FnMut(&str) -> &str,
245 {
246 let functor = term.functor();
247 let args = term.args();
248 let num_args = args.len();
249
250 let functor = conv_map.int_to_str(*functor);
251 prolog_text.push_str(functor);
252
253 if num_args > 0 {
254 prolog_text.push('(');
255 for (i, arg) in args.enumerate() {
256 write_term(arg, conv_map, prolog_text);
257 if i + 1 < num_args {
258 prolog_text.push_str(", ");
259 }
260 }
261 prolog_text.push(')');
262 }
263 }
264
265 fn write_expr<T, F>(
266 expr: ExprView<'_, Integer>,
267 conv_map: &mut ConversionMap<'_, T, F>,
268 prolog_text: &mut String,
269 ) where
270 T: AsRef<str>,
271 F: FnMut(&str) -> &str,
272 {
273 match expr.as_kind() {
274 ExprKind::Term(term) => {
275 write_term(term, conv_map, prolog_text);
276 }
277 ExprKind::Not(inner) => {
278 prolog_text.push_str("\\+ ");
279 if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
280 prolog_text.push('(');
281 write_expr(inner, conv_map, prolog_text);
282 prolog_text.push(')');
283 } else {
284 write_expr(inner, conv_map, prolog_text);
285 }
286 }
287 ExprKind::And(args) => {
288 let num_args = args.len();
289 for (i, arg) in args.enumerate() {
290 if matches!(arg.as_kind(), ExprKind::Or(_)) {
291 prolog_text.push('(');
292 write_expr(arg, conv_map, prolog_text);
293 prolog_text.push(')');
294 } else {
295 write_expr(arg, conv_map, prolog_text);
296 }
297 if i + 1 < num_args {
298 prolog_text.push_str(", ");
299 }
300 }
301 }
302 ExprKind::Or(args) => {
303 let num_args = args.len();
304 for (i, arg) in args.enumerate() {
305 write_expr(arg, conv_map, prolog_text);
306 if i + 1 < num_args {
307 prolog_text.push_str("; ");
308 }
309 }
310 }
311 }
312 }
313 }
314
315 fn revert(
316 &mut self,
317 DatabaseState {
318 clauses_len,
319 clause_set_len,
320 stor_len,
321 nimap_state,
322 }: DatabaseState,
323 ) {
324 self.clauses.truncate(clauses_len.len());
325 for (i, len) in clauses_len.into_iter().enumerate() {
326 self.clauses[i].truncate(len);
327 }
328 self.dup_checker.truncate(clause_set_len);
329 self.stor.truncate(stor_len);
330 self.nimap.revert(nimap_state);
331 }
333
334 fn state(&self) -> DatabaseState {
335 DatabaseState {
336 clauses_len: self.clauses.values().map(|v| v.len()).collect(),
337 clause_set_len: self.dup_checker.len(),
338 stor_len: self.stor.len(),
339 nimap_state: self.nimap.state(),
340 }
341 }
342}
343
344impl<T: Atom> Default for Database<T> {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
350impl<T: Debug> Debug for Database<T> {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 f.debug_struct("Database")
353 .field("clauses", &self.clauses)
354 .field("dup_checker", &self.dup_checker)
355 .field("stor", &self.stor)
356 .field("nimap", &self.nimap)
357 .field("revert_point", &self.revert_point)
358 .finish_non_exhaustive()
359 }
360}
361
362#[derive(Debug, PartialEq, Eq)]
363struct DatabaseState {
364 clauses_len: Vec<usize>,
365 clause_set_len: usize,
366 stor_len: TermStorageLen,
367 nimap_state: NameIntMapState,
368}
369
370#[derive(Debug, Default)]
371struct DuplicateClauseChecker {
372 seen: IndexSet<Clause<Integer>>,
373
374 vars: Vec<Integer>,
376}
377
378impl DuplicateClauseChecker {
379 fn insert(&mut self, clause: Clause<Integer>) -> bool {
381 let canonical_clause = clause.map(&mut |t| {
382 if !t.is_variable() {
383 t
384 } else if let Some(found) = self.vars.iter().find(|&&var| var == t) {
385 *found
386 } else {
387 let next_int = self.vars.len() as u32;
388 let int = Integer::variable(next_int);
389 self.vars.push(int);
390 int
391 }
392 });
393 let is_new = self.seen.insert(canonical_clause);
394 self.vars.clear();
395 is_new
396 }
397
398 fn len(&self) -> usize {
399 self.seen.len()
400 }
401
402 fn truncate(&mut self, len: usize) {
403 self.seen.truncate(len);
404 }
405}
406
407fn _convert_var_into_num<T: Atom>(
411 this: &Clause<T>,
412 canonical_var: Option<&dyn Fn(usize) -> T>,
413) -> Option<Clause<T>> {
414 let canonical_var = canonical_var?;
415 let mut cloned: Option<Clause<T>> = None;
416
417 let mut i = 0;
418
419 while let Some(from) = find_var_in_clause(cloned.as_ref().unwrap_or(this)) {
420 let from = from.clone();
421 let canonical_t = canonical_var(i);
422
423 let mut convert = |term: &Term<T>| {
424 (term.functor == from && term.args.is_empty()).then_some(Term {
425 functor: canonical_t.clone(),
426 args: vec![],
427 })
428 };
429
430 if let Some(cloned) = &mut cloned {
431 cloned.replace_term(&mut convert);
432 } else {
433 let mut this = this.clone();
434 this.replace_term(&mut convert);
435 cloned = Some(this);
436 }
437
438 i += 1;
439 }
440
441 return cloned;
442
443 fn find_var_in_clause<T: Atom>(clause: &Clause<T>) -> Option<T> {
446 find_var_in_term(&clause.head).or_else(|| clause.body.as_ref().and_then(find_var_in_expr))
447 }
448
449 fn find_var_in_expr<T: Atom>(expr: &Expr<T>) -> Option<T> {
450 match expr {
451 Expr::Term(term) => find_var_in_term(term),
452 Expr::Not(arg) => find_var_in_expr(arg),
453 Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
454 }
455 }
456
457 fn find_var_in_term<T: Atom>(term: &Term<T>) -> Option<T> {
458 if term.functor.is_variable() {
459 Some(term.functor.clone())
460 } else {
461 term.args.iter().find_map(find_var_in_term)
462 }
463 }
464}
465
466#[derive(Clone)]
467pub struct ClauseIter<'a, T> {
468 clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
469 stor: &'a TermStorage<Integer>,
470 nimap: &'a NameIntMap<T>,
471 i: usize,
472 j: usize,
473}
474
475impl<'a, T> Iterator for ClauseIter<'a, T> {
476 type Item = ClauseRef<'a, T>;
477
478 fn next(&mut self) -> Option<Self::Item> {
479 let id = loop {
480 let (_, group) = self.clauses.get_index(self.i)?;
481
482 if let Some(id) = group.get(self.j) {
483 self.j += 1;
484 break *id;
485 }
486
487 self.i += 1;
488 self.j = 0;
489 };
490
491 Some(ClauseRef {
492 id,
493 stor: self.stor,
494 nimap: self.nimap,
495 })
496 }
497}
498
499impl<T> FusedIterator for ClauseIter<'_, T> {}
500
501pub struct ClauseRef<'a, T> {
502 id: ClauseId,
503 stor: &'a TermStorage<Integer>,
504 nimap: &'a NameIntMap<T>,
505}
506
507impl<'a, T: Atom> ClauseRef<'a, T> {
508 pub fn head(&self) -> NamedTermView<'a, T> {
509 let head = self.stor.get_term(self.id.head);
510 NamedTermView::new(head, self.nimap)
511 }
512
513 pub fn body(&self) -> Option<NamedExprView<'a, T>> {
514 self.id.body.map(|id| {
515 let body = self.stor.get_expr(id);
516 NamedExprView::new(body, self.nimap)
517 })
518 }
519}
520
521impl<T: Atom + Display> Display for ClauseRef<'_, T> {
522 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
523 Display::fmt(&self.head(), f)?;
524
525 if let Some(body) = self.body() {
526 f.write_str(" :- ")?;
527 Display::fmt(&body, f)?
528 }
529
530 f.write_char('.')
531 }
532}
533
534impl<T: Atom + Debug> Debug for ClauseRef<'_, T> {
535 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
536 let mut d = f.debug_struct("Clause");
537
538 let head = self.stor.get_term(self.id.head);
539 d.field("head", &NamedTermView::new(head, self.nimap));
540
541 if let Some(body) = self.id.body {
542 let body = self.stor.get_expr(body);
543 d.field("body", &NamedExprView::new(body, self.nimap));
544 }
545
546 d.finish()
547 }
548}
549
550pub struct NamedTermViewIter<'a, T> {
551 term_iter: TermViewIter<'a, Integer>,
552 nimap: &'a NameIntMap<T>,
553}
554
555impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> {
556 type Item = NamedTermView<'a, T>;
557
558 fn next(&mut self) -> Option<Self::Item> {
559 self.term_iter
560 .next()
561 .map(|view| NamedTermView::new(view, self.nimap))
562 }
563}
564
565impl<T: Atom> FusedIterator for NamedTermViewIter<'_, T> {}
566
#[cfg(test)]
mod str_atom_tests {
    use crate::{parse, NameIn};

    type Interner = any_intern::DroplessInterner;
    type Database<'int> = crate::Database<NameIn<'int, Interner>>;
    type ProveCx<'a, 'int> = crate::ProveCx<'a, NameIn<'int, Interner>>;
    type ClauseDataset<'int> = crate::ClauseDatasetIn<'int, Interner>;
    type Expr<'int> = crate::ExprIn<'int, Interner>;
    type Clause<'int> = crate::ClauseIn<'int, Interner>;

    /// Re-inserting an identical dataset must not duplicate answers, and a
    /// query must leave the term storage at its pre-query length.
    #[test]
    fn test_serial_queries() {
        let mut db = Database::new();
        let interner = Interner::new();

        for _round in 0..2 {
            insert_dataset(
                &mut db,
                &interner,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );

            let stor_len_before = db.stor.len();
            let answers = collect_answer(db.query(parse_query("g($X).", &interner)));
            assert_eq!(answers, [["$X = a"], ["$X = b"]]);
            assert_eq!(db.stor.len(), stor_len_before);
        }
    }

    /// `\+` succeeds exactly when its goal has no proof.
    #[test]
    fn test_not_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let proofs_of_fa = collect_answer(db.query(parse_query("f(a).", &interner)));
        assert!(proofs_of_fa.is_empty());

        let proofs_of_fb = collect_answer(db.query(parse_query("f(b).", &interner)));
        assert_eq!(proofs_of_fb.len(), 1);
    }

    /// A conjunction only yields bindings satisfying every conjunct.
    #[test]
    fn test_and_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let answers = collect_answer(db.query(parse_query("f($X).", &interner)));
        assert_eq!(answers, [["$X = b"]]);
    }

    /// A disjunction yields the bindings of each alternative, in order.
    #[test]
    fn test_or_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let answers = collect_answer(db.query(parse_query("f($X).", &interner)));
        assert_eq!(answers, [["$X = a"], ["$X = b"]]);
    }

    /// Negation nested in a parenthesized disjunction inside a conjunction.
    #[test]
    fn test_mixed_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let answers = collect_answer(db.query(parse_query("f($X).", &interner)));
        assert_eq!(answers, [["$X = b"]]);
    }

    /// A recursive rule keeps producing answers; they arrive in the order
    /// asserted here.
    #[test]
    fn test_simple_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let mut cx = db.query(parse_query("impl(Clone, $T).", &interner));

        let expected_in_order = [
            "$T = a",
            "$T = b",
            "$T = c",
            "$T = Vec(a)",
            "$T = Vec(b)",
            "$T = Vec(c)",
            "$T = Vec(Vec(a))",
            "$T = Vec(Vec(b))",
            "$T = Vec(Vec(c))",
        ];
        for expected in expected_in_order {
            let eval = cx.prove_next().unwrap();
            let assignments: Vec<String> = eval.map(|assign| assign.to_string()).collect();
            assert_eq!(assignments, [expected]);
        }
    }

    /// Transitive closure through a right-recursive rule.
    #[test]
    fn test_right_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let mut answers = collect_answer(db.query(parse_query("descend($X, $Y).", &interner)));
        answers.sort_unstable();

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];
        expected.sort_unstable();

        assert_eq!(answers, expected);
    }

    /// Recursion in the middle of a body over a cyclic graph terminates.
    #[test]
    fn test_mid_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            edge(a, b).
            edge(b, c).
            edge(c, a).
            path($X, $Y) :- edge($X, $Z), path($Z, $W), edge($W, $Y).
            path($X, $Y) :- edge($X, $Y).
            ",
        );

        let mut answers = collect_answer(db.query(parse_query("path($X, $Y).", &interner)));
        answers.sort_unstable();

        let mut expected = [
            ["$X = a", "$Y = a"],
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = b", "$Y = a"],
            ["$X = b", "$Y = b"],
            ["$X = b", "$Y = c"],
            ["$X = c", "$Y = a"],
            ["$X = c", "$Y = b"],
            ["$X = c", "$Y = c"],
        ];
        expected.sort_unstable();

        assert_eq!(answers, expected);
    }

    /// A left-recursive rule listed before its base case terminates.
    #[test]
    fn test_left_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            parent(a, b).
            parent(b, c).
            parent(c, d).
            ancestor($X, $Y) :- ancestor($X, $Z), parent($Z, $Y).
            ancestor($X, $Y) :- parent($X, $Y).
            ",
        );

        let mut answers = collect_answer(db.query(parse_query("ancestor($X, $Y).", &interner)));
        answers.sort_unstable();

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];
        expected.sort_unstable();

        assert_eq!(answers, expected);
    }

    /// An insertion that was not committed is rolled back by the next query.
    /// The rollback restores the state recorded right after the committed
    /// insertion.
    #[test]
    fn test_discarding_uncomitted_change() {
        let mut db = Database::new();
        let interner = Interner::new();

        db.insert_clause(parse_clause("f(a).", &interner));
        let fa_state = db.state();
        db.commit();

        db.insert_clause(parse_clause("f(b).", &interner));

        let answers = collect_answer(db.query(parse_query("f($X).", &interner)));
        assert_eq!(answers, [["$X = a"]]);
        assert_eq!(db.state(), fa_state);
    }

    /// Parses `text` as a query expression, panicking on syntax errors.
    fn parse_query<'int>(text: &str, interner: &'int Interner) -> Expr<'int> {
        parse::parse_str(text, interner).unwrap()
    }

    /// Parses `text` as a single clause, panicking on syntax errors.
    fn parse_clause<'int>(text: &str, interner: &'int Interner) -> Clause<'int> {
        parse::parse_str(text, interner).unwrap()
    }

    /// Parses `text` as a dataset, inserts it, and commits it.
    fn insert_dataset<'int>(db: &mut Database<'int>, interner: &'int Interner, text: &str) {
        let dataset: ClauseDataset<'int> = parse::parse_str(text, interner).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, rendering each answer as its list of `var = value` strings.
    fn collect_answer(mut cx: ProveCx<'_, '_>) -> Vec<Vec<String>> {
        let mut answers = Vec::new();
        while let Some(eval) = cx.prove_next() {
            answers.push(eval.map(|assign| assign.to_string()).collect::<Vec<_>>());
        }
        answers
    }
}
885
#[cfg(test)]
mod tests {
    use crate::{Atom, Clause, ClauseDataset, Database, Expr, ProveCx, Term};

    /// The database works with a user-defined atom type whose variables are
    /// identified by `Atom::is_variable` rather than by a name prefix.
    #[test]
    fn test_custom_atom() {
        // Variant names mirror the lowercase atoms and uppercase variables of
        // the equivalent Prolog program, hence the naming-lint exemption.
        #[allow(non_camel_case_types)]
        #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
        enum A {
            child,
            descend,
            a,
            b,
            c,
            d,
            X,
            Y,
            Z,
        }

        impl Atom for A {
            fn is_variable(&self) -> bool {
                matches!(self, A::X | A::Y | A::Z)
            }
        }

        // `functor(lhs, rhs)` with two atomic arguments.
        let pair = |functor: A, lhs: A, rhs: A| {
            Term::compound(functor, [Term::atom(lhs), Term::atom(rhs)])
        };
        // The expected answer `X = x, Y = y`.
        let binding = |x: A, y: A| {
            [
                (Term::atom(A::X), Term::atom(x)),
                (Term::atom(A::Y), Term::atom(y)),
            ]
        };

        let child_a_b = Clause::fact(pair(A::child, A::a, A::b));
        let child_b_c = Clause::fact(pair(A::child, A::b, A::c));
        let child_c_d = Clause::fact(pair(A::child, A::c, A::d));
        let descend_x_y = Clause::rule(
            pair(A::descend, A::X, A::Y),
            Expr::term_compound(A::child, [Term::atom(A::X), Term::atom(A::Y)]),
        );
        let descend_x_z = Clause::rule(
            pair(A::descend, A::X, A::Z),
            Expr::expr_and([
                Expr::term_compound(A::child, [Term::atom(A::X), Term::atom(A::Y)]),
                Expr::term_compound(A::descend, [Term::atom(A::Y), Term::atom(A::Z)]),
            ]),
        );

        let mut db = Database::new();
        insert_dataset(
            &mut db,
            crate::ClauseDataset(vec![
                child_a_b,
                child_b_c,
                child_c_d,
                descend_x_y,
                descend_x_z,
            ]),
        );

        let query = Expr::term_compound(A::descend, [Term::atom(A::X), Term::atom(A::Y)]);
        let mut answers = collect_answer(db.query(query));
        answers.sort_unstable();

        let mut expected = [
            binding(A::a, A::b),
            binding(A::a, A::c),
            binding(A::a, A::d),
            binding(A::b, A::c),
            binding(A::b, A::d),
            binding(A::c, A::d),
        ];
        expected.sort_unstable();

        assert_eq!(answers, expected);
    }

    /// Inserts `dataset` and commits it.
    fn insert_dataset<T: Atom>(db: &mut Database<T>, dataset: ClauseDataset<T>) {
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, collecting each answer as `(variable, value)` term pairs.
    fn collect_answer<T: Atom>(mut cx: ProveCx<'_, T>) -> Vec<Vec<(Term<T>, Term<T>)>> {
        let mut answers = Vec::new();
        while let Some(eval) = cx.prove_next() {
            answers.push(
                eval.map(|assign| (assign.lhs(), assign.rhs()))
                    .collect::<Vec<_>>(),
            );
        }
        answers
    }
}