1use super::{
2 prover::{
3 Int, NameIntMap, NameIntMapState, ProveCx, Prover,
4 format::{NamedExprView, NamedTermView},
5 },
6 repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9 Map,
10 parse::{
11 VAR_PREFIX,
12 repr::{Clause, ClauseDataset, Expr, Predicate, Term},
13 text::Name,
14 },
15 prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
16};
17use indexmap::{IndexMap, IndexSet};
18use std::{
19 fmt::{self, Write},
20 iter,
21};
22
/// A clause database supporting uncommitted insertions and querying.
///
/// Clauses inserted since the last [`Database::commit`] are discarded by the
/// next [`Database::query`].
#[derive(Debug)]
pub struct Database {
    /// Stored clauses, grouped by the predicate of their head. Groups keep the
    /// order in which their predicate was first inserted.
    clauses: IndexMap<Predicate<Int>, Vec<ClauseId>>,

    /// Variable-normalized text of every inserted clause; used to skip clauses
    /// that have already been inserted.
    clause_texts: IndexSet<String>,

    /// Storage for the terms and expressions referred to by `clauses`.
    stor: TermStorage<Int>,

    /// Engine used by [`Database::query`] to prove expressions.
    prover: Prover,

    /// Mapping between names and their interned integer ids.
    nimap: NameIntMap,

    /// Snapshot taken at the first insertion after a commit. `query` reverts
    /// to it; `commit` drops it.
    revert_point: Option<DatabaseState>,
}
48
49impl Database {
50 pub fn new() -> Self {
51 Self {
52 clauses: IndexMap::default(),
53 clause_texts: IndexSet::default(),
54 stor: TermStorage::new(),
55 prover: Prover::new(),
56 nimap: NameIntMap::new(),
57 revert_point: None,
58 }
59 }
60
61 pub fn terms(&self) -> NamedTermViewIter<'_> {
62 NamedTermViewIter {
63 term_iter: self.stor.terms.terms(),
64 int2name: &self.nimap.int2name,
65 }
66 }
67
68 pub fn clauses(&self) -> impl iter::FusedIterator<Item = ClauseRef<'_>> {
69 self.clauses
70 .values()
71 .flat_map(|group| group.iter().cloned())
72 .map(|id| ClauseRef {
73 id,
74 stor: &self.stor,
75 int2name: &self.nimap.int2name,
76 })
77 }
78
79 pub fn insert_dataset(&mut self, dataset: ClauseDataset<Name>) {
80 for clause in dataset {
81 self.insert_clause(clause);
82 }
83 }
84
85 pub fn insert_clause(&mut self, clause: Clause<Name>) {
86 if self.revert_point.is_none() {
89 self.revert_point = Some(self.state());
90 }
91
92 let serialized = if let Some(converted) = clause.convert_var_into_num() {
94 converted.to_string()
95 } else {
96 clause.to_string()
97 };
98 if !self.clause_texts.insert(serialized) {
99 return;
100 }
101
102 let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
103
104 let key = clause.head.predicate();
105 let value = ClauseId {
106 head: self.stor.insert_term(clause.head),
107 body: clause.body.map(|expr| self.stor.insert_expr(expr)),
108 };
109
110 self.clauses
111 .entry(key)
112 .and_modify(|similar_clauses| {
113 if similar_clauses.iter().all(|clause| clause != &value) {
114 similar_clauses.push(value);
115 }
116 })
117 .or_insert(vec![value]);
118 }
119
120 pub fn query(&mut self, expr: Expr<Name>) -> ProveCx<'_> {
121 if let Some(revert_point) = self.revert_point.take() {
123 self.revert(revert_point);
124 }
125
126 self.prover
127 .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
128 }
129
130 pub fn commit(&mut self) {
131 self.revert_point.take();
132 }
133
134 pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
136 let mut prolog_text = String::new();
137
138 let mut conv_map = ConversionMap {
139 int_to_str: Map::default(),
140 sanitized_to_suffix: Map::default(),
141 int2name: &self.nimap.int2name,
142 sanitizer: sanitize,
143 };
144
145 for clauses in self.clauses.values() {
146 for clause in clauses {
147 let head = self.stor.get_term(clause.head);
148 write_term(head, &mut conv_map, &mut prolog_text);
149
150 if let Some(body) = clause.body {
151 prolog_text.push_str(" :- ");
152
153 let body = self.stor.get_expr(body);
154 write_expr(body, &mut conv_map, &mut prolog_text);
155 }
156
157 prolog_text.push_str(".\n");
158 }
159 }
160
161 return prolog_text;
162
163 struct ConversionMap<'a, F> {
166 int_to_str: Map<Int, String>,
167 sanitized_to_suffix: Map<&'a str, u32>,
169 int2name: &'a IndexMap<Int, Name>,
170 sanitizer: F,
171 }
172
173 impl<F: FnMut(&str) -> &str> ConversionMap<'_, F> {
174 fn int_to_str(&mut self, int: Int) -> &str {
175 self.int_to_str.entry(int).or_insert_with(|| {
176 let name = self.int2name.get(&int).unwrap();
177
178 let mut is_var = false;
179
180 let name = if name.starts_with(VAR_PREFIX) {
182 is_var = true;
183 &name[1..]
184 } else {
185 name
186 };
187
188 let pure_name = (self.sanitizer)(name);
190
191 let suffix = self
192 .sanitized_to_suffix
193 .entry(pure_name)
194 .and_modify(|x| *x += 1)
195 .or_insert(0);
196
197 let mut buf = String::new();
198
199 if is_var {
200 let upper = pure_name.chars().next().unwrap().to_uppercase();
201 for c in upper {
202 buf.push(c);
203 }
204 } else {
205 let lower = pure_name.chars().next().unwrap().to_lowercase();
206 for c in lower {
207 buf.push(c);
208 }
209 };
210 buf.push_str(&pure_name[1..]);
211
212 if *suffix == 0 {
213 buf
214 } else {
215 write!(&mut buf, "_{suffix}").unwrap();
216 buf
217 }
218 })
219 }
220 }
221
222 fn write_term<F: FnMut(&str) -> &str>(
223 term: TermView<'_, Int>,
224 conv_map: &mut ConversionMap<'_, F>,
225 prolog_text: &mut String,
226 ) {
227 let functor = term.functor();
228 let args = term.args();
229 let num_args = args.len();
230
231 let functor = conv_map.int_to_str(*functor);
232 prolog_text.push_str(functor);
233
234 if num_args > 0 {
235 prolog_text.push('(');
236 for (i, arg) in args.enumerate() {
237 write_term(arg, conv_map, prolog_text);
238 if i + 1 < num_args {
239 prolog_text.push_str(", ");
240 }
241 }
242 prolog_text.push(')');
243 }
244 }
245
246 fn write_expr<F: FnMut(&str) -> &str>(
247 expr: ExprView<'_, Int>,
248 conv_map: &mut ConversionMap<'_, F>,
249 prolog_text: &mut String,
250 ) {
251 match expr.as_kind() {
252 ExprKind::Term(term) => {
253 write_term(term, conv_map, prolog_text);
254 }
255 ExprKind::Not(inner) => {
256 prolog_text.push_str("\\+ ");
257 if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
258 prolog_text.push('(');
259 write_expr(inner, conv_map, prolog_text);
260 prolog_text.push(')');
261 } else {
262 write_expr(inner, conv_map, prolog_text);
263 }
264 }
265 ExprKind::And(args) => {
266 let num_args = args.len();
267 for (i, arg) in args.enumerate() {
268 if matches!(arg.as_kind(), ExprKind::Or(_)) {
269 prolog_text.push('(');
270 write_expr(arg, conv_map, prolog_text);
271 prolog_text.push(')');
272 } else {
273 write_expr(arg, conv_map, prolog_text);
274 }
275 if i + 1 < num_args {
276 prolog_text.push_str(", ");
277 }
278 }
279 }
280 ExprKind::Or(args) => {
281 let num_args = args.len();
282 for (i, arg) in args.enumerate() {
283 write_expr(arg, conv_map, prolog_text);
284 if i + 1 < num_args {
285 prolog_text.push_str("; ");
286 }
287 }
288 }
289 }
290 }
291 }
292
293 fn revert(
294 &mut self,
295 DatabaseState {
296 clauses_len,
297 clause_texts_len,
298 stor_len,
299 nimap_state,
300 }: DatabaseState,
301 ) {
302 self.clauses.truncate(clauses_len.len());
303 for (i, len) in clauses_len.into_iter().enumerate() {
304 self.clauses[i].truncate(len);
305 }
306 self.clause_texts.truncate(clause_texts_len);
307 self.stor.truncate(stor_len);
308 self.nimap.revert(nimap_state);
309 }
311
312 fn state(&self) -> DatabaseState {
313 DatabaseState {
314 clauses_len: self.clauses.values().map(|v| v.len()).collect(),
315 clause_texts_len: self.clause_texts.len(),
316 stor_len: self.stor.len(),
317 nimap_state: self.nimap.state(),
318 }
319 }
320}
321
322impl Default for Database {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
/// Sizes of the database's containers at one point in time. Captured by
/// `Database::state` and restored by truncation in `Database::revert`.
#[derive(Debug, PartialEq, Eq)]
struct DatabaseState {
    /// Length of each clause group, in the iteration order of `Database::clauses`.
    clauses_len: Vec<usize>,
    /// Number of entries in `Database::clause_texts`.
    clause_texts_len: usize,
    /// Length snapshot of the term storage, passed to `TermStorage::truncate`.
    stor_len: TermStorageLen,
    /// Snapshot of the name/int map, passed to `NameIntMap::revert`.
    nimap_state: NameIntMapState,
}
335
/// Borrowed view of one stored clause that formats with its original names.
/// Yielded by [`Database::clauses`].
pub struct ClauseRef<'a> {
    /// Ids of the clause's head term and optional body expression in `stor`.
    id: ClauseId,
    /// Storage that the ids in `id` refer to.
    stor: &'a TermStorage<Int>,
    /// Table used to turn interned integers back into names.
    int2name: &'a IndexMap<Int, Name>,
}
341
342impl fmt::Display for ClauseRef<'_> {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 let head = self.stor.get_term(self.id.head);
345 fmt::Display::fmt(&NamedTermView::new(head, self.int2name), f)?;
346
347 if let Some(body) = self.id.body {
348 let body = self.stor.get_expr(body);
349 f.write_str(" :- ")?;
350 fmt::Display::fmt(&NamedExprView::new(body, self.int2name), f)?
351 }
352 f.write_char('.')
353 }
354}
355
356impl fmt::Debug for ClauseRef<'_> {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 let mut d = f.debug_struct("Clause");
359
360 let head = self.stor.get_term(self.id.head);
361 d.field("head", &NamedTermView::new(head, self.int2name));
362
363 if let Some(body) = self.id.body {
364 let body = self.stor.get_expr(body);
365 d.field("body", &NamedExprView::new(body, self.int2name));
366 }
367
368 d.finish()
369 }
370}
371
372impl Clause<Name> {
373 pub(crate) fn convert_var_into_num(&self) -> Option<Self> {
375 let mut cloned: Option<Self> = None;
376
377 let mut i = 0;
378
379 while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(self)) {
380 let from = var.clone();
381
382 let mut convert = |term: &Term<Name>| {
383 (term == &from).then_some(Term {
384 functor: format!("_{VAR_PREFIX}{i}").into(),
385 args: [].into(),
386 })
387 };
388
389 if let Some(cloned) = &mut cloned {
390 cloned.replace_term(&mut convert);
391 } else {
392 let mut this = self.clone();
393 this.replace_term(&mut convert);
394 cloned = Some(this);
395 }
396
397 i += 1;
398 }
399
400 return cloned;
401
402 fn find_var_in_clause(clause: &Clause<Name>) -> Option<&Term<Name>> {
405 let var = find_var_in_term(&clause.head);
406 if var.is_some() {
407 return var;
408 }
409 find_var_in_expr(clause.body.as_ref()?)
410 }
411
412 fn find_var_in_expr(expr: &Expr<Name>) -> Option<&Term<Name>> {
413 match expr {
414 Expr::Term(term) => find_var_in_term(term),
415 Expr::Not(inner) => find_var_in_expr(inner),
416 Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
417 }
418 }
419
420 fn find_var_in_term(term: &Term<Name>) -> Option<&Term<Name>> {
421 const _: () = assert!(VAR_PREFIX == '$');
422
423 if term.is_variable() && !term.functor.starts_with("_$") {
424 Some(term)
425 } else {
426 term.args.iter().find_map(find_var_in_term)
427 }
428 }
429 }
430}
431
/// Iterator over stored terms that pairs each term view with the
/// integer-to-name table, yielding [`NamedTermView`]s. Returned by
/// [`Database::terms`].
pub struct NamedTermViewIter<'a> {
    /// Underlying iterator over raw, integer-keyed term views.
    term_iter: TermViewIter<'a, Int>,
    /// Table used to resolve interned integers back to names.
    int2name: &'a IndexMap<Int, Name>,
}
436
437impl<'a> Iterator for NamedTermViewIter<'a> {
438 type Item = NamedTermView<'a>;
439
440 fn next(&mut self) -> Option<Self::Item> {
441 self.term_iter
442 .next()
443 .map(|view| NamedTermView::new(view, self.int2name))
444 }
445}
446
// NOTE(review): `FusedIterator` promises that `next` keeps returning `None`
// once it has done so. `NamedTermViewIter::next` just forwards to `term_iter`,
// so this only holds if `TermViewIter` is itself fused. Confirm before
// relying on it.
impl iter::FusedIterator for NamedTermViewIter<'_> {}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::{
        self,
        repr::{Clause, Expr},
    };

    /// Parsing a clause and printing it back reproduces the input text.
    #[test]
    fn test_parse() {
        fn assert(text: &str) {
            let clause: Clause<Name> = parse::parse_str(text).unwrap();
            assert_eq!(text, clause.to_string());
        }

        assert("f.");
        assert("f(a, b).");
        assert("f(a, b) :- f.");
        assert("f(a, b) :- f(a).");
        assert("f(a, b) :- f(a), f(b).");
        assert("f(a, b) :- f(a); f(b).");
        assert("f(a, b) :- f(a), (f(b); f(c)).");
    }

    /// Repeating the same inserts and queries gives the same answers and
    /// leaves the term storage at the size it had after the first insert.
    #[test]
    fn test_serial_queries() {
        let mut db = Database::new();

        fn insert(db: &mut Database) {
            insert_dataset(
                db,
                r"
            f(a).
            f(b).
            g($X) :- f($X).
            ",
            );
        }

        fn query(db: &mut Database) {
            let query = "g($X).";
            let query: Expr<Name> = parse::parse_str(query).unwrap();
            let answer = collect_answer(db.query(query));

            let expected = [["$X = a"], ["$X = b"]];

            assert_eq!(answer, expected);
        }

        // Plain `assert_eq!` rather than `debug_assert_eq!`: these storage
        // checks are the point of the test and must also run under
        // `cargo test --release`.
        insert(&mut db);
        let org_stor_len = db.stor.len();
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());

        insert(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
    }

    #[test]
    fn test_various_expressions() {
        test_not_expression();
        test_and_expression();
        test_or_expression();
        test_mixed_expression();
    }

    /// `\+ goal` succeeds only when `goal` has no solution.
    fn test_not_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query = "f(a).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query = "f(b).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);
    }

    /// A conjunction only yields bindings satisfying every conjunct.
    fn test_and_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    /// A disjunction yields the bindings of each alternative.
    fn test_or_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"], ["$X = b"]];

        assert_eq!(answer, expected);
    }

    /// Negation, disjunction and conjunction combined in one rule body.
    fn test_mixed_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    #[test]
    fn test_recursion() {
        test_simple_recursion();
        test_right_recursion();
    }

    /// A self-recursive rule enumerates ever deeper answers in order.
    fn test_simple_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query = "impl(Clone, $T).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);
    }

    /// A right-recursive transitive closure finds every ancestor pair.
    fn test_right_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query = "descend($X, $Y).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    /// Insertions made after the last commit are discarded by the next query,
    /// and the database returns exactly to its committed state.
    #[test]
    fn test_discarding_uncomitted_change() {
        let mut db = Database::new();

        let text = "f(a).";
        let clause = parse::parse_str(text).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        let text = "f(b).";
        let clause = parse::parse_str(text).unwrap();
        db.insert_clause(clause);

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);
    }

    /// Parses `text` as a dataset, inserts it and commits it.
    fn insert_dataset(db: &mut Database, text: &str) {
        let dataset: ClauseDataset<Name> = parse::parse_str(text).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, turning each answer into its list of `var = value` strings.
    fn collect_answer(mut cx: ProveCx<'_>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}
720}