Skip to main content

logic_eval/prove/
prover.rs

1use super::{
2    canonical,
3    repr::{
4        ApplyResult, ClauseId, ExprId, ExprKind, ExprView, TermDeepView, TermElem, TermId,
5        TermStorage, TermStorageLen, TermView, TermViewMut, UniqueTermArray,
6    },
7    table::Table,
8};
9use crate::{
10    parse::repr::{Expr, Predicate, Term},
11    prove::table::{TableEntry, TableIndex},
12    Atom, IndexMap, IndexSet, Map, VAR_PREFIX,
13};
14use core::{
15    fmt::{self, Debug, Display, Write},
16    hash::Hash,
17    iter,
18    ops::{self, Range},
19};
20use smallvec::SmallVec;
21use std::collections::VecDeque;
22
23#[derive(Debug)]
24pub(crate) struct Prover {
25    uni_op: UnificationOperator,
26
27    /// Nodes created during proof search.
28    nodes: Vec<Node>,
29
30    /// Variable assignments (e.g. X = a, Y = z)
31    term_assigns: TermAssignments,
32
33    /// A given query.
34    query: ExprId,
35
36    /// Variables in the root node(query).
37    ///
38    /// This could be used to find what terms these variables are assigned to.
39    query_vars: Vec<TermId>,
40
41    /// Previously returned query answers.
42    query_answers: Vec<Vec<TermId>>,
43
44    /// Task queue containing node index.
45    queue: NodeQueue,
46
47    /// A buffer containing mapping between variables and temporary variables.
48    ///
49    /// This buffer is used when we convert variables into temporary variables for a clause. After
50    /// the conversion, this buffer get empty.
51    temp_var_buf: Map<TermId, TermId>,
52
53    /// A monotonically increasing integer that is used for generating temporary variables.
54    temp_var_int: u32,
55
56    /// SLG resolution.
57    table: Table,
58}
59
60impl Prover {
61    pub(crate) fn new() -> Self {
62        Self {
63            uni_op: UnificationOperator::new(),
64            nodes: Vec::new(),
65            term_assigns: TermAssignments::default(),
66            query: ExprId(0),
67            query_vars: Vec::new(),
68            query_answers: Vec::new(),
69            queue: NodeQueue::default(),
70            temp_var_buf: Map::default(),
71            temp_var_int: 0,
72            table: Table::default(),
73        }
74    }
75
76    fn clear(&mut self) {
77        self.uni_op.clear();
78        self.nodes.clear();
79        self.term_assigns.clear();
80        self.query_vars.clear();
81        self.query_answers.clear();
82        self.queue.clear();
83        self.table.clear();
84    }
85
86    pub(crate) fn prove<'a, T: Atom>(
87        &'a mut self,
88        query: Expr<T>,
89        clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
90        table_clauses: &'a IndexSet<Predicate<Integer>>,
91        stor: &'a mut TermStorage<Integer>,
92        nimap: &'a mut NameIntMap<T>,
93    ) -> ProveCx<'a, T> {
94        self.clear();
95
96        let old_nimap_state = nimap.state();
97        let query = query.map(&mut |name| nimap.name_to_int(name));
98
99        let old_stor_len = stor.len();
100        self.query = stor.insert_expr(query);
101
102        stor.get_expr(self.query)
103            .with_term(&mut |term: TermView<'_, Integer>| {
104                term.with_variable(|term| self.query_vars.push(term.id));
105            });
106
107        let node_kind = NodeKind::Expr(self.query);
108        let node_parent = self.nodes.len();
109        self.nodes.push(Node::new(node_kind, node_parent));
110        self.queue.push(0);
111
112        ProveCx {
113            prover: self,
114            clauses,
115            table_clauses,
116            stor,
117            nimap,
118            old_stor_len,
119            old_nimap_state,
120        }
121    }
122
123    /// Evaluates the given node with all possible clauses in the clause dataset, then returns
124    /// whether a proof search path is complete or not.
125    ///
126    /// If it reached an end of paths, it returns proof search result within `Some`. The proof
127    /// search result is either true or false, which means the expression in the given node is
128    /// evaluted as true or false.
129    fn evaluate_node(
130        &mut self,
131        node_index: usize,
132        clauses: &IndexMap<Predicate<Integer>, Vec<ClauseId>>,
133        table_clauses: &IndexSet<Predicate<Integer>>,
134        stor: &mut TermStorage<Integer>,
135    ) -> Option<bool> {
136        let node_expr = match self.nodes[node_index].kind {
137            NodeKind::Expr(expr_id) => expr_id,
138            NodeKind::Leaf(eval) => {
139                self.find_assignments(node_index);
140                // On a successful proof, records the answer in the nearest ancestor-owned SLG
141                // table entry, then notifies all waiting consumers.
142                if eval {
143                    self.update_answer_and_notify(node_index, stor);
144                }
145                return Some(eval);
146            }
147        };
148
149        let node_leftmost = stor.get_expr(node_expr).leftmost_term().id;
150        let node_leftmost_pred = stor.get_term(node_leftmost).predicate();
151        let mut similar_clauses = &[][..];
152        let mut clause_buf: SmallVec<[ClauseId; 1]> = SmallVec::new();
153
154        // === SLG path ===
155        // * Table entry - Created from non-canonical leftmost term of the node. In tabling,
156        //   we use canonical variables for table keys only.
157
158        if table_clauses.contains(&node_leftmost_pred) {
159            let key = canonical::canonicalize_term_id(stor, node_leftmost);
160            if let Some((_, entry)) = self.table.get_mut(&key) {
161                entry.register_consumer(node_index);
162
163                // No answers yet? the node may be woken up by notification.
164                let answer_offset = self.nodes[node_index].table_answer_offset;
165                let answers = entry.answers(answer_offset);
166                if answers.is_empty() {
167                    return None;
168                }
169                let next_offset = answer_offset + 1;
170                self.nodes[node_index].table_answer_offset = next_offset;
171
172                // Synthesizes an answer clause then let this to be unified with the current node.
173                let mut term = stor.get_term_mut(node_leftmost);
174                let vars = term.as_view().collect_variables();
175                for (var, answer) in vars.into_iter().zip(answers) {
176                    term.replace(var, *answer);
177                }
178                clause_buf.push(ClauseId {
179                    head: term.id(),
180                    body: None,
181                });
182                similar_clauses = &clause_buf[..];
183
184                // More answers? We'll handle them next time.
185                if !entry.answers(next_offset).is_empty() {
186                    self.queue.push(node_index);
187                }
188            } else {
189                // First encounter: Just creates a table entry then proceeds with SLD.
190                if let Some(entry) = TableEntry::from_term_view(&stor.get_term(node_leftmost)) {
191                    let index = self.table.register(key, entry);
192                    self.nodes[node_index].table_owner = Some(index);
193                }
194            }
195        }
196
197        // === BFS based SLD path ===
198
199        if similar_clauses.is_empty() {
200            if let Some(v) = clauses.get(&node_leftmost_pred) {
201                similar_clauses = v.as_slice()
202            }
203        }
204
205        let old_len = self.nodes.len();
206
207        for clause in similar_clauses {
208            let head = stor.get_term(clause.head);
209
210            if !stor.get_expr(node_expr).is_unifiable(head) {
211                continue;
212            }
213
214            let clause = Self::convert_var_into_temp(
215                *clause,
216                stor,
217                &mut self.temp_var_buf,
218                &mut self.temp_var_int,
219            );
220            if let Some(new_node) = self.unify_node_with_clause(node_index, clause, stor) {
221                self.nodes.push(new_node);
222                self.queue.push(self.nodes.len() - 1);
223            }
224        }
225
226        // We may need to apply true or false to the leftmost term of the node expression due to
227        // unification failure or exhaustive search.
228        // - Unification failure means the leftmost term should be false.
229        // - But we need to consider exhaustive search at the same time.
230
231        let expr = stor.get_expr(node_expr);
232        let eval = self.nodes.len() > old_len;
233        let mut need_apply = None;
234
235        let lost_possibility = match assume_leftmost_term(expr, eval) {
236            AssumeResult::Incomplete { lost } => lost,
237            AssumeResult::Complete { lost, .. } => lost,
238        };
239        if lost_possibility {
240            need_apply = Some(!eval);
241        } else if !eval {
242            need_apply = Some(false);
243        }
244
245        if let Some(to) = need_apply {
246            let mut expr = stor.get_expr_mut(node_expr);
247            let node_kind = match expr.apply_to_leftmost_term(to) {
248                ApplyResult::Expr => NodeKind::Expr(expr.id()),
249                ApplyResult::Complete(eval) => NodeKind::Leaf(eval),
250            };
251            let node_parent = node_index;
252            self.nodes.push(Node::new(node_kind, node_parent));
253            self.queue.push(self.nodes.len() - 1);
254        }
255
256        return None;
257
258        // === Internal helper functions ===
259
260        enum AssumeResult {
261            /// The whole expression could not completely evaluated even though the assumption is
262            /// realized.
263            Incomplete {
264                /// Whether or not the assumption will make us lose some search possibilities.
265                lost: bool,
266            },
267
268            /// The whole expression will be completely evaluated if the assumption is realized.
269            Complete {
270                /// Evalauted as true or false.
271                eval: bool,
272                lost: bool,
273            },
274        }
275
276        fn assume_leftmost_term(expr: ExprView<'_, Integer>, to: bool) -> AssumeResult {
277            match expr.as_kind() {
278                ExprKind::Term(_) => AssumeResult::Complete {
279                    eval: to,
280                    lost: false,
281                },
282                ExprKind::Not(inner) => match assume_leftmost_term(inner, to) {
283                    res @ AssumeResult::Incomplete { .. } => res,
284                    AssumeResult::Complete { eval, lost } => {
285                        AssumeResult::Complete { eval: !eval, lost }
286                    }
287                },
288                ExprKind::And(mut args) => {
289                    // Unlike 'Or', even if 'And' contains variables and the whole expression will
290                    // be evaluated false, those variables must be ignored. They don't belong to
291                    // 'lost'.
292                    match assume_leftmost_term(args.next().unwrap(), to) {
293                        res @ AssumeResult::Incomplete { .. } => res,
294                        AssumeResult::Complete { eval, lost } => {
295                            if !eval {
296                                AssumeResult::Complete { eval: false, lost }
297                            } else {
298                                AssumeResult::Incomplete { lost }
299                            }
300                        }
301                    }
302                }
303                ExprKind::Or(mut args) => {
304                    // The whole 'Or' is true if any argument is true. But we will lose possible
305                    // search paths if we ignore right variables.
306                    match assume_leftmost_term(args.next().unwrap(), to) {
307                        res @ AssumeResult::Incomplete { .. } => res,
308                        AssumeResult::Complete { eval, lost } => {
309                            if eval {
310                                let right_var = args.any(|arg| arg.contains_variable());
311                                AssumeResult::Complete {
312                                    eval: true,
313                                    lost: lost | right_var,
314                                }
315                            } else {
316                                AssumeResult::Incomplete { lost }
317                            }
318                        }
319                    }
320                }
321            }
322        }
323    }
324
325    /// Finds the nearest ancestor node that owns SLG table entry, then updates the entry and
326    /// notifies all waiting consumers.
327    fn update_answer_and_notify(&mut self, node_index: usize, stor: &TermStorage<Integer>) {
328        let tabled_ancestor = {
329            let mut cur = node_index;
330            loop {
331                if self.nodes[cur].table_owner.is_some() {
332                    break Some(cur);
333                }
334                let parent = self.nodes[cur].parent;
335                if parent == cur {
336                    break None;
337                }
338                cur = parent;
339            }
340        };
341
342        if let Some(ancestor) = tabled_ancestor {
343            let table_index = self.nodes[ancestor].table_owner.unwrap();
344            let entry = &mut self.table[table_index];
345            let all_answers_concrete = entry.variables().iter().all(|&var| {
346                if let Some(answer) = self.term_assigns.find(var) {
347                    !stor.get_term(answer).contains_variable()
348                } else {
349                    false
350                }
351            });
352
353            if all_answers_concrete && !entry.has_answer(&self.term_assigns) {
354                entry.update_answer(&self.term_assigns);
355                for i in entry.consumer_nodes() {
356                    if i != node_index {
357                        self.queue.push(i);
358                    }
359                }
360            }
361        }
362    }
363
364    /// Replaces variables in a clause with other temporary variables.
365    //
366    // Why we replace variables with temporary variables in clauses before unifying?
367    // 1. That's because variables in different clauses are actually different from each other even
368    //    they have the same identity. Variable's identity is valid only in the one clause where
369    //    they belong.
370    // 2. Also, we apply this method whenever unification happens because one clause can be used
371    //    mupltiple times in a single proof search path. Then it is considered as a different
372    //    clause.
373    fn convert_var_into_temp(
374        mut clause_id: ClauseId,
375        stor: &mut TermStorage<Integer>,
376        temp_var_buf: &mut Map<TermId, TermId>,
377        temp_var_int: &mut u32,
378    ) -> ClauseId {
379        debug_assert!(temp_var_buf.is_empty());
380
381        let mut f = |terms: &mut UniqueTermArray<Integer>, term_id: TermId| {
382            let term = terms.get_mut(term_id);
383            if term.is_variable() {
384                let src = term.id();
385
386                temp_var_buf.entry(src).or_insert_with(|| {
387                    let temp_term = Term {
388                        functor: Integer::temporary(*temp_var_int),
389                        args: [].into(),
390                    };
391                    *temp_var_int += 1;
392                    terms.insert(temp_term)
393                });
394            }
395        };
396
397        stor.get_term_mut(clause_id.head).with_terminal(&mut f);
398
399        if let Some(body) = clause_id.body {
400            stor.get_expr_mut(body).with_terminal(&mut f);
401        }
402
403        for (src, dst) in temp_var_buf.drain() {
404            let mut head = stor.get_term_mut(clause_id.head);
405            head.replace(src, dst);
406            clause_id.head = head.id();
407
408            if let Some(body) = clause_id.body {
409                let mut body = stor.get_expr_mut(body);
410                body.replace_term(src, dst);
411                clause_id.body = Some(body.id());
412            }
413        }
414
415        clause_id
416    }
417
418    fn unify_node_with_clause(
419        &mut self,
420        node_index: usize,
421        clause: ClauseId,
422        stor: &mut TermStorage<Integer>,
423    ) -> Option<Node> {
424        debug_assert!(self.uni_op.ops.is_empty());
425
426        let NodeKind::Expr(node_expr) = self.nodes[node_index].kind else {
427            unreachable!()
428        };
429
430        if !stor
431            .get_expr(node_expr)
432            .leftmost_term()
433            .unify(stor.get_term(clause.head), &mut |op| {
434                self.uni_op.push_op(op)
435            })
436        {
437            return None;
438        }
439        let (node_expr, clause, uni_history) = self.uni_op.consume_ops(stor, node_expr, clause);
440
441        if let Some(body) = clause.body {
442            let mut lhs = stor.get_expr_mut(node_expr);
443            lhs.replace_leftmost_term(body);
444            let node_kind = NodeKind::Expr(lhs.id());
445            let node_parent = node_index;
446            let node = Node::new(node_kind, node_parent).with_unification_history(uni_history);
447            return Some(node);
448        }
449
450        let mut lhs = stor.get_expr_mut(node_expr);
451        let node_kind = match lhs.apply_to_leftmost_term(true) {
452            ApplyResult::Expr => NodeKind::Expr(lhs.id()),
453            ApplyResult::Complete(eval) => NodeKind::Leaf(eval),
454        };
455        let node_parent = node_index;
456        let node = Node::new(node_kind, node_parent).with_unification_history(uni_history);
457        Some(node)
458    }
459
460    /// Finds all from/to relations while traversing from the given node to the root node then add
461    /// the relations to [`TermAssignments`].
462    fn find_assignments(&mut self, node_index: usize) {
463        self.term_assigns.clear();
464
465        let mut cur_index = node_index;
466        loop {
467            let node = &self.nodes[cur_index];
468            let range = node.uni_history.clone();
469
470            for (from, to) in self.uni_op.get_record(range).iter().cloned() {
471                self.term_assigns.add(from, to);
472            }
473
474            if node.parent == cur_index {
475                break;
476            }
477            cur_index = node.parent;
478        }
479    }
480
481    /// Records the current proof result as a query answer if it is ground and not duplicated,
482    /// then returns whether a new answer was recorded.
483    fn record_query_answer(&mut self, stor: &mut TermStorage<Integer>) -> bool {
484        let mut answer = Vec::with_capacity(self.query_vars.len());
485        for &var in &self.query_vars {
486            let Some(resolved) = self.materialize_assigned_term(var, stor) else {
487                return false;
488            };
489            answer.push(resolved);
490        }
491
492        // no query vars -> empty iter -> all() returns true
493        if self.query_answers.iter().all(|seen| seen != &answer) {
494            self.query_answers.push(answer);
495            true
496        } else {
497            false
498        }
499    }
500
501    /// Builds a fully substituted term for a query-side term from `term_assigns`.
502    ///
503    /// Examples:
504    ///
505    /// | assignments         |  input   |  output  |
506    /// | ------------------- | :------: | :------: |
507    /// | `T = Vec(a)`        |   `T`    | `Vec(a)` |
508    /// | `T = a`             | `Vec(T)` | `Vec(a)` |
509    /// | `T = Vec(U), U = a` |   `T`    | `Vec(a)` |
510    /// | `T = Vec(U)`        |   `T`    |  `None`  |
511    ///
512    /// This must materialize the whole term tree, not just rewrite functors in place. The returned
513    /// `TermId` always points to a ground term inserted into `stor`.
514    fn materialize_assigned_term(
515        &self,
516        term_id: TermId,
517        stor: &mut TermStorage<Integer>,
518    ) -> Option<TermId> {
519        let term = stor.get_term(term_id);
520        if term.is_variable() {
521            let resolved = self.term_assigns.find(term_id)?;
522            if resolved == term_id {
523                return None;
524            }
525            return self.materialize_assigned_term(resolved, stor);
526        }
527
528        let functor = *term.functor();
529        let arg_ids = term.args().map(|arg| arg.id).collect::<Vec<_>>();
530        let args = arg_ids
531            .into_iter()
532            .map(|arg_id| {
533                self.materialize_assigned_term(arg_id, stor)
534                    .map(|id| stor.get_term(id).deserialize())
535            })
536            .collect::<Option<Vec<_>>>()?;
537
538        let materialized = Term { functor, args };
539        Some(stor.insert_term(materialized))
540    }
541}
542
543/// Manages unification operations between a goal(node) and a clause during SLG resolution.
544///
545/// You can make unification operations, [`UnifyOp`]s, by unifying the leftmost term of the goal and
546/// the head of the clause. Append the operations in order to apply them to the whole goal and
547/// clause. You can apply them at once via [`consume_ops`].
548///
549/// [`consume_ops`]: Self::consume_ops
550#[derive(Debug)]
551struct UnificationOperator {
552    /// Buffered unification operations.
553    ops: Vec<UnifyOp>,
554
555    /// Unification history.
556    ///
557    /// This is a record of `(from, to)` pairs. It means there has been unification that substitute
558    /// the `from` with `to`. For example, `(X, a)` means the variable `X` was substituted with `a`.
559    record: Vec<(TermId, TermId)>,
560}
561
562impl UnificationOperator {
563    const fn new() -> Self {
564        Self {
565            ops: Vec::new(),
566            record: Vec::new(),
567        }
568    }
569
570    fn clear(&mut self) {
571        self.ops.clear();
572        self.record.clear();
573    }
574
575    fn push_op(&mut self, op: UnifyOp) {
576        self.ops.push(op);
577    }
578
579    /// Returns
580    /// * `ExprId` - Operation applied `left`
581    /// * `ClauseId` - Operation applied `right`
582    /// * `Range<usize>` - A range of unification history(from/to pairs). You can retrieve the
583    ///   from/to pairs via [`get_record`]
584    ///
585    /// [`get_record`]: Self::get_record
586    #[must_use]
587    fn consume_ops(
588        &mut self,
589        stor: &mut TermStorage<Integer>,
590        mut left: ExprId,
591        mut right: ClauseId,
592    ) -> (ExprId, ClauseId, Range<usize>) {
593        let record_start = self.record.len();
594
595        for op in self.ops.drain(..) {
596            match op {
597                UnifyOp::Left { from, to } => {
598                    let mut expr = stor.get_expr_mut(left);
599                    expr.replace_term(from, to);
600                    left = expr.id();
601
602                    self.record.push((from, to));
603                }
604                UnifyOp::Right { from, to } => {
605                    if let Some(right_body) = right.body {
606                        let mut expr = stor.get_expr_mut(right_body);
607                        expr.replace_term(from, to);
608                        right.body = Some(expr.id());
609
610                        self.record.push((from, to));
611                    }
612                }
613            }
614        }
615
616        (left, right, record_start..self.record.len())
617    }
618
619    fn get_record(&self, range: Range<usize>) -> &[(TermId, TermId)] {
620        &self.record[range]
621    }
622}
623
/// FIFO queue of node indices awaiting evaluation, never holding the same index twice.
///
/// Membership is tracked in a side table indexed by node index, so [`push`](Self::push) and
/// [`contains`](Self::contains) are O(1). Scanning the deque instead would make every push
/// linear, and the proof search quadratic in the number of nodes.
#[derive(Debug, Default)]
struct NodeQueue {
    /// Pending node indices in evaluation (FIFO) order.
    inner: VecDeque<usize>,

    /// `queued[i]` is `true` exactly while node `i` is present in `inner`.
    queued: Vec<bool>,
}

impl NodeQueue {
    /// Removes every pending node.
    fn clear(&mut self) {
        self.inner.clear();
        self.queued.clear();
    }

    /// Returns whether `node_index` is currently waiting in the queue.
    fn contains(&self, node_index: &usize) -> bool {
        self.queued.get(*node_index).copied().unwrap_or(false)
    }

    /// Appends `node_index` to the back of the queue unless it is already queued.
    fn push(&mut self, node_index: usize) {
        if self.contains(&node_index) {
            return;
        }
        if node_index >= self.queued.len() {
            self.queued.resize(node_index + 1, false);
        }
        self.queued[node_index] = true;
        self.inner.push_back(node_index);
    }

    /// Removes and returns the node at the front of the queue.
    fn pop(&mut self) -> Option<usize> {
        let node_index = self.inner.pop_front()?;
        self.queued[node_index] = false;
        Some(node_index)
    }
}
648
649#[derive(Debug, Clone)]
650struct Node {
651    kind: NodeKind,
652    parent: usize,
653
654    /// A range of unification history that applied to prove this node:
655    /// Pairs of from([`TermId`]) -> to([`TermId`]).
656    ///
657    /// You can retreive the from/to pairs via [`UnificationOperator::get_record`].
658    uni_history: Range<usize>,
659
660    /// Table entry owned by this node, if this node is the producer of a tabled subgoal.
661    table_owner: Option<TableIndex>,
662
663    /// Number of answers already consumed from a table entry.
664    table_answer_offset: usize,
665}
666
667impl Node {
668    fn new(kind: NodeKind, parent: usize) -> Self {
669        Self {
670            kind,
671            parent,
672            uni_history: 0..0,
673            table_owner: None,
674            table_answer_offset: 0,
675        }
676    }
677
678    fn with_unification_history(mut self, uni_history: Range<usize>) -> Self {
679        self.uni_history = uni_history;
680        self
681    }
682}
683
684#[derive(Debug, Clone, Copy)]
685enum NodeKind {
686    /// A non-terminal node containig an expression id that needs to be evaluated.
687    Expr(ExprId),
688
689    /// A terminal node containing whether a proof path ends with true or false.
690    Leaf(bool),
691}
692
693#[derive(Debug, Default)]
694pub(crate) struct TermAssignments {
695    /// Union-find from-to relations.
696    ///
697    /// # Examples
698    /// `roots[a]: a` means TermId(a) is not unified with anything.
699    /// `roots[v]: w` means TermId(v) is a variable and it is unified with TermId(w).
700    relations: Vec<TermId>,
701}
702
703impl TermAssignments {
704    pub(crate) fn find(&self, from: TermId) -> Option<TermId> {
705        let to = *self.relations.get(from.0)?;
706        if from == to {
707            Some(to)
708        } else {
709            self.find(to)
710        }
711    }
712
713    pub(crate) fn find_optimize(&mut self, from: TermId) -> TermId {
714        let new_len = from.0 + 1;
715        for i in self.len()..new_len {
716            self.relations.push(TermId(i));
717        }
718
719        let to = self.relations[from.0];
720        if from == to {
721            to
722        } else {
723            let root = self.find_optimize(to);
724            self.relations[from.0] = root;
725            root
726        }
727    }
728
729    fn len(&self) -> usize {
730        self.relations.len()
731    }
732
733    fn clear(&mut self) {
734        self.relations.clear();
735    }
736
737    fn add(&mut self, from: TermId, to: TermId) {
738        let root_from = self.find_optimize(from);
739        let root_to = self.find_optimize(to);
740        self.relations[root_from.0] = root_to;
741    }
742}
743
744/// Unification operation between `node expr - clause's body(expr)`.
745#[derive(Debug)]
746enum UnifyOp {
747    /// Unification operation that rewrites the goal expression on the query side.
748    ///
749    /// Substitues all `from`s in the goal expression with `to`.
750    Left { from: TermId, to: TermId },
751
752    /// Unification operation that rewrites the clause body on the clause side.
753    ///
754    /// Substitues all `from`s in the clause's body with `to`.
755    Right { from: TermId, to: TermId },
756}
757
758pub struct ProveCx<'a, T: Atom> {
759    prover: &'a mut Prover,
760    clauses: &'a IndexMap<Predicate<Integer>, Vec<ClauseId>>,
761    table_clauses: &'a IndexSet<Predicate<Integer>>,
762    stor: &'a mut TermStorage<Integer>,
763    nimap: &'a mut NameIntMap<T>,
764    old_stor_len: TermStorageLen,
765    old_nimap_state: NameIntMapState,
766}
767
768impl<'a, T: Atom> ProveCx<'a, T> {
769    pub fn prove_next(&mut self) -> Option<EvalView<'_, T>> {
770        while let Some(node_index) = self.prover.queue.pop() {
771            if let Some(proof_result) =
772                self.prover
773                    .evaluate_node(node_index, self.clauses, self.table_clauses, self.stor)
774            {
775                // Returns Some(EvalView) only if the result is TRUE.
776                if proof_result && self.prover.record_query_answer(self.stor) {
777                    return Some(EvalView {
778                        query_vars: &self.prover.query_vars,
779                        terms: &self.stor.terms.buf,
780                        term_assigns: &self.prover.term_assigns,
781                        nimap: self.nimap,
782                        start: 0,
783                        end: self.prover.query_vars.len(),
784                    });
785                }
786            }
787        }
788        None
789    }
790
791    pub fn is_true(mut self) -> bool {
792        self.prove_next().is_some()
793    }
794}
795
796impl<T: Atom> Drop for ProveCx<'_, T> {
797    fn drop(&mut self) {
798        self.stor.truncate(self.old_stor_len.clone());
799        self.nimap.revert(self.old_nimap_state.clone());
800    }
801}
802
803pub struct EvalView<'a, T> {
804    query_vars: &'a [TermId],
805    terms: &'a [TermElem<Integer>],
806    term_assigns: &'a TermAssignments,
807    nimap: &'a NameIntMap<T>,
808    /// Inclusive
809    start: usize,
810    /// Exclusive
811    end: usize,
812}
813
814impl<T> EvalView<'_, T> {
815    const fn len(&self) -> usize {
816        self.end - self.start
817    }
818}
819
820impl<'a, T> Iterator for EvalView<'a, T> {
821    type Item = Assignment<'a, T>;
822
823    fn next(&mut self) -> Option<Self::Item> {
824        if self.start < self.end {
825            let from = self.query_vars[self.start];
826            self.start += 1;
827
828            Some(Assignment {
829                buf: self.terms,
830                from,
831                term_assigns: self.term_assigns,
832                nimap: self.nimap,
833            })
834        } else {
835            None
836        }
837    }
838
839    fn size_hint(&self) -> (usize, Option<usize>) {
840        let len = <Self>::len(self);
841        (len, Some(len))
842    }
843}
844
845impl<T> ExactSizeIterator for EvalView<'_, T> {
846    fn len(&self) -> usize {
847        <Self>::len(self)
848    }
849}
850
851impl<T> iter::FusedIterator for EvalView<'_, T> {}
852
853pub struct Assignment<'a, T> {
854    buf: &'a [TermElem<Integer>],
855    from: TermId,
856    term_assigns: &'a TermAssignments,
857    nimap: &'a NameIntMap<T>,
858}
859
860impl<'a, T: 'a> Assignment<'a, T> {
861    /// Returns left hand side variable name of the assignment.
862    ///
863    /// Note that assignment's left hand side is always variable.
864    pub fn get_lhs_variable(&self) -> &T {
865        let int = self.lhs_view().find_variable().unwrap();
866        self.nimap.get_name(&int).unwrap()
867    }
868
869    const fn lhs_view(&self) -> TermView<'_, Integer> {
870        TermView {
871            buf: self.buf,
872            id: self.from,
873        }
874    }
875
876    const fn rhs_view(&self) -> TermDeepView<'_, Integer> {
877        TermDeepView {
878            buf: self.buf,
879            term_assigns: self.term_assigns,
880            id: self.from,
881        }
882    }
883}
884
885impl<'a, T: Atom + 'a> Assignment<'a, T> {
886    /// Creates left hand side term of the assignment.
887    ///
888    /// To create a term, this method could allocate memory for the term.
889    pub fn lhs(&self) -> Term<T> {
890        Self::term_view_to_term(self.lhs_view(), self.nimap)
891    }
892
893    /// Creates right hand side term of the assignment.
894    ///
895    /// To create a term, this method could allocate memory for the term.
896    pub fn rhs(&self) -> Term<T> {
897        Self::term_deep_view_to_term(self.rhs_view(), self.nimap)
898    }
899
900    fn term_view_to_term(view: TermView<'_, Integer>, nimap: &NameIntMap<T>) -> Term<T> {
901        let functor = view.functor();
902        let args = view.args();
903
904        let functor = if let Some(name) = nimap.get_name(functor) {
905            name.clone()
906        } else {
907            unreachable!("integer {:?} has no name mapping", functor)
908        };
909
910        let args = args
911            .into_iter()
912            .map(|arg| Self::term_view_to_term(arg, nimap))
913            .collect();
914
915        Term { functor, args }
916    }
917
918    fn term_deep_view_to_term(view: TermDeepView<'_, Integer>, nimap: &NameIntMap<T>) -> Term<T> {
919        let functor = view.functor();
920        let args = view.args();
921
922        let functor = if let Some(name) = nimap.get_name(functor) {
923            name.clone()
924        } else {
925            unreachable!("integer {:?} has no name mapping", functor)
926        };
927
928        let args = args
929            .into_iter()
930            .map(|arg| Self::term_deep_view_to_term(arg, nimap))
931            .collect();
932
933        Term { functor, args }
934    }
935}
936
937impl<T: Atom + Display> Display for Assignment<'_, T> {
938    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
939        let view = format::NamedTermView::new(self.lhs_view(), self.nimap);
940        Display::fmt(&view, f)?;
941
942        f.write_str(" = ")?;
943
944        let view = format::NamedTermDeepView::new(self.rhs_view(), self.nimap);
945        Display::fmt(&view, f)
946    }
947}
948
949impl<T: Atom + Debug> Debug for Assignment<'_, T> {
950    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
951        let lhs = format::NamedTermView::new(self.lhs_view(), self.nimap);
952        let rhs = format::NamedTermDeepView::new(self.rhs_view(), self.nimap);
953
954        f.debug_struct("Assignment")
955            .field("lhs", &lhs)
956            .field("rhs", &rhs)
957            .finish()
958    }
959}
960
961impl ExprView<'_, Integer> {
962    fn is_unifiable(&self, other: TermView<'_, Integer>) -> bool {
963        match self.as_kind() {
964            ExprKind::Term(term) => term.is_unifiable(other),
965            ExprKind::Not(inner) => inner.is_unifiable(other),
966            ExprKind::And(mut args) | ExprKind::Or(mut args) => {
967                args.next().unwrap().is_unifiable(other)
968            }
969        }
970    }
971
972    fn contains_variable(&self) -> bool {
973        match self.as_kind() {
974            ExprKind::Term(term) => term.contains_variable(),
975            ExprKind::Not(inner) => inner.contains_variable(),
976            ExprKind::And(mut args) | ExprKind::Or(mut args) => {
977                args.any(|arg| arg.contains_variable())
978            }
979        }
980    }
981}
982
983impl TermView<'_, Integer> {
984    fn unify<F: FnMut(UnifyOp)>(self, other: Self, f: &mut F) -> bool {
985        if self.is_variable() {
986            f(UnifyOp::Left {
987                from: self.id,
988                to: other.id,
989            });
990            true
991        } else if other.is_variable() {
992            f(UnifyOp::Right {
993                from: other.id,
994                to: self.id,
995            });
996            true
997        } else if self.functor() == other.functor() {
998            let zip = self.args().zip(other.args());
999            // Unifies only if all arguments are unifiable.
1000            if self.arity() == other.arity() && zip.clone().all(|(l, r)| l.is_unifiable(r)) {
1001                for (l, r) in zip {
1002                    l.unify(r, f);
1003                }
1004                true
1005            } else {
1006                false
1007            }
1008        } else {
1009            false
1010        }
1011    }
1012
1013    fn is_unifiable(&self, other: Self) -> bool {
1014        if self.is_variable() || other.is_variable() {
1015            true
1016        } else if self.functor() == other.functor() {
1017            if self.arity() == other.arity() {
1018                self.args()
1019                    .zip(other.args())
1020                    .all(|(l, r)| l.is_unifiable(r))
1021            } else {
1022                false
1023            }
1024        } else {
1025            false
1026        }
1027    }
1028}
1029
1030impl TermViewMut<'_, Integer> {
1031    fn is_variable(&self) -> bool {
1032        self.arity() == 0 && self.functor().is_variable()
1033    }
1034}
1035
1036#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
1037pub struct Integer(u32);
1038
1039impl Integer {
1040    const VAR_FLAG: u32 = 0x1 << 31;
1041    const TEMPORARY_FLAG: u32 = 0x1 << 30;
1042
1043    pub(crate) fn from_value<T: Atom>(s: &T, mut index: u32) -> Self {
1044        if s.is_variable() {
1045            index |= Self::VAR_FLAG;
1046        }
1047        Self(index)
1048    }
1049
1050    pub(crate) fn variable(int: u32) -> Self {
1051        let mask = Self::VAR_FLAG;
1052        debug_assert_eq!(int & mask, 0);
1053        Self(int | mask)
1054    }
1055
1056    pub(crate) fn temporary(int: u32) -> Self {
1057        let mask = Self::VAR_FLAG | Self::TEMPORARY_FLAG;
1058        debug_assert_eq!(int & mask, 0);
1059        Self(int | mask)
1060    }
1061
1062    pub(crate) const fn is_temporary_variable(self) -> bool {
1063        let mask = Self::VAR_FLAG | Self::TEMPORARY_FLAG;
1064        (mask & self.0) == mask
1065    }
1066}
1067
1068impl Atom for Integer {
1069    fn is_variable(&self) -> bool {
1070        (Self::VAR_FLAG & self.0) == Self::VAR_FLAG
1071    }
1072}
1073
1074impl Debug for Integer {
1075    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1076        let mask: u32 = Self::VAR_FLAG | Self::TEMPORARY_FLAG;
1077        let index = !mask & self.0;
1078
1079        if self.is_variable() {
1080            f.write_char(VAR_PREFIX)?;
1081        }
1082        if self.is_temporary_variable() {
1083            f.write_char('#')?;
1084        }
1085        Debug::fmt(&index, f)
1086    }
1087}
1088
1089impl ops::AddAssign<u32> for Integer {
1090    fn add_assign(&mut self, rhs: u32) {
1091        self.0 += rhs;
1092    }
1093}
1094
1095/// Only mapping of user-input clauses and queries are stored in this map. Auto-generated variables
1096/// or something like that are not stored here.
1097#[derive(Debug)]
1098pub(crate) struct NameIntMap<T> {
1099    name2int: IndexMap<T, Integer>,
1100    int2name: IndexMap<Integer, T>,
1101    next_int: u32,
1102}
1103
1104impl<T> NameIntMap<T> {
1105    pub(crate) fn new() -> Self {
1106        Self {
1107            name2int: IndexMap::default(),
1108            int2name: IndexMap::default(),
1109            next_int: 0,
1110        }
1111    }
1112
1113    pub(crate) fn get_name(&self, int: &Integer) -> Option<&T> {
1114        self.int2name.get(int)
1115    }
1116
1117    pub(crate) fn state(&self) -> NameIntMapState {
1118        NameIntMapState {
1119            name2int_len: self.name2int.len(),
1120            int2name_len: self.int2name.len(),
1121            next_int: self.next_int,
1122        }
1123    }
1124
1125    pub(crate) fn revert(
1126        &mut self,
1127        NameIntMapState {
1128            name2int_len,
1129            int2name_len,
1130            next_int,
1131        }: NameIntMapState,
1132    ) {
1133        self.name2int.truncate(name2int_len);
1134        self.int2name.truncate(int2name_len);
1135        self.next_int = next_int;
1136    }
1137}
1138
1139impl<T: Atom> NameIntMap<T> {
1140    pub(crate) fn name_to_int(&mut self, name: T) -> Integer {
1141        if let Some(int) = self.name2int.get(&name) {
1142            *int
1143        } else {
1144            let int = Integer::from_value(&name, self.next_int);
1145
1146            self.name2int.insert(name.clone(), int);
1147            self.int2name.insert(int, name);
1148
1149            self.next_int += 1;
1150            int
1151        }
1152    }
1153}
1154
1155#[derive(Debug, Clone, PartialEq, Eq)]
1156pub(crate) struct NameIntMapState {
1157    name2int_len: usize,
1158    int2name_len: usize,
1159    next_int: u32,
1160}
1161
1162pub(crate) mod format {
1163    use super::*;
1164
1165    pub struct NamedTermView<'a, T> {
1166        view: TermView<'a, Integer>,
1167        nimap: &'a NameIntMap<T>,
1168    }
1169
1170    impl<'a, T> NamedTermView<'a, T> {
1171        pub(crate) const fn new(view: TermView<'a, Integer>, nimap: &'a NameIntMap<T>) -> Self {
1172            Self { view, nimap }
1173        }
1174
1175        fn args<'s>(&'s self) -> impl Iterator<Item = NamedTermView<'a, T>> + 's {
1176            self.view.args().map(|arg| Self {
1177                view: arg,
1178                nimap: self.nimap,
1179            })
1180        }
1181    }
1182
1183    impl<'a, T: Atom> NamedTermView<'a, T> {
1184        pub fn is(&self, term: &Term<T>) -> bool {
1185            let functor = self.view.functor();
1186            let Some(functor) = self.nimap.get_name(functor) else {
1187                return false;
1188            };
1189
1190            if functor != &term.functor {
1191                return false;
1192            }
1193
1194            self.args().zip(&term.args).all(|(l, r)| l.is(r))
1195        }
1196
1197        pub fn contains(&self, term: &Term<T>) -> bool {
1198            if self.is(term) {
1199                return true;
1200            }
1201
1202            self.args().any(|arg| arg.contains(term))
1203        }
1204    }
1205
1206    impl<'a, T: Display> Display for NamedTermView<'a, T> {
1207        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1208            let Self { view, nimap } = self;
1209
1210            let functor = view.functor();
1211            let args = view.args();
1212            let num_args = args.len();
1213
1214            write_int(functor, nimap, f)?;
1215
1216            if num_args > 0 {
1217                f.write_char('(')?;
1218                for (i, arg) in args.enumerate() {
1219                    fmt::Display::fmt(&Self::new(arg, nimap), f)?;
1220                    if i + 1 < num_args {
1221                        f.write_str(", ")?;
1222                    }
1223                }
1224                f.write_char(')')?;
1225            }
1226            Ok(())
1227        }
1228    }
1229
1230    impl<'a, T: Debug> Debug for NamedTermView<'a, T> {
1231        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1232            let Self { view, nimap } = self;
1233
1234            let functor = view.functor();
1235            let args = view.args();
1236            let num_args = args.len();
1237
1238            if num_args == 0 {
1239                if let Some(name) = nimap.get_name(functor) {
1240                    fmt::Debug::fmt(name, f)
1241                } else {
1242                    fmt::Debug::fmt(functor, f)
1243                }
1244            } else {
1245                let name_str = if let Some(name) = nimap.get_name(functor) {
1246                    format!("{:?}", name)
1247                } else {
1248                    format!("{:?}", functor)
1249                };
1250                let mut d = f.debug_tuple(&name_str);
1251
1252                for arg in args {
1253                    d.field(&Self::new(arg, nimap));
1254                }
1255                d.finish()
1256            }
1257        }
1258    }
1259
1260    pub(crate) struct NamedTermDeepView<'a, T> {
1261        view: TermDeepView<'a, Integer>,
1262        nimap: &'a NameIntMap<T>,
1263    }
1264
1265    impl<'a, T> NamedTermDeepView<'a, T> {
1266        pub(crate) const fn new(view: TermDeepView<'a, Integer>, nimap: &'a NameIntMap<T>) -> Self {
1267            Self { view, nimap }
1268        }
1269    }
1270
1271    impl<'a, T: Display> Display for NamedTermDeepView<'a, T> {
1272        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1273            let Self { view, nimap } = self;
1274
1275            let functor = view.functor();
1276            let args = view.args();
1277            let num_args = args.len();
1278
1279            write_int(functor, nimap, f)?;
1280
1281            if num_args > 0 {
1282                f.write_char('(')?;
1283                for (i, arg) in args.enumerate() {
1284                    fmt::Display::fmt(&Self::new(arg, nimap), f)?;
1285                    if i + 1 < num_args {
1286                        f.write_str(", ")?;
1287                    }
1288                }
1289                f.write_char(')')?;
1290            }
1291            Ok(())
1292        }
1293    }
1294
1295    impl<'a, T: Debug> Debug for NamedTermDeepView<'a, T> {
1296        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1297            let Self { view, nimap } = self;
1298
1299            let functor = view.functor();
1300            let args = view.args();
1301            let num_args = args.len();
1302
1303            if num_args == 0 {
1304                if let Some(name) = nimap.get_name(functor) {
1305                    fmt::Debug::fmt(name, f)
1306                } else {
1307                    fmt::Debug::fmt(functor, f)
1308                }
1309            } else {
1310                let name_str = if let Some(name) = nimap.get_name(functor) {
1311                    format!("{:?}", name)
1312                } else {
1313                    format!("{:?}", functor)
1314                };
1315                let mut d = f.debug_tuple(&name_str);
1316
1317                for arg in args {
1318                    d.field(&Self::new(arg, nimap));
1319                }
1320                d.finish()
1321            }
1322        }
1323    }
1324
1325    pub struct NamedExprView<'a, T> {
1326        view: ExprView<'a, Integer>,
1327        nimap: &'a NameIntMap<T>,
1328    }
1329
1330    impl<'a, T> NamedExprView<'a, T> {
1331        pub(crate) const fn new(view: ExprView<'a, Integer>, nimap: &'a NameIntMap<T>) -> Self {
1332            Self { view, nimap }
1333        }
1334    }
1335
1336    impl<'a, T: Atom> NamedExprView<'a, T> {
1337        pub fn contains_term(&self, term: &Term<T>) -> bool {
1338            match self.view.as_kind() {
1339                ExprKind::Term(view) => NamedTermView {
1340                    view,
1341                    nimap: self.nimap,
1342                }
1343                .contains(term),
1344                ExprKind::Not(view) => NamedExprView {
1345                    view,
1346                    nimap: self.nimap,
1347                }
1348                .contains_term(term),
1349                ExprKind::And(args) | ExprKind::Or(args) => args.into_iter().any(|view| {
1350                    NamedExprView {
1351                        view,
1352                        nimap: self.nimap,
1353                    }
1354                    .contains_term(term)
1355                }),
1356            }
1357        }
1358    }
1359
1360    impl<'a, T: Display> Display for NamedExprView<'a, T> {
1361        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1362            let Self { view, nimap } = self;
1363
1364            match view.as_kind() {
1365                ExprKind::Term(term) => fmt::Display::fmt(&NamedTermView { view: term, nimap }, f)?,
1366                ExprKind::Not(inner) => {
1367                    f.write_str("\\+ ")?;
1368                    if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) {
1369                        f.write_char('(')?;
1370                        fmt::Display::fmt(&Self::new(inner, nimap), f)?;
1371                        f.write_char(')')?;
1372                    } else {
1373                        fmt::Display::fmt(&Self::new(inner, nimap), f)?;
1374                    }
1375                }
1376                ExprKind::And(args) => {
1377                    let num_args = args.len();
1378                    for (i, arg) in args.enumerate() {
1379                        if matches!(arg.as_kind(), ExprKind::Or(_)) {
1380                            f.write_char('(')?;
1381                            fmt::Display::fmt(&Self::new(arg, nimap), f)?;
1382                            f.write_char(')')?;
1383                        } else {
1384                            fmt::Display::fmt(&Self::new(arg, nimap), f)?;
1385                        }
1386                        if i + 1 < num_args {
1387                            f.write_str(", ")?;
1388                        }
1389                    }
1390                }
1391                ExprKind::Or(args) => {
1392                    let num_args = args.len();
1393                    for (i, arg) in args.enumerate() {
1394                        fmt::Display::fmt(&Self::new(arg, nimap), f)?;
1395                        if i + 1 < num_args {
1396                            f.write_str("; ")?;
1397                        }
1398                    }
1399                }
1400            }
1401            Ok(())
1402        }
1403    }
1404
1405    impl<'a, T: Debug> Debug for NamedExprView<'a, T> {
1406        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1407            let Self { view, nimap } = self;
1408
1409            match view.as_kind() {
1410                ExprKind::Term(term) => fmt::Debug::fmt(&NamedTermView::new(term, nimap), f),
1411                ExprKind::Not(inner) => f
1412                    .debug_tuple("Not")
1413                    .field(&NamedExprView::new(inner, nimap))
1414                    .finish(),
1415                ExprKind::And(args) => {
1416                    let mut d = f.debug_tuple("And");
1417                    for arg in args {
1418                        d.field(&NamedExprView::new(arg, nimap));
1419                    }
1420                    d.finish()
1421                }
1422                ExprKind::Or(args) => {
1423                    let mut d = f.debug_tuple("Or");
1424                    for arg in args {
1425                        d.field(&NamedExprView::new(arg, nimap));
1426                    }
1427                    d.finish()
1428                }
1429            }
1430        }
1431    }
1432
1433    fn write_int<T: fmt::Display>(
1434        int: &Integer,
1435        nimap: &NameIntMap<T>,
1436        f: &mut fmt::Formatter<'_>,
1437    ) -> fmt::Result {
1438        if let Some(name) = nimap.get_name(int) {
1439            fmt::Display::fmt(name, f)
1440        } else {
1441            fmt::Debug::fmt(int, f)
1442        }
1443    }
1444}