// logic_eval/prove/prover.rs

1use super::repr::{
2    ApplyResult, ClauseId, ExprId, ExprKind, ExprView, TermDeepView, TermElem, TermId, TermStorage,
3    TermStorageLen, TermView, TermViewMut, UniqueTermArray,
4};
5use crate::{
6    Map,
7    parse::{
8        VAR_PREFIX,
9        repr::{Expr, Predicate, Term},
10        text::Name,
11    },
12};
13use indexmap::IndexMap;
14use std::{
15    collections::VecDeque,
16    fmt::{self, Write},
17    iter,
18    ops::{self, Range},
19};
20
21pub(crate) type ClauseMap = IndexMap<Predicate<Int>, Vec<ClauseId>>;
22
23#[derive(Debug)]
24pub(crate) struct Prover {
25    uni_op: UnificationOperator,
26
27    /// Nodes created during proof search.
28    nodes: Vec<Node>,
29
30    /// Variable assignments.
31    ///
32    /// For example, `assignment[X] = a` means that `X(term id)` is assigned to
33    /// `a(term id)`. If a value is identical to its index, it means the term is
34    /// not assigned to anything.
35    assignments: Vec<usize>,
36
37    /// A given query.
38    query: ExprId,
39
40    /// Variables in the root node(query).
41    ///
42    /// This could be used to find what terms these variables are assigned to.
43    query_vars: Vec<TermId>,
44
45    /// Task queue containing node index.
46    queue: VecDeque<usize>,
47
48    /// A buffer containing mapping between variables and temporary variables.
49    ///
50    /// This buffer is used when we convert variables into temporary variables
51    /// for a clause. After the conversion, this buffer get empty.
52    temp_var_buf: Map<TermId, TermId>,
53
54    /// A monotonically increasing integer that is used for generating
55    /// temporary variables.
56    temp_var_int: u32,
57}
58
59impl Prover {
60    pub(crate) fn new() -> Self {
61        Self {
62            uni_op: UnificationOperator::new(),
63            nodes: Vec::new(),
64            assignments: Vec::new(),
65            query: ExprId(0),
66            query_vars: Vec::new(),
67            queue: VecDeque::new(),
68            temp_var_buf: Map::default(),
69            temp_var_int: 0,
70        }
71    }
72
73    fn clear(&mut self) {
74        self.uni_op.clear();
75        self.nodes.clear();
76        self.assignments.clear();
77        self.query_vars.clear();
78        self.queue.clear();
79    }
80
81    pub(crate) fn prove<'a>(
82        &'a mut self,
83        query: Expr<Name>,
84        clause_map: &'a ClauseMap,
85        stor: &'a mut TermStorage<Int>,
86        nimap: &'a mut NameIntMap,
87    ) -> ProveCx<'a> {
88        self.clear();
89
90        let old_nimap_state = nimap.state();
91        let query = query.map(&mut |name| nimap.name_to_int(name));
92
93        let old_stor_len = stor.len();
94        self.query = stor.insert_expr(query);
95
96        stor.get_expr(self.query)
97            .with_term(&mut |term: TermView<'_, Int>| {
98                term.with_variable(&mut |term| self.query_vars.push(term.id));
99            });
100
101        self.nodes.push(Node {
102            kind: NodeKind::Expr(self.query),
103            uni_delta: 0..0,
104            parent: self.nodes.len(),
105        });
106        self.queue.push_back(0);
107
108        ProveCx::new(self, clause_map, stor, nimap, old_stor_len, old_nimap_state)
109    }
110
111    /// Evaluates the given node with all possible clauses in the clause
112    /// dataset, then returns whether a proof search path is complete or not.
113    ///
114    /// If it reached an end of paths, it returns proof search result within
115    /// `Some`. The proof search result is either true or false, which means
116    /// the expression in the given node is true or not.
117    fn evaluate_node(
118        &mut self,
119        node_index: usize,
120        clause_map: &ClauseMap,
121        stor: &mut TermStorage<Int>,
122    ) -> Option<bool> {
123        let node = self.nodes[node_index].clone();
124        let node_expr = match node.kind {
125            NodeKind::Expr(expr_id) => expr_id,
126            NodeKind::Leaf(eval) => {
127                self.find_assignment(node_index);
128                return Some(eval);
129            }
130        };
131
132        let predicate = stor.get_expr(node_expr).leftmost_term().predicate();
133        let similar_clauses = if let Some(v) = clause_map.get(&predicate) {
134            v.as_slice()
135        } else {
136            &[]
137        };
138
139        let old_len = self.nodes.len();
140        for clause in similar_clauses {
141            let head = stor.get_term(clause.head);
142
143            if !stor.get_expr(node_expr).is_unifiable(head) {
144                continue;
145            }
146
147            let clause = Self::convert_var_into_temp(
148                *clause,
149                stor,
150                &mut self.temp_var_buf,
151                &mut self.temp_var_int,
152            );
153            if let Some(new_node) = self.unify_node_with_clause(node_index, clause, stor) {
154                self.nodes.push(new_node);
155                self.queue.push_back(self.nodes.len() - 1);
156            }
157        }
158
159        // We may need to apply true or false to the leftmost term of the node
160        // expression due to unification failure or exhaustive search.
161        // - Unification failure means the leftmost term should be false.
162        // - But we need to consider exhaustive search possibility at the same
163        //   time.
164
165        let expr = stor.get_expr(node_expr);
166        let eval = self.nodes.len() > old_len;
167        let mut need_apply = None;
168
169        let lost_possibility = match assume_leftmost_term(expr, eval) {
170            AssumeResult::Incomplete { lost } => lost,
171            AssumeResult::Complete { lost, .. } => lost,
172        };
173        if lost_possibility {
174            need_apply = Some(!eval);
175        } else if !eval {
176            need_apply = Some(false);
177        }
178
179        if let Some(to) = need_apply {
180            let mut expr = stor.get_expr_mut(node_expr);
181            let kind = match expr.apply_to_leftmost_term(to) {
182                ApplyResult::Expr => NodeKind::Expr(expr.id()),
183                ApplyResult::Complete(eval) => NodeKind::Leaf(eval),
184            };
185            self.nodes.push(Node {
186                kind,
187                uni_delta: 0..0,
188                parent: node_index,
189            });
190            self.queue.push_back(self.nodes.len() - 1);
191        }
192
193        return None;
194
195        // === Internal helper functions ===
196
197        enum AssumeResult {
198            /// The whole expression could not completely evaluated even though
199            /// the assumption is realized.
200            Incomplete {
201                /// Whether or not the assumption will make us lose some search
202                /// possibilities.
203                lost: bool,
204            },
205
206            /// The whole expression will be completely evaluated if the
207            /// assumption is realized.
208            Complete {
209                /// Evalauted as true or false.
210                eval: bool,
211                lost: bool,
212            },
213        }
214
215        fn assume_leftmost_term(expr: ExprView<'_, Int>, to: bool) -> AssumeResult {
216            match expr.as_kind() {
217                ExprKind::Term(_) => AssumeResult::Complete {
218                    eval: to,
219                    lost: false,
220                },
221                ExprKind::Not(inner) => match assume_leftmost_term(inner, to) {
222                    res @ AssumeResult::Incomplete { .. } => res,
223                    AssumeResult::Complete { eval, lost } => {
224                        AssumeResult::Complete { eval: !eval, lost }
225                    }
226                },
227                ExprKind::And(mut args) => {
228                    // Unlike 'Or', even if 'And' contains variables and the
229                    // whole expression will be evaluated false, those variables
230                    // must be ignored. They don't belong to 'lost'.
231                    match assume_leftmost_term(args.next().unwrap(), to) {
232                        res @ AssumeResult::Incomplete { .. } => res,
233                        AssumeResult::Complete { eval, lost } => {
234                            if !eval {
235                                AssumeResult::Complete { eval: false, lost }
236                            } else {
237                                AssumeResult::Incomplete { lost }
238                            }
239                        }
240                    }
241                }
242                ExprKind::Or(mut args) => {
243                    // The whole 'Or' is true if any argument is true. But we
244                    // will lose possible search paths if we ignore right
245                    // variables.
246                    match assume_leftmost_term(args.next().unwrap(), to) {
247                        res @ AssumeResult::Incomplete { .. } => res,
248                        AssumeResult::Complete { eval, lost } => {
249                            if eval {
250                                let right_var = args.any(|arg| arg.contains_variable());
251                                AssumeResult::Complete {
252                                    eval: true,
253                                    lost: lost | right_var,
254                                }
255                            } else {
256                                AssumeResult::Incomplete { lost }
257                            }
258                        }
259                    }
260                }
261            }
262        }
263    }
264
265    /// Replaces variables in a clause with other temporary variables.
266    ///
267    // Why we replace variables with temporary variables in clauses before
268    // unifying?
269    // 1. That's because variables in different clauses are actually different
270    // from each other even they have the same identity. Variable's identity
271    // is valid only in the one clause where they belong.
272    // 2. Also, we apply this method whenever unification happens because one
273    // clause can be used mupltiple times in a single proof search path. Then
274    // it is considered as a different clause.
275    fn convert_var_into_temp(
276        mut clause_id: ClauseId,
277        stor: &mut TermStorage<Int>,
278        temp_var_buf: &mut Map<TermId, TermId>,
279        temp_var_int: &mut u32,
280    ) -> ClauseId {
281        debug_assert!(temp_var_buf.is_empty());
282
283        let mut f = |terms: &mut UniqueTermArray<Int>, term_id: TermId| {
284            let term = terms.get_mut(term_id);
285            if term.is_variable() {
286                let src = term.id();
287
288                temp_var_buf.entry(src).or_insert_with(|| {
289                    let temp_term = Term {
290                        functor: Int::temporary(*temp_var_int),
291                        args: [].into(),
292                    };
293                    *temp_var_int += 1;
294                    terms.insert(temp_term)
295                });
296            }
297        };
298
299        stor.get_term_mut(clause_id.head).with_terminal(&mut f);
300
301        if let Some(body) = clause_id.body {
302            stor.get_expr_mut(body).with_terminal(&mut f);
303        }
304
305        for (src, dst) in temp_var_buf.drain() {
306            let mut head = stor.get_term_mut(clause_id.head);
307            head.replace(src, dst);
308            clause_id.head = head.id();
309
310            if let Some(body) = clause_id.body {
311                let mut body = stor.get_expr_mut(body);
312                body.replace_term(src, dst);
313                clause_id.body = Some(body.id());
314            }
315        }
316
317        clause_id
318    }
319
320    fn unify_node_with_clause(
321        &mut self,
322        node_index: usize,
323        clause: ClauseId,
324        stor: &mut TermStorage<Int>,
325    ) -> Option<Node> {
326        debug_assert!(self.uni_op.ops.is_empty());
327
328        let NodeKind::Expr(node_expr) = self.nodes[node_index].kind else {
329            unreachable!()
330        };
331
332        if !stor
333            .get_expr(node_expr)
334            .leftmost_term()
335            .unify(stor.get_term(clause.head), &mut |op| {
336                self.uni_op.push_op(op)
337            })
338        {
339            return None;
340        }
341        let (node_expr, clause, uni_delta) = self.uni_op.consume_ops(stor, node_expr, clause);
342
343        if let Some(body) = clause.body {
344            let mut lhs = stor.get_expr_mut(node_expr);
345            lhs.replace_leftmost_term(body);
346            return Some(Node {
347                kind: NodeKind::Expr(lhs.id()),
348                uni_delta,
349                parent: node_index,
350            });
351        }
352
353        let mut lhs = stor.get_expr_mut(node_expr);
354        let kind = match lhs.apply_to_leftmost_term(true) {
355            ApplyResult::Expr => NodeKind::Expr(lhs.id()),
356            ApplyResult::Complete(eval) => NodeKind::Leaf(eval),
357        };
358        Some(Node {
359            kind,
360            uni_delta,
361            parent: node_index,
362        })
363    }
364
365    /// Finds all assignments from the given node to the root node.
366    ///
367    /// Then, the assignment information is stored at [`Self::assignments`].
368    fn find_assignment(&mut self, node_index: usize) {
369        // Collects unification records.
370        self.assignments.clear();
371
372        let mut cur_index = node_index;
373        loop {
374            let node = &self.nodes[cur_index];
375            let range = node.uni_delta.clone();
376
377            for (from, to) in self.uni_op.get_record(range).iter().cloned() {
378                let (from, to) = (from.0, to.0);
379
380                for i in self.assignments.len()..=from.max(to) {
381                    self.assignments.push(i);
382                }
383
384                let root_from = find(&mut self.assignments, from);
385                let root_to = find(&mut self.assignments, to);
386                self.assignments[root_from] = root_to;
387            }
388
389            if node.parent == cur_index {
390                break;
391            }
392            cur_index = node.parent;
393        }
394
395        return;
396
397        // === Internal helper functions ===
398
399        fn find(buf: &mut [usize], i: usize) -> usize {
400            if buf[i] == i {
401                i
402            } else {
403                let root = find(buf, buf[i]);
404                buf[i] = root;
405                root
406            }
407        }
408    }
409}
410
411#[derive(Debug)]
412struct UnificationOperator {
413    ops: Vec<UnifyOp>,
414
415    /// History of unification.
416    ///
417    /// A pair of term ids means that `pair.0` is assiend to `pair.1`. For
418    /// example, `(X, a)` means `X` is assigned to `a`.
419    record: Vec<(TermId, TermId)>,
420}
421
422impl UnificationOperator {
423    const fn new() -> Self {
424        Self {
425            ops: Vec::new(),
426            record: Vec::new(),
427        }
428    }
429
430    fn clear(&mut self) {
431        self.ops.clear();
432        self.record.clear();
433    }
434
435    fn push_op(&mut self, op: UnifyOp) {
436        self.ops.push(op);
437    }
438
439    #[must_use]
440    fn consume_ops(
441        &mut self,
442        stor: &mut TermStorage<Int>,
443        mut left: ExprId,
444        mut right: ClauseId,
445    ) -> (ExprId, ClauseId, Range<usize>) {
446        let record_start = self.record.len();
447
448        for op in self.ops.drain(..) {
449            match op {
450                UnifyOp::Left { from, to } => {
451                    let mut expr = stor.get_expr_mut(left);
452                    expr.replace_term(from, to);
453                    left = expr.id();
454
455                    self.record.push((from, to));
456                }
457                UnifyOp::Right { from, to } => {
458                    if let Some(right_body) = right.body {
459                        let mut expr = stor.get_expr_mut(right_body);
460                        expr.replace_term(from, to);
461                        right.body = Some(expr.id());
462
463                        self.record.push((from, to));
464                    }
465                }
466            }
467        }
468
469        (left, right, record_start..self.record.len())
470    }
471
472    fn get_record(&self, range: Range<usize>) -> &[(TermId, TermId)] {
473        &self.record[range]
474    }
475}
476
477#[derive(Debug, Clone)]
478struct Node {
479    kind: NodeKind,
480    uni_delta: Range<usize>,
481    parent: usize,
482}
483
484#[derive(Debug, Clone, Copy)]
485enum NodeKind {
486    /// A non-terminal node containig an expression id that needs to be
487    /// evaluated.
488    Expr(ExprId),
489
490    /// A terminal node containing whether a proof path ends with true or false.
491    Leaf(bool),
492}
493
494#[derive(Debug)]
495enum UnifyOp {
496    Left { from: TermId, to: TermId },
497    Right { from: TermId, to: TermId },
498}
499
500pub struct ProveCx<'a> {
501    prover: &'a mut Prover,
502    clause_map: &'a ClauseMap,
503    stor: &'a mut TermStorage<Int>,
504    nimap: &'a mut NameIntMap,
505    old_stor_len: TermStorageLen,
506    old_nimap_state: NameIntMapState,
507}
508
509impl<'a> ProveCx<'a> {
510    fn new(
511        prover: &'a mut Prover,
512        clause_map: &'a ClauseMap,
513        stor: &'a mut TermStorage<Int>,
514        nimap: &'a mut NameIntMap,
515        old_stor_len: TermStorageLen,
516        old_nimap_state: NameIntMapState,
517    ) -> Self {
518        Self {
519            prover,
520            clause_map,
521            stor,
522            nimap,
523            old_stor_len,
524            old_nimap_state,
525        }
526    }
527
528    pub fn prove_next(&mut self) -> Option<EvalView<'_>> {
529        while let Some(node_index) = self.prover.queue.pop_front() {
530            if let Some(proof_result) =
531                self.prover
532                    .evaluate_node(node_index, self.clause_map, self.stor)
533            {
534                // Returns Some(EvalView) only if the result is TRUE.
535                if proof_result {
536                    return Some(EvalView {
537                        query_vars: &self.prover.query_vars,
538                        terms: &self.stor.terms.buf,
539                        assignments: &self.prover.assignments,
540                        int2name: &self.nimap.int2name,
541                        start: 0,
542                        end: self.prover.query_vars.len(),
543                    });
544                }
545            }
546        }
547        None
548    }
549
550    pub fn is_true(mut self) -> bool {
551        self.prove_next().is_some()
552    }
553}
554
555impl Drop for ProveCx<'_> {
556    fn drop(&mut self) {
557        self.stor.truncate(self.old_stor_len.clone());
558        self.nimap.revert(self.old_nimap_state.clone());
559    }
560}
561
562pub struct EvalView<'a> {
563    query_vars: &'a [TermId],
564    terms: &'a [TermElem<Int>],
565    assignments: &'a [usize],
566    int2name: &'a IndexMap<Int, Name>,
567    /// Inclusive
568    start: usize,
569    /// Exclusive
570    end: usize,
571}
572
573impl EvalView<'_> {
574    const fn len(&self) -> usize {
575        self.end - self.start
576    }
577}
578
579impl<'a> Iterator for EvalView<'a> {
580    type Item = Assignment<'a>;
581
582    fn next(&mut self) -> Option<Self::Item> {
583        if self.start < self.end {
584            let from = self.query_vars[self.start];
585            self.start += 1;
586
587            Some(Assignment {
588                buf: self.terms,
589                from,
590                assignments: self.assignments,
591                int2name: self.int2name,
592            })
593        } else {
594            None
595        }
596    }
597
598    fn size_hint(&self) -> (usize, Option<usize>) {
599        let len = <Self>::len(self);
600        (len, Some(len))
601    }
602}
603
604impl ExactSizeIterator for EvalView<'_> {
605    fn len(&self) -> usize {
606        <Self>::len(self)
607    }
608}
609
610impl iter::FusedIterator for EvalView<'_> {}
611
612pub struct Assignment<'a> {
613    buf: &'a [TermElem<Int>],
614    from: TermId,
615    assignments: &'a [usize],
616    int2name: &'a IndexMap<Int, Name>,
617}
618
619impl<'a> Assignment<'a> {
620    /// Creates left hand side term of the assignment.
621    ///
622    /// To create a term, this method could allocate memory for the term.
623    pub fn lhs(&self) -> Term<Name> {
624        Self::term_view_to_term(self.lhs_view(), self.int2name)
625    }
626
627    /// Creates right hand side term of the assignment.
628    ///
629    /// To create a term, this method could allocate memory for the term.
630    pub fn rhs(&self) -> Term<Name> {
631        Self::term_deep_view_to_term(self.rhs_view(), self.int2name)
632    }
633
634    /// Returns left hand side variable name of the assignment.
635    ///
636    /// Note that assignment's left hand side is always variable.
637    pub fn get_lhs_variable(&self) -> &Name {
638        let int = self.lhs_view().get_contained_variable().unwrap();
639        self.int2name.get(&int).unwrap()
640    }
641
642    fn term_view_to_term(view: TermView<'_, Int>, int2name: &IndexMap<Int, Name>) -> Term<Name> {
643        let functor = view.functor();
644        let args = view.args();
645
646        let functor = if let Some(name) = int2name.get(functor) {
647            name.clone()
648        } else {
649            let mut debug_string = String::new();
650            write!(&mut debug_string, "{:?}", functor).unwrap();
651            debug_string.into()
652        };
653
654        let args = args
655            .into_iter()
656            .map(|arg| Self::term_view_to_term(arg, int2name))
657            .collect();
658
659        Term { functor, args }
660    }
661
662    fn term_deep_view_to_term(
663        view: TermDeepView<'_, Int>,
664        int2name: &IndexMap<Int, Name>,
665    ) -> Term<Name> {
666        let functor = view.functor();
667        let args = view.args();
668
669        let functor = if let Some(name) = int2name.get(functor) {
670            name.clone()
671        } else {
672            let mut debug_string = String::new();
673            write!(&mut debug_string, "{:?}", functor).unwrap();
674            debug_string.into()
675        };
676
677        let args = args
678            .into_iter()
679            .map(|arg| Self::term_deep_view_to_term(arg, int2name))
680            .collect();
681
682        Term { functor, args }
683    }
684
685    const fn lhs_view(&self) -> TermView<'_, Int> {
686        TermView {
687            buf: self.buf,
688            id: self.from,
689        }
690    }
691
692    const fn rhs_view(&self) -> TermDeepView<'_, Int> {
693        TermDeepView {
694            buf: self.buf,
695            links: self.assignments,
696            id: self.from,
697        }
698    }
699}
700
701impl fmt::Display for Assignment<'_> {
702    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
703        let view = format::NamedTermView::new(self.lhs_view(), self.int2name);
704        fmt::Display::fmt(&view, f)?;
705
706        f.write_str(" = ")?;
707
708        let view = format::NamedTermDeepView::new(self.rhs_view(), self.int2name);
709        fmt::Display::fmt(&view, f)
710    }
711}
712
713impl fmt::Debug for Assignment<'_> {
714    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
715        let lhs = format::NamedTermView::new(self.lhs_view(), self.int2name);
716        let rhs = format::NamedTermDeepView::new(self.rhs_view(), self.int2name);
717
718        f.debug_struct("Assignment")
719            .field("lhs", &lhs)
720            .field("rhs", &rhs)
721            .finish()
722    }
723}
724
725impl ExprView<'_, Int> {
726    fn is_unifiable(&self, other: TermView<'_, Int>) -> bool {
727        match self.as_kind() {
728            ExprKind::Term(term) => term.is_unifiable(other),
729            ExprKind::Not(inner) => inner.is_unifiable(other),
730            ExprKind::And(mut args) | ExprKind::Or(mut args) => {
731                args.next().unwrap().is_unifiable(other)
732            }
733        }
734    }
735
736    fn contains_variable(&self) -> bool {
737        match self.as_kind() {
738            ExprKind::Term(term) => term.contains_variable(),
739            ExprKind::Not(inner) => inner.contains_variable(),
740            ExprKind::And(mut args) | ExprKind::Or(mut args) => {
741                args.any(|arg| arg.contains_variable())
742            }
743        }
744    }
745}
746
747impl TermView<'_, Int> {
748    fn unify<F: FnMut(UnifyOp)>(self, other: Self, f: &mut F) -> bool {
749        if self.is_variable() {
750            f(UnifyOp::Left {
751                from: self.id,
752                to: other.id,
753            });
754            true
755        } else if other.is_variable() {
756            f(UnifyOp::Right {
757                from: other.id,
758                to: self.id,
759            });
760            true
761        } else if self.functor() == other.functor() {
762            let zip = self.args().zip(other.args());
763            // Unifies only if all arguments are unifiable.
764            if self.arity() == other.arity() && zip.clone().all(|(l, r)| l.is_unifiable(r)) {
765                for (l, r) in zip {
766                    l.unify(r, f);
767                }
768                true
769            } else {
770                false
771            }
772        } else {
773            false
774        }
775    }
776
777    fn is_unifiable(&self, other: Self) -> bool {
778        if self.is_variable() || other.is_variable() {
779            true
780        } else if self.functor() == other.functor() {
781            if self.arity() == other.arity() {
782                self.args()
783                    .zip(other.args())
784                    .all(|(l, r)| l.is_unifiable(r))
785            } else {
786                false
787            }
788        } else {
789            false
790        }
791    }
792
793    /// Returns true if this term is a variable.
794    ///
795    /// e.g. Terms like `X`, `Y` will return true.
796    fn is_variable(&self) -> bool {
797        self.arity() == 0 && self.functor().is_variable()
798    }
799
800    /// Returns true if this term is a variable or contains variable in it.
801    ///
802    /// e.g. Terms like `X` of `f(X)` will return true.
803    fn contains_variable(&self) -> bool {
804        self.is_variable() || self.args().any(|arg| arg.contains_variable())
805    }
806
807    fn get_contained_variable(&self) -> Option<Int> {
808        if self.is_variable() {
809            Some(*self.functor())
810        } else {
811            self.args().find_map(|arg| arg.get_contained_variable())
812        }
813    }
814
815    fn with_variable<F: FnMut(&Self)>(&self, f: &mut F) {
816        if self.is_variable() {
817            f(self);
818        } else {
819            for arg in self.args() {
820                arg.with_variable(f);
821            }
822        }
823    }
824}
825
826impl TermViewMut<'_, Int> {
827    fn is_variable(&self) -> bool {
828        self.arity() == 0 && self.functor().is_variable()
829    }
830}
831
/// Interned integer standing for a name.
///
/// The top bits are flags: bit 31 marks a variable and bit 30 marks a
/// temporary variable generated during proof search. The remaining bits hold
/// the index.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Int(u32);
834
835impl Int {
836    const VAR_FLAG: u32 = 0x1 << 31;
837    const TEMPORARY_FLAG: u32 = 0x1 << 30;
838
839    pub(crate) fn from_text(s: &Name, mut index: u32) -> Self {
840        if s.is_variable() {
841            index |= Self::VAR_FLAG;
842        }
843        Self(index)
844    }
845
846    pub(crate) const fn temporary(int: u32) -> Self {
847        Self(int | Self::VAR_FLAG | Self::TEMPORARY_FLAG)
848    }
849
850    pub(crate) const fn is_variable(self) -> bool {
851        (Self::VAR_FLAG & self.0) == Self::VAR_FLAG
852    }
853
854    pub(crate) const fn is_temporary_variable(self) -> bool {
855        let mask: u32 = Self::VAR_FLAG | Self::TEMPORARY_FLAG;
856        (mask & self.0) == mask
857    }
858}
859
860impl fmt::Debug for Int {
861    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
862        let mask: u32 = Self::VAR_FLAG | Self::TEMPORARY_FLAG;
863        let index = !mask & self.0;
864
865        if self.is_variable() {
866            f.write_char(VAR_PREFIX)?;
867        }
868        if self.is_temporary_variable() {
869            f.write_char('#')?;
870        }
871        index.fmt(f)
872    }
873}
874
875impl ops::AddAssign<u32> for Int {
876    fn add_assign(&mut self, rhs: u32) {
877        self.0 += rhs;
878    }
879}
880
881/// Only mapping of user-input clauses and queries are stored in this map.
882/// Auto-generated variables or something like that are not stored here.
883#[derive(Debug)]
884pub(crate) struct NameIntMap {
885    pub(crate) name2int: IndexMap<Name, Int>,
886    pub(crate) int2name: IndexMap<Int, Name>,
887    next_int: u32,
888}
889
890impl NameIntMap {
891    pub(crate) fn new() -> Self {
892        Self {
893            name2int: IndexMap::default(),
894            int2name: IndexMap::default(),
895            next_int: 0,
896        }
897    }
898
899    pub(crate) fn name_to_int(&mut self, name: Name) -> Int {
900        if let Some(int) = self.name2int.get(&name) {
901            *int
902        } else {
903            let int = Int::from_text(&name, self.next_int);
904
905            self.name2int.insert(name.clone(), int);
906            self.int2name.insert(int, name);
907
908            self.next_int += 1;
909            int
910        }
911    }
912
913    pub(crate) fn state(&self) -> NameIntMapState {
914        NameIntMapState {
915            name2int_len: self.name2int.len(),
916            int2name_len: self.int2name.len(),
917            next_int: self.next_int,
918        }
919    }
920
921    pub(crate) fn revert(
922        &mut self,
923        NameIntMapState {
924            name2int_len,
925            int2name_len,
926            next_int,
927        }: NameIntMapState,
928    ) {
929        self.name2int.truncate(name2int_len);
930        self.int2name.truncate(int2name_len);
931        self.next_int = next_int;
932    }
933}
934
/// Snapshot of a `NameIntMap`, used to roll back names interned for a query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct NameIntMapState {
    /// Number of entries in `name2int` at snapshot time.
    name2int_len: usize,
    /// Number of entries in `int2name` at snapshot time.
    int2name_len: usize,
    /// Value of the next index at snapshot time.
    next_int: u32,
}
941
942pub(crate) mod format {
943    use super::*;
944
945    pub struct NamedTermView<'a> {
946        view: TermView<'a, Int>,
947        int2name: &'a IndexMap<Int, Name>,
948    }
949
950    impl<'a> NamedTermView<'a> {
951        pub(crate) const fn new(
952            view: TermView<'a, Int>,
953            int2name: &'a IndexMap<Int, Name>,
954        ) -> Self {
955            Self { view, int2name }
956        }
957    }
958
959    impl fmt::Display for NamedTermView<'_> {
960        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
961            let Self { view, int2name } = self;
962
963            let functor = view.functor();
964            let args = view.args();
965            let num_args = args.len();
966
967            write_int(functor, int2name, f)?;
968
969            if num_args > 0 {
970                f.write_char('(')?;
971                for (i, arg) in args.enumerate() {
972                    fmt::Display::fmt(&Self::new(arg, int2name), f)?;
973                    if i + 1 < num_args {
974                        f.write_str(", ")?;
975                    }
976                }
977                f.write_char(')')?;
978            }
979            Ok(())
980        }
981    }
982
983    impl fmt::Debug for NamedTermView<'_> {
984        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
985            let Self { view, int2name } = self;
986
987            let functor = view.functor();
988            let args = view.args();
989            let num_args = args.len();
990
991            if num_args == 0 {
992                write_int(functor, int2name, f)
993            } else {
994                let mut d = if let Some(name) = int2name.get(functor) {
995                    f.debug_tuple(name)
996                } else {
997                    let mut debug_string = String::new();
998                    write!(&mut debug_string, "{:?}", functor)?;
999                    f.debug_tuple(&debug_string)
1000                };
1001
1002                for arg in args {
1003                    d.field(&Self::new(arg, int2name));
1004                }
1005                d.finish()
1006            }
1007        }
1008    }
1009
1010    pub(crate) struct NamedTermDeepView<'a> {
1011        view: TermDeepView<'a, Int>,
1012        int2name: &'a IndexMap<Int, Name>,
1013    }
1014
1015    impl<'a> NamedTermDeepView<'a> {
1016        pub(crate) const fn new(
1017            view: TermDeepView<'a, Int>,
1018            int2name: &'a IndexMap<Int, Name>,
1019        ) -> Self {
1020            Self { view, int2name }
1021        }
1022    }
1023
1024    impl fmt::Display for NamedTermDeepView<'_> {
1025        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1026            let Self { view, int2name } = self;
1027
1028            let functor = view.functor();
1029            let args = view.args();
1030            let num_args = args.len();
1031
1032            write_int(functor, int2name, f)?;
1033
1034            if num_args > 0 {
1035                f.write_char('(')?;
1036                for (i, arg) in args.enumerate() {
1037                    fmt::Display::fmt(&Self::new(arg, int2name), f)?;
1038                    if i + 1 < num_args {
1039                        f.write_str(", ")?;
1040                    }
1041                }
1042                f.write_char(')')?;
1043            }
1044            Ok(())
1045        }
1046    }
1047
1048    impl fmt::Debug for NamedTermDeepView<'_> {
1049        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1050            let Self { view, int2name } = self;
1051
1052            let functor = view.functor();
1053            let args = view.args();
1054            let num_args = args.len();
1055
1056            if num_args == 0 {
1057                write_int(functor, int2name, f)
1058            } else {
1059                let mut d = if let Some(name) = int2name.get(functor) {
1060                    f.debug_tuple(name)
1061                } else {
1062                    let mut debug_string = String::new();
1063                    write!(&mut debug_string, "{:?}", functor)?;
1064                    f.debug_tuple(&debug_string)
1065                };
1066
1067                for arg in args {
1068                    d.field(&Self::new(arg, int2name));
1069                }
1070                d.finish()
1071            }
1072        }
1073    }
1074
1075    pub(crate) struct NamedExprView<'a> {
1076        view: ExprView<'a, Int>,
1077        int2name: &'a IndexMap<Int, Name>,
1078    }
1079
1080    impl<'a> NamedExprView<'a> {
1081        pub(crate) const fn new(
1082            view: ExprView<'a, Int>,
1083            int2name: &'a IndexMap<Int, Name>,
1084        ) -> Self {
1085            Self { view, int2name }
1086        }
1087    }
1088
1089    impl fmt::Display for NamedExprView<'_> {
1090        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1091            let Self { view, int2name } = self;
1092
1093            match view.as_kind() {
1094                ExprKind::Term(term) => fmt::Display::fmt(
1095                    &NamedTermView {
1096                        view: term,
1097                        int2name,
1098                    },
1099                    f,
1100                )?,
1101                ExprKind::Not(inner) => {
1102                    f.write_str("\\+ ")?;
1103                    if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
1104                        f.write_char('(')?;
1105                        fmt::Display::fmt(&Self::new(inner, int2name), f)?;
1106                        f.write_char(')')?;
1107                    } else {
1108                        fmt::Display::fmt(&Self::new(inner, int2name), f)?;
1109                    }
1110                }
1111                ExprKind::And(args) => {
1112                    let num_args = args.len();
1113                    for (i, arg) in args.enumerate() {
1114                        if matches!(arg.as_kind(), ExprKind::Or(_)) {
1115                            f.write_char('(')?;
1116                            fmt::Display::fmt(&Self::new(arg, int2name), f)?;
1117                            f.write_char(')')?;
1118                        } else {
1119                            fmt::Display::fmt(&Self::new(arg, int2name), f)?;
1120                        }
1121                        if i + 1 < num_args {
1122                            f.write_str(", ")?;
1123                        }
1124                    }
1125                }
1126                ExprKind::Or(args) => {
1127                    let num_args = args.len();
1128                    for (i, arg) in args.enumerate() {
1129                        fmt::Display::fmt(&Self::new(arg, int2name), f)?;
1130                        if i + 1 < num_args {
1131                            f.write_str("; ")?;
1132                        }
1133                    }
1134                }
1135            }
1136            Ok(())
1137        }
1138    }
1139
1140    impl fmt::Debug for NamedExprView<'_> {
1141        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1142            let Self { view, int2name } = self;
1143
1144            match view.as_kind() {
1145                ExprKind::Term(term) => fmt::Debug::fmt(&NamedTermView::new(term, int2name), f),
1146                ExprKind::Not(inner) => f
1147                    .debug_tuple("Not")
1148                    .field(&NamedExprView::new(inner, int2name))
1149                    .finish(),
1150                ExprKind::And(args) => {
1151                    let mut d = f.debug_tuple("And");
1152                    for arg in args {
1153                        d.field(&NamedExprView::new(arg, int2name));
1154                    }
1155                    d.finish()
1156                }
1157                ExprKind::Or(args) => {
1158                    let mut d = f.debug_tuple("Or");
1159                    for arg in args {
1160                        d.field(&NamedExprView::new(arg, int2name));
1161                    }
1162                    d.finish()
1163                }
1164            }
1165        }
1166    }
1167
1168    fn write_int(int: &Int, map: &IndexMap<Int, Name>, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1169        if let Some(name) = map.get(int) {
1170            f.write_str(name)
1171        } else {
1172            fmt::Debug::fmt(int, f)
1173        }
1174    }
1175}