1use super::{
2 prover::{
3 format::{NamedExprView, NamedTermView},
4 Integer, NameIntMap, NameIntMapState, ProveCx, Prover,
5 },
6 repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9 parse::{
10 repr::{Clause, Expr, Predicate, Term},
11 text::Name,
12 VAR_PREFIX,
13 },
14 prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
15 ClauseDatasetIn, ClauseIn, DefaultInterner, ExprIn, Int2Name, Intern, Map,
16};
17use indexmap::{IndexMap, IndexSet};
18use logic_eval_util::reference::Ref;
19use std::{
20 fmt::{self, Write},
21 iter,
22};
23
/// A clause database: stores clauses, rejects duplicates, and answers queries.
///
/// Insertions stay uncommitted until [`Database::commit`]; the next
/// [`Database::query`] first rolls back any uncommitted insertions.
#[derive(Debug)]
pub struct Database<'int, Int: Intern = DefaultInterner> {
    /// Stored clause ids, grouped by the predicate of their head. Groups keep
    /// the order in which their predicate was first inserted.
    clauses: IndexMap<Predicate<Integer>, Vec<ClauseId>>,

    /// Text of every inserted clause (with variables renamed canonically),
    /// used to skip insertions of an already-present clause.
    clause_texts: IndexSet<String>,

    /// Storage for the heads and bodies referenced by `clauses`.
    stor: TermStorage<Integer>,

    /// Prover that `query` delegates to.
    prover: Prover,

    /// Mapping between names and the integers stored in `stor`.
    nimap: NameIntMap<'int, Int>,

    /// Snapshot taken at the first uncommitted insertion; restored by `query`
    /// and discarded by `commit`.
    revert_point: Option<DatabaseState>,

    /// Interner that names are interned into.
    interner: Ref<'int, Int>,
}
51
52impl Database<'static> {
53 #[allow(clippy::new_without_default)]
58 pub fn new() -> Self {
59 let leaked = Box::leak(Box::new(DefaultInterner::default()));
60 let interner = Ref::from_mut(leaked);
61
62 Self {
63 clauses: IndexMap::default(),
64 clause_texts: IndexSet::default(),
65 stor: TermStorage::new(),
66 prover: Prover::new(),
67 nimap: NameIntMap::new(interner),
68 revert_point: None,
69 interner,
70 }
71 }
72
73 #[rustfmt::skip]
74 pub fn dealloc(self) {
75 #[allow(dead_code)]
78 {
79 struct ImplDetector<T>(std::marker::PhantomData<T>);
80 trait NotClone { const IS_CLONE: bool = false; }
81 impl<T> NotClone for ImplDetector<T> {}
82 impl<T: Clone> ImplDetector<T> { const IS_CLONE: bool = true; }
83 const _: () = assert!(!ImplDetector::<Database<'static>>::IS_CLONE);
84 }
85
86 let _ = unsafe { Box::from_raw(self.interner.as_ptr()) };
89 }
90}
91
92impl<'int, Int: Intern> Database<'int, Int> {
93 pub fn with_interner(interner: &'int Int) -> Self {
94 let interner = Ref::from_ref(interner);
95
96 Self {
97 clauses: IndexMap::default(),
98 clause_texts: IndexSet::default(),
99 stor: TermStorage::new(),
100 prover: Prover::new(),
101 nimap: NameIntMap::new(interner),
102 revert_point: None,
103 interner,
104 }
105 }
106
107 pub fn interner(&self) -> &'int Int {
108 self.interner.as_ref()
109 }
110
111 pub fn terms(&self) -> NamedTermViewIter<'_, 'int, Int> {
112 NamedTermViewIter {
113 term_iter: self.stor.terms.terms(),
114 int2name: &self.nimap.int2name,
115 }
116 }
117
118 pub fn clauses(&self) -> ClauseIter<'_, 'int, Int> {
119 ClauseIter {
120 clauses: &self.clauses,
121 stor: &self.stor,
122 int2name: &self.nimap.int2name,
123 i: 0,
124 j: 0,
125 }
126 }
127
128 pub fn insert_dataset(&mut self, dataset: ClauseDatasetIn<'int, Int>) {
129 for clause in dataset {
130 self.insert_clause(clause);
131 }
132 }
133
134 pub fn insert_clause(&mut self, clause: ClauseIn<'int, Int>) {
135 if self.revert_point.is_none() {
138 self.revert_point = Some(self.state());
139 }
140
141 let serialized = if let Some(converted) =
143 Clause::convert_var_into_num(&clause, self.interner.as_ref())
144 {
145 converted.to_string()
146 } else {
147 clause.to_string()
148 };
149 if !self.clause_texts.insert(serialized) {
150 return;
151 }
152
153 let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
154
155 let key = clause.head.predicate();
156 let value = ClauseId {
157 head: self.stor.insert_term(clause.head),
158 body: clause.body.map(|expr| self.stor.insert_expr(expr)),
159 };
160
161 self.clauses
162 .entry(key)
163 .and_modify(|similar_clauses| {
164 if similar_clauses.iter().all(|clause| clause != &value) {
165 similar_clauses.push(value);
166 }
167 })
168 .or_insert(vec![value]);
169 }
170
171 pub fn query(&mut self, expr: ExprIn<'int, Int>) -> ProveCx<'_, 'int, Int> {
172 if let Some(revert_point) = self.revert_point.take() {
174 self.revert(revert_point);
175 }
176
177 self.prover
178 .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
179 }
180
181 pub fn commit(&mut self) {
182 self.revert_point.take();
183 }
184
185 pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
187 let mut prolog_text = String::new();
188
189 let mut conv_map = ConversionMap {
190 int_to_str: Map::default(),
191 sanitized_to_suffix: Map::default(),
192 int2name: &self.nimap.int2name,
193 sanitizer: sanitize,
194 };
195
196 for clauses in self.clauses.values() {
197 for clause in clauses {
198 let head = self.stor.get_term(clause.head);
199 write_term(head, &mut conv_map, &mut prolog_text);
200
201 if let Some(body) = clause.body {
202 prolog_text.push_str(" :- ");
203
204 let body = self.stor.get_expr(body);
205 write_expr(body, &mut conv_map, &mut prolog_text);
206 }
207
208 prolog_text.push_str(".\n");
209 }
210 }
211
212 return prolog_text;
213
214 struct ConversionMap<'a, 'int, Int: Intern, F> {
217 int_to_str: Map<Integer, String>,
218 sanitized_to_suffix: Map<&'a str, u32>,
220 int2name: &'a Int2Name<'int, Int>,
221 sanitizer: F,
222 }
223
224 impl<Int, F> ConversionMap<'_, '_, Int, F>
225 where
226 Int: Intern,
227 F: FnMut(&str) -> &str,
228 {
229 fn int_to_str(&mut self, int: Integer) -> &str {
230 self.int_to_str.entry(int).or_insert_with(|| {
231 let name = self.int2name.get(&int).unwrap();
232 let name: &str = name.as_ref();
233
234 let mut is_var = false;
235
236 let name = if name.starts_with(VAR_PREFIX) {
238 is_var = true;
239 &name[1..]
240 } else {
241 name
242 };
243
244 let pure_name = (self.sanitizer)(name);
246
247 let suffix = self
248 .sanitized_to_suffix
249 .entry(pure_name)
250 .and_modify(|x| *x += 1)
251 .or_insert(0);
252
253 let mut buf = String::new();
254
255 if is_var {
256 let upper = pure_name.chars().next().unwrap().to_uppercase();
257 for c in upper {
258 buf.push(c);
259 }
260 } else {
261 let lower = pure_name.chars().next().unwrap().to_lowercase();
262 for c in lower {
263 buf.push(c);
264 }
265 };
266 buf.push_str(&pure_name[1..]);
267
268 if *suffix == 0 {
269 buf
270 } else {
271 write!(&mut buf, "_{suffix}").unwrap();
272 buf
273 }
274 })
275 }
276 }
277
278 fn write_term<Int, F>(
279 term: TermView<'_, Integer>,
280 conv_map: &mut ConversionMap<'_, '_, Int, F>,
281 prolog_text: &mut String,
282 ) where
283 Int: Intern,
284 F: FnMut(&str) -> &str,
285 {
286 let functor = term.functor();
287 let args = term.args();
288 let num_args = args.len();
289
290 let functor = conv_map.int_to_str(*functor);
291 prolog_text.push_str(functor);
292
293 if num_args > 0 {
294 prolog_text.push('(');
295 for (i, arg) in args.enumerate() {
296 write_term(arg, conv_map, prolog_text);
297 if i + 1 < num_args {
298 prolog_text.push_str(", ");
299 }
300 }
301 prolog_text.push(')');
302 }
303 }
304
305 fn write_expr<Int, F>(
306 expr: ExprView<'_, Integer>,
307 conv_map: &mut ConversionMap<'_, '_, Int, F>,
308 prolog_text: &mut String,
309 ) where
310 Int: Intern,
311 F: FnMut(&str) -> &str,
312 {
313 match expr.as_kind() {
314 ExprKind::Term(term) => {
315 write_term(term, conv_map, prolog_text);
316 }
317 ExprKind::Not(inner) => {
318 prolog_text.push_str("\\+ ");
319 if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
320 prolog_text.push('(');
321 write_expr(inner, conv_map, prolog_text);
322 prolog_text.push(')');
323 } else {
324 write_expr(inner, conv_map, prolog_text);
325 }
326 }
327 ExprKind::And(args) => {
328 let num_args = args.len();
329 for (i, arg) in args.enumerate() {
330 if matches!(arg.as_kind(), ExprKind::Or(_)) {
331 prolog_text.push('(');
332 write_expr(arg, conv_map, prolog_text);
333 prolog_text.push(')');
334 } else {
335 write_expr(arg, conv_map, prolog_text);
336 }
337 if i + 1 < num_args {
338 prolog_text.push_str(", ");
339 }
340 }
341 }
342 ExprKind::Or(args) => {
343 let num_args = args.len();
344 for (i, arg) in args.enumerate() {
345 write_expr(arg, conv_map, prolog_text);
346 if i + 1 < num_args {
347 prolog_text.push_str("; ");
348 }
349 }
350 }
351 }
352 }
353 }
354
355 fn revert(
356 &mut self,
357 DatabaseState {
358 clauses_len,
359 clause_texts_len,
360 stor_len,
361 nimap_state,
362 }: DatabaseState,
363 ) {
364 self.clauses.truncate(clauses_len.len());
365 for (i, len) in clauses_len.into_iter().enumerate() {
366 self.clauses[i].truncate(len);
367 }
368 self.clause_texts.truncate(clause_texts_len);
369 self.stor.truncate(stor_len);
370 self.nimap.revert(nimap_state);
371 }
373
374 fn state(&self) -> DatabaseState {
375 DatabaseState {
376 clauses_len: self.clauses.values().map(|v| v.len()).collect(),
377 clause_texts_len: self.clause_texts.len(),
378 stor_len: self.stor.len(),
379 nimap_state: self.nimap.state(),
380 }
381 }
382}
383
/// Container sizes of a [`Database`] at one point in time. Uncommitted
/// insertions are undone by truncating back to these sizes.
#[derive(Debug, PartialEq, Eq)]
struct DatabaseState {
    /// Length of each clause group in `Database::clauses`, in map order.
    clauses_len: Vec<usize>,
    /// Number of entries in `Database::clause_texts`.
    clause_texts_len: usize,
    /// Length snapshot of the term storage.
    stor_len: TermStorageLen,
    /// Snapshot of the name/integer map.
    nimap_state: NameIntMapState,
}
391
392#[derive(Clone)]
393pub struct ClauseIter<'a, 'int, Int: Intern> {
394 clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
395 stor: &'a TermStorage<Integer>,
396 int2name: &'a Int2Name<'int, Int>,
397 i: usize,
398 j: usize,
399}
400
401impl<'a, 'int, Int: Intern> Iterator for ClauseIter<'a, 'int, Int> {
402 type Item = ClauseRef<'a, 'int, Int>;
403
404 fn next(&mut self) -> Option<Self::Item> {
405 let id = loop {
406 let (_, group) = self.clauses.get_index(self.i)?;
407
408 if let Some(id) = group.get(self.j) {
409 self.j += 1;
410 break *id;
411 }
412
413 self.i += 1;
414 self.j = 0;
415 };
416
417 Some(ClauseRef {
418 id,
419 stor: self.stor,
420 int2name: self.int2name,
421 })
422 }
423}
424
425impl<Int: Intern> iter::FusedIterator for ClauseIter<'_, '_, Int> {}
426
/// A borrowed view of one stored clause, yielded by [`ClauseIter`].
pub struct ClauseRef<'a, 'int, Int: Intern> {
    /// Storage ids of the clause head and optional body.
    id: ClauseId,
    /// Storage the ids point into.
    stor: &'a TermStorage<Integer>,
    /// Used to resolve stored integers back to names.
    int2name: &'a Int2Name<'int, Int>,
}
432
433impl<'a, 'int, Int: Intern> ClauseRef<'a, 'int, Int> {
434 pub fn head(&self) -> NamedTermView<'a, 'int, Int> {
435 let head = self.stor.get_term(self.id.head);
436 NamedTermView::new(head, self.int2name)
437 }
438
439 pub fn body(&self) -> Option<NamedExprView<'a, 'int, Int>> {
440 self.id.body.map(|id| {
441 let body = self.stor.get_expr(id);
442 NamedExprView::new(body, self.int2name)
443 })
444 }
445}
446
447impl<Int: Intern> fmt::Display for ClauseRef<'_, '_, Int> {
448 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449 fmt::Display::fmt(&self.head(), f)?;
450
451 if let Some(body) = self.body() {
452 f.write_str(" :- ")?;
453 fmt::Display::fmt(&body, f)?
454 }
455
456 f.write_char('.')
457 }
458}
459
460impl<Int: Intern> fmt::Debug for ClauseRef<'_, '_, Int> {
461 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462 let mut d = f.debug_struct("Clause");
463
464 let head = self.stor.get_term(self.id.head);
465 d.field("head", &NamedTermView::new(head, self.int2name));
466
467 if let Some(body) = self.id.body {
468 let body = self.stor.get_expr(body);
469 d.field("body", &NamedExprView::new(body, self.int2name));
470 }
471
472 d.finish()
473 }
474}
475
476impl Clause<Name<()>> {
477 pub(crate) fn convert_var_into_num<'int, Int: Intern>(
479 this: &ClauseIn<'int, Int>,
480 interner: &'int Int,
481 ) -> Option<ClauseIn<'int, Int>> {
482 let mut cloned: Option<ClauseIn<'int, Int>> = None;
483
484 let mut i = 0;
485
486 while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(this)) {
487 let from = var.clone();
488
489 let mut convert = |term: &Term<Name<_>>| {
490 (term == &from).then_some(Term {
491 functor: Name::with_intern(&format!("_{VAR_PREFIX}{i}"), interner),
492 args: [].into(),
493 })
494 };
495
496 if let Some(cloned) = &mut cloned {
497 cloned.replace_term(&mut convert);
498 } else {
499 let mut this = this.clone();
500 this.replace_term(&mut convert);
501 cloned = Some(this);
502 }
503
504 i += 1;
505 }
506
507 return cloned;
508
509 fn find_var_in_clause<T: AsRef<str>>(clause: &Clause<Name<T>>) -> Option<&Term<Name<T>>> {
512 let var = find_var_in_term(&clause.head);
513 if var.is_some() {
514 return var;
515 }
516 find_var_in_expr(clause.body.as_ref()?)
517 }
518
519 fn find_var_in_expr<T: AsRef<str>>(expr: &Expr<Name<T>>) -> Option<&Term<Name<T>>> {
520 match expr {
521 Expr::Term(term) => find_var_in_term(term),
522 Expr::Not(inner) => find_var_in_expr(inner),
523 Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
524 }
525 }
526
527 fn find_var_in_term<T: AsRef<str>>(term: &Term<Name<T>>) -> Option<&Term<Name<T>>> {
528 const _: () = assert!(VAR_PREFIX == '$');
529
530 if term.is_variable() && !term.functor.as_ref().starts_with("_$") {
531 Some(term)
532 } else {
533 term.args.iter().find_map(find_var_in_term)
534 }
535 }
536 }
537}
538
/// Iterator over stored terms, with integers resolved back to names. Created
/// by [`Database::terms`].
pub struct NamedTermViewIter<'a, 'int, Int: Intern> {
    /// Underlying iterator over integer-encoded terms.
    term_iter: TermViewIter<'a, Integer>,
    /// Used to resolve integers back to names.
    int2name: &'a Int2Name<'int, Int>,
}
543
544impl<'a, 'int, Int: Intern> Iterator for NamedTermViewIter<'a, 'int, Int> {
545 type Item = NamedTermView<'a, 'int, Int>;
546
547 fn next(&mut self) -> Option<Self::Item> {
548 self.term_iter
549 .next()
550 .map(|view| NamedTermView::new(view, self.int2name))
551 }
552}
553
554impl<Int: Intern> iter::FusedIterator for NamedTermViewIter<'_, '_, Int> {}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::{
        self,
        repr::{Clause, ClauseDataset, Expr},
    };
    use any_intern::DroplessInterner;

    #[test]
    fn test_parse() {
        fn assert(text: &str, interner: &impl Intern) {
            let clause: Clause<Name<_>> = parse::parse_str(text, interner).unwrap();
            assert_eq!(text, clause.to_string());
        }

        let interner = DroplessInterner::default();

        assert("f.", &interner);
        assert("f(a, b).", &interner);
        assert("f(a, b) :- f.", &interner);
        assert("f(a, b) :- f(a).", &interner);
        assert("f(a, b) :- f(a), f(b).", &interner);
        assert("f(a, b) :- f(a); f(b).", &interner);
        assert("f(a, b) :- f(a), (f(b); f(c)).", &interner);
    }

    #[test]
    fn test_serial_queries() {
        fn insert<Int: Intern>(db: &mut Database<'_, Int>) {
            insert_dataset(
                db,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );
        }

        fn query<Int: Intern>(db: &mut Database<'_, Int>) {
            let query = "g($X).";
            let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
            let answer = collect_answer(db.query(query));

            let expected = [["$X = a"], ["$X = b"]];

            assert_eq!(answer, expected);
        }

        let mut db = Database::new();

        // `assert_eq!` rather than `debug_assert_eq!`: these checks must also run
        // under `cargo test --release`.
        insert(&mut db);
        let org_stor_len = db.stor.len();
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());

        // Re-inserting the same clauses must not grow the storage.
        insert(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());

        db.dealloc();
    }

    #[test]
    fn test_various_expressions() {
        test_not_expression();
        test_and_expression();
        test_or_expression();
        test_mixed_expression();
    }

    fn test_not_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query = "f(a).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query = "f(b).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);

        db.dealloc();
    }

    fn test_and_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];
        assert_eq!(answer, expected);

        db.dealloc();
    }

    fn test_or_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"], ["$X = b"]];
        assert_eq!(answer, expected);

        db.dealloc();
    }

    fn test_mixed_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];
        assert_eq!(answer, expected);

        db.dealloc();
    }

    #[test]
    fn test_recursion() {
        test_simple_recursion();
        test_right_recursion();
    }

    fn test_simple_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query = "impl(Clone, $T).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);

        drop(cx);
        db.dealloc();
    }

    fn test_right_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query = "descend($X, $Y).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);

        db.dealloc();
    }

    #[test]
    fn test_discarding_uncommitted_change() {
        let mut db = Database::new();

        let text = "f(a).";
        let clause = parse::parse_str(text, db.interner()).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        // Inserted but never committed: the next query must discard it.
        let text = "f(b).";
        let clause = parse::parse_str(text, db.interner()).unwrap();
        db.insert_clause(clause);

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);

        db.dealloc();
    }

    #[test]
    fn test_to_prolog() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            f(a).
            f(b).
            g($X) :- f($X), \+ h($X).
            ",
        );

        // Variables are upper-cased, and `\+` applied to a plain term needs no
        // parentheses.
        let prolog = db.to_prolog(|name| name);
        assert_eq!(prolog, "f(a).\nf(b).\ng(X) :- f(X), \\+ h(X).\n");

        db.dealloc();
    }

    /// Parses `text` as a clause dataset, inserts it, and commits.
    fn insert_dataset<Int: Intern>(db: &mut Database<'_, Int>, text: &str) {
        let dataset: ClauseDataset<Name<_>> = parse::parse_str(text, db.interner()).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, rendering each answer as its list of assignment strings.
    fn collect_answer<Int: Intern>(mut cx: ProveCx<'_, '_, Int>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}
844}