logic_eval/prove/db.rs

1use super::{
2    prover::{
3        IdxMap, Int, NameIntMap, NameIntMapState, ProveCx, Prover,
4        format::{NamedExprView, NamedTermView},
5    },
6    repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9    Map,
10    parse::{
11        GlobalCx, VAR_PREFIX,
12        repr::{Clause, ClauseDataset, Expr, Predicate, Term},
13        text::Name,
14    },
15    prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
16};
17use any_intern::DroplessInterner;
18use indexmap::{IndexMap, IndexSet};
19use std::{
20    fmt::{self, Write},
21    iter,
22};
23
24#[derive(Debug)]
25pub struct Database<'cx> {
26    /// Clause id dataset.
27    clauses: IndexMap<Predicate<Int>, Vec<ClauseId>>,
28
29    /// We do not allow duplicated clauses in the dataset.
30    clause_texts: IndexSet<String>,
31
32    /// Term and expression storage.
33    stor: TermStorage<Int>,
34
35    /// Proof search engine.
36    prover: Prover,
37
38    /// Mappings between [`Name`] and [`Int`].
39    ///
40    /// [`Int`] is internally used for fast comparison, but we need to get it
41    /// back to [`Name`] for the clients.
42    nimap: NameIntMap<'cx>,
43
44    /// States of DB's fields.
45    ///
46    /// This is used when we discard some changes on the DB.
47    revert_point: Option<DatabaseState>,
48
49    gcx: GlobalCx<'cx>,
50}
51
52impl<'cx> Database<'cx> {
53    pub fn new(interner: &'cx DroplessInterner) -> Self {
54        let gcx = GlobalCx { interner };
55        Self {
56            clauses: IndexMap::default(),
57            clause_texts: IndexSet::default(),
58            stor: TermStorage::new(),
59            prover: Prover::new(),
60            nimap: NameIntMap::new(gcx),
61            revert_point: None,
62            gcx,
63        }
64    }
65
66    pub fn gcx(&self) -> &GlobalCx<'cx> {
67        &self.gcx
68    }
69
70    pub fn terms(&self) -> NamedTermViewIter<'_, 'cx> {
71        NamedTermViewIter {
72            term_iter: self.stor.terms.terms(),
73            int2name: &self.nimap.int2name,
74        }
75    }
76
77    pub fn clauses(&self) -> ClauseIter<'_, 'cx> {
78        ClauseIter {
79            clauses: &self.clauses,
80            stor: &self.stor,
81            int2name: &self.nimap.int2name,
82            i: 0,
83            j: 0,
84        }
85    }
86
87    pub fn insert_dataset(&mut self, dataset: ClauseDataset<Name<'cx>>) {
88        for clause in dataset {
89            self.insert_clause(clause);
90        }
91    }
92
93    pub fn insert_clause(&mut self, clause: Clause<Name<'cx>>) {
94        // Saves current state. We will revert DB when the change is not
95        // committed.
96        if self.revert_point.is_none() {
97            self.revert_point = Some(self.state());
98        }
99
100        // If this DB contains the given clause, then returns.
101        let serialized = if let Some(converted) = clause.convert_var_into_num(&self.gcx) {
102            converted.to_string()
103        } else {
104            clause.to_string()
105        };
106        if !self.clause_texts.insert(serialized) {
107            return;
108        }
109
110        let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
111
112        let key = clause.head.predicate();
113        let value = ClauseId {
114            head: self.stor.insert_term(clause.head),
115            body: clause.body.map(|expr| self.stor.insert_expr(expr)),
116        };
117
118        self.clauses
119            .entry(key)
120            .and_modify(|similar_clauses| {
121                if similar_clauses.iter().all(|clause| clause != &value) {
122                    similar_clauses.push(value);
123                }
124            })
125            .or_insert(vec![value]);
126    }
127
128    pub fn query(&mut self, expr: Expr<Name<'cx>>) -> ProveCx<'_, 'cx> {
129        // Discards uncomitted changes.
130        if let Some(revert_point) = self.revert_point.take() {
131            self.revert(revert_point);
132        }
133
134        self.prover
135            .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
136    }
137
138    pub fn commit(&mut self) {
139        self.revert_point.take();
140    }
141
142    /// * sanitize - Removes unacceptable characters from prolog.
143    pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
144        let mut prolog_text = String::new();
145
146        let mut conv_map = ConversionMap {
147            int_to_str: Map::default(),
148            sanitized_to_suffix: Map::default(),
149            int2name: &self.nimap.int2name,
150            sanitizer: sanitize,
151        };
152
153        for clauses in self.clauses.values() {
154            for clause in clauses {
155                let head = self.stor.get_term(clause.head);
156                write_term(head, &mut conv_map, &mut prolog_text);
157
158                if let Some(body) = clause.body {
159                    prolog_text.push_str(" :- ");
160
161                    let body = self.stor.get_expr(body);
162                    write_expr(body, &mut conv_map, &mut prolog_text);
163                }
164
165                prolog_text.push_str(".\n");
166            }
167        }
168
169        return prolog_text;
170
171        // === Internal helper functions ===
172
173        struct ConversionMap<'a, 'cx, F> {
174            int_to_str: Map<Int, String>,
175            // e.g. 0 -> No suffix, 1 -> _1, 2 -> _2, ...
176            sanitized_to_suffix: Map<&'a str, u32>,
177            int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
178            sanitizer: F,
179        }
180
181        impl<F: FnMut(&str) -> &str> ConversionMap<'_, '_, F> {
182            fn int_to_str(&mut self, int: Int) -> &str {
183                self.int_to_str.entry(int).or_insert_with(|| {
184                    let name = self.int2name.get(&int).unwrap();
185
186                    let mut is_var = false;
187
188                    // Removes variable prefix.
189                    let name = if name.starts_with(VAR_PREFIX) {
190                        is_var = true;
191                        &name[1..]
192                    } else {
193                        name
194                    };
195
196                    // Removes other user-defined characters.
197                    let pure_name = (self.sanitizer)(name);
198
199                    let suffix = self
200                        .sanitized_to_suffix
201                        .entry(pure_name)
202                        .and_modify(|x| *x += 1)
203                        .or_insert(0);
204
205                    let mut buf = String::new();
206
207                    if is_var {
208                        let upper = pure_name.chars().next().unwrap().to_uppercase();
209                        for c in upper {
210                            buf.push(c);
211                        }
212                    } else {
213                        let lower = pure_name.chars().next().unwrap().to_lowercase();
214                        for c in lower {
215                            buf.push(c);
216                        }
217                    };
218                    buf.push_str(&pure_name[1..]);
219
220                    if *suffix == 0 {
221                        buf
222                    } else {
223                        write!(&mut buf, "_{suffix}").unwrap();
224                        buf
225                    }
226                })
227            }
228        }
229
230        fn write_term<F: FnMut(&str) -> &str>(
231            term: TermView<'_, Int>,
232            conv_map: &mut ConversionMap<'_, '_, F>,
233            prolog_text: &mut String,
234        ) {
235            let functor = term.functor();
236            let args = term.args();
237            let num_args = args.len();
238
239            let functor = conv_map.int_to_str(*functor);
240            prolog_text.push_str(functor);
241
242            if num_args > 0 {
243                prolog_text.push('(');
244                for (i, arg) in args.enumerate() {
245                    write_term(arg, conv_map, prolog_text);
246                    if i + 1 < num_args {
247                        prolog_text.push_str(", ");
248                    }
249                }
250                prolog_text.push(')');
251            }
252        }
253
254        fn write_expr<F: FnMut(&str) -> &str>(
255            expr: ExprView<'_, Int>,
256            conv_map: &mut ConversionMap<'_, '_, F>,
257            prolog_text: &mut String,
258        ) {
259            match expr.as_kind() {
260                ExprKind::Term(term) => {
261                    write_term(term, conv_map, prolog_text);
262                }
263                ExprKind::Not(inner) => {
264                    prolog_text.push_str("\\+ ");
265                    if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
266                        prolog_text.push('(');
267                        write_expr(inner, conv_map, prolog_text);
268                        prolog_text.push(')');
269                    } else {
270                        write_expr(inner, conv_map, prolog_text);
271                    }
272                }
273                ExprKind::And(args) => {
274                    let num_args = args.len();
275                    for (i, arg) in args.enumerate() {
276                        if matches!(arg.as_kind(), ExprKind::Or(_)) {
277                            prolog_text.push('(');
278                            write_expr(arg, conv_map, prolog_text);
279                            prolog_text.push(')');
280                        } else {
281                            write_expr(arg, conv_map, prolog_text);
282                        }
283                        if i + 1 < num_args {
284                            prolog_text.push_str(", ");
285                        }
286                    }
287                }
288                ExprKind::Or(args) => {
289                    let num_args = args.len();
290                    for (i, arg) in args.enumerate() {
291                        write_expr(arg, conv_map, prolog_text);
292                        if i + 1 < num_args {
293                            prolog_text.push_str("; ");
294                        }
295                    }
296                }
297            }
298        }
299    }
300
301    fn revert(
302        &mut self,
303        DatabaseState {
304            clauses_len,
305            clause_texts_len,
306            stor_len,
307            nimap_state,
308        }: DatabaseState,
309    ) {
310        self.clauses.truncate(clauses_len.len());
311        for (i, len) in clauses_len.into_iter().enumerate() {
312            self.clauses[i].truncate(len);
313        }
314        self.clause_texts.truncate(clause_texts_len);
315        self.stor.truncate(stor_len);
316        self.nimap.revert(nimap_state);
317        // `self.prover: Prover` does not store any persistent data.
318    }
319
320    fn state(&self) -> DatabaseState {
321        DatabaseState {
322            clauses_len: self.clauses.values().map(|v| v.len()).collect(),
323            clause_texts_len: self.clause_texts.len(),
324            stor_len: self.stor.len(),
325            nimap_state: self.nimap.state(),
326        }
327    }
328}
329
330#[derive(Debug, PartialEq, Eq)]
331struct DatabaseState {
332    clauses_len: Vec<usize>,
333    clause_texts_len: usize,
334    stor_len: TermStorageLen,
335    nimap_state: NameIntMapState,
336}
337
338#[derive(Clone)]
339pub struct ClauseIter<'a, 'cx> {
340    clauses: &'a IndexMap<Predicate<Int>, Vec<ClauseId>>,
341    stor: &'a TermStorage<Int>,
342    int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
343    i: usize,
344    j: usize,
345}
346
347impl<'a, 'cx> Iterator for ClauseIter<'a, 'cx> {
348    type Item = ClauseRef<'a, 'cx>;
349
350    fn next(&mut self) -> Option<Self::Item> {
351        let id = loop {
352            let (_, group) = self.clauses.get_index(self.i)?;
353
354            if let Some(id) = group.get(self.j) {
355                self.j += 1;
356                break *id;
357            }
358
359            self.i += 1;
360            self.j = 0;
361        };
362
363        Some(ClauseRef {
364            id,
365            stor: self.stor,
366            int2name: self.int2name,
367        })
368    }
369}
370
371impl iter::FusedIterator for ClauseIter<'_, '_> {}
372
373pub struct ClauseRef<'a, 'cx> {
374    id: ClauseId,
375    stor: &'a TermStorage<Int>,
376    int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
377}
378
379impl<'a, 'cx> ClauseRef<'a, 'cx> {
380    pub fn head(&self) -> NamedTermView<'a, 'cx> {
381        let head = self.stor.get_term(self.id.head);
382        NamedTermView::new(head, self.int2name)
383    }
384
385    pub fn body(&self) -> Option<NamedExprView<'a, 'cx>> {
386        self.id.body.map(|id| {
387            let body = self.stor.get_expr(id);
388            NamedExprView::new(body, self.int2name)
389        })
390    }
391}
392
393impl fmt::Display for ClauseRef<'_, '_> {
394    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395        fmt::Display::fmt(&self.head(), f)?;
396
397        if let Some(body) = self.body() {
398            f.write_str(" :- ")?;
399            fmt::Display::fmt(&body, f)?
400        }
401
402        f.write_char('.')
403    }
404}
405
406impl fmt::Debug for ClauseRef<'_, '_> {
407    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
408        let mut d = f.debug_struct("Clause");
409
410        let head = self.stor.get_term(self.id.head);
411        d.field("head", &NamedTermView::new(head, self.int2name));
412
413        if let Some(body) = self.id.body {
414            let body = self.stor.get_expr(body);
415            d.field("body", &NamedExprView::new(body, self.int2name));
416        }
417
418        d.finish()
419    }
420}
421
422impl<'cx> Clause<Name<'cx>> {
423    /// Turns variables into `_$0`, `_$1`, and so on.
424    pub(crate) fn convert_var_into_num(&self, gcx: &GlobalCx<'cx>) -> Option<Self> {
425        let mut cloned: Option<Self> = None;
426
427        let mut i = 0;
428
429        while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(self)) {
430            let from = var.clone();
431
432            let mut convert = |term: &Term<Name<'_>>| {
433                (term == &from).then_some(Term {
434                    functor: Name::create(gcx, &format!("_{VAR_PREFIX}{i}")),
435                    args: [].into(),
436                })
437            };
438
439            if let Some(cloned) = &mut cloned {
440                cloned.replace_term(&mut convert);
441            } else {
442                let mut this = self.clone();
443                this.replace_term(&mut convert);
444                cloned = Some(this);
445            }
446
447            i += 1;
448        }
449
450        return cloned;
451
452        // === Internal helper functions ===
453
454        fn find_var_in_clause<'a, 'cx>(
455            clause: &'a Clause<Name<'cx>>,
456        ) -> Option<&'a Term<Name<'cx>>> {
457            let var = find_var_in_term(&clause.head);
458            if var.is_some() {
459                return var;
460            }
461            find_var_in_expr(clause.body.as_ref()?)
462        }
463
464        fn find_var_in_expr<'a, 'cx>(expr: &'a Expr<Name<'cx>>) -> Option<&'a Term<Name<'cx>>> {
465            match expr {
466                Expr::Term(term) => find_var_in_term(term),
467                Expr::Not(inner) => find_var_in_expr(inner),
468                Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
469            }
470        }
471
472        fn find_var_in_term<'a, 'cx>(term: &'a Term<Name<'cx>>) -> Option<&'a Term<Name<'cx>>> {
473            const _: () = assert!(VAR_PREFIX == '$');
474
475            if term.is_variable() && !term.functor.starts_with("_$") {
476                Some(term)
477            } else {
478                term.args.iter().find_map(find_var_in_term)
479            }
480        }
481    }
482}
483
484pub struct NamedTermViewIter<'a, 'cx> {
485    term_iter: TermViewIter<'a, Int>,
486    int2name: &'a IdxMap<'cx, Int, Name<'cx>>,
487}
488
489impl<'a, 'cx> Iterator for NamedTermViewIter<'a, 'cx> {
490    type Item = NamedTermView<'a, 'cx>;
491
492    fn next(&mut self) -> Option<Self::Item> {
493        self.term_iter
494            .next()
495            .map(|view| NamedTermView::new(view, self.int2name))
496    }
497}
498
499impl iter::FusedIterator for NamedTermViewIter<'_, '_> {}
500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::{
        self,
        repr::{Clause, Expr},
    };

    #[test]
    fn test_parse() {
        // Parsing `text` and printing the clause back must reproduce it exactly.
        fn assert(gcx: &GlobalCx<'_>, text: &str) {
            let clause: Clause<Name<'_>> = parse::parse_str(gcx, text).unwrap();
            assert_eq!(text, clause.to_string());
        }

        let interner = DroplessInterner::default();
        let gcx = GlobalCx {
            interner: &interner,
        };

        assert(&gcx, "f.");
        assert(&gcx, "f(a, b).");
        assert(&gcx, "f(a, b) :- f.");
        assert(&gcx, "f(a, b) :- f(a).");
        assert(&gcx, "f(a, b) :- f(a), f(b).");
        assert(&gcx, "f(a, b) :- f(a); f(b).");
        assert(&gcx, "f(a, b) :- f(a), (f(b); f(c)).");
    }

    #[test]
    fn test_serial_queries() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        fn insert(db: &mut Database<'_>) {
            insert_dataset(
                db,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );
        }

        fn query(db: &mut Database<'_>) {
            let query = "g($X).";
            let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
            let answer = collect_answer(db.query(query));

            let expected = [["$X = a"], ["$X = b"]];

            assert_eq!(answer, expected);
        }

        insert(&mut db);
        let org_stor_len = db.stor.len();
        query(&mut db);
        // A finished query must not leave extra terms in the storage.
        // (`assert_eq!`, not `debug_assert_eq!`: the check must also run under
        // `cargo test --release`.)
        assert_eq!(org_stor_len, db.stor.len());

        // Re-inserting identical clauses is rejected as duplicates.
        insert(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
    }

    #[test]
    fn test_various_expressions() {
        test_not_expression();
        test_and_expression();
        test_or_expression();
        test_mixed_expression();
    }

    fn test_not_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query = "f(a).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query = "f(b).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);
    }

    fn test_and_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    fn test_or_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"], ["$X = b"]];

        assert_eq!(answer, expected);
    }

    fn test_mixed_expression() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    #[test]
    fn test_recursion() {
        test_simple_recursion();
        test_right_recursion();
    }

    fn test_simple_recursion() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query = "impl(Clone, $T).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);
    }

    fn test_right_recursion() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        insert_dataset(
            &mut db,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query = "descend($X, $Y).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        // Answer order is not part of the contract here; compare as sets.
        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    #[test]
    fn test_discarding_uncommitted_change() {
        let interner = DroplessInterner::default();
        let mut db = Database::new(&interner);

        let text = "f(a).";
        let clause = parse::parse_str(db.gcx(), text).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        let text = "f(b).";
        let clause = parse::parse_str(db.gcx(), text).unwrap();
        db.insert_clause(clause);

        let query = "f($X).";
        let query: Expr<Name<'_>> = parse::parse_str(db.gcx(), query).unwrap();
        let answer = collect_answer(db.query(query));

        // `f(b).` was discarded.
        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);
    }

    /// Parses `text` as a dataset, inserts it and commits it.
    fn insert_dataset(db: &mut Database<'_>, text: &str) {
        let dataset: ClauseDataset<Name<'_>> = parse::parse_str(db.gcx(), text).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, returning each answer as its list of printed assignments.
    fn collect_answer(mut cx: ProveCx<'_, '_>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}