1use super::{
2 prover::{
3 Integer, NameIntMap, NameIntMapState, ProveCx, Prover,
4 format::{NamedExprView, NamedTermView},
5 },
6 repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9 ClauseDatasetIn, ClauseIn, ExprIn, Int2Name, Intern, Map,
10 parse::{
11 VAR_PREFIX,
12 repr::{Clause, Expr, Predicate, Term},
13 text::Name,
14 },
15 prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
16};
17use indexmap::{IndexMap, IndexSet};
18use std::{
19 fmt::{self, Write},
20 iter,
21};
22
23#[derive(Debug)]
24pub struct Database<'int, Int: Intern> {
25 clauses: IndexMap<Predicate<Integer>, Vec<ClauseId>>,
27
28 clause_texts: IndexSet<String>,
30
31 stor: TermStorage<Integer>,
33
34 prover: Prover,
36
37 nimap: NameIntMap<'int, Int>,
42
43 revert_point: Option<DatabaseState>,
47
48 interner: &'int Int,
49}
50
51impl<'int, Int: Intern> Database<'int, Int> {
52 pub fn new(interner: &'int Int) -> Self {
53 Self {
54 clauses: IndexMap::default(),
55 clause_texts: IndexSet::default(),
56 stor: TermStorage::new(),
57 prover: Prover::new(),
58 nimap: NameIntMap::new(interner),
59 revert_point: None,
60 interner,
61 }
62 }
63
64 pub fn interner(&self) -> &'int Int {
65 self.interner
66 }
67
68 pub fn terms(&self) -> NamedTermViewIter<'_, 'int, Int> {
69 NamedTermViewIter {
70 term_iter: self.stor.terms.terms(),
71 int2name: &self.nimap.int2name,
72 }
73 }
74
75 pub fn clauses(&self) -> ClauseIter<'_, 'int, Int> {
76 ClauseIter {
77 clauses: &self.clauses,
78 stor: &self.stor,
79 int2name: &self.nimap.int2name,
80 i: 0,
81 j: 0,
82 }
83 }
84
85 pub fn insert_dataset(&mut self, dataset: ClauseDatasetIn<'int, Int>) {
86 for clause in dataset {
87 self.insert_clause(clause);
88 }
89 }
90
91 pub fn insert_clause(&mut self, clause: ClauseIn<'int, Int>) {
92 if self.revert_point.is_none() {
95 self.revert_point = Some(self.state());
96 }
97
98 let serialized =
100 if let Some(converted) = Clause::convert_var_into_num(&clause, self.interner) {
101 converted.to_string()
102 } else {
103 clause.to_string()
104 };
105 if !self.clause_texts.insert(serialized) {
106 return;
107 }
108
109 let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
110
111 let key = clause.head.predicate();
112 let value = ClauseId {
113 head: self.stor.insert_term(clause.head),
114 body: clause.body.map(|expr| self.stor.insert_expr(expr)),
115 };
116
117 self.clauses
118 .entry(key)
119 .and_modify(|similar_clauses| {
120 if similar_clauses.iter().all(|clause| clause != &value) {
121 similar_clauses.push(value);
122 }
123 })
124 .or_insert(vec![value]);
125 }
126
127 pub fn query(&mut self, expr: ExprIn<'int, Int>) -> ProveCx<'_, 'int, Int> {
128 if let Some(revert_point) = self.revert_point.take() {
130 self.revert(revert_point);
131 }
132
133 self.prover
134 .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
135 }
136
137 pub fn commit(&mut self) {
138 self.revert_point.take();
139 }
140
141 pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
143 let mut prolog_text = String::new();
144
145 let mut conv_map = ConversionMap {
146 int_to_str: Map::default(),
147 sanitized_to_suffix: Map::default(),
148 int2name: &self.nimap.int2name,
149 sanitizer: sanitize,
150 };
151
152 for clauses in self.clauses.values() {
153 for clause in clauses {
154 let head = self.stor.get_term(clause.head);
155 write_term(head, &mut conv_map, &mut prolog_text);
156
157 if let Some(body) = clause.body {
158 prolog_text.push_str(" :- ");
159
160 let body = self.stor.get_expr(body);
161 write_expr(body, &mut conv_map, &mut prolog_text);
162 }
163
164 prolog_text.push_str(".\n");
165 }
166 }
167
168 return prolog_text;
169
170 struct ConversionMap<'a, 'int, Int: Intern, F> {
173 int_to_str: Map<Integer, String>,
174 sanitized_to_suffix: Map<&'a str, u32>,
176 int2name: &'a Int2Name<'int, Int>,
177 sanitizer: F,
178 }
179
180 impl<Int, F> ConversionMap<'_, '_, Int, F>
181 where
182 Int: Intern,
183 F: FnMut(&str) -> &str,
184 {
185 fn int_to_str(&mut self, int: Integer) -> &str {
186 self.int_to_str.entry(int).or_insert_with(|| {
187 let name = self.int2name.get(&int).unwrap();
188 let name: &str = name.as_ref();
189
190 let mut is_var = false;
191
192 let name = if name.starts_with(VAR_PREFIX) {
194 is_var = true;
195 &name[1..]
196 } else {
197 name
198 };
199
200 let pure_name = (self.sanitizer)(name);
202
203 let suffix = self
204 .sanitized_to_suffix
205 .entry(pure_name)
206 .and_modify(|x| *x += 1)
207 .or_insert(0);
208
209 let mut buf = String::new();
210
211 if is_var {
212 let upper = pure_name.chars().next().unwrap().to_uppercase();
213 for c in upper {
214 buf.push(c);
215 }
216 } else {
217 let lower = pure_name.chars().next().unwrap().to_lowercase();
218 for c in lower {
219 buf.push(c);
220 }
221 };
222 buf.push_str(&pure_name[1..]);
223
224 if *suffix == 0 {
225 buf
226 } else {
227 write!(&mut buf, "_{suffix}").unwrap();
228 buf
229 }
230 })
231 }
232 }
233
234 fn write_term<Int, F>(
235 term: TermView<'_, Integer>,
236 conv_map: &mut ConversionMap<'_, '_, Int, F>,
237 prolog_text: &mut String,
238 ) where
239 Int: Intern,
240 F: FnMut(&str) -> &str,
241 {
242 let functor = term.functor();
243 let args = term.args();
244 let num_args = args.len();
245
246 let functor = conv_map.int_to_str(*functor);
247 prolog_text.push_str(functor);
248
249 if num_args > 0 {
250 prolog_text.push('(');
251 for (i, arg) in args.enumerate() {
252 write_term(arg, conv_map, prolog_text);
253 if i + 1 < num_args {
254 prolog_text.push_str(", ");
255 }
256 }
257 prolog_text.push(')');
258 }
259 }
260
261 fn write_expr<Int, F>(
262 expr: ExprView<'_, Integer>,
263 conv_map: &mut ConversionMap<'_, '_, Int, F>,
264 prolog_text: &mut String,
265 ) where
266 Int: Intern,
267 F: FnMut(&str) -> &str,
268 {
269 match expr.as_kind() {
270 ExprKind::Term(term) => {
271 write_term(term, conv_map, prolog_text);
272 }
273 ExprKind::Not(inner) => {
274 prolog_text.push_str("\\+ ");
275 if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
276 prolog_text.push('(');
277 write_expr(inner, conv_map, prolog_text);
278 prolog_text.push(')');
279 } else {
280 write_expr(inner, conv_map, prolog_text);
281 }
282 }
283 ExprKind::And(args) => {
284 let num_args = args.len();
285 for (i, arg) in args.enumerate() {
286 if matches!(arg.as_kind(), ExprKind::Or(_)) {
287 prolog_text.push('(');
288 write_expr(arg, conv_map, prolog_text);
289 prolog_text.push(')');
290 } else {
291 write_expr(arg, conv_map, prolog_text);
292 }
293 if i + 1 < num_args {
294 prolog_text.push_str(", ");
295 }
296 }
297 }
298 ExprKind::Or(args) => {
299 let num_args = args.len();
300 for (i, arg) in args.enumerate() {
301 write_expr(arg, conv_map, prolog_text);
302 if i + 1 < num_args {
303 prolog_text.push_str("; ");
304 }
305 }
306 }
307 }
308 }
309 }
310
311 fn revert(
312 &mut self,
313 DatabaseState {
314 clauses_len,
315 clause_texts_len,
316 stor_len,
317 nimap_state,
318 }: DatabaseState,
319 ) {
320 self.clauses.truncate(clauses_len.len());
321 for (i, len) in clauses_len.into_iter().enumerate() {
322 self.clauses[i].truncate(len);
323 }
324 self.clause_texts.truncate(clause_texts_len);
325 self.stor.truncate(stor_len);
326 self.nimap.revert(nimap_state);
327 }
329
330 fn state(&self) -> DatabaseState {
331 DatabaseState {
332 clauses_len: self.clauses.values().map(|v| v.len()).collect(),
333 clause_texts_len: self.clause_texts.len(),
334 stor_len: self.stor.len(),
335 nimap_state: self.nimap.state(),
336 }
337 }
338}
339
340#[derive(Debug, PartialEq, Eq)]
341struct DatabaseState {
342 clauses_len: Vec<usize>,
343 clause_texts_len: usize,
344 stor_len: TermStorageLen,
345 nimap_state: NameIntMapState,
346}
347
348#[derive(Clone)]
349pub struct ClauseIter<'a, 'int, Int: Intern> {
350 clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
351 stor: &'a TermStorage<Integer>,
352 int2name: &'a Int2Name<'int, Int>,
353 i: usize,
354 j: usize,
355}
356
357impl<'a, 'int, Int: Intern> Iterator for ClauseIter<'a, 'int, Int> {
358 type Item = ClauseRef<'a, 'int, Int>;
359
360 fn next(&mut self) -> Option<Self::Item> {
361 let id = loop {
362 let (_, group) = self.clauses.get_index(self.i)?;
363
364 if let Some(id) = group.get(self.j) {
365 self.j += 1;
366 break *id;
367 }
368
369 self.i += 1;
370 self.j = 0;
371 };
372
373 Some(ClauseRef {
374 id,
375 stor: self.stor,
376 int2name: self.int2name,
377 })
378 }
379}
380
381impl<Int: Intern> iter::FusedIterator for ClauseIter<'_, '_, Int> {}
382
383pub struct ClauseRef<'a, 'int, Int: Intern> {
384 id: ClauseId,
385 stor: &'a TermStorage<Integer>,
386 int2name: &'a Int2Name<'int, Int>,
387}
388
389impl<'a, 'int, Int: Intern> ClauseRef<'a, 'int, Int> {
390 pub fn head(&self) -> NamedTermView<'a, 'int, Int> {
391 let head = self.stor.get_term(self.id.head);
392 NamedTermView::new(head, self.int2name)
393 }
394
395 pub fn body(&self) -> Option<NamedExprView<'a, 'int, Int>> {
396 self.id.body.map(|id| {
397 let body = self.stor.get_expr(id);
398 NamedExprView::new(body, self.int2name)
399 })
400 }
401}
402
403impl<Int: Intern> fmt::Display for ClauseRef<'_, '_, Int> {
404 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
405 fmt::Display::fmt(&self.head(), f)?;
406
407 if let Some(body) = self.body() {
408 f.write_str(" :- ")?;
409 fmt::Display::fmt(&body, f)?
410 }
411
412 f.write_char('.')
413 }
414}
415
416impl<Int: Intern> fmt::Debug for ClauseRef<'_, '_, Int> {
417 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418 let mut d = f.debug_struct("Clause");
419
420 let head = self.stor.get_term(self.id.head);
421 d.field("head", &NamedTermView::new(head, self.int2name));
422
423 if let Some(body) = self.id.body {
424 let body = self.stor.get_expr(body);
425 d.field("body", &NamedExprView::new(body, self.int2name));
426 }
427
428 d.finish()
429 }
430}
431
432impl Clause<Name<()>> {
433 pub(crate) fn convert_var_into_num<'int, Int: Intern>(
435 this: &ClauseIn<'int, Int>,
436 interner: &'int Int,
437 ) -> Option<ClauseIn<'int, Int>> {
438 let mut cloned: Option<ClauseIn<'int, Int>> = None;
439
440 let mut i = 0;
441
442 while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(this)) {
443 let from = var.clone();
444
445 let mut convert = |term: &Term<Name<_>>| {
446 (term == &from).then_some(Term {
447 functor: Name::with_intern(&format!("_{VAR_PREFIX}{i}"), interner),
448 args: [].into(),
449 })
450 };
451
452 if let Some(cloned) = &mut cloned {
453 cloned.replace_term(&mut convert);
454 } else {
455 let mut this = this.clone();
456 this.replace_term(&mut convert);
457 cloned = Some(this);
458 }
459
460 i += 1;
461 }
462
463 return cloned;
464
465 fn find_var_in_clause<T: AsRef<str>>(clause: &Clause<Name<T>>) -> Option<&Term<Name<T>>> {
468 let var = find_var_in_term(&clause.head);
469 if var.is_some() {
470 return var;
471 }
472 find_var_in_expr(clause.body.as_ref()?)
473 }
474
475 fn find_var_in_expr<T: AsRef<str>>(expr: &Expr<Name<T>>) -> Option<&Term<Name<T>>> {
476 match expr {
477 Expr::Term(term) => find_var_in_term(term),
478 Expr::Not(inner) => find_var_in_expr(inner),
479 Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
480 }
481 }
482
483 fn find_var_in_term<T: AsRef<str>>(term: &Term<Name<T>>) -> Option<&Term<Name<T>>> {
484 const _: () = assert!(VAR_PREFIX == '$');
485
486 if term.is_variable() && !term.functor.as_ref().starts_with("_$") {
487 Some(term)
488 } else {
489 term.args.iter().find_map(find_var_in_term)
490 }
491 }
492 }
493}
494
495pub struct NamedTermViewIter<'a, 'int, Int: Intern> {
496 term_iter: TermViewIter<'a, Integer>,
497 int2name: &'a Int2Name<'int, Int>,
498}
499
500impl<'a, 'int, Int: Intern> Iterator for NamedTermViewIter<'a, 'int, Int> {
501 type Item = NamedTermView<'a, 'int, Int>;
502
503 fn next(&mut self) -> Option<Self::Item> {
504 self.term_iter
505 .next()
506 .map(|view| NamedTermView::new(view, self.int2name))
507 }
508}
509
510impl<Int: Intern> iter::FusedIterator for NamedTermViewIter<'_, '_, Int> {}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::{
        self,
        repr::{Clause, ClauseDataset, Expr},
    };
    use any_intern::DroplessInterner;

    /// Parsing a clause and printing it back must reproduce the input text exactly.
    #[test]
    fn test_parse() {
        fn assert(text: &str, interner: &impl Intern) {
            let clause: Clause<Name<_>> = parse::parse_str(text, interner).unwrap();
            assert_eq!(text, clause.to_string());
        }

        let interner = DroplessInterner::default();

        assert("f.", &interner);
        assert("f(a, b).", &interner);
        assert("f(a, b) :- f.", &interner);
        assert("f(a, b) :- f(a).", &interner);
        assert("f(a, b) :- f(a), f(b).", &interner);
        assert("f(a, b) :- f(a); f(b).", &interner);
        assert("f(a, b) :- f(a), (f(b); f(c)).", &interner);
    }

    /// Re-inserting the same clauses and re-running the same query must leave the term
    /// storage at the same length, and the query must keep returning the same answers.
    #[test]
    fn test_serial_queries() {
        fn insert<Int: Intern>(db: &mut Database<'_, Int>) {
            insert_dataset(
                db,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );
        }

        fn query<Int: Intern>(db: &mut Database<'_, Int>) {
            let query = "g($X).";
            let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
            let answer = collect_answer(db.query(query));

            let expected = [["$X = a"], ["$X = b"]];

            assert_eq!(answer, expected);
        }

        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        // `assert_eq!` (not `debug_assert_eq!`) so these checks also run under
        // `cargo test --release`.
        insert(&mut db);
        let org_stor_len = db.stor.len();
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());

        insert(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
    }

    #[test]
    fn test_various_expressions() {
        test_not_expression();
        test_and_expression();
        test_or_expression();
        test_mixed_expression();
    }

    /// Negation: `f(X)` holds exactly when `g(X)` cannot be proven.
    fn test_not_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query = "f(a).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query = "f(b).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);
    }

    /// Conjunction: only bindings satisfying both goals survive.
    fn test_and_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    /// Disjunction: bindings from either goal are returned.
    fn test_or_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"], ["$X = b"]];

        assert_eq!(answer, expected);
    }

    /// Negation nested inside a disjunction nested inside a conjunction.
    fn test_mixed_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    #[test]
    fn test_recursion() {
        test_simple_recursion();
        test_right_recursion();
    }

    /// A self-recursive rule must enumerate answers breadth-wise (a, b, c, Vec(a), ...)
    /// instead of diving into infinite recursion.
    fn test_simple_recursion() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query = "impl(Clone, $T).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);
    }

    /// Transitive closure via a right-recursive rule.
    fn test_right_recursion() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query = "descend($X, $Y).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    /// An insertion that was never committed must be rolled back by the next query.
    #[test]
    fn test_discarding_uncommitted_change() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        let text = "f(a).";
        let clause = parse::parse_str(text, &interner).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        let text = "f(b).";
        let clause = parse::parse_str(text, &interner).unwrap();
        db.insert_clause(clause);

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, &interner).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);
    }

    /// Parses `text` as a dataset, inserts it and commits it.
    fn insert_dataset<Int: Intern>(db: &mut Database<'_, Int>, text: &str) {
        let dataset: ClauseDataset<Name<_>> = parse::parse_str(text, db.interner()).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, rendering each answer as its list of assignment strings.
    fn collect_answer<Int: Intern>(mut cx: ProveCx<'_, '_, Int>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}
794}