// logic_eval/prove/db.rs

1use super::{
2    prover::{
3        format::{NamedExprView, NamedTermView},
4        Integer, NameIntMap, NameIntMapState, ProveCx, Prover,
5    },
6    repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9    parse::{
10        repr::{Clause, ClauseDataset, Expr, Predicate, Term},
11        VAR_PREFIX,
12    },
13    prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
14    Atom, IndexMap, IndexSet, Map,
15};
16use core::{
17    fmt::{self, Debug, Display, Write},
18    iter::FusedIterator,
19};
20
21pub struct Database<T> {
22    /// Clause id dataset.
23    clauses: IndexMap<Predicate<Integer>, Vec<ClauseId>>,
24
25    /// Clauses that should be handled by tabling.
26    table_clauses: IndexSet<Predicate<Integer>>,
27
28    /// We do not allow duplicate clauses in the dataset.
29    dup_checker: DuplicateClauseChecker,
30
31    /// Term and expression storage.
32    stor: TermStorage<Integer>,
33
34    /// Proof search engine.
35    prover: Prover,
36
37    /// Mappings between T and Integer.
38    ///
39    /// Integer is internally used for fast comparison, but we need to get it back to T for the
40    /// clients.
41    nimap: NameIntMap<T>,
42
43    /// States of DB's fields.
44    ///
45    /// This is used when we discard some changes on the DB.
46    revert_point: Option<DatabaseState>,
47}
48
49impl<T: Atom> Database<T> {
50    pub fn new() -> Self {
51        Self {
52            clauses: IndexMap::default(),
53            table_clauses: IndexSet::default(),
54            dup_checker: DuplicateClauseChecker::default(),
55            stor: TermStorage::new(),
56            prover: Prover::new(),
57            nimap: NameIntMap::new(),
58            revert_point: None,
59        }
60    }
61
62    pub fn terms(&self) -> NamedTermViewIter<'_, T> {
63        NamedTermViewIter {
64            term_iter: self.stor.terms.terms(),
65            nimap: &self.nimap,
66        }
67    }
68
69    pub fn clauses(&self) -> ClauseIter<'_, T> {
70        ClauseIter {
71            clauses: &self.clauses,
72            stor: &self.stor,
73            nimap: &self.nimap,
74            i: 0,
75            j: 0,
76        }
77    }
78
79    pub fn insert_dataset(&mut self, dataset: ClauseDataset<T>) {
80        for clause in dataset {
81            self.insert_clause(clause);
82        }
83    }
84
85    /// Inserts the given clause to the DB.
86    pub fn insert_clause(&mut self, clause: Clause<T>) {
87        // Saves current state. We will revert DB when the change is not committed.
88        if self.revert_point.is_none() {
89            self.revert_point = Some(self.state());
90        }
91
92        let clause = clause.map(&mut |t| self.nimap.name_to_int(t));
93
94        // Records whether the clause needs tabling.
95        if clause.needs_tabling() {
96            self.table_clauses.insert(clause.head.predicate());
97        }
98
99        // If the DB already contains the given clause, then returns.
100        if !self.dup_checker.insert(clause.clone()) {
101            return;
102        }
103
104        let key = clause.head.predicate();
105        let value = ClauseId {
106            head: self.stor.insert_term(clause.head),
107            body: clause.body.map(|expr| self.stor.insert_expr(expr)),
108        };
109
110        self.clauses
111            .entry(key)
112            .and_modify(|similar_clauses| {
113                if similar_clauses.iter().all(|clause| clause != &value) {
114                    similar_clauses.push(value);
115                }
116            })
117            .or_insert(vec![value]);
118    }
119
120    pub fn query(&mut self, expr: Expr<T>) -> ProveCx<'_, T> {
121        // Discards uncommitted changes.
122        if let Some(revert_point) = self.revert_point.take() {
123            self.revert(revert_point);
124        }
125
126        self.prover.prove(
127            expr,
128            &self.clauses,
129            &self.table_clauses,
130            &mut self.stor,
131            &mut self.nimap,
132        )
133    }
134
135    pub fn commit(&mut self) {
136        self.revert_point.take();
137    }
138
139    /// * sanitize - Removes unacceptable characters from prolog.
140    ///
141    /// Requires T to implement [`AsRef<str>`] so that functor names can be serialized into Prolog
142    /// syntax.
143    pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String
144    where
145        T: AsRef<str>,
146    {
147        let mut prolog_text = String::new();
148
149        let mut conv_map = ConversionMap {
150            int_to_str: Map::default(),
151            sanitized_to_suffix: Map::default(),
152            nimap: &self.nimap,
153            sanitizer: sanitize,
154        };
155
156        for clauses in self.clauses.values() {
157            for clause in clauses {
158                let head = self.stor.get_term(clause.head);
159                write_term(head, &mut conv_map, &mut prolog_text);
160
161                if let Some(body) = clause.body {
162                    prolog_text.push_str(" :- ");
163
164                    let body = self.stor.get_expr(body);
165                    write_expr(body, &mut conv_map, &mut prolog_text);
166                }
167
168                prolog_text.push_str(".\n");
169            }
170        }
171
172        return prolog_text;
173
174        // === Internal helper functions ===
175
176        struct ConversionMap<'a, T, F> {
177            int_to_str: Map<Integer, String>,
178            // e.g. 0 -> No suffix, 1 -> _1, 2 -> _2, ...
179            sanitized_to_suffix: Map<&'a str, u32>,
180            nimap: &'a NameIntMap<T>,
181            sanitizer: F,
182        }
183
184        impl<T, F> ConversionMap<'_, T, F>
185        where
186            T: AsRef<str>,
187            F: FnMut(&str) -> &str,
188        {
189            fn int_to_str(&mut self, int: Integer) -> &str {
190                self.int_to_str.entry(int).or_insert_with(|| {
191                    let name = self.nimap.get_name(&int).unwrap();
192                    let name: &str = name.as_ref();
193
194                    let mut is_var = false;
195
196                    // Removes variable prefix.
197                    let name = if name.starts_with(VAR_PREFIX) {
198                        is_var = true;
199                        &name[1..]
200                    } else {
201                        name
202                    };
203
204                    // Removes other user-defined characters.
205                    let pure_name = (self.sanitizer)(name);
206
207                    let suffix = self
208                        .sanitized_to_suffix
209                        .entry(pure_name)
210                        .and_modify(|x| *x += 1)
211                        .or_insert(0);
212
213                    let mut buf = String::new();
214
215                    if is_var {
216                        let upper = pure_name.chars().next().unwrap().to_uppercase();
217                        for c in upper {
218                            buf.push(c);
219                        }
220                    } else {
221                        let lower = pure_name.chars().next().unwrap().to_lowercase();
222                        for c in lower {
223                            buf.push(c);
224                        }
225                    };
226                    buf.push_str(&pure_name[1..]);
227
228                    if *suffix == 0 {
229                        buf
230                    } else {
231                        write!(&mut buf, "_{suffix}").unwrap();
232                        buf
233                    }
234                })
235            }
236        }
237
238        fn write_term<T, F>(
239            term: TermView<'_, Integer>,
240            conv_map: &mut ConversionMap<'_, T, F>,
241            prolog_text: &mut String,
242        ) where
243            T: AsRef<str>,
244            F: FnMut(&str) -> &str,
245        {
246            let functor = term.functor();
247            let args = term.args();
248            let num_args = args.len();
249
250            let functor = conv_map.int_to_str(*functor);
251            prolog_text.push_str(functor);
252
253            if num_args > 0 {
254                prolog_text.push('(');
255                for (i, arg) in args.enumerate() {
256                    write_term(arg, conv_map, prolog_text);
257                    if i + 1 < num_args {
258                        prolog_text.push_str(", ");
259                    }
260                }
261                prolog_text.push(')');
262            }
263        }
264
265        fn write_expr<T, F>(
266            expr: ExprView<'_, Integer>,
267            conv_map: &mut ConversionMap<'_, T, F>,
268            prolog_text: &mut String,
269        ) where
270            T: AsRef<str>,
271            F: FnMut(&str) -> &str,
272        {
273            match expr.as_kind() {
274                ExprKind::Term(term) => {
275                    write_term(term, conv_map, prolog_text);
276                }
277                ExprKind::Not(inner) => {
278                    prolog_text.push_str("\\+ ");
279                    if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
280                        prolog_text.push('(');
281                        write_expr(inner, conv_map, prolog_text);
282                        prolog_text.push(')');
283                    } else {
284                        write_expr(inner, conv_map, prolog_text);
285                    }
286                }
287                ExprKind::And(args) => {
288                    let num_args = args.len();
289                    for (i, arg) in args.enumerate() {
290                        if matches!(arg.as_kind(), ExprKind::Or(_)) {
291                            prolog_text.push('(');
292                            write_expr(arg, conv_map, prolog_text);
293                            prolog_text.push(')');
294                        } else {
295                            write_expr(arg, conv_map, prolog_text);
296                        }
297                        if i + 1 < num_args {
298                            prolog_text.push_str(", ");
299                        }
300                    }
301                }
302                ExprKind::Or(args) => {
303                    let num_args = args.len();
304                    for (i, arg) in args.enumerate() {
305                        write_expr(arg, conv_map, prolog_text);
306                        if i + 1 < num_args {
307                            prolog_text.push_str("; ");
308                        }
309                    }
310                }
311            }
312        }
313    }
314
315    fn revert(
316        &mut self,
317        DatabaseState {
318            clauses_len,
319            clause_set_len,
320            stor_len,
321            nimap_state,
322        }: DatabaseState,
323    ) {
324        self.clauses.truncate(clauses_len.len());
325        for (i, len) in clauses_len.into_iter().enumerate() {
326            self.clauses[i].truncate(len);
327        }
328        self.dup_checker.truncate(clause_set_len);
329        self.stor.truncate(stor_len);
330        self.nimap.revert(nimap_state);
331        // `self.prover: Prover` does not store any persistent data.
332    }
333
334    fn state(&self) -> DatabaseState {
335        DatabaseState {
336            clauses_len: self.clauses.values().map(|v| v.len()).collect(),
337            clause_set_len: self.dup_checker.len(),
338            stor_len: self.stor.len(),
339            nimap_state: self.nimap.state(),
340        }
341    }
342}
343
344impl<T: Atom> Default for Database<T> {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350impl<T: Debug> Debug for Database<T> {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        f.debug_struct("Database")
353            .field("clauses", &self.clauses)
354            .field("dup_checker", &self.dup_checker)
355            .field("stor", &self.stor)
356            .field("nimap", &self.nimap)
357            .field("revert_point", &self.revert_point)
358            .finish_non_exhaustive()
359    }
360}
361
362#[derive(Debug, PartialEq, Eq)]
363struct DatabaseState {
364    clauses_len: Vec<usize>,
365    clause_set_len: usize,
366    stor_len: TermStorageLen,
367    nimap_state: NameIntMapState,
368}
369
370#[derive(Debug, Default)]
371struct DuplicateClauseChecker {
372    seen: IndexSet<Clause<Integer>>,
373
374    /// Temporary buffer for granting [`Integer`] to variables.
375    vars: Vec<Integer>,
376}
377
378impl DuplicateClauseChecker {
379    /// Returns true if the given clause is new, has not been seen before.
380    fn insert(&mut self, clause: Clause<Integer>) -> bool {
381        let canonical_clause = clause.map(&mut |t| {
382            if !t.is_variable() {
383                t
384            } else if let Some(found) = self.vars.iter().find(|&&var| var == t) {
385                *found
386            } else {
387                let next_int = self.vars.len() as u32;
388                let int = Integer::variable(next_int);
389                self.vars.push(int);
390                int
391            }
392        });
393        let is_new = self.seen.insert(canonical_clause);
394        self.vars.clear();
395        is_new
396    }
397
398    fn len(&self) -> usize {
399        self.seen.len()
400    }
401
402    fn truncate(&mut self, len: usize) {
403        self.seen.truncate(len);
404    }
405}
406
407/// Turns variables into `_$0`, `_$1`, and so on using the given canonical_var function.
408///
409/// Returns `None` if `canonical_var` is `None` (i.e. deduplication disabled).
410fn _convert_var_into_num<T: Atom>(
411    this: &Clause<T>,
412    canonical_var: Option<&dyn Fn(usize) -> T>,
413) -> Option<Clause<T>> {
414    let canonical_var = canonical_var?;
415    let mut cloned: Option<Clause<T>> = None;
416
417    let mut i = 0;
418
419    while let Some(from) = find_var_in_clause(cloned.as_ref().unwrap_or(this)) {
420        let from = from.clone();
421        let canonical_t = canonical_var(i);
422
423        let mut convert = |term: &Term<T>| {
424            (term.functor == from && term.args.is_empty()).then_some(Term {
425                functor: canonical_t.clone(),
426                args: vec![],
427            })
428        };
429
430        if let Some(cloned) = &mut cloned {
431            cloned.replace_term(&mut convert);
432        } else {
433            let mut this = this.clone();
434            this.replace_term(&mut convert);
435            cloned = Some(this);
436        }
437
438        i += 1;
439    }
440
441    return cloned;
442
443    // === Internal helper functions ===
444
445    fn find_var_in_clause<T: Atom>(clause: &Clause<T>) -> Option<T> {
446        find_var_in_term(&clause.head).or_else(|| clause.body.as_ref().and_then(find_var_in_expr))
447    }
448
449    fn find_var_in_expr<T: Atom>(expr: &Expr<T>) -> Option<T> {
450        match expr {
451            Expr::Term(term) => find_var_in_term(term),
452            Expr::Not(arg) => find_var_in_expr(arg),
453            Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
454        }
455    }
456
457    fn find_var_in_term<T: Atom>(term: &Term<T>) -> Option<T> {
458        if term.functor.is_variable() {
459            Some(term.functor.clone())
460        } else {
461            term.args.iter().find_map(find_var_in_term)
462        }
463    }
464}
465
466#[derive(Clone)]
467pub struct ClauseIter<'a, T> {
468    clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
469    stor: &'a TermStorage<Integer>,
470    nimap: &'a NameIntMap<T>,
471    i: usize,
472    j: usize,
473}
474
475impl<'a, T> Iterator for ClauseIter<'a, T> {
476    type Item = ClauseRef<'a, T>;
477
478    fn next(&mut self) -> Option<Self::Item> {
479        let id = loop {
480            let (_, group) = self.clauses.get_index(self.i)?;
481
482            if let Some(id) = group.get(self.j) {
483                self.j += 1;
484                break *id;
485            }
486
487            self.i += 1;
488            self.j = 0;
489        };
490
491        Some(ClauseRef {
492            id,
493            stor: self.stor,
494            nimap: self.nimap,
495        })
496    }
497}
498
499impl<T> FusedIterator for ClauseIter<'_, T> {}
500
501pub struct ClauseRef<'a, T> {
502    id: ClauseId,
503    stor: &'a TermStorage<Integer>,
504    nimap: &'a NameIntMap<T>,
505}
506
507impl<'a, T: Atom> ClauseRef<'a, T> {
508    pub fn head(&self) -> NamedTermView<'a, T> {
509        let head = self.stor.get_term(self.id.head);
510        NamedTermView::new(head, self.nimap)
511    }
512
513    pub fn body(&self) -> Option<NamedExprView<'a, T>> {
514        self.id.body.map(|id| {
515            let body = self.stor.get_expr(id);
516            NamedExprView::new(body, self.nimap)
517        })
518    }
519}
520
521impl<T: Atom + Display> Display for ClauseRef<'_, T> {
522    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
523        Display::fmt(&self.head(), f)?;
524
525        if let Some(body) = self.body() {
526            f.write_str(" :- ")?;
527            Display::fmt(&body, f)?
528        }
529
530        f.write_char('.')
531    }
532}
533
534impl<T: Atom + Debug> Debug for ClauseRef<'_, T> {
535    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
536        let mut d = f.debug_struct("Clause");
537
538        let head = self.stor.get_term(self.id.head);
539        d.field("head", &NamedTermView::new(head, self.nimap));
540
541        if let Some(body) = self.id.body {
542            let body = self.stor.get_expr(body);
543            d.field("body", &NamedExprView::new(body, self.nimap));
544        }
545
546        d.finish()
547    }
548}
549
550pub struct NamedTermViewIter<'a, T> {
551    term_iter: TermViewIter<'a, Integer>,
552    nimap: &'a NameIntMap<T>,
553}
554
555impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> {
556    type Item = NamedTermView<'a, T>;
557
558    fn next(&mut self) -> Option<Self::Item> {
559        self.term_iter
560            .next()
561            .map(|view| NamedTermView::new(view, self.nimap))
562    }
563}
564
565impl<T: Atom> FusedIterator for NamedTermViewIter<'_, T> {}
566
/// Tests driving the database with interned string atoms parsed from Prolog-like text.
#[cfg(test)]
mod str_atom_tests {
    use crate::{parse, NameIn};

    type Interner = any_intern::DroplessInterner;
    type Database<'int> = crate::Database<NameIn<'int, Interner>>;
    type ProveCx<'a, 'int> = crate::ProveCx<'a, NameIn<'int, Interner>>;
    type ClauseDataset<'int> = crate::ClauseDatasetIn<'int, Interner>;
    type Expr<'int> = crate::ExprIn<'int, Interner>;
    type Clause<'int> = crate::ClauseIn<'int, Interner>;

    /// Re-inserting the same committed dataset and querying again gives the same answers and
    /// does not grow the term storage during the query.
    #[test]
    fn test_serial_queries() {
        fn assert_query<'int>(db: &mut Database<'int>, interner: &'int Interner) {
            let query = "g($X).";
            let query: Expr<'int> = parse::parse_str(query, interner).unwrap();

            let cx = db.query(query);
            let answer = collect_answer(cx);
            let expected = [["$X = a"], ["$X = b"]];
            assert_eq!(answer, expected);
        }

        let mut db = Database::new();
        let interner = Interner::new();

        for _ in 0..2 {
            insert_dataset(
                &mut db,
                &interner,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );
            let len = db.stor.len();
            assert_query(&mut db, &interner);
            assert_eq!(db.stor.len(), len);
        }
    }

    #[test]
    fn test_not_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query: Expr<'_> = parse::parse_str("f(a).", &interner).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query: Expr<'_> = parse::parse_str("f(b).", &interner).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);
    }

    #[test]
    fn test_and_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query: Expr<'_> = parse::parse_str("f($X).", &interner).unwrap();
        let answer = collect_answer(db.query(query));
        let expected = [["$X = b"]];
        assert_eq!(answer, expected);
    }

    #[test]
    fn test_or_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query: Expr<'_> = parse::parse_str("f($X).", &interner).unwrap();
        let answer = collect_answer(db.query(query));
        let expected = [["$X = a"], ["$X = b"]];
        assert_eq!(answer, expected);
    }

    #[test]
    fn test_mixed_expression() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query: Expr<'_> = parse::parse_str("f($X).", &interner).unwrap();
        let answer = collect_answer(db.query(query));
        let expected = [["$X = b"]];
        assert_eq!(answer, expected);
    }

    /// The answer order is asserted one by one, so this test depends on the prover's search
    /// order.
    #[test]
    fn test_simple_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query: Expr<'_> = parse::parse_str("impl(Clone, $T).", &interner).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);
    }

    #[test]
    fn test_right_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query: Expr<'_> = parse::parse_str("descend($X, $Y).", &interner).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    // SLG resolution (tabling) is required to pass this test.
    #[test]
    fn test_mid_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            edge(a, b).
            edge(b, c).
            edge(c, a).
            path($X, $Y) :- edge($X, $Z), path($Z, $W), edge($W, $Y).
            path($X, $Y) :- edge($X, $Y).
            ",
        );

        let query: Expr<'_> = parse::parse_str("path($X, $Y).", &interner).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = a"],
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = b", "$Y = a"],
            ["$X = b", "$Y = b"],
            ["$X = b", "$Y = c"],
            ["$X = c", "$Y = a"],
            ["$X = c", "$Y = b"],
            ["$X = c", "$Y = c"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    // SLG resolution (tabling) is required to pass this test.
    #[test]
    fn test_left_recursion() {
        let mut db = Database::new();
        let interner = Interner::new();

        insert_dataset(
            &mut db,
            &interner,
            r"
            parent(a, b).
            parent(b, c).
            parent(c, d).
            ancestor($X, $Y) :- ancestor($X, $Z), parent($Z, $Y).
            ancestor($X, $Y) :- parent($X, $Y).
            ",
        );

        let query: Expr<'_> = parse::parse_str("ancestor($X, $Y).", &interner).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    /// A clause inserted after the last commit is discarded by the next query, and the DB
    /// returns to exactly the state it had before that insertion.
    #[test]
    fn test_discarding_uncommitted_change() {
        let mut db = Database::new();
        let interner = Interner::new();

        let clause: Clause<'_> = parse::parse_str("f(a).", &interner).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        let clause: Clause<'_> = parse::parse_str("f(b).", &interner).unwrap();
        db.insert_clause(clause);

        let query: Expr<'_> = parse::parse_str("f($X).", &interner).unwrap();
        let answer = collect_answer(db.query(query));

        // `f(b).` was discarded.
        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);
    }

    // === Test helper functions ===

    /// Parses `text` as a dataset, inserts it and commits it.
    fn insert_dataset<'int>(db: &mut Database<'int>, interner: &'int Interner, text: &str) {
        let dataset: ClauseDataset<'int> = parse::parse_str(text, interner).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, rendering each answer as its list of `Display`ed assignments.
    fn collect_answer(mut cx: ProveCx<'_, '_>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}
885
/// Tests driving the database with a user-defined [`Atom`] type instead of strings.
#[cfg(test)]
mod tests {
    use crate::{Atom, Clause, ClauseDataset, Database, Expr, ProveCx, Term};

    #[test]
    fn test_custom_atom() {
        // Variant names mirror how each atom would be spelled in Prolog, hence lower case.
        #[allow(non_camel_case_types)]
        #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
        enum A {
            child,
            descend,
            a,
            b,
            c,
            d,
            X,
            Y,
            Z,
        }

        impl Atom for A {
            fn is_variable(&self) -> bool {
                matches!(self, A::X | A::Y | A::Z)
            }
        }

        // `functor(lhs, rhs)` as a term.
        let binary =
            |functor: A, lhs: A, rhs: A| Term::compound(functor, [Term::atom(lhs), Term::atom(rhs)]);
        // `functor(lhs, rhs)` as a body goal.
        let goal = |functor: A, lhs: A, rhs: A| {
            Expr::term_compound(functor, [Term::atom(lhs), Term::atom(rhs)])
        };

        let dataset = ClauseDataset(vec![
            Clause::fact(binary(A::child, A::a, A::b)),
            Clause::fact(binary(A::child, A::b, A::c)),
            Clause::fact(binary(A::child, A::c, A::d)),
            Clause::rule(
                binary(A::descend, A::X, A::Y),
                goal(A::child, A::X, A::Y),
            ),
            Clause::rule(
                binary(A::descend, A::X, A::Z),
                Expr::expr_and([goal(A::child, A::X, A::Y), goal(A::descend, A::Y, A::Z)]),
            ),
        ]);

        let mut db = Database::new();
        insert_dataset(&mut db, dataset);

        let query = goal(A::descend, A::X, A::Y);
        let mut answer = collect_answer(db.query(query));

        // One answer: `X = x, Y = y`.
        let bindings = |x: A, y: A| {
            [
                (Term::atom(A::X), Term::atom(x)),
                (Term::atom(A::Y), Term::atom(y)),
            ]
        };
        let mut expected = [
            bindings(A::a, A::b),
            bindings(A::a, A::c),
            bindings(A::a, A::d),
            bindings(A::b, A::c),
            bindings(A::b, A::d),
            bindings(A::c, A::d),
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);
    }

    // === Test helper functions ===

    /// Inserts `dataset` and commits it.
    fn insert_dataset<T: Atom>(db: &mut Database<T>, dataset: ClauseDataset<T>) {
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx`, turning each answer into its `(lhs, rhs)` assignment pairs.
    fn collect_answer<T: Atom>(mut cx: ProveCx<'_, T>) -> Vec<Vec<(Term<T>, Term<T>)>> {
        let mut answers = Vec::new();
        while let Some(eval) = cx.prove_next() {
            answers.push(eval.map(|assign| (assign.lhs(), assign.rhs())).collect());
        }
        answers
    }
}