// logic_eval/prove/db.rs

1use super::{
2    prover::{
3        Int, NameIntMap, NameIntMapState, ProveCx, Prover,
4        format::{NamedExprView, NamedTermView},
5    },
6    repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9    Map,
10    parse::{
11        VAR_PREFIX,
12        repr::{Clause, ClauseDataset, Expr, Predicate, Term},
13        text::Name,
14    },
15    prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
16};
17use indexmap::{IndexMap, IndexSet};
18use std::{
19    fmt::{self, Write},
20    iter,
21};
22
/// A clause database together with the proof engine that answers queries
/// against it.
///
/// Insertions are staged: the first insertion after a commit records a
/// revert point, and the next [`Database::query`] rolls all staged
/// insertions back unless [`Database::commit`] was called in between.
#[derive(Debug)]
pub struct Database {
    /// Clause ids grouped by the predicate of each clause's head, in the order
    /// predicates were first inserted.
    clauses: IndexMap<Predicate<Int>, Vec<ClauseId>>,

    /// Serialized form of every inserted clause, with variables renamed to
    /// `_$0`, `_$1`, ... so that clauses differing only in variable names
    /// compare equal. We do not allow duplicated clauses in the dataset.
    clause_texts: IndexSet<String>,

    /// Term and expression storage.
    stor: TermStorage<Int>,

    /// Proof search engine.
    prover: Prover,

    /// Mappings between [`Name`] and [`Int`].
    ///
    /// [`Int`] is internally used for fast comparison, but we need to get it
    /// back to [`Name`] for the clients.
    nimap: NameIntMap,

    /// Snapshot of the DB's field lengths taken before the first uncommitted
    /// insertion.
    ///
    /// `query` uses it to discard uncommitted changes; `commit` clears it.
    revert_point: Option<DatabaseState>,
}
48
49impl Database {
50    pub fn new() -> Self {
51        Self {
52            clauses: IndexMap::default(),
53            clause_texts: IndexSet::default(),
54            stor: TermStorage::new(),
55            prover: Prover::new(),
56            nimap: NameIntMap::new(),
57            revert_point: None,
58        }
59    }
60
61    pub fn terms(&self) -> NamedTermViewIter<'_> {
62        NamedTermViewIter {
63            term_iter: self.stor.terms.terms(),
64            int2name: &self.nimap.int2name,
65        }
66    }
67
68    pub fn clauses(&self) -> impl iter::FusedIterator<Item = ClauseRef<'_>> {
69        self.clauses
70            .values()
71            .flat_map(|group| group.iter().cloned())
72            .map(|id| ClauseRef {
73                id,
74                stor: &self.stor,
75                int2name: &self.nimap.int2name,
76            })
77    }
78
79    pub fn insert_dataset(&mut self, dataset: ClauseDataset<Name>) {
80        for clause in dataset {
81            self.insert_clause(clause);
82        }
83    }
84
85    pub fn insert_clause(&mut self, clause: Clause<Name>) {
86        // Saves current state. We will revert DB when the change is not
87        // committed.
88        if self.revert_point.is_none() {
89            self.revert_point = Some(self.state());
90        }
91
92        // If this DB contains the given clause, then returns.
93        let serialized = if let Some(converted) = clause.convert_var_into_num() {
94            converted.to_string()
95        } else {
96            clause.to_string()
97        };
98        if !self.clause_texts.insert(serialized) {
99            return;
100        }
101
102        let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
103
104        let key = clause.head.predicate();
105        let value = ClauseId {
106            head: self.stor.insert_term(clause.head),
107            body: clause.body.map(|expr| self.stor.insert_expr(expr)),
108        };
109
110        self.clauses
111            .entry(key)
112            .and_modify(|similar_clauses| {
113                if similar_clauses.iter().all(|clause| clause != &value) {
114                    similar_clauses.push(value);
115                }
116            })
117            .or_insert(vec![value]);
118    }
119
120    pub fn query(&mut self, expr: Expr<Name>) -> ProveCx<'_> {
121        // Discards uncomitted changes.
122        if let Some(revert_point) = self.revert_point.take() {
123            self.revert(revert_point);
124        }
125
126        self.prover
127            .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
128    }
129
130    pub fn commit(&mut self) {
131        self.revert_point.take();
132    }
133
134    /// * sanitize - Removes unacceptable characters from prolog.
135    pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
136        let mut prolog_text = String::new();
137
138        let mut conv_map = ConversionMap {
139            int_to_str: Map::default(),
140            sanitized_to_suffix: Map::default(),
141            int2name: &self.nimap.int2name,
142            sanitizer: sanitize,
143        };
144
145        for clauses in self.clauses.values() {
146            for clause in clauses {
147                let head = self.stor.get_term(clause.head);
148                write_term(head, &mut conv_map, &mut prolog_text);
149
150                if let Some(body) = clause.body {
151                    prolog_text.push_str(" :- ");
152
153                    let body = self.stor.get_expr(body);
154                    write_expr(body, &mut conv_map, &mut prolog_text);
155                }
156
157                prolog_text.push_str(".\n");
158            }
159        }
160
161        return prolog_text;
162
163        // === Internal helper functions ===
164
165        struct ConversionMap<'a, F> {
166            int_to_str: Map<Int, String>,
167            // e.g. 0 -> No suffix, 1 -> _1, 2 -> _2, ...
168            sanitized_to_suffix: Map<&'a str, u32>,
169            int2name: &'a IndexMap<Int, Name>,
170            sanitizer: F,
171        }
172
173        impl<F: FnMut(&str) -> &str> ConversionMap<'_, F> {
174            fn int_to_str(&mut self, int: Int) -> &str {
175                self.int_to_str.entry(int).or_insert_with(|| {
176                    let name = self.int2name.get(&int).unwrap();
177
178                    let mut is_var = false;
179
180                    // Removes variable prefix.
181                    let name = if name.starts_with(VAR_PREFIX) {
182                        is_var = true;
183                        &name[1..]
184                    } else {
185                        name
186                    };
187
188                    // Removes other user-defined characters.
189                    let pure_name = (self.sanitizer)(name);
190
191                    let suffix = self
192                        .sanitized_to_suffix
193                        .entry(pure_name)
194                        .and_modify(|x| *x += 1)
195                        .or_insert(0);
196
197                    let mut buf = String::new();
198
199                    if is_var {
200                        let upper = pure_name.chars().next().unwrap().to_uppercase();
201                        for c in upper {
202                            buf.push(c);
203                        }
204                    } else {
205                        let lower = pure_name.chars().next().unwrap().to_lowercase();
206                        for c in lower {
207                            buf.push(c);
208                        }
209                    };
210                    buf.push_str(&pure_name[1..]);
211
212                    if *suffix == 0 {
213                        buf
214                    } else {
215                        write!(&mut buf, "_{suffix}").unwrap();
216                        buf
217                    }
218                })
219            }
220        }
221
222        fn write_term<F: FnMut(&str) -> &str>(
223            term: TermView<'_, Int>,
224            conv_map: &mut ConversionMap<'_, F>,
225            prolog_text: &mut String,
226        ) {
227            let functor = term.functor();
228            let args = term.args();
229            let num_args = args.len();
230
231            let functor = conv_map.int_to_str(*functor);
232            prolog_text.push_str(functor);
233
234            if num_args > 0 {
235                prolog_text.push('(');
236                for (i, arg) in args.enumerate() {
237                    write_term(arg, conv_map, prolog_text);
238                    if i + 1 < num_args {
239                        prolog_text.push_str(", ");
240                    }
241                }
242                prolog_text.push(')');
243            }
244        }
245
246        fn write_expr<F: FnMut(&str) -> &str>(
247            expr: ExprView<'_, Int>,
248            conv_map: &mut ConversionMap<'_, F>,
249            prolog_text: &mut String,
250        ) {
251            match expr.as_kind() {
252                ExprKind::Term(term) => {
253                    write_term(term, conv_map, prolog_text);
254                }
255                ExprKind::Not(inner) => {
256                    prolog_text.push_str("\\+ ");
257                    if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
258                        prolog_text.push('(');
259                        write_expr(inner, conv_map, prolog_text);
260                        prolog_text.push(')');
261                    } else {
262                        write_expr(inner, conv_map, prolog_text);
263                    }
264                }
265                ExprKind::And(args) => {
266                    let num_args = args.len();
267                    for (i, arg) in args.enumerate() {
268                        if matches!(arg.as_kind(), ExprKind::Or(_)) {
269                            prolog_text.push('(');
270                            write_expr(arg, conv_map, prolog_text);
271                            prolog_text.push(')');
272                        } else {
273                            write_expr(arg, conv_map, prolog_text);
274                        }
275                        if i + 1 < num_args {
276                            prolog_text.push_str(", ");
277                        }
278                    }
279                }
280                ExprKind::Or(args) => {
281                    let num_args = args.len();
282                    for (i, arg) in args.enumerate() {
283                        write_expr(arg, conv_map, prolog_text);
284                        if i + 1 < num_args {
285                            prolog_text.push_str("; ");
286                        }
287                    }
288                }
289            }
290        }
291    }
292
293    fn revert(
294        &mut self,
295        DatabaseState {
296            clauses_len,
297            clause_texts_len,
298            stor_len,
299            nimap_state,
300        }: DatabaseState,
301    ) {
302        self.clauses.truncate(clauses_len.len());
303        for (i, len) in clauses_len.into_iter().enumerate() {
304            self.clauses[i].truncate(len);
305        }
306        self.clause_texts.truncate(clause_texts_len);
307        self.stor.truncate(stor_len);
308        self.nimap.revert(nimap_state);
309        // `self.prover: Prover` does not store any persistent data.
310    }
311
312    fn state(&self) -> DatabaseState {
313        DatabaseState {
314            clauses_len: self.clauses.values().map(|v| v.len()).collect(),
315            clause_texts_len: self.clause_texts.len(),
316            stor_len: self.stor.len(),
317            nimap_state: self.nimap.state(),
318        }
319    }
320}
321
322impl Default for Database {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
/// Snapshot of the lengths of [`Database`]'s growable fields, used to roll
/// back uncommitted insertions.
#[derive(Debug, PartialEq, Eq)]
struct DatabaseState {
    /// Number of clause ids per predicate group, in the map's order.
    clauses_len: Vec<usize>,
    /// Number of serialized clause texts.
    clause_texts_len: usize,
    /// Length marker of the term storage.
    stor_len: TermStorageLen,
    /// Snapshot of the name/int interning table.
    nimap_state: NameIntMapState,
}
335
/// A borrowed view of one stored clause. Its `Display` and `Debug` impls
/// resolve interned [`Int`]s back to [`Name`]s.
pub struct ClauseRef<'a> {
    /// Storage ids of the clause's head term and optional body expression.
    id: ClauseId,
    /// Storage the ids point into.
    stor: &'a TermStorage<Int>,
    /// Table used to turn interned ids back into names.
    int2name: &'a IndexMap<Int, Name>,
}
341
342impl fmt::Display for ClauseRef<'_> {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        let head = self.stor.get_term(self.id.head);
345        fmt::Display::fmt(&NamedTermView::new(head, self.int2name), f)?;
346
347        if let Some(body) = self.id.body {
348            let body = self.stor.get_expr(body);
349            f.write_str(" :- ")?;
350            fmt::Display::fmt(&NamedExprView::new(body, self.int2name), f)?
351        }
352        f.write_char('.')
353    }
354}
355
356impl fmt::Debug for ClauseRef<'_> {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        let mut d = f.debug_struct("Clause");
359
360        let head = self.stor.get_term(self.id.head);
361        d.field("head", &NamedTermView::new(head, self.int2name));
362
363        if let Some(body) = self.id.body {
364            let body = self.stor.get_expr(body);
365            d.field("body", &NamedExprView::new(body, self.int2name));
366        }
367
368        d.finish()
369    }
370}
371
372impl Clause<Name> {
373    /// Turns variables into `_$0`, `_$1`, and so on.
374    pub(crate) fn convert_var_into_num(&self) -> Option<Self> {
375        let mut cloned: Option<Self> = None;
376
377        let mut i = 0;
378
379        while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(self)) {
380            let from = var.clone();
381
382            let mut convert = |term: &Term<Name>| {
383                (term == &from).then_some(Term {
384                    functor: format!("_{VAR_PREFIX}{i}").into(),
385                    args: [].into(),
386                })
387            };
388
389            if let Some(cloned) = &mut cloned {
390                cloned.replace_term(&mut convert);
391            } else {
392                let mut this = self.clone();
393                this.replace_term(&mut convert);
394                cloned = Some(this);
395            }
396
397            i += 1;
398        }
399
400        return cloned;
401
402        // === Internal helper functions ===
403
404        fn find_var_in_clause(clause: &Clause<Name>) -> Option<&Term<Name>> {
405            let var = find_var_in_term(&clause.head);
406            if var.is_some() {
407                return var;
408            }
409            find_var_in_expr(clause.body.as_ref()?)
410        }
411
412        fn find_var_in_expr(expr: &Expr<Name>) -> Option<&Term<Name>> {
413            match expr {
414                Expr::Term(term) => find_var_in_term(term),
415                Expr::Not(inner) => find_var_in_expr(inner),
416                Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
417            }
418        }
419
420        fn find_var_in_term(term: &Term<Name>) -> Option<&Term<Name>> {
421            const _: () = assert!(VAR_PREFIX == '$');
422
423            if term.is_variable() && !term.functor.starts_with("_$") {
424                Some(term)
425            } else {
426                term.args.iter().find_map(find_var_in_term)
427            }
428        }
429    }
430}
431
/// Iterator over stored terms that pairs each term view with the name table,
/// yielding [`NamedTermView`]s.
pub struct NamedTermViewIter<'a> {
    /// Underlying iterator over raw (interned) term views.
    term_iter: TermViewIter<'a, Int>,
    /// Table used to turn interned ids back into names.
    int2name: &'a IndexMap<Int, Name>,
}
436
437impl<'a> Iterator for NamedTermViewIter<'a> {
438    type Item = NamedTermView<'a>;
439
440    fn next(&mut self) -> Option<Self::Item> {
441        self.term_iter
442            .next()
443            .map(|view| NamedTermView::new(view, self.int2name))
444    }
445}
446
// `next` only maps the inner iterator's result, so this iterator is fused
// exactly when `TermViewIter` is. NOTE(review): confirm `TermViewIter` keeps
// returning `None` once exhausted, otherwise this marker is unsound.
impl iter::FusedIterator for NamedTermViewIter<'_> {}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::{
        self,
        repr::{Clause, Expr},
    };

    /// Parsing then printing a clause round-trips to the same text.
    #[test]
    fn test_parse() {
        fn assert(text: &str) {
            let clause: Clause<Name> = parse::parse_str(text).unwrap();
            assert_eq!(text, clause.to_string());
        }

        assert("f.");
        assert("f(a, b).");
        assert("f(a, b) :- f.");
        assert("f(a, b) :- f(a).");
        assert("f(a, b) :- f(a), f(b).");
        assert("f(a, b) :- f(a); f(b).");
        assert("f(a, b) :- f(a), (f(b); f(c)).");
    }

    /// Repeated queries and duplicate insertions must not grow term storage.
    #[test]
    fn test_serial_queries() {
        let mut db = Database::new();

        fn insert(db: &mut Database) {
            insert_dataset(
                db,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );
        }

        fn query(db: &mut Database) {
            let query = "g($X).";
            let query: Expr<Name> = parse::parse_str(query).unwrap();
            let answer = collect_answer(db.query(query));

            let expected = [["$X = a"], ["$X = b"]];

            assert_eq!(answer, expected);
        }

        // `assert_eq!` rather than `debug_assert_eq!` so the storage invariant
        // is also checked under `cargo test --release`.
        insert(&mut db);
        let org_stor_len = db.stor.len();
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());

        insert(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
    }

    #[test]
    fn test_various_expressions() {
        test_not_expression();
        test_and_expression();
        test_or_expression();
        test_mixed_expression();
    }

    fn test_not_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query = "f(a).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query = "f(b).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);
    }

    fn test_and_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    fn test_or_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"], ["$X = b"]];

        assert_eq!(answer, expected);
    }

    fn test_mixed_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];

        assert_eq!(answer, expected);
    }

    #[test]
    fn test_recursion() {
        test_simple_recursion();
        test_right_recursion();
    }

    fn test_simple_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query = "impl(Clone, $T).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);
    }

    fn test_right_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query = "descend($X, $Y).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    /// An insertion that is not committed is rolled back by the next query.
    #[test]
    fn test_discarding_uncomitted_change() {
        let mut db = Database::new();

        let text = "f(a).";
        let clause = parse::parse_str(text).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        let text = "f(b).";
        let clause = parse::parse_str(text).unwrap();
        db.insert_clause(clause);

        let query = "f($X).";
        let query: Expr<Name> = parse::parse_str(query).unwrap();
        let answer = collect_answer(db.query(query));

        // `f(b).` was discarded.
        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);
    }

    /// Parses `text` as a dataset, inserts it, and commits.
    fn insert_dataset(db: &mut Database, text: &str) {
        let dataset: ClauseDataset<Name> = parse::parse_str(text).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains every answer from `cx`, rendering each assignment as a string.
    fn collect_answer(mut cx: ProveCx<'_>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}