1use super::{
2 prover::{
3 IdxMap, Int, NameIntMap, NameIntMapState, ProveCx, Prover,
4 format::{NamedExprView, NamedTermView},
5 },
6 repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9 Map,
10 parse::{
11 GlobalCx, VAR_PREFIX,
12 repr::{Clause, ClauseDataset, Expr, Predicate, Term},
13 text::Name,
14 },
15 prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
16};
17use any_intern::DroplessInterner;
18use indexmap::{IndexMap, IndexSet};
19use std::{
20 fmt::{self, Write},
21 iter,
22};
23
24#[derive(Debug)]
25pub struct Database<'cx> {
26 clauses: IndexMap<Predicate<Int>, Vec<ClauseId>>,
28
29 clause_texts: IndexSet<String>,
31
32 stor: TermStorage<Int>,
34
35 prover: Prover,
37
38 nimap: NameIntMap<'cx>,
43
44 revert_point: Option<DatabaseState>,
48
49 gcx: GlobalCx<'cx>,
50}
51
52impl<'cx> Database<'cx> {
53 pub fn new(interner: &'cx DroplessInterner) -> Self {
54 let gcx = GlobalCx { interner };
55 Self {
56 clauses: IndexMap::default(),
57 clause_texts: IndexSet::default(),
58 stor: TermStorage::new(),
59 prover: Prover::new(),
60 nimap: NameIntMap::new(gcx),
61 revert_point: None,
62 gcx,
63 }
64 }
65
66 pub fn gcx(&self) -> &GlobalCx<'cx> {
67 &self.gcx
68 }
69
70 pub fn terms(&self) -> NamedTermViewIter<'_, 'cx> {
71 NamedTermViewIter {
72 term_iter: self.stor.terms.terms(),
73 int2name: &self.nimap.int2name,
74 }
75 }
76
77 pub fn clauses(&self) -> ClauseIter<'_, 'cx> {
78 ClauseIter {
79 clauses: &self.clauses,
80 stor: &self.stor,
81 int2name: &self.nimap.int2name,
82 i: 0,
83 j: 0,
84 }
85 }
86
87 pub fn insert_dataset(&mut self, dataset: ClauseDataset<Name<'cx>>) {
88 for clause in dataset {
89 self.insert_clause(clause);
90 }
91 }
92
93 pub fn insert_clause(&mut self, clause: Clause<Name<'cx>>) {
94 if self.revert_point.is_none() {
97 self.revert_point = Some(self.state());
98 }
99
100 let serialized = if let Some(converted) = clause.convert_var_into_num(&self.gcx) {
102 converted.to_string()
103 } else {
104 clause.to_string()
105 };
106 if !self.clause_texts.insert(serialized) {
107 return;
108 }
109
110 let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
111
112 let key = clause.head.predicate();
113 let value = ClauseId {
114 head: self.stor.insert_term(clause.head),
115 body: clause.body.map(|expr| self.stor.insert_expr(expr)),
116 };
117
118 self.clauses
119 .entry(key)
120 .and_modify(|similar_clauses| {
121 if similar_clauses.iter().all(|clause| clause != &value) {
122 similar_clauses.push(value);
123 }
124 })
125 .or_insert(vec![value]);
126 }
127
128 pub fn query(&mut self, expr: Expr<Name<'cx>>) -> ProveCx<'_, 'cx> {
129 if let Some(revert_point) = self.revert_point.take() {
131 self.revert(revert_point);
132 }
133
134 self.prover
135 .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
136 }
137
138 pub fn commit(&mut self) {
139 self.revert_point.take();
140 }
141
142 pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
144 let mut prolog_text = String::new();
145
146 let mut conv_map = ConversionMap {
147 int_to_str: Map::default(),
148 sanitized_to_suffix: Map::default(),
149 int2name: &self.nimap.int2name,
150 sanitizer: sanitize,
151 };
152
153 for clauses in self.clauses.values() {
154 for clause in clauses {
155 let head = self.stor.get_term(clause.head);
156 write_term(head, &mut conv_map, &mut prolog_text);
157
158 if let Some(body) = clause.body {
159 prolog_text.push_str(" :- ");
160
161 let body = self.stor.get_expr(body);
162 write_expr(body, &mut conv_map, &mut prolog_text);
163 }
164
165 prolog_text.push_str(".\n");
166 }
167 }
168
169 return prolog_text;
170
171 struct ConversionMap<'a, 'cx, F> {
174 int_to_str: Map<Int, String>,
175 sanitized_to_suffix: Map<&'a str, u32>,
177 int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
178 sanitizer: F,
179 }
180
181 impl<F: FnMut(&str) -> &str> ConversionMap<'_, '_, F> {
182 fn int_to_str(&mut self, int: Int) -> &str {
183 self.int_to_str.entry(int).or_insert_with(|| {
184 let name = self.int2name.get(&int).unwrap();
185
186 let mut is_var = false;
187
188 let name = if name.starts_with(VAR_PREFIX) {
190 is_var = true;
191 &name[1..]
192 } else {
193 name
194 };
195
196 let pure_name = (self.sanitizer)(name);
198
199 let suffix = self
200 .sanitized_to_suffix
201 .entry(pure_name)
202 .and_modify(|x| *x += 1)
203 .or_insert(0);
204
205 let mut buf = String::new();
206
207 if is_var {
208 let upper = pure_name.chars().next().unwrap().to_uppercase();
209 for c in upper {
210 buf.push(c);
211 }
212 } else {
213 let lower = pure_name.chars().next().unwrap().to_lowercase();
214 for c in lower {
215 buf.push(c);
216 }
217 };
218 buf.push_str(&pure_name[1..]);
219
220 if *suffix == 0 {
221 buf
222 } else {
223 write!(&mut buf, "_{suffix}").unwrap();
224 buf
225 }
226 })
227 }
228 }
229
230 fn write_term<F: FnMut(&str) -> &str>(
231 term: TermView<'_, Int>,
232 conv_map: &mut ConversionMap<'_, '_, F>,
233 prolog_text: &mut String,
234 ) {
235 let functor = term.functor();
236 let args = term.args();
237 let num_args = args.len();
238
239 let functor = conv_map.int_to_str(*functor);
240 prolog_text.push_str(functor);
241
242 if num_args > 0 {
243 prolog_text.push('(');
244 for (i, arg) in args.enumerate() {
245 write_term(arg, conv_map, prolog_text);
246 if i + 1 < num_args {
247 prolog_text.push_str(", ");
248 }
249 }
250 prolog_text.push(')');
251 }
252 }
253
254 fn write_expr<F: FnMut(&str) -> &str>(
255 expr: ExprView<'_, Int>,
256 conv_map: &mut ConversionMap<'_, '_, F>,
257 prolog_text: &mut String,
258 ) {
259 match expr.as_kind() {
260 ExprKind::Term(term) => {
261 write_term(term, conv_map, prolog_text);
262 }
263 ExprKind::Not(inner) => {
264 prolog_text.push_str("\\+ ");
265 if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
266 prolog_text.push('(');
267 write_expr(inner, conv_map, prolog_text);
268 prolog_text.push(')');
269 } else {
270 write_expr(inner, conv_map, prolog_text);
271 }
272 }
273 ExprKind::And(args) => {
274 let num_args = args.len();
275 for (i, arg) in args.enumerate() {
276 if matches!(arg.as_kind(), ExprKind::Or(_)) {
277 prolog_text.push('(');
278 write_expr(arg, conv_map, prolog_text);
279 prolog_text.push(')');
280 } else {
281 write_expr(arg, conv_map, prolog_text);
282 }
283 if i + 1 < num_args {
284 prolog_text.push_str(", ");
285 }
286 }
287 }
288 ExprKind::Or(args) => {
289 let num_args = args.len();
290 for (i, arg) in args.enumerate() {
291 write_expr(arg, conv_map, prolog_text);
292 if i + 1 < num_args {
293 prolog_text.push_str("; ");
294 }
295 }
296 }
297 }
298 }
299 }
300
301 fn revert(
302 &mut self,
303 DatabaseState {
304 clauses_len,
305 clause_texts_len,
306 stor_len,
307 nimap_state,
308 }: DatabaseState,
309 ) {
310 self.clauses.truncate(clauses_len.len());
311 for (i, len) in clauses_len.into_iter().enumerate() {
312 self.clauses[i].truncate(len);
313 }
314 self.clause_texts.truncate(clause_texts_len);
315 self.stor.truncate(stor_len);
316 self.nimap.revert(nimap_state);
317 }
319
320 fn state(&self) -> DatabaseState {
321 DatabaseState {
322 clauses_len: self.clauses.values().map(|v| v.len()).collect(),
323 clause_texts_len: self.clause_texts.len(),
324 stor_len: self.stor.len(),
325 nimap_state: self.nimap.state(),
326 }
327 }
328}
329
330#[derive(Debug, PartialEq, Eq)]
331struct DatabaseState {
332 clauses_len: Vec<usize>,
333 clause_texts_len: usize,
334 stor_len: TermStorageLen,
335 nimap_state: NameIntMapState,
336}
337
338#[derive(Clone)]
339pub struct ClauseIter<'a, 'cx> {
340 clauses: &'a IndexMap<Predicate<Int>, Vec<ClauseId>>,
341 stor: &'a TermStorage<Int>,
342 int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
343 i: usize,
344 j: usize,
345}
346
347impl<'a, 'cx> Iterator for ClauseIter<'a, 'cx> {
348 type Item = ClauseRef<'a, 'cx>;
349
350 fn next(&mut self) -> Option<Self::Item> {
351 let id = loop {
352 let (_, group) = self.clauses.get_index(self.i)?;
353
354 if let Some(id) = group.get(self.j) {
355 self.j += 1;
356 break *id;
357 }
358
359 self.i += 1;
360 self.j = 0;
361 };
362
363 Some(ClauseRef {
364 id,
365 stor: self.stor,
366 int2name: self.int2name,
367 })
368 }
369}
370
371impl iter::FusedIterator for ClauseIter<'_, '_> {}
372
373pub struct ClauseRef<'a, 'cx> {
374 id: ClauseId,
375 stor: &'a TermStorage<Int>,
376 int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
377}
378
379impl<'a, 'cx> ClauseRef<'a, 'cx> {
380 pub fn head(&self) -> NamedTermView<'a, 'cx> {
381 let head = self.stor.get_term(self.id.head);
382 NamedTermView::new(head, self.int2name)
383 }
384
385 pub fn body(&self) -> Option<NamedExprView<'a, 'cx>> {
386 self.id.body.map(|id| {
387 let body = self.stor.get_expr(id);
388 NamedExprView::new(body, self.int2name)
389 })
390 }
391}
392
393impl fmt::Display for ClauseRef<'_, '_> {
394 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395 fmt::Display::fmt(&self.head(), f)?;
396
397 if let Some(body) = self.body() {
398 f.write_str(" :- ")?;
399 fmt::Display::fmt(&body, f)?
400 }
401
402 f.write_char('.')
403 }
404}
405
406impl fmt::Debug for ClauseRef<'_, '_> {
407 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
408 let mut d = f.debug_struct("Clause");
409
410 let head = self.stor.get_term(self.id.head);
411 d.field("head", &NamedTermView::new(head, self.int2name));
412
413 if let Some(body) = self.id.body {
414 let body = self.stor.get_expr(body);
415 d.field("body", &NamedExprView::new(body, self.int2name));
416 }
417
418 d.finish()
419 }
420}
421
422impl<'cx> Clause<Name<'cx>> {
423 pub(crate) fn convert_var_into_num(&self, gcx: &GlobalCx<'cx>) -> Option<Self> {
425 let mut cloned: Option<Self> = None;
426
427 let mut i = 0;
428
429 while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(self)) {
430 let from = var.clone();
431
432 let mut convert = |term: &Term<Name<'_>>| {
433 (term == &from).then_some(Term {
434 functor: Name::create(gcx, &format!("_{VAR_PREFIX}{i}")),
435 args: [].into(),
436 })
437 };
438
439 if let Some(cloned) = &mut cloned {
440 cloned.replace_term(&mut convert);
441 } else {
442 let mut this = self.clone();
443 this.replace_term(&mut convert);
444 cloned = Some(this);
445 }
446
447 i += 1;
448 }
449
450 return cloned;
451
452 fn find_var_in_clause<'a, 'cx>(
455 clause: &'a Clause<Name<'cx>>,
456 ) -> Option<&'a Term<Name<'cx>>> {
457 let var = find_var_in_term(&clause.head);
458 if var.is_some() {
459 return var;
460 }
461 find_var_in_expr(clause.body.as_ref()?)
462 }
463
464 fn find_var_in_expr<'a, 'cx>(expr: &'a Expr<Name<'cx>>) -> Option<&'a Term<Name<'cx>>> {
465 match expr {
466 Expr::Term(term) => find_var_in_term(term),
467 Expr::Not(inner) => find_var_in_expr(inner),
468 Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
469 }
470 }
471
472 fn find_var_in_term<'a, 'cx>(term: &'a Term<Name<'cx>>) -> Option<&'a Term<Name<'cx>>> {
473 const _: () = assert!(VAR_PREFIX == '$');
474
475 if term.is_variable() && !term.functor.starts_with("_$") {
476 Some(term)
477 } else {
478 term.args.iter().find_map(find_var_in_term)
479 }
480 }
481 }
482}
483
484pub struct NamedTermViewIter<'a, 'cx> {
485 term_iter: TermViewIter<'a, Int>,
486 int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
487}
488
489impl<'a, 'cx> Iterator for NamedTermViewIter<'a, 'cx> {
490 type Item = NamedTermView<'a, 'cx>;
491
492 fn next(&mut self) -> Option<Self::Item> {
493 self.term_iter
494 .next()
495 .map(|view| NamedTermView::new(view, self.int2name))
496 }
497}
498
499impl iter::FusedIterator for NamedTermViewIter<'_, '_> {}
500
501#[cfg(test)]
502mod tests {
503 use super::*;
504 use crate::parse::{
505 self,
506 repr::{Clause, Expr},
507 };
508
509 #[test]
510 fn test_parse() {
511 fn assert(gcx: &GlobalCx<'_>, text: &str) {
512 let clause: Clause<Name<'_>> = parse::parse_str(gcx, text).unwrap();
513 assert_eq!(text, clause.to_string());
514 }
515
516 let interner = DroplessInterner::default();
517 let gcx = GlobalCx {
518 interner: &interner,
519 };
520
521 assert(&gcx, "f.");
522 assert(&gcx, "f(a, b).");
523 assert(&gcx, "f(a, b) :- f.");
524 assert(&gcx, "f(a, b) :- f(a).");
525 assert(&gcx, "f(a, b) :- f(a), f(b).");
526 assert(&gcx, "f(a, b) :- f(a); f(b).");
527 assert(&gcx, "f(a, b) :- f(a), (f(b); f(c)).");
528 }
529
530 #[test]
531 fn test_serial_queries() {
532 let interner = DroplessInterner::default();
533 let mut db = Database::new(&interner);
534
535 fn insert(db: &mut Database<'_>) {
536 insert_dataset(
537 db,
538 r"
539 f(a).
540 f(b).
541 g($X) :- f($X).
542 ",
543 );
544 }
545
546 fn query(db: &mut Database<'_>) {
547 let query = "g($X).";
548 let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
549 let answer = collect_answer(db.query(query));
550
551 let expected = [["$X = a"], ["$X = b"]];
552
553 assert_eq!(answer, expected);
554 }
555
556 insert(&mut db);
557 let org_stor_len = db.stor.len();
558 query(&mut db);
559 debug_assert_eq!(org_stor_len, db.stor.len());
560
561 insert(&mut db);
562 debug_assert_eq!(org_stor_len, db.stor.len());
563 query(&mut db);
564 debug_assert_eq!(org_stor_len, db.stor.len());
565 }
566
567 #[test]
568 fn test_various_expressions() {
569 test_not_expression();
570 test_and_expression();
571 test_or_expression();
572 test_mixed_expression();
573 }
574
575 fn test_not_expression() {
576 let interner = DroplessInterner::default();
577 let mut db = Database::new(&interner);
578
579 insert_dataset(
580 &mut db,
581 r"
582 g(a).
583 f($X) :- \+ g($X).
584 ",
585 );
586
587 let query = "f(a).";
588 let query: Expr<Name> = parse::parse_str(db.gcx(), query).unwrap();
589 let answer = collect_answer(db.query(query));
590 assert!(answer.is_empty());
591
592 let query = "f(b).";
593 let query: Expr<Name> = parse::parse_str(db.gcx(), query).unwrap();
594 let answer = collect_answer(db.query(query));
595 assert_eq!(answer.len(), 1);
596 }
597
598 fn test_and_expression() {
599 let interner = DroplessInterner::default();
600 let mut db = Database::new(&interner);
601
602 insert_dataset(
603 &mut db,
604 r"
605 g(a).
606 g(b).
607 h(b).
608 f($X) :- g($X), h($X).
609 ",
610 );
611
612 let query = "f($X).";
613 let query: Expr<Name> = parse::parse_str(db.gcx(), query).unwrap();
614 let answer = collect_answer(db.query(query));
615
616 let expected = [["$X = b"]];
617
618 assert_eq!(answer, expected);
619 }
620
621 fn test_or_expression() {
622 let interner = DroplessInterner::default();
623 let mut db = Database::new(&interner);
624
625 insert_dataset(
626 &mut db,
627 r"
628 g(a).
629 h(b).
630 f($X) :- g($X); h($X).
631 ",
632 );
633
634 let query = "f($X).";
635 let query: Expr<Name> = parse::parse_str(db.gcx(), query).unwrap();
636 let answer = collect_answer(db.query(query));
637
638 let expected = [["$X = a"], ["$X = b"]];
639
640 assert_eq!(answer, expected);
641 }
642
643 fn test_mixed_expression() {
644 let interner = DroplessInterner::default();
645 let mut db = Database::new(&interner);
646
647 insert_dataset(
648 &mut db,
649 r"
650 g(b).
651 g(c).
652
653 h(b).
654
655 i(a).
656 i(b).
657 i(c).
658
659 f($X) :- (\+ g($X); h($X)), i($X).
660 ",
661 );
662
663 let query = "f($X).";
664 let query: Expr<Name> = parse::parse_str(db.gcx(), query).unwrap();
665 let answer = collect_answer(db.query(query));
666
667 let expected = [["$X = b"]];
668
669 assert_eq!(answer, expected);
670 }
671
672 #[test]
673 fn test_recursion() {
674 test_simple_recursion();
675 test_right_recursion();
676 }
677
678 fn test_simple_recursion() {
679 let interner = DroplessInterner::default();
680 let mut db = Database::new(&interner);
681
682 insert_dataset(
683 &mut db,
684 r"
685 impl(Clone, a).
686 impl(Clone, b).
687 impl(Clone, c).
688 impl(Clone, Vec($T)) :- impl(Clone, $T).
689 ",
690 );
691
692 let query = "impl(Clone, $T).";
693 let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
694 let mut cx = db.query(query);
695
696 let mut assert_next = |expected: &[&str]| {
697 let eval = cx.prove_next().unwrap();
698 let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
699 assert_eq!(assignments, expected);
700 };
701
702 assert_next(&["$T = a"]);
703 assert_next(&["$T = b"]);
704 assert_next(&["$T = c"]);
705 assert_next(&["$T = Vec(a)"]);
706 assert_next(&["$T = Vec(b)"]);
707 assert_next(&["$T = Vec(c)"]);
708 assert_next(&["$T = Vec(Vec(a))"]);
709 assert_next(&["$T = Vec(Vec(b))"]);
710 assert_next(&["$T = Vec(Vec(c))"]);
711 }
712
713 fn test_right_recursion() {
714 let interner = DroplessInterner::default();
715 let mut db = Database::new(&interner);
716
717 insert_dataset(
718 &mut db,
719 r"
720 child(a, b).
721 child(b, c).
722 child(c, d).
723 descend($X, $Y) :- child($X, $Y).
724 descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
725 ",
726 );
727
728 let query = "descend($X, $Y).";
729 let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
730 let mut answer = collect_answer(db.query(query));
731
732 let mut expected = [
733 ["$X = a", "$Y = b"],
734 ["$X = a", "$Y = c"],
735 ["$X = a", "$Y = d"],
736 ["$X = b", "$Y = c"],
737 ["$X = b", "$Y = d"],
738 ["$X = c", "$Y = d"],
739 ];
740
741 answer.sort_unstable();
742 expected.sort_unstable();
743 assert_eq!(answer, expected);
744 }
745
746 #[test]
747 fn test_discarding_uncomitted_change() {
748 let interner = DroplessInterner::default();
749 let mut db = Database::new(&interner);
750
751 let text = "f(a).";
752 let clause = parse::parse_str(db.gcx(), text).unwrap();
753 db.insert_clause(clause);
754 let fa_state = db.state();
755 db.commit();
756
757 let text = "f(b).";
758 let clause = parse::parse_str(db.gcx(), text).unwrap();
759 db.insert_clause(clause);
760
761 let query = "f($X).";
762 let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
763 let answer = collect_answer(db.query(query));
764
765 let expected = [["$X = a"]];
767 assert_eq!(answer, expected);
768 assert_eq!(db.state(), fa_state);
769 }
770
771 fn insert_dataset(db: &mut Database, text: &str) {
772 let dataset: ClauseDataset<Name<'_>> = parse::parse_str(db.gcx(), text).unwrap();
773 db.insert_dataset(dataset);
774 db.commit();
775 }
776
777 fn collect_answer(mut cx: ProveCx<'_, '_>) -> Vec<Vec<String>> {
778 let mut v = Vec::new();
779 while let Some(eval) = cx.prove_next() {
780 let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
781 v.push(x);
782 }
783 v
784 }
785}