Skip to main content

logic_eval/prove/
db.rs

1use super::{
2    prover::{
3        format::{NamedExprView, NamedTermView},
4        Integer, NameIntMap, NameIntMapState, ProveCx, Prover,
5    },
6    repr::{ClauseId, TermStorage, TermStorageLen},
7};
8use crate::{
9    parse::{
10        repr::{Clause, Expr, Predicate, Term},
11        text::Name,
12        VAR_PREFIX,
13    },
14    prove::repr::{ExprKind, ExprView, TermView, TermViewIter},
15    ClauseDatasetIn, ClauseIn, DefaultInterner, ExprIn, Int2Name, Intern, Map,
16};
17use indexmap::{IndexMap, IndexSet};
18use logic_eval_util::reference::Ref;
19use std::{
20    fmt::{self, Write},
21    iter,
22};
23
24#[derive(Debug)]
25pub struct Database<'int, Int: Intern = DefaultInterner> {
26    /// Clause id dataset.
27    clauses: IndexMap<Predicate<Integer>, Vec<ClauseId>>,
28
29    /// We do not allow duplicated clauses in the dataset.
30    clause_texts: IndexSet<String>,
31
32    /// Term and expression storage.
33    stor: TermStorage<Integer>,
34
35    /// Proof search engine.
36    prover: Prover,
37
38    /// Mappings between [`Name`] and [`Int`].
39    ///
40    /// [`Int`] is internally used for fast comparison, but we need to get it
41    /// back to [`Name`] for the clients.
42    nimap: NameIntMap<'int, Int>,
43
44    /// States of DB's fields.
45    ///
46    /// This is used when we discard some changes on the DB.
47    revert_point: Option<DatabaseState>,
48
49    interner: Ref<'int, Int>,
50}
51
52impl Database<'static> {
53    /// Creates a database with default string interner.
54    ///
55    /// Because creating database through this method makes a leaked memory, you should not forget
56    /// to call [`dealloc`](Self::dealloc).
57    #[allow(clippy::new_without_default)]
58    pub fn new() -> Self {
59        let leaked = Box::leak(Box::new(DefaultInterner::default()));
60        let interner = Ref::from_mut(leaked);
61
62        Self {
63            clauses: IndexMap::default(),
64            clause_texts: IndexSet::default(),
65            stor: TermStorage::new(),
66            prover: Prover::new(),
67            nimap: NameIntMap::new(interner),
68            revert_point: None,
69            interner,
70        }
71    }
72
73    #[rustfmt::skip]
74    pub fn dealloc(self) {
75        // Prevents us from accidently implementing Clone for the type. Deallocation relies on the
76        // fact that the type cannot be cloned, so that we have only one instance.
77        #[allow(dead_code)]
78        {
79            struct ImplDetector<T>(std::marker::PhantomData<T>);
80            trait NotClone { const IS_CLONE: bool = false; }
81            impl<T> NotClone for ImplDetector<T> {}
82            impl<T: Clone> ImplDetector<T> { const IS_CLONE: bool = true; }
83            const _: () = assert!(!ImplDetector::<Database<'static>>::IS_CLONE);
84        }
85
86        // Safety:
87        // * There's only one instance and we got the ownership of the instance.
88        let _ = unsafe { Box::from_raw(self.interner.as_ptr()) };
89    }
90}
91
92impl<'int, Int: Intern> Database<'int, Int> {
93    pub fn with_interner(interner: &'int Int) -> Self {
94        let interner = Ref::from_ref(interner);
95
96        Self {
97            clauses: IndexMap::default(),
98            clause_texts: IndexSet::default(),
99            stor: TermStorage::new(),
100            prover: Prover::new(),
101            nimap: NameIntMap::new(interner),
102            revert_point: None,
103            interner,
104        }
105    }
106
107    pub fn interner(&self) -> &'int Int {
108        self.interner.as_ref()
109    }
110
111    pub fn terms(&self) -> NamedTermViewIter<'_, 'int, Int> {
112        NamedTermViewIter {
113            term_iter: self.stor.terms.terms(),
114            int2name: &self.nimap.int2name,
115        }
116    }
117
118    pub fn clauses(&self) -> ClauseIter<'_, 'int, Int> {
119        ClauseIter {
120            clauses: &self.clauses,
121            stor: &self.stor,
122            int2name: &self.nimap.int2name,
123            i: 0,
124            j: 0,
125        }
126    }
127
128    pub fn insert_dataset(&mut self, dataset: ClauseDatasetIn<'int, Int>) {
129        for clause in dataset {
130            self.insert_clause(clause);
131        }
132    }
133
134    pub fn insert_clause(&mut self, clause: ClauseIn<'int, Int>) {
135        // Saves current state. We will revert DB when the change is not
136        // committed.
137        if self.revert_point.is_none() {
138            self.revert_point = Some(self.state());
139        }
140
141        // If this DB contains the given clause, then returns.
142        let serialized = if let Some(converted) =
143            Clause::convert_var_into_num(&clause, self.interner.as_ref())
144        {
145            converted.to_string()
146        } else {
147            clause.to_string()
148        };
149        if !self.clause_texts.insert(serialized) {
150            return;
151        }
152
153        let clause = clause.map(&mut |name| self.nimap.name_to_int(name));
154
155        let key = clause.head.predicate();
156        let value = ClauseId {
157            head: self.stor.insert_term(clause.head),
158            body: clause.body.map(|expr| self.stor.insert_expr(expr)),
159        };
160
161        self.clauses
162            .entry(key)
163            .and_modify(|similar_clauses| {
164                if similar_clauses.iter().all(|clause| clause != &value) {
165                    similar_clauses.push(value);
166                }
167            })
168            .or_insert(vec![value]);
169    }
170
171    pub fn query(&mut self, expr: ExprIn<'int, Int>) -> ProveCx<'_, 'int, Int> {
172        // Discards uncomitted changes.
173        if let Some(revert_point) = self.revert_point.take() {
174            self.revert(revert_point);
175        }
176
177        self.prover
178            .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap)
179    }
180
181    pub fn commit(&mut self) {
182        self.revert_point.take();
183    }
184
185    /// * sanitize - Removes unacceptable characters from prolog.
186    pub fn to_prolog<F: FnMut(&str) -> &str>(&self, sanitize: F) -> String {
187        let mut prolog_text = String::new();
188
189        let mut conv_map = ConversionMap {
190            int_to_str: Map::default(),
191            sanitized_to_suffix: Map::default(),
192            int2name: &self.nimap.int2name,
193            sanitizer: sanitize,
194        };
195
196        for clauses in self.clauses.values() {
197            for clause in clauses {
198                let head = self.stor.get_term(clause.head);
199                write_term(head, &mut conv_map, &mut prolog_text);
200
201                if let Some(body) = clause.body {
202                    prolog_text.push_str(" :- ");
203
204                    let body = self.stor.get_expr(body);
205                    write_expr(body, &mut conv_map, &mut prolog_text);
206                }
207
208                prolog_text.push_str(".\n");
209            }
210        }
211
212        return prolog_text;
213
214        // === Internal helper functions ===
215
216        struct ConversionMap<'a, 'int, Int: Intern, F> {
217            int_to_str: Map<Integer, String>,
218            // e.g. 0 -> No suffix, 1 -> _1, 2 -> _2, ...
219            sanitized_to_suffix: Map<&'a str, u32>,
220            int2name: &'a Int2Name<'int, Int>,
221            sanitizer: F,
222        }
223
224        impl<Int, F> ConversionMap<'_, '_, Int, F>
225        where
226            Int: Intern,
227            F: FnMut(&str) -> &str,
228        {
229            fn int_to_str(&mut self, int: Integer) -> &str {
230                self.int_to_str.entry(int).or_insert_with(|| {
231                    let name = self.int2name.get(&int).unwrap();
232                    let name: &str = name.as_ref();
233
234                    let mut is_var = false;
235
236                    // Removes variable prefix.
237                    let name = if name.starts_with(VAR_PREFIX) {
238                        is_var = true;
239                        &name[1..]
240                    } else {
241                        name
242                    };
243
244                    // Removes other user-defined characters.
245                    let pure_name = (self.sanitizer)(name);
246
247                    let suffix = self
248                        .sanitized_to_suffix
249                        .entry(pure_name)
250                        .and_modify(|x| *x += 1)
251                        .or_insert(0);
252
253                    let mut buf = String::new();
254
255                    if is_var {
256                        let upper = pure_name.chars().next().unwrap().to_uppercase();
257                        for c in upper {
258                            buf.push(c);
259                        }
260                    } else {
261                        let lower = pure_name.chars().next().unwrap().to_lowercase();
262                        for c in lower {
263                            buf.push(c);
264                        }
265                    };
266                    buf.push_str(&pure_name[1..]);
267
268                    if *suffix == 0 {
269                        buf
270                    } else {
271                        write!(&mut buf, "_{suffix}").unwrap();
272                        buf
273                    }
274                })
275            }
276        }
277
278        fn write_term<Int, F>(
279            term: TermView<'_, Integer>,
280            conv_map: &mut ConversionMap<'_, '_, Int, F>,
281            prolog_text: &mut String,
282        ) where
283            Int: Intern,
284            F: FnMut(&str) -> &str,
285        {
286            let functor = term.functor();
287            let args = term.args();
288            let num_args = args.len();
289
290            let functor = conv_map.int_to_str(*functor);
291            prolog_text.push_str(functor);
292
293            if num_args > 0 {
294                prolog_text.push('(');
295                for (i, arg) in args.enumerate() {
296                    write_term(arg, conv_map, prolog_text);
297                    if i + 1 < num_args {
298                        prolog_text.push_str(", ");
299                    }
300                }
301                prolog_text.push(')');
302            }
303        }
304
305        fn write_expr<Int, F>(
306            expr: ExprView<'_, Integer>,
307            conv_map: &mut ConversionMap<'_, '_, Int, F>,
308            prolog_text: &mut String,
309        ) where
310            Int: Intern,
311            F: FnMut(&str) -> &str,
312        {
313            match expr.as_kind() {
314                ExprKind::Term(term) => {
315                    write_term(term, conv_map, prolog_text);
316                }
317                ExprKind::Not(inner) => {
318                    prolog_text.push_str("\\+ ");
319                    if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
320                        prolog_text.push('(');
321                        write_expr(inner, conv_map, prolog_text);
322                        prolog_text.push(')');
323                    } else {
324                        write_expr(inner, conv_map, prolog_text);
325                    }
326                }
327                ExprKind::And(args) => {
328                    let num_args = args.len();
329                    for (i, arg) in args.enumerate() {
330                        if matches!(arg.as_kind(), ExprKind::Or(_)) {
331                            prolog_text.push('(');
332                            write_expr(arg, conv_map, prolog_text);
333                            prolog_text.push(')');
334                        } else {
335                            write_expr(arg, conv_map, prolog_text);
336                        }
337                        if i + 1 < num_args {
338                            prolog_text.push_str(", ");
339                        }
340                    }
341                }
342                ExprKind::Or(args) => {
343                    let num_args = args.len();
344                    for (i, arg) in args.enumerate() {
345                        write_expr(arg, conv_map, prolog_text);
346                        if i + 1 < num_args {
347                            prolog_text.push_str("; ");
348                        }
349                    }
350                }
351            }
352        }
353    }
354
355    fn revert(
356        &mut self,
357        DatabaseState {
358            clauses_len,
359            clause_texts_len,
360            stor_len,
361            nimap_state,
362        }: DatabaseState,
363    ) {
364        self.clauses.truncate(clauses_len.len());
365        for (i, len) in clauses_len.into_iter().enumerate() {
366            self.clauses[i].truncate(len);
367        }
368        self.clause_texts.truncate(clause_texts_len);
369        self.stor.truncate(stor_len);
370        self.nimap.revert(nimap_state);
371        // `self.prover: Prover` does not store any persistent data.
372    }
373
374    fn state(&self) -> DatabaseState {
375        DatabaseState {
376            clauses_len: self.clauses.values().map(|v| v.len()).collect(),
377            clause_texts_len: self.clause_texts.len(),
378            stor_len: self.stor.len(),
379            nimap_state: self.nimap.state(),
380        }
381    }
382}
383
384#[derive(Debug, PartialEq, Eq)]
385struct DatabaseState {
386    clauses_len: Vec<usize>,
387    clause_texts_len: usize,
388    stor_len: TermStorageLen,
389    nimap_state: NameIntMapState,
390}
391
392#[derive(Clone)]
393pub struct ClauseIter<'a, 'int, Int: Intern> {
394    clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
395    stor: &'a TermStorage<Integer>,
396    int2name: &'a Int2Name<'int, Int>,
397    i: usize,
398    j: usize,
399}
400
401impl<'a, 'int, Int: Intern> Iterator for ClauseIter<'a, 'int, Int> {
402    type Item = ClauseRef<'a, 'int, Int>;
403
404    fn next(&mut self) -> Option<Self::Item> {
405        let id = loop {
406            let (_, group) = self.clauses.get_index(self.i)?;
407
408            if let Some(id) = group.get(self.j) {
409                self.j += 1;
410                break *id;
411            }
412
413            self.i += 1;
414            self.j = 0;
415        };
416
417        Some(ClauseRef {
418            id,
419            stor: self.stor,
420            int2name: self.int2name,
421        })
422    }
423}
424
425impl<Int: Intern> iter::FusedIterator for ClauseIter<'_, '_, Int> {}
426
427pub struct ClauseRef<'a, 'int, Int: Intern> {
428    id: ClauseId,
429    stor: &'a TermStorage<Integer>,
430    int2name: &'a Int2Name<'int, Int>,
431}
432
433impl<'a, 'int, Int: Intern> ClauseRef<'a, 'int, Int> {
434    pub fn head(&self) -> NamedTermView<'a, 'int, Int> {
435        let head = self.stor.get_term(self.id.head);
436        NamedTermView::new(head, self.int2name)
437    }
438
439    pub fn body(&self) -> Option<NamedExprView<'a, 'int, Int>> {
440        self.id.body.map(|id| {
441            let body = self.stor.get_expr(id);
442            NamedExprView::new(body, self.int2name)
443        })
444    }
445}
446
447impl<Int: Intern> fmt::Display for ClauseRef<'_, '_, Int> {
448    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449        fmt::Display::fmt(&self.head(), f)?;
450
451        if let Some(body) = self.body() {
452            f.write_str(" :- ")?;
453            fmt::Display::fmt(&body, f)?
454        }
455
456        f.write_char('.')
457    }
458}
459
460impl<Int: Intern> fmt::Debug for ClauseRef<'_, '_, Int> {
461    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462        let mut d = f.debug_struct("Clause");
463
464        let head = self.stor.get_term(self.id.head);
465        d.field("head", &NamedTermView::new(head, self.int2name));
466
467        if let Some(body) = self.id.body {
468            let body = self.stor.get_expr(body);
469            d.field("body", &NamedExprView::new(body, self.int2name));
470        }
471
472        d.finish()
473    }
474}
475
476impl Clause<Name<()>> {
477    /// Turns variables into `_$0`, `_$1`, and so on.
478    pub(crate) fn convert_var_into_num<'int, Int: Intern>(
479        this: &ClauseIn<'int, Int>,
480        interner: &'int Int,
481    ) -> Option<ClauseIn<'int, Int>> {
482        let mut cloned: Option<ClauseIn<'int, Int>> = None;
483
484        let mut i = 0;
485
486        while let Some(var) = find_var_in_clause(cloned.as_ref().unwrap_or(this)) {
487            let from = var.clone();
488
489            let mut convert = |term: &Term<Name<_>>| {
490                (term == &from).then_some(Term {
491                    functor: Name::with_intern(&format!("_{VAR_PREFIX}{i}"), interner),
492                    args: [].into(),
493                })
494            };
495
496            if let Some(cloned) = &mut cloned {
497                cloned.replace_term(&mut convert);
498            } else {
499                let mut this = this.clone();
500                this.replace_term(&mut convert);
501                cloned = Some(this);
502            }
503
504            i += 1;
505        }
506
507        return cloned;
508
509        // === Internal helper functions ===
510
511        fn find_var_in_clause<T: AsRef<str>>(clause: &Clause<Name<T>>) -> Option<&Term<Name<T>>> {
512            let var = find_var_in_term(&clause.head);
513            if var.is_some() {
514                return var;
515            }
516            find_var_in_expr(clause.body.as_ref()?)
517        }
518
519        fn find_var_in_expr<T: AsRef<str>>(expr: &Expr<Name<T>>) -> Option<&Term<Name<T>>> {
520            match expr {
521                Expr::Term(term) => find_var_in_term(term),
522                Expr::Not(inner) => find_var_in_expr(inner),
523                Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr),
524            }
525        }
526
527        fn find_var_in_term<T: AsRef<str>>(term: &Term<Name<T>>) -> Option<&Term<Name<T>>> {
528            const _: () = assert!(VAR_PREFIX == '$');
529
530            if term.is_variable() && !term.functor.as_ref().starts_with("_$") {
531                Some(term)
532            } else {
533                term.args.iter().find_map(find_var_in_term)
534            }
535        }
536    }
537}
538
539pub struct NamedTermViewIter<'a, 'int, Int: Intern> {
540    term_iter: TermViewIter<'a, Integer>,
541    int2name: &'a Int2Name<'int, Int>,
542}
543
544impl<'a, 'int, Int: Intern> Iterator for NamedTermViewIter<'a, 'int, Int> {
545    type Item = NamedTermView<'a, 'int, Int>;
546
547    fn next(&mut self) -> Option<Self::Item> {
548        self.term_iter
549            .next()
550            .map(|view| NamedTermView::new(view, self.int2name))
551    }
552}
553
554impl<Int: Intern> iter::FusedIterator for NamedTermViewIter<'_, '_, Int> {}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::{
        self,
        repr::{Clause, ClauseDataset, Expr},
    };
    use any_intern::DroplessInterner;

    #[test]
    fn test_parse() {
        fn assert(text: &str, interner: &impl Intern) {
            let clause: Clause<Name<_>> = parse::parse_str(text, interner).unwrap();
            assert_eq!(text, clause.to_string());
        }

        let interner = DroplessInterner::default();

        assert("f.", &interner);
        assert("f(a, b).", &interner);
        assert("f(a, b) :- f.", &interner);
        assert("f(a, b) :- f(a).", &interner);
        assert("f(a, b) :- f(a), f(b).", &interner);
        assert("f(a, b) :- f(a); f(b).", &interner);
        assert("f(a, b) :- f(a), (f(b); f(c)).", &interner);
    }

    #[test]
    fn test_serial_queries() {
        fn insert<Int: Intern>(db: &mut Database<'_, Int>) {
            insert_dataset(
                db,
                r"
                f(a).
                f(b).
                g($X) :- f($X).
                ",
            );
        }

        fn query<Int: Intern>(db: &mut Database<'_, Int>) {
            let query = "g($X).";
            let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
            let answer = collect_answer(db.query(query));

            let expected = [["$X = a"], ["$X = b"]];

            assert_eq!(answer, expected);
        }

        let mut db = Database::new();

        // `assert_eq!` rather than `debug_assert_eq!`: these checks must also run under
        // `cargo test --release`.
        insert(&mut db);
        let org_stor_len = db.stor.len();
        query(&mut db);
        // A query must not leave anything behind in the storage.
        assert_eq!(org_stor_len, db.stor.len());

        // Re-inserting the same clauses is rejected as duplicates, so storage must not grow.
        insert(&mut db);
        assert_eq!(org_stor_len, db.stor.len());
        query(&mut db);
        assert_eq!(org_stor_len, db.stor.len());

        db.dealloc();
    }

    #[test]
    fn test_various_expressions() {
        test_not_expression();
        test_and_expression();
        test_or_expression();
        test_mixed_expression();
    }

    fn test_not_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            f($X) :- \+ g($X).
            ",
        );

        let query = "f(a).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));
        assert!(answer.is_empty());

        let query = "f(b).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));
        assert_eq!(answer.len(), 1);

        db.dealloc();
    }

    fn test_and_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            g(b).
            h(b).
            f($X) :- g($X), h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];
        assert_eq!(answer, expected);

        db.dealloc();
    }

    fn test_or_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(a).
            h(b).
            f($X) :- g($X); h($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = a"], ["$X = b"]];
        assert_eq!(answer, expected);

        db.dealloc();
    }

    fn test_mixed_expression() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            g(b).
            g(c).

            h(b).

            i(a).
            i(b).
            i(c).

            f($X) :- (\+ g($X); h($X)), i($X).
            ",
        );

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        let expected = [["$X = b"]];
        assert_eq!(answer, expected);

        db.dealloc();
    }

    #[test]
    fn test_recursion() {
        test_simple_recursion();
        test_right_recursion();
    }

    fn test_simple_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            impl(Clone, a).
            impl(Clone, b).
            impl(Clone, c).
            impl(Clone, Vec($T)) :- impl(Clone, $T).
            ",
        );

        let query = "impl(Clone, $T).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let mut cx = db.query(query);

        let mut assert_next = |expected: &[&str]| {
            let eval = cx.prove_next().unwrap();
            let assignments = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            assert_eq!(assignments, expected);
        };

        assert_next(&["$T = a"]);
        assert_next(&["$T = b"]);
        assert_next(&["$T = c"]);
        assert_next(&["$T = Vec(a)"]);
        assert_next(&["$T = Vec(b)"]);
        assert_next(&["$T = Vec(c)"]);
        assert_next(&["$T = Vec(Vec(a))"]);
        assert_next(&["$T = Vec(Vec(b))"]);
        assert_next(&["$T = Vec(Vec(c))"]);

        drop(cx);
        db.dealloc();
    }

    fn test_right_recursion() {
        let mut db = Database::new();

        insert_dataset(
            &mut db,
            r"
            child(a, b).
            child(b, c).
            child(c, d).
            descend($X, $Y) :- child($X, $Y).
            descend($X, $Z) :- child($X, $Y), descend($Y, $Z).
            ",
        );

        let query = "descend($X, $Y).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let mut answer = collect_answer(db.query(query));

        let mut expected = [
            ["$X = a", "$Y = b"],
            ["$X = a", "$Y = c"],
            ["$X = a", "$Y = d"],
            ["$X = b", "$Y = c"],
            ["$X = b", "$Y = d"],
            ["$X = c", "$Y = d"],
        ];

        answer.sort_unstable();
        expected.sort_unstable();
        assert_eq!(answer, expected);

        db.dealloc();
    }

    #[test]
    fn test_discarding_uncomitted_change() {
        let mut db = Database::new();

        let text = "f(a).";
        let clause = parse::parse_str(text, db.interner()).unwrap();
        db.insert_clause(clause);
        let fa_state = db.state();
        db.commit();

        let text = "f(b).";
        let clause = parse::parse_str(text, db.interner()).unwrap();
        db.insert_clause(clause);

        let query = "f($X).";
        let query: Expr<Name<_>> = parse::parse_str(query, db.interner()).unwrap();
        let answer = collect_answer(db.query(query));

        // `f(b).` was discarded.
        let expected = [["$X = a"]];
        assert_eq!(answer, expected);
        assert_eq!(db.state(), fa_state);

        db.dealloc();
    }

    /// Parses `text` as a clause dataset, inserts it into `db` and commits.
    fn insert_dataset<Int: Intern>(db: &mut Database<'_, Int>, text: &str) {
        let dataset: ClauseDataset<Name<_>> = parse::parse_str(text, db.interner()).unwrap();
        db.insert_dataset(dataset);
        db.commit();
    }

    /// Drains `cx` and renders each answer as its list of `var = value` strings.
    fn collect_answer<Int: Intern>(mut cx: ProveCx<'_, '_, Int>) -> Vec<Vec<String>> {
        let mut v = Vec::new();
        while let Some(eval) = cx.prove_next() {
            let x = eval.map(|assign| assign.to_string()).collect::<Vec<_>>();
            v.push(x);
        }
        v
    }
}