egg/
explain.rs

1#![allow(clippy::only_used_in_recursion)]
2use crate::Symbol;
3use crate::{
4    util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language,
5    PatternAst, RecExpr, Rewrite, UnionFind, Var,
6};
7
8use std::cmp::Ordering;
9use std::collections::{BinaryHeap, VecDeque};
10use std::fmt::{self, Debug, Display, Formatter};
11use std::ops::{Deref, DerefMut};
12use std::rc::Rc;
13
14use num_bigint::BigUint;
15use num_traits::identities::{One, Zero};
16use symbolic_expressions::Sexp;
17
18type ProofCost = BigUint;
19
20const CONGRUENCE_LIMIT: usize = 2;
21const GREEDY_NUM_ITERS: usize = 2;
22
/// A justification for a union, either via a rule or congruence.
/// A direct union with a justification is also stored as a rule,
/// i.e. as [`Justification::Rule`] carrying the given reason.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub enum Justification {
    /// Justification by a rule with this name.
    Rule(Symbol),
    /// Justification by congruence.
    Congruence,
}
33
// A single edge of the explanation structure. An `ExplainNode` stores these
// both in its `neighbors` list and as its `parent_connection`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct Connection {
    // NOTE(review): presumably the Id at the far end of this edge — confirm
    // against the code that builds connections (not visible here).
    next: Id,
    // NOTE(review): presumably the Id this edge starts from — confirm.
    current: Id,
    // The rule or congruence that justifies this edge.
    justification: Justification,
    // NOTE(review): presumably true when the rule rewrites `current` into
    // `next` rather than the other way around — confirm.
    is_rewrite_forward: bool,
}
42
// Per-node explanation record; `Explain::explainfind` holds a vector of these.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode {
    // neighbors includes parent connections
    neighbors: Vec<Connection>,
    // The edge towards this node's parent.
    // NOTE(review): how a root node is represented is not visible here — confirm.
    parent_connection: Connection,
}
50
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct Explain<L: Language> {
    // One `ExplainNode` per entry.
    // NOTE(review): presumably indexed by enode Id — confirm.
    explainfind: Vec<ExplainNode>,
    /// Maps a node to an [`Id`].
    // NOTE(review): the name suggests the keys are un-canonicalized enodes — confirm.
    #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
    #[cfg_attr(
        feature = "serde-1",
        serde(bound(
            serialize = "L: serde::Serialize",
            deserialize = "L: serde::Deserialize<'de>",
        ))
    )]
    pub uncanon_memo: HashMap<L, Id>,
    /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted.
    pub optimize_explanation_lengths: bool,
    // For a given pair of enodes in the same eclass,
    // stores the length of the shortest found explanation
    // and the Id of the neighbor for retrieving
    // the explanation.
    // Invariant: The distance is always <= the unoptimized distance
    // That is, less than or equal to the result of `distance_between`
    // Skipped during (de)serialization when the `serde-1` feature is enabled.
    #[cfg_attr(feature = "serde-1", serde(skip))]
    shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
}
75
// Pairs a mutable borrow of an `Explain` with a borrowed slice of nodes.
// NOTE(review): `nodes` is presumably indexed by enode Id — confirm at the call sites.
pub(crate) struct ExplainNodes<'a, L: Language> {
    explain: &'a mut Explain<L>,
    nodes: &'a [L],
}
80
// Memo tables for distance computations.
// NOTE(review): the field meanings below are inferred from names and types only;
// the code that fills them is outside this view — confirm.
#[derive(Default)]
struct DistanceMemo {
    // Per entry: an Id paired with a cost (presumably the parent and the distance to it).
    parent_distance: Vec<(Id, ProofCost)>,
    // Memoized result for a pair of Ids (presumably their common ancestor).
    common_ancestor: HashMap<(Id, Id), Id>,
    // Memoized cost per Id (presumably the depth of that Id in its tree).
    tree_depth: HashMap<Id, ProofCost>,
}
87
/// Explanation trees are the compact representation showing
/// how one term can be rewritten to another.
///
/// Each [`TreeTerm`] has child [`TreeExplanation`]s, each
/// justifying a transformation from the initial child to the final child term.
/// Child [`TreeTerm`]s can be shared, thus re-using explanations.
/// This sharing can be checked via Rc pointer equality.
///
/// See [`TreeTerm`] for more details on how to
/// interpret each term.
pub type TreeExplanation<L> = Vec<Rc<TreeTerm<L>>>;
99
/// A `FlatExplanation` is the simpler, expanded representation
/// showing one term being rewritten to another.
/// Each step contains a full [`FlatTerm`]. Each flat term
/// is connected to the previous one by exactly one rewrite.
///
/// See [`FlatTerm`] for more details on how to find this rewrite.
pub type FlatExplanation<L> = Vec<FlatTerm<L>>;
107
/// A vector of equalities based on enode ids. Each entry `(a, b, reason)` represents
/// two enode ids that are equal and why.
pub type UnionEqualities = Vec<(Id, Id, Symbol)>;
111
// Cached tree terms, looked up given two adjacent nodes and the direction of the proof.
// NOTE(review): presumably the order of the `(Id, Id)` pair encodes the direction — confirm.
type ExplainCache<L> = HashMap<(Id, Id), Rc<TreeTerm<L>>>;
// Cached tree terms keyed by a single Id.
// NOTE(review): the code that fills this cache is outside this view.
type NodeExplanationCache<L> = HashMap<Id, Rc<TreeTerm<L>>>;
115
/** A data structure representing an explanation that two terms are equivalent.

There are two representations of explanations, each of which can be
represented as s-expressions in strings: the tree form
(see [`get_string`](Explanation::get_string)) and the flat form
(see [`get_flat_string`](Explanation::get_flat_string)).
**/
pub struct Explanation<L: Language> {
    /// The tree representation of the explanation.
    pub explanation_trees: TreeExplanation<L>,
    // Flat form of `explanation_trees`; starts as `None` and is filled in by
    // `make_flat_explanation` the first time it is requested.
    flat_explanation: Option<FlatExplanation<L>>,
}
127
128impl<L: Language + Display + FromOp> Display for Explanation<L> {
129    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
130        let s = self.get_sexp().to_string();
131        f.write_str(&s)
132    }
133}
134
135impl<L: Language + Display + FromOp> Explanation<L> {
136    /// Get each flattened term in the explanation as an s-expression string.
137    ///
138    /// The s-expression format mirrors the format of each [`FlatTerm`].
139    /// Each expression after the first will be annotated in one location with a rewrite.
140    /// When a term is being re-written it is wrapped with "(Rewrite=> rule-name expression)"
141    /// or "(Rewrite<= rule-name expression)".
142    /// "Rewrite=>" indicates that the previous term is rewritten to the current term
143    /// and "Rewrite<=" indicates that the current term is rewritten to the previous term.
    /// The rule name is either the name of the rewrite rule or the reason provided to [`union_instantiations`](super::EGraph::union_instantiations).
145    ///
146    /// Example explanation:
147    /// ```text
148    /// (+ 1 (- a (* (- 2 1) a)))
149    /// (+ 1 (- a (* (Rewrite=> constant_fold 1) a)))
150    /// (+ 1 (- a (Rewrite=> comm-mul (* a 1))))
151    /// (+ 1 (- a (Rewrite<= mul-one a)))
152    /// (+ 1 (Rewrite=> cancel-sub 0))
153    /// (Rewrite=> constant_fold 1)
154    /// ```
155    pub fn get_flat_string(&mut self) -> String {
156        self.get_flat_strings().join("\n")
157    }
158
    /// Get the tree-style explanation as an s-expression string.
160    ///
161    /// The s-expression format mirrors the format of each [`TreeTerm`].
162    /// When a child contains an explanation, the explanation is wrapped with
163    /// "(Explanation ...)".
164    /// When a term is being re-written it is wrapped with "(Rewrite=> rule-name expression)"
165    /// or "(Rewrite<= rule-name expression)".
166    /// "Rewrite=>" indicates that the previous term is rewritten to the current term
167    /// and "Rewrite<=" indicates that the current term is rewritten to the previous term.
    /// The rule name is either the name of the rewrite rule or the reason provided to [`union_instantiations`](super::EGraph::union_instantiations).
169    ///
170    /// The following example shows that `(+ 1 (- a (* (- 2 1) a))) = 1`
171    /// Example explanation:
172    /// ```text
173    /// (+ 1 (- a (* (- 2 1) a)))
174    /// (+
175    ///    1
176    ///    (Explanation
177    ///      (- a (* (- 2 1) a))
178    ///      (-
179    ///        a
180    ///        (Explanation
181    ///          (* (- 2 1) a)
182    ///          (* (Explanation (- 2 1) (Rewrite=> constant_fold 1)) a)
183    ///          (Rewrite=> comm-mul (* a 1))
184    ///          (Rewrite<= mul-one a)))
185    ///      (Rewrite=> cancel-sub 0)))
186    /// (Rewrite=> constant_fold 1)
187    /// ```
188    pub fn get_string(&self) -> String {
189        self.to_string()
190    }
191
192    /// Get the tree-style explanation as an s-expression string
193    /// with let binding to enable sharing of subproofs.
194    ///
195    /// The following explanation shows that `(+ x (+ x (+ x x))) = (* 4 x)`.
196    /// Steps such as factoring are shared via the let bindings.
197    /// Example explanation:
198    ///
199    /// ```text
200    /// (let
201    ///     (v_0 (Rewrite=> mul-one (* x 1)))
202    ///     (let
203    ///       (v_1 (+ (Explanation x v_0) (Explanation x v_0)))
204    ///       (let
205    ///         (v_2 (+ 1 1))
206    ///         (let
207    ///           (v_3 (Rewrite=> factor (* x v_2)))
208    ///           (Explanation
209    ///             (+ x (+ x (+ x x)))
210    ///             (Rewrite=> assoc-add (+ (+ x x) (+ x x)))
211    ///             (+ (Explanation (+ x x) v_1 v_3) (Explanation (+ x x) v_1 v_3))
212    ///             (Rewrite=> factor (* x (+ (+ 1 1) (+ 1 1))))
213    ///             (Rewrite=> comm-mul (* (+ (+ 1 1) (+ 1 1)) x))
214    ///             (*
215    ///               (Explanation
216    ///                 (+ (+ 1 1) (+ 1 1))
217    ///                 (+
218    ///                   (Explanation (+ 1 1) (Rewrite=> constant_fold 2))
219    ///                   (Explanation (+ 1 1) (Rewrite=> constant_fold 2)))
220    ///                 (Rewrite=> constant_fold 4))
221    ///               x))))))
222    /// ```
223    pub fn get_string_with_let(&self) -> String {
224        let mut s = "".to_string();
225        pretty_print(&mut s, &self.get_sexp_with_let(), 100, 0).unwrap();
226        s
227    }
228
229    /// Get each term in the explanation as a string.
230    /// See [`get_string`](Explanation::get_string) for the format of these strings.
231    pub fn get_flat_strings(&mut self) -> Vec<String> {
232        self.make_flat_explanation()
233            .iter()
234            .map(|e| e.to_string())
235            .collect()
236    }
237
238    fn get_sexp(&self) -> Sexp {
239        let mut items = vec![Sexp::String("Explanation".to_string())];
240        for e in self.explanation_trees.iter() {
241            items.push(e.get_sexp());
242        }
243
244        Sexp::List(items)
245    }
246
247    /// Get the size of this explanation tree in terms of the number of rewrites
248    /// in the let-bound version of the tree.
249    pub fn get_tree_size(&self) -> ProofCost {
250        let mut seen = Default::default();
251        let mut seen_adjacent = Default::default();
252        let mut sum: ProofCost = BigUint::zero();
253        for e in self.explanation_trees.iter() {
254            sum += self.tree_size(&mut seen, &mut seen_adjacent, e);
255        }
256        sum
257    }
258
259    fn tree_size(
260        &self,
261        seen: &mut HashSet<*const TreeTerm<L>>,
262        seen_adjacent: &mut HashSet<(Id, Id)>,
263        current: &Rc<TreeTerm<L>>,
264    ) -> ProofCost {
265        if !seen.insert(&**current as *const TreeTerm<L>) {
266            return BigUint::zero();
267        }
268        let mut my_size: ProofCost = BigUint::zero();
269        if current.forward_rule.is_some() {
270            my_size += 1_u32;
271        }
272        if current.backward_rule.is_some() {
273            my_size += 1_u32;
274        }
275        assert!(my_size.is_zero() || my_size.is_one());
276        if my_size.is_one() {
277            if !seen_adjacent.insert((current.current, current.last)) {
278                return BigUint::zero();
279            } else {
280                seen_adjacent.insert((current.last, current.current));
281            }
282        }
283
284        for child_proof in &current.child_proofs {
285            for child in child_proof {
286                my_size += self.tree_size(seen, seen_adjacent, child);
287            }
288        }
289        my_size
290    }
291
292    fn get_sexp_with_let(&self) -> Sexp {
293        let mut shared: HashSet<*const TreeTerm<L>> = Default::default();
294        let mut to_let_bind = vec![];
295        for term in &self.explanation_trees {
296            self.find_to_let_bind(term.clone(), &mut shared, &mut to_let_bind);
297        }
298
299        let mut bindings: HashMap<*const TreeTerm<L>, Sexp> = Default::default();
300        let mut generated_bindings: Vec<(Sexp, Sexp)> = Default::default();
301        for to_bind in to_let_bind {
302            if bindings.get(&(&*to_bind as *const TreeTerm<L>)).is_none() {
303                let name = Sexp::String("v_".to_string() + &generated_bindings.len().to_string());
304                let ast = to_bind.get_sexp_with_bindings(&bindings);
305                generated_bindings.push((name.clone(), ast));
306                bindings.insert(&*to_bind as *const TreeTerm<L>, name);
307            }
308        }
309
310        let mut items = vec![Sexp::String("Explanation".to_string())];
311        for e in self.explanation_trees.iter() {
312            if let Some(existing) = bindings.get(&(&**e as *const TreeTerm<L>)) {
313                items.push(existing.clone());
314            } else {
315                items.push(e.get_sexp_with_bindings(&bindings));
316            }
317        }
318
319        let mut result = Sexp::List(items);
320
321        for (name, expr) in generated_bindings.into_iter().rev() {
322            let let_expr = Sexp::List(vec![name, expr]);
323            result = Sexp::List(vec![Sexp::String("let".to_string()), let_expr, result]);
324        }
325
326        result
327    }
328
329    // for every subterm which is shared in
330    // multiple places, add it to to_let_bind
331    fn find_to_let_bind(
332        &self,
333        term: Rc<TreeTerm<L>>,
334        shared: &mut HashSet<*const TreeTerm<L>>,
335        to_let_bind: &mut Vec<Rc<TreeTerm<L>>>,
336    ) {
337        if !term.child_proofs.is_empty() {
338            if shared.insert(&*term as *const TreeTerm<L>) {
339                for proof in &term.child_proofs {
340                    for child in proof {
341                        self.find_to_let_bind(child.clone(), shared, to_let_bind);
342                    }
343                }
344            } else {
345                to_let_bind.push(term);
346            }
347        }
348    }
349}
350
351impl<L: Language> Explanation<L> {
352    /// Construct a new explanation given its tree representation.
353    pub fn new(explanation_trees: TreeExplanation<L>) -> Explanation<L> {
354        Explanation {
355            explanation_trees,
356            flat_explanation: None,
357        }
358    }
359
360    /// Construct the flat representation of the explanation and return it.
361    pub fn make_flat_explanation(&mut self) -> &FlatExplanation<L> {
362        if self.flat_explanation.is_some() {
363            return self.flat_explanation.as_ref().unwrap();
364        } else {
365            self.flat_explanation = Some(TreeTerm::flatten_proof(&self.explanation_trees));
366            self.flat_explanation.as_ref().unwrap()
367        }
368    }
369
370    /// Check the validity of the explanation with respect to the given rules.
371    /// This only is able to check rule applications when the rules are implement `get_pattern_ast`.
372    pub fn check_proof<'a, R, N>(&mut self, rules: R)
373    where
374        R: IntoIterator<Item = &'a Rewrite<L, N>>,
375        L: 'a,
376        N: Analysis<L> + 'a,
377    {
378        let rules: Vec<&Rewrite<L, N>> = rules.into_iter().collect();
379        let rule_table = Explain::make_rule_table(rules.as_slice());
380        self.make_flat_explanation();
381        let flat_explanation = self.flat_explanation.as_ref().unwrap();
382        assert!(!flat_explanation[0].has_rewrite_forward());
383        assert!(!flat_explanation[0].has_rewrite_backward());
384        for i in 0..flat_explanation.len() - 1 {
385            let current = &flat_explanation[i];
386            let next = &flat_explanation[i + 1];
387
388            let has_forward = next.has_rewrite_forward();
389            let has_backward = next.has_rewrite_backward();
390            assert!(has_forward ^ has_backward);
391
392            if has_forward {
393                assert!(self.check_rewrite_at(current, next, &rule_table, true));
394            } else {
395                assert!(self.check_rewrite_at(current, next, &rule_table, false));
396            }
397        }
398    }
399
400    fn check_rewrite_at<N: Analysis<L>>(
401        &self,
402        current: &FlatTerm<L>,
403        next: &FlatTerm<L>,
404        table: &HashMap<Symbol, &Rewrite<L, N>>,
405        is_forward: bool,
406    ) -> bool {
407        if is_forward && next.forward_rule.is_some() {
408            let rule_name = next.forward_rule.as_ref().unwrap();
409            if let Some(rule) = table.get(rule_name) {
410                Explanation::check_rewrite(current, next, rule)
411            } else {
412                // give up when the rule is not provided
413                true
414            }
415        } else if !is_forward && next.backward_rule.is_some() {
416            let rule_name = next.backward_rule.as_ref().unwrap();
417            if let Some(rule) = table.get(rule_name) {
418                Explanation::check_rewrite(next, current, rule)
419            } else {
420                true
421            }
422        } else {
423            for (left, right) in current.children.iter().zip(next.children.iter()) {
424                if !self.check_rewrite_at(left, right, table, is_forward) {
425                    return false;
426                }
427            }
428            true
429        }
430    }
431
432    // if the rewrite is just patterns, then it can check it
433    fn check_rewrite<'a, N: Analysis<L>>(
434        current: &'a FlatTerm<L>,
435        next: &'a FlatTerm<L>,
436        rewrite: &Rewrite<L, N>,
437    ) -> bool {
438        if let Some(lhs) = rewrite.searcher.get_pattern_ast() {
439            if let Some(rhs) = rewrite.applier.get_pattern_ast() {
440                let rewritten = current.rewrite(lhs, rhs);
441                if &rewritten != next {
442                    return false;
443                }
444            }
445        }
446        true
447    }
448}
449
/// An explanation for a term and its equivalent children.
/// Each child is a proof transforming the initial child into the final child term.
/// The initial term is given by taking each first sub-term
/// in each [`child_proofs`](TreeTerm::child_proofs) recursively.
/// The final term is given by all of the final terms in each [`child_proofs`](TreeTerm::child_proofs).
///
/// If [`forward_rule`](TreeTerm::forward_rule) is provided, then this TreeTerm's initial term
/// can be derived from the previous
/// TreeTerm by applying the rule.
/// Similarly, if [`backward_rule`](TreeTerm::backward_rule) is provided,
/// then the previous TreeTerm's final term is given by applying the rule to this TreeTerm's initial term.
///
/// TreeTerms are flattened by first flattening [`child_proofs`](TreeTerm::child_proofs), then wrapping
/// the flattened proof with this TreeTerm's node.
#[derive(Debug, Clone)]
pub struct TreeTerm<L: Language> {
    /// A node representing this TreeTerm's operator. The children of the node should be ignored.
    pub node: L,
    /// A rule rewriting this TreeTerm's initial term back to the last TreeTerm's final term.
    pub backward_rule: Option<Symbol>,
    /// A rule rewriting the last TreeTerm's final term to this TreeTerm's initial term.
    pub forward_rule: Option<Symbol>,
    /// A list of child proofs, each transforming the initial term to the final term for that child.
    pub child_proofs: Vec<TreeExplanation<L>>,

    // `last` and `current` are read by `Explanation::tree_size`, which uses the
    // pair `(current, last)` to avoid counting the same rewrite step twice.
    // `TreeTerm::new` initializes both to `Id::from(0)`.
    // NOTE(review): presumably the enode ids on either side of this term's
    // rewrite — confirm against the code that builds tree terms.
    last: Id,
    current: Id,
}
478
479impl<L: Language> TreeTerm<L> {
480    /// Construct a new TreeTerm given its node and child_proofs.
481    pub fn new(node: L, child_proofs: Vec<TreeExplanation<L>>) -> TreeTerm<L> {
482        TreeTerm {
483            node,
484            backward_rule: None,
485            forward_rule: None,
486            child_proofs,
487            current: Id::from(0),
488            last: Id::from(0),
489        }
490    }
491
492    fn flatten_proof(proof: &[Rc<TreeTerm<L>>]) -> FlatExplanation<L> {
493        let mut flat_proof: FlatExplanation<L> = vec![];
494        for tree in proof {
495            let mut explanation = tree.flatten_explanation();
496
497            if !flat_proof.is_empty()
498                && !explanation[0].has_rewrite_forward()
499                && !explanation[0].has_rewrite_backward()
500            {
501                let last = flat_proof.pop().unwrap();
502                explanation[0].combine_rewrites(&last);
503            }
504
505            flat_proof.extend(explanation);
506        }
507
508        flat_proof
509    }
510
511    /// Get a FlatTerm representing the first term in this proof.
512    pub fn get_initial_flat_term(&self) -> FlatTerm<L> {
513        FlatTerm {
514            node: self.node.clone(),
515            backward_rule: self.backward_rule,
516            forward_rule: self.forward_rule,
517            children: self
518                .child_proofs
519                .iter()
520                .map(|child_proof| child_proof[0].get_initial_flat_term())
521                .collect(),
522        }
523    }
524
525    /// Get a FlatTerm representing the final term in this proof.
526    pub fn get_last_flat_term(&self) -> FlatTerm<L> {
527        FlatTerm {
528            node: self.node.clone(),
529            backward_rule: self.backward_rule,
530            forward_rule: self.forward_rule,
531            children: self
532                .child_proofs
533                .iter()
534                .map(|child_proof| child_proof[child_proof.len() - 1].get_last_flat_term())
535                .collect(),
536        }
537    }
538
539    /// Construct the [`FlatExplanation`] for this TreeTerm.
540    pub fn flatten_explanation(&self) -> FlatExplanation<L> {
541        let mut proof = vec![];
542        let mut child_proofs = vec![];
543        let mut representative_terms = vec![];
544        for child_explanation in &self.child_proofs {
545            let flat_proof = TreeTerm::flatten_proof(child_explanation);
546            representative_terms.push(flat_proof[0].remove_rewrites());
547            child_proofs.push(flat_proof);
548        }
549
550        proof.push(FlatTerm::new(
551            self.node.clone(),
552            representative_terms.clone(),
553        ));
554
555        for (i, child_proof) in child_proofs.iter().enumerate() {
556            // replace first one to preserve the rule annotation
557            proof.last_mut().unwrap().children[i] = child_proof[0].clone();
558
559            for child in child_proof.iter().skip(1) {
560                let mut children = vec![];
561                for (j, rep_term) in representative_terms.iter().enumerate() {
562                    if j == i {
563                        children.push(child.clone());
564                    } else {
565                        children.push(rep_term.clone());
566                    }
567                }
568
569                proof.push(FlatTerm::new(self.node.clone(), children));
570            }
571            representative_terms[i] = child_proof.last().unwrap().remove_rewrites();
572        }
573
574        proof[0].backward_rule = self.backward_rule;
575        proof[0].forward_rule = self.forward_rule;
576
577        proof
578    }
579}
580
/// A single term in a flattened explanation.
/// After the first term in a [`FlatExplanation`], each term
/// will be annotated with exactly one [`backward_rule`](FlatTerm::backward_rule) or one
/// [`forward_rule`](FlatTerm::forward_rule). This can appear in children [`FlatTerm`]s,
/// indicating that the child is being rewritten.
///
/// When [`forward_rule`](FlatTerm::forward_rule) is provided, the previous FlatTerm can be rewritten
/// to this FlatTerm by applying the rule.
/// When [`backward_rule`](FlatTerm::backward_rule) is provided, the previous FlatTerm is given by applying
/// the rule to this FlatTerm.
/// Rules are either the string of the name of the rule or the reason provided to
/// [`union_instantiations`](super::EGraph::union_instantiations).
///
/// Equality between FlatTerms is implemented by hand (see the `PartialEq` impl)
/// and does not look at the rule annotations.
#[derive(Debug, Clone, Eq)]
pub struct FlatTerm<L: Language> {
    /// The node representing this FlatTerm's operator.
    /// The children of the node should be ignored.
    pub node: L,
    /// A rule rewriting this FlatTerm back to the last FlatTerm.
    pub backward_rule: Option<Symbol>,
    /// A rule rewriting the last FlatTerm to this FlatTerm.
    pub forward_rule: Option<Symbol>,
    /// The children of this FlatTerm.
    pub children: FlatExplanation<L>,
}
606
607impl<L: Language + Display + FromOp> Display for FlatTerm<L> {
608    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
609        let s = self.get_sexp().to_string();
610        write!(f, "{}", s)
611    }
612}
613
614impl<L: Language> PartialEq for FlatTerm<L> {
615    fn eq(&self, other: &FlatTerm<L>) -> bool {
616        if !self.node.matches(&other.node) {
617            return false;
618        }
619
620        for (child1, child2) in self.children.iter().zip(other.children.iter()) {
621            if !child1.eq(child2) {
622                return false;
623            }
624        }
625        true
626    }
627}
628
629impl<L: Language> FlatTerm<L> {
630    /// Remove the rewrite annotation from this flatterm, if any.
631    pub fn remove_rewrites(&self) -> FlatTerm<L> {
632        FlatTerm::new(
633            self.node.clone(),
634            self.children
635                .iter()
636                .map(|child| child.remove_rewrites())
637                .collect(),
638        )
639    }
640
641    fn combine_rewrites(&mut self, other: &FlatTerm<L>) {
642        if other.forward_rule.is_some() {
643            assert!(self.forward_rule.is_none());
644            self.forward_rule = other.forward_rule;
645        }
646
647        if other.backward_rule.is_some() {
648            assert!(self.backward_rule.is_none());
649            self.backward_rule = other.backward_rule;
650        }
651
652        for (left, right) in self.children.iter_mut().zip(other.children.iter()) {
653            left.combine_rewrites(right);
654        }
655    }
656}
657
658impl<L: Language> Default for Explain<L> {
659    fn default() -> Self {
660        Self::new()
661    }
662}
663
664impl<L: Language + Display + FromOp> FlatTerm<L> {
665    /// Convert this FlatTerm to an S-expression.
666    /// See [`get_flat_string`](Explanation::get_flat_string) for the format of these expressions.
667    pub fn get_string(&self) -> String {
668        self.get_sexp().to_string()
669    }
670
671    fn get_sexp(&self) -> Sexp {
672        let op = Sexp::String(self.node.to_string());
673        let mut expr = if self.node.is_leaf() {
674            op
675        } else {
676            let mut vec = vec![op];
677            for child in &self.children {
678                vec.push(child.get_sexp());
679            }
680            Sexp::List(vec)
681        };
682
683        if let Some(rule_name) = &self.backward_rule {
684            expr = Sexp::List(vec![
685                Sexp::String("Rewrite<=".to_string()),
686                Sexp::String((*rule_name).to_string()),
687                expr,
688            ]);
689        }
690
691        if let Some(rule_name) = &self.forward_rule {
692            expr = Sexp::List(vec![
693                Sexp::String("Rewrite=>".to_string()),
694                Sexp::String((*rule_name).to_string()),
695                expr,
696            ]);
697        }
698
699        expr
700    }
701
702    /// Convert this FlatTerm to a RecExpr.
703    pub fn get_recexpr(&self) -> RecExpr<L> {
704        self.remove_rewrites().to_string().parse().unwrap()
705    }
706}
707
708impl<L: Language + Display + FromOp> Display for TreeTerm<L> {
709    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
710        let mut buf = String::new();
711        let width = 80;
712        pretty_print(&mut buf, &self.get_sexp(), width, 1).unwrap();
713        write!(f, "{}", buf)
714    }
715}
716
717impl<L: Language + Display + FromOp> TreeTerm<L> {
718    /// Convert this TreeTerm to an S-expression.
719    fn get_sexp(&self) -> Sexp {
720        self.get_sexp_with_bindings(&Default::default())
721    }
722
723    fn get_sexp_with_bindings(&self, bindings: &HashMap<*const TreeTerm<L>, Sexp>) -> Sexp {
724        let op = Sexp::String(self.node.to_string());
725        let mut expr = if self.node.is_leaf() {
726            op
727        } else {
728            let mut vec = vec![op];
729            for child in &self.child_proofs {
730                assert!(!child.is_empty());
731                if child.len() == 1 {
732                    if let Some(existing) = bindings.get(&(&*child[0] as *const TreeTerm<L>)) {
733                        vec.push(existing.clone());
734                    } else {
735                        vec.push(child[0].get_sexp_with_bindings(bindings));
736                    }
737                } else {
738                    let mut child_expressions = vec![Sexp::String("Explanation".to_string())];
739                    for child_explanation in child.iter() {
740                        if let Some(existing) =
741                            bindings.get(&(&**child_explanation as *const TreeTerm<L>))
742                        {
743                            child_expressions.push(existing.clone());
744                        } else {
745                            child_expressions
746                                .push(child_explanation.get_sexp_with_bindings(bindings));
747                        }
748                    }
749                    vec.push(Sexp::List(child_expressions));
750                }
751            }
752            Sexp::List(vec)
753        };
754
755        if let Some(rule_name) = &self.backward_rule {
756            expr = Sexp::List(vec![
757                Sexp::String("Rewrite<=".to_string()),
758                Sexp::String((*rule_name).to_string()),
759                expr,
760            ]);
761        }
762
763        if let Some(rule_name) = &self.forward_rule {
764            expr = Sexp::List(vec![
765                Sexp::String("Rewrite=>".to_string()),
766                Sexp::String((*rule_name).to_string()),
767                expr,
768            ]);
769        }
770
771        expr
772    }
773}
774
775impl<L: Language> FlatTerm<L> {
776    /// Construct a new FlatTerm given a node and its children.
777    pub fn new(node: L, children: FlatExplanation<L>) -> FlatTerm<L> {
778        FlatTerm {
779            node,
780            backward_rule: None,
781            forward_rule: None,
782            children,
783        }
784    }
785
786    /// Rewrite the FlatTerm by matching the lhs and substituting the rhs.
787    /// The lhs must be guaranteed to match.
788    pub fn rewrite(&self, lhs: &PatternAst<L>, rhs: &PatternAst<L>) -> FlatTerm<L> {
789        let mut bindings = Default::default();
790        self.make_bindings(lhs, lhs.len() - 1, &mut bindings);
791        FlatTerm::from_pattern(rhs, rhs.len() - 1, &bindings)
792    }
793
794    /// Checks if this term or any child has a [`forward_rule`](FlatTerm::forward_rule).
795    pub fn has_rewrite_forward(&self) -> bool {
796        self.forward_rule.is_some()
797            || self
798                .children
799                .iter()
800                .any(|child| child.has_rewrite_forward())
801    }
802
803    /// Checks if this term or any child has a [`backward_rule`](FlatTerm::backward_rule).
804    pub fn has_rewrite_backward(&self) -> bool {
805        self.backward_rule.is_some()
806            || self
807                .children
808                .iter()
809                .any(|child| child.has_rewrite_backward())
810    }
811
812    fn from_pattern(
813        pattern: &[ENodeOrVar<L>],
814        location: usize,
815        bindings: &HashMap<Var, &FlatTerm<L>>,
816    ) -> FlatTerm<L> {
817        match &pattern[location] {
818            ENodeOrVar::Var(var) => (*bindings.get(var).unwrap()).clone(),
819            ENodeOrVar::ENode(node) => {
820                let children = node.fold(vec![], |mut acc, child| {
821                    acc.push(FlatTerm::from_pattern(
822                        pattern,
823                        usize::from(child),
824                        bindings,
825                    ));
826                    acc
827                });
828                FlatTerm::new(node.clone(), children)
829            }
830        }
831    }
832
833    fn make_bindings<'a>(
834        &'a self,
835        pattern: &[ENodeOrVar<L>],
836        location: usize,
837        bindings: &mut HashMap<Var, &'a FlatTerm<L>>,
838    ) {
839        match &pattern[location] {
840            ENodeOrVar::Var(var) => {
841                if let Some(existing) = bindings.get(var) {
842                    if existing != &self {
843                        panic!(
844                            "Invalid proof: binding for variable {:?} does not match between {:?} \n and \n {:?}",
845                            var, existing, self);
846                    }
847                } else {
848                    bindings.insert(*var, self);
849                }
850            }
851            ENodeOrVar::ENode(node) => {
852                // The node must match the rewrite or the proof is invalid.
853                assert!(node.matches(&self.node));
854                let mut counter = 0;
855                node.for_each(|child| {
856                    self.children[counter].make_bindings(pattern, usize::from(child), bindings);
857                    counter += 1;
858                });
859            }
860        }
861    }
862}
863
864// Make sure to use push_increase instead of push when using priority queue
865#[derive(Clone, Eq, PartialEq)]
866struct HeapState<I> {
867    cost: ProofCost,
868    item: I,
869}
870// The priority queue depends on `Ord`.
871// Explicitly implement the trait so the queue becomes a min-heap
872// instead of a max-heap.
873impl<I: Eq + PartialEq> Ord for HeapState<I> {
874    fn cmp(&self, other: &Self) -> Ordering {
875        // Notice that the we flip the ordering on costs.
876        // In case of a tie we compare positions - this step is necessary
877        // to make implementations of `PartialEq` and `Ord` consistent.
878        other
879            .cost
880            .cmp(&self.cost)
881            .then_with(|| self.cost.cmp(&other.cost))
882    }
883}
884
885// `PartialOrd` needs to be implemented as well.
886impl<I: Eq + PartialEq> PartialOrd for HeapState<I> {
887    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
888        Some(self.cmp(other))
889    }
890}
891
892impl<L: Language> Explain<L> {
893    fn make_rule_table<'a, N: Analysis<L>>(
894        rules: &[&'a Rewrite<L, N>],
895    ) -> HashMap<Symbol, &'a Rewrite<L, N>> {
896        let mut table: HashMap<Symbol, &'a Rewrite<L, N>> = Default::default();
897        for r in rules {
898            table.insert(r.name, r);
899        }
900        table
901    }
902    pub fn new() -> Self {
903        Explain {
904            explainfind: vec![],
905            uncanon_memo: Default::default(),
906            shortest_explanation_memo: Default::default(),
907            optimize_explanation_lengths: true,
908        }
909    }
910
911    pub(crate) fn add(&mut self, node: L, set: Id) -> Id {
912        assert_eq!(self.explainfind.len(), usize::from(set));
913        self.uncanon_memo.insert(node, set);
914        self.explainfind.push(ExplainNode {
915            neighbors: vec![],
916            parent_connection: Connection {
917                justification: Justification::Congruence,
918                is_rewrite_forward: false,
919                next: set,
920                current: set,
921            },
922        });
923        set
924    }
925
926    // reverse edges recursively to make this node the leader
927    fn make_leader(&mut self, node: Id) {
928        let next = self.explainfind[usize::from(node)].parent_connection.next;
929        if next != node {
930            self.make_leader(next);
931            let node_connection = &self.explainfind[usize::from(node)].parent_connection;
932            let pconnection = Connection {
933                justification: node_connection.justification.clone(),
934                is_rewrite_forward: !node_connection.is_rewrite_forward,
935                next: node,
936                current: next,
937            };
938            self.explainfind[usize::from(next)].parent_connection = pconnection;
939        }
940    }
941
942    pub(crate) fn alternate_rewrite(&mut self, node1: Id, node2: Id, justification: Justification) {
943        if node1 == node2 {
944            return;
945        }
946        if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) {
947            if cost.is_zero() || cost.is_one() {
948                return;
949            }
950        }
951
952        let lconnection = Connection {
953            justification: justification.clone(),
954            is_rewrite_forward: true,
955            next: node2,
956            current: node1,
957        };
958
959        let rconnection = Connection {
960            justification,
961            is_rewrite_forward: false,
962            next: node1,
963            current: node2,
964        };
965
966        self.explainfind[usize::from(node1)]
967            .neighbors
968            .push(lconnection);
969        self.explainfind[usize::from(node2)]
970            .neighbors
971            .push(rconnection);
972        self.shortest_explanation_memo
973            .insert((node1, node2), (BigUint::one(), node2));
974        self.shortest_explanation_memo
975            .insert((node2, node1), (BigUint::one(), node1));
976    }
977
978    pub(crate) fn union(&mut self, node1: Id, node2: Id, justification: Justification) {
979        if let Justification::Congruence = justification {
980            // assert!(self.node(node1).matches(self.node(node2)));
981        }
982
983        self.make_leader(node1);
984        self.explainfind[usize::from(node1)].parent_connection.next = node2;
985
986        if let Justification::Rule(_) = justification {
987            self.shortest_explanation_memo
988                .insert((node1, node2), (BigUint::one(), node2));
989            self.shortest_explanation_memo
990                .insert((node2, node1), (BigUint::one(), node1));
991        }
992
993        let pconnection = Connection {
994            justification: justification.clone(),
995            is_rewrite_forward: true,
996            next: node2,
997            current: node1,
998        };
999        let other_pconnection = Connection {
1000            justification,
1001            is_rewrite_forward: false,
1002            next: node1,
1003            current: node2,
1004        };
1005        self.explainfind[usize::from(node1)]
1006            .neighbors
1007            .push(pconnection.clone());
1008        self.explainfind[usize::from(node2)]
1009            .neighbors
1010            .push(other_pconnection);
1011        self.explainfind[usize::from(node1)].parent_connection = pconnection;
1012    }
1013    pub(crate) fn get_union_equalities(&self) -> UnionEqualities {
1014        let mut equalities = vec![];
1015        for node in &self.explainfind {
1016            for neighbor in &node.neighbors {
1017                if neighbor.is_rewrite_forward {
1018                    if let Justification::Rule(r) = neighbor.justification {
1019                        equalities.push((neighbor.current, neighbor.next, r));
1020                    }
1021                }
1022            }
1023        }
1024        equalities
1025    }
1026
1027    pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> {
1028        ExplainNodes {
1029            explain: self,
1030            nodes,
1031        }
1032    }
1033}
1034
1035impl<'a, L: Language> Deref for ExplainNodes<'a, L> {
1036    type Target = Explain<L>;
1037
1038    fn deref(&self) -> &Self::Target {
1039        self.explain
1040    }
1041}
1042
1043impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> {
1044    fn deref_mut(&mut self) -> &mut Self::Target {
1045        &mut *self.explain
1046    }
1047}
1048
1049impl<'x, L: Language> ExplainNodes<'x, L> {
1050    pub(crate) fn node(&self, node_id: Id) -> &L {
1051        &self.nodes[usize::from(node_id)]
1052    }
1053    fn node_to_explanation(
1054        &self,
1055        node_id: Id,
1056        cache: &mut NodeExplanationCache<L>,
1057    ) -> Rc<TreeTerm<L>> {
1058        if let Some(existing) = cache.get(&node_id) {
1059            existing.clone()
1060        } else {
1061            let node = self.node(node_id).clone();
1062            let children = node.fold(vec![], |mut sofar, child| {
1063                sofar.push(vec![self.node_to_explanation(child, cache)]);
1064                sofar
1065            });
1066            let res = Rc::new(TreeTerm::new(node, children));
1067            cache.insert(node_id, res.clone());
1068            res
1069        }
1070    }
1071
1072    fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm<L> {
1073        let node = self.node(node_id).clone();
1074        let children = node.fold(vec![], |mut sofar, child| {
1075            sofar.push(self.node_to_flat_explanation(child));
1076            sofar
1077        });
1078        FlatTerm::new(node, children)
1079    }
1080
1081    pub fn check_each_explain<N: Analysis<L>>(&self, rules: &[&Rewrite<L, N>]) -> bool {
1082        let rule_table = Explain::make_rule_table(rules);
1083        for i in 0..self.explainfind.len() {
1084            let explain_node = &self.explainfind[i];
1085
1086            if explain_node.parent_connection.next != Id::from(i) {
1087                let mut current_explanation = self.node_to_flat_explanation(Id::from(i));
1088                let mut next_explanation =
1089                    self.node_to_flat_explanation(explain_node.parent_connection.next);
1090                if let Justification::Rule(rule_name) =
1091                    &explain_node.parent_connection.justification
1092                {
1093                    if let Some(rule) = rule_table.get(rule_name) {
1094                        if !explain_node.parent_connection.is_rewrite_forward {
1095                            std::mem::swap(&mut current_explanation, &mut next_explanation);
1096                        }
1097                        if !Explanation::check_rewrite(
1098                            &current_explanation,
1099                            &next_explanation,
1100                            rule,
1101                        ) {
1102                            return false;
1103                        }
1104                    }
1105                }
1106            }
1107        }
1108        true
1109    }
1110
1111    pub(crate) fn explain_equivalence<N: Analysis<L>>(
1112        &mut self,
1113        left: Id,
1114        right: Id,
1115        unionfind: &mut UnionFind,
1116        classes: &HashMap<Id, EClass<L, N::Data>>,
1117    ) -> Explanation<L> {
1118        if self.optimize_explanation_lengths {
1119            self.calculate_shortest_explanations::<N>(left, right, classes, unionfind);
1120        }
1121
1122        let mut cache = Default::default();
1123        let mut enode_cache = Default::default();
1124        Explanation::new(self.explain_enodes(left, right, &mut cache, &mut enode_cache, false))
1125    }
1126
1127    fn common_ancestor(&self, mut left: Id, mut right: Id) -> Id {
1128        let mut seen_left: HashSet<Id> = Default::default();
1129        let mut seen_right: HashSet<Id> = Default::default();
1130        loop {
1131            seen_left.insert(left);
1132            if seen_right.contains(&left) {
1133                return left;
1134            }
1135
1136            seen_right.insert(right);
1137            if seen_left.contains(&right) {
1138                return right;
1139            }
1140
1141            let next_left = self.explainfind[usize::from(left)].parent_connection.next;
1142            let next_right = self.explainfind[usize::from(right)].parent_connection.next;
1143            assert!(next_left != left || next_right != right);
1144            left = next_left;
1145            right = next_right;
1146        }
1147    }
1148
1149    fn get_connections(&self, mut node: Id, ancestor: Id) -> Vec<Connection> {
1150        if node == ancestor {
1151            return vec![];
1152        }
1153
1154        let mut nodes = vec![];
1155        loop {
1156            let next = self.explainfind[usize::from(node)].parent_connection.next;
1157            nodes.push(
1158                self.explainfind[usize::from(node)]
1159                    .parent_connection
1160                    .clone(),
1161            );
1162            if next == ancestor {
1163                return nodes;
1164            }
1165            assert!(next != node);
1166            node = next;
1167        }
1168    }
1169
1170    fn get_path_unoptimized(&self, left: Id, right: Id) -> (Vec<Connection>, Vec<Connection>) {
1171        let ancestor = self.common_ancestor(left, right);
1172        let left_connections = self.get_connections(left, ancestor);
1173        let right_connections = self.get_connections(right, ancestor);
1174        (left_connections, right_connections)
1175    }
1176
1177    fn get_neighbor(&self, current: Id, next: Id) -> Connection {
1178        for neighbor in &self.explainfind[usize::from(current)].neighbors {
1179            if neighbor.next == next {
1180                if let Justification::Rule(_) = neighbor.justification {
1181                    return neighbor.clone();
1182                }
1183            }
1184        }
1185        Connection {
1186            justification: Justification::Congruence,
1187            current,
1188            next,
1189            is_rewrite_forward: true,
1190        }
1191    }
1192
1193    fn get_path(&self, mut left: Id, right: Id) -> (Vec<Connection>, Vec<Connection>) {
1194        let mut left_connections = vec![];
1195        loop {
1196            if left == right {
1197                return (left_connections, vec![]);
1198            }
1199            if let Some((_, next)) = self.shortest_explanation_memo.get(&(left, right)) {
1200                left_connections.push(self.get_neighbor(left, *next));
1201                left = *next;
1202            } else {
1203                break;
1204            }
1205        }
1206
1207        let (restleft, right_connections) = self.get_path_unoptimized(left, right);
1208        left_connections.extend(restleft);
1209        (left_connections, right_connections)
1210    }
1211
1212    fn explain_enodes(
1213        &self,
1214        left: Id,
1215        right: Id,
1216        cache: &mut ExplainCache<L>,
1217        node_explanation_cache: &mut NodeExplanationCache<L>,
1218        use_unoptimized: bool,
1219    ) -> TreeExplanation<L> {
1220        let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)];
1221        let (left_connections, right_connections) = if use_unoptimized {
1222            self.get_path_unoptimized(left, right)
1223        } else {
1224            self.get_path(left, right)
1225        };
1226
1227        for (i, connection) in left_connections
1228            .iter()
1229            .chain(right_connections.iter().rev())
1230            .enumerate()
1231        {
1232            let mut connection = connection.clone();
1233            if i >= left_connections.len() {
1234                connection.is_rewrite_forward = !connection.is_rewrite_forward;
1235                std::mem::swap(&mut connection.next, &mut connection.current);
1236            }
1237
1238            proof.push(self.explain_adjacent(
1239                connection,
1240                cache,
1241                node_explanation_cache,
1242                use_unoptimized,
1243            ));
1244        }
1245        proof
1246    }
1247
1248    fn explain_adjacent(
1249        &self,
1250        connection: Connection,
1251        cache: &mut ExplainCache<L>,
1252        node_explanation_cache: &mut NodeExplanationCache<L>,
1253        use_unoptimized: bool,
1254    ) -> Rc<TreeTerm<L>> {
1255        let fingerprint = (connection.current, connection.next);
1256
1257        if let Some(answer) = cache.get(&fingerprint) {
1258            return answer.clone();
1259        }
1260
1261        let term = match connection.justification {
1262            Justification::Rule(name) => {
1263                let mut rewritten =
1264                    (*self.node_to_explanation(connection.next, node_explanation_cache)).clone();
1265                if connection.is_rewrite_forward {
1266                    rewritten.forward_rule = Some(name);
1267                } else {
1268                    rewritten.backward_rule = Some(name);
1269                }
1270
1271                rewritten.current = connection.next;
1272                rewritten.last = connection.current;
1273
1274                Rc::new(rewritten)
1275            }
1276            Justification::Congruence => {
1277                // add the children proofs to the last explanation
1278                let current_node = self.node(connection.current);
1279                let next_node = self.node(connection.next);
1280                assert!(current_node.matches(next_node));
1281                let mut subproofs = vec![];
1282
1283                for (left_child, right_child) in current_node
1284                    .children()
1285                    .iter()
1286                    .zip(next_node.children().iter())
1287                {
1288                    subproofs.push(self.explain_enodes(
1289                        *left_child,
1290                        *right_child,
1291                        cache,
1292                        node_explanation_cache,
1293                        use_unoptimized,
1294                    ));
1295                }
1296                Rc::new(TreeTerm::new(current_node.clone(), subproofs))
1297            }
1298        };
1299
1300        cache.insert(fingerprint, term.clone());
1301
1302        term
1303    }
1304
1305    fn find_all_enodes(&self, eclass: Id) -> HashSet<Id> {
1306        let mut enodes = HashSet::default();
1307        let mut todo = vec![eclass];
1308
1309        while let Some(current) = todo.pop() {
1310            if enodes.insert(current) {
1311                for neighbor in &self.explainfind[usize::from(current)].neighbors {
1312                    todo.push(neighbor.next);
1313                }
1314            }
1315        }
1316        enodes
1317    }
1318
1319    fn add_tree_depths(&self, node: Id, depths: &mut HashMap<Id, ProofCost>) -> ProofCost {
1320        if depths.get(&node).is_none() {
1321            let parent = self.parent(node);
1322            let depth = if parent == node {
1323                BigUint::zero()
1324            } else {
1325                self.add_tree_depths(parent, depths) + 1_u32
1326            };
1327
1328            depths.insert(node, depth);
1329        }
1330
1331        depths.get(&node).unwrap().clone()
1332    }
1333
1334    fn calculate_tree_depths(&self) -> HashMap<Id, ProofCost> {
1335        let mut depths = HashMap::default();
1336        for i in 0..self.explainfind.len() {
1337            self.add_tree_depths(Id::from(i), &mut depths);
1338        }
1339        depths
1340    }
1341
1342    fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: ProofCost) {
1343        self.shortest_explanation_memo
1344            .insert((current, right), (distance, next));
1345    }
1346
1347    fn populate_path_length(
1348        &mut self,
1349        right: Id,
1350        left_connections: &[Connection],
1351        distance_memo: &mut DistanceMemo,
1352    ) {
1353        self.shortest_explanation_memo
1354            .insert((right, right), (BigUint::zero(), right));
1355        for connection in left_connections.iter().rev() {
1356            let next = connection.next;
1357            let current = connection.current;
1358            let next_cost = self
1359                .shortest_explanation_memo
1360                .get(&(next, right))
1361                .unwrap()
1362                .0
1363                .clone();
1364            let dist = self.connection_distance(connection, distance_memo);
1365            self.replace_distance(current, next, right, next_cost + dist);
1366        }
1367    }
1368
1369    fn distance_between(
1370        &mut self,
1371        left: Id,
1372        right: Id,
1373        distance_memo: &mut DistanceMemo,
1374    ) -> ProofCost {
1375        if left == right {
1376            return BigUint::zero();
1377        }
1378        let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) {
1379            *a
1380        } else {
1381            // fall back on calculating ancestor for top-level query (not from congruence)
1382            self.common_ancestor(left, right)
1383        };
1384        // calculate edges until you are past the ancestor
1385        self.calculate_parent_distance(left, ancestor, distance_memo);
1386        self.calculate_parent_distance(right, ancestor, distance_memo);
1387
1388        // now all three share an ancestor
1389        let a = self.calculate_parent_distance(ancestor, Id::from(usize::MAX), distance_memo);
1390        let b = self.calculate_parent_distance(left, Id::from(usize::MAX), distance_memo);
1391        let c = self.calculate_parent_distance(right, Id::from(usize::MAX), distance_memo);
1392
1393        assert!(
1394            distance_memo.parent_distance[usize::from(ancestor)].0
1395                == distance_memo.parent_distance[usize::from(left)].0
1396        );
1397        assert!(
1398            distance_memo.parent_distance[usize::from(ancestor)].0
1399                == distance_memo.parent_distance[usize::from(right)].0
1400        );
1401
1402        // calculate distance to find upper bound
1403        b + c - (a << 1)
1404
1405        //assert_eq!(dist+1, Explanation::new(self.explain_enodes(left, right, &mut Default::default())).make_flat_explanation().len());
1406    }
1407
1408    fn congruence_distance(
1409        &mut self,
1410        current: Id,
1411        next: Id,
1412        distance_memo: &mut DistanceMemo,
1413    ) -> ProofCost {
1414        let current_node = self.node(current).clone();
1415        let next_node = self.node(next).clone();
1416        let mut cost: ProofCost = BigUint::zero();
1417        for (left_child, right_child) in current_node
1418            .children()
1419            .iter()
1420            .zip(next_node.children().iter())
1421        {
1422            cost += self.distance_between(*left_child, *right_child, distance_memo);
1423        }
1424        cost
1425    }
1426
1427    fn connection_distance(
1428        &mut self,
1429        connection: &Connection,
1430        distance_memo: &mut DistanceMemo,
1431    ) -> ProofCost {
1432        match connection.justification {
1433            Justification::Congruence => {
1434                self.congruence_distance(connection.current, connection.next, distance_memo)
1435            }
1436            Justification::Rule(_) => BigUint::one(),
1437        }
1438    }
1439
1440    fn calculate_parent_distance(
1441        &mut self,
1442        enode: Id,
1443        ancestor: Id,
1444        distance_memo: &mut DistanceMemo,
1445    ) -> ProofCost {
1446        loop {
1447            let parent = distance_memo.parent_distance[usize::from(enode)].0;
1448            let dist = distance_memo.parent_distance[usize::from(enode)].1.clone();
1449            if self.parent(parent) == parent {
1450                break;
1451            }
1452
1453            let parent_parent = distance_memo.parent_distance[usize::from(parent)].0;
1454            if parent_parent != parent {
1455                let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1.clone();
1456                distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
1457            } else {
1458                if ancestor == Id::from(usize::MAX) {
1459                    break;
1460                }
1461                if distance_memo.tree_depth.get(&parent).unwrap()
1462                    <= distance_memo.tree_depth.get(&ancestor).unwrap()
1463                {
1464                    break;
1465                }
1466
1467                // find the length of one parent connection
1468                let connection = &self.explainfind[usize::from(parent)].parent_connection;
1469                let current = connection.current;
1470                let next = connection.next;
1471                let cost = match connection.justification {
1472                    Justification::Congruence => {
1473                        self.congruence_distance(current, next, distance_memo)
1474                    }
1475                    Justification::Rule(_) => BigUint::one(),
1476                };
1477                distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost);
1478            }
1479        }
1480
1481        //assert_eq!(distance_memo.parent_distance[usize::from(enode)].1+1,
1482        //Explanation::new(self.explain_enodes(enode, distance_memo.parent_distance[usize::from(enode)].0, &mut Default::default())).make_flat_explanation().len());
1483
1484        distance_memo.parent_distance[usize::from(enode)].1.clone()
1485    }
1486
1487    fn find_congruence_neighbors<N: Analysis<L>>(
1488        &self,
1489        classes: &HashMap<Id, EClass<L, N::Data>>,
1490        congruence_neighbors: &mut [Vec<Id>],
1491        unionfind: &UnionFind,
1492    ) {
1493        let mut counter = 0;
1494        // add the normal congruence edges first
1495        for node in &self.explainfind {
1496            if let Justification::Congruence = node.parent_connection.justification {
1497                let current = node.parent_connection.current;
1498                let next = node.parent_connection.next;
1499                congruence_neighbors[usize::from(current)].push(next);
1500                congruence_neighbors[usize::from(next)].push(current);
1501                counter += 1;
1502            }
1503        }
1504
1505        'outer: for eclass in classes.keys() {
1506            let enodes = self.find_all_enodes(*eclass);
1507            // find all congruence nodes
1508            let mut cannon_enodes: HashMap<L, Vec<Id>> = Default::default();
1509            for enode in &enodes {
1510                let cannon = self
1511                    .node(*enode)
1512                    .clone()
1513                    .map_children(|child| unionfind.find(child));
1514                if let Some(others) = cannon_enodes.get_mut(&cannon) {
1515                    for other in others.iter() {
1516                        congruence_neighbors[usize::from(*enode)].push(*other);
1517                        congruence_neighbors[usize::from(*other)].push(*enode);
1518                        counter += 1;
1519                    }
1520                    others.push(*enode);
1521                } else {
1522                    counter += 1;
1523                    cannon_enodes.insert(cannon, vec![*enode]);
1524                }
1525                // Don't find every congruence edge because that could be n^2 edges
1526                if counter > CONGRUENCE_LIMIT * self.explainfind.len() {
1527                    break 'outer;
1528                }
1529            }
1530        }
1531    }
1532
1533    pub fn get_num_congr<N: Analysis<L>>(
1534        &self,
1535        classes: &HashMap<Id, EClass<L, N::Data>>,
1536        unionfind: &UnionFind,
1537    ) -> usize {
1538        let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
1539        self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
1540        let mut count = 0;
1541        for v in congruence_neighbors {
1542            count += v.len();
1543        }
1544
1545        count / 2
1546    }
1547
1548    pub fn get_num_nodes(&self) -> usize {
1549        self.explainfind.len()
1550    }
1551
1552    fn shortest_path_modulo_congruence(
1553        &mut self,
1554        start: Id,
1555        end: Id,
1556        congruence_neighbors: &[Vec<Id>],
1557        distance_memo: &mut DistanceMemo,
1558    ) -> Option<(Vec<Connection>, Vec<Connection>)> {
1559        let mut todo = BinaryHeap::new();
1560        todo.push(HeapState {
1561            cost: BigUint::zero(),
1562            item: Connection {
1563                current: start,
1564                next: start,
1565                justification: Justification::Congruence,
1566                is_rewrite_forward: true,
1567            },
1568        });
1569
1570        let mut last = HashMap::default();
1571        let mut path_cost = HashMap::default();
1572
1573        'outer: loop {
1574            if todo.is_empty() {
1575                break 'outer;
1576            }
1577            let state = todo.pop().unwrap();
1578            let connection = state.item;
1579            let cost_so_far = state.cost.clone();
1580            let current = connection.next;
1581
1582            if last.get(&current).is_some() {
1583                continue 'outer;
1584            } else {
1585                last.insert(current, connection);
1586                path_cost.insert(current, cost_so_far.clone());
1587            }
1588
1589            if current == end {
1590                break;
1591            }
1592
1593            for neighbor in &self.explainfind[usize::from(current)].neighbors {
1594                if let Justification::Rule(_) = neighbor.justification {
1595                    let neighbor_cost = cost_so_far.clone() + 1_u32;
1596                    todo.push(HeapState {
1597                        item: neighbor.clone(),
1598                        cost: neighbor_cost,
1599                    });
1600                }
1601            }
1602
1603            for other in congruence_neighbors[usize::from(current)].iter() {
1604                let next = other;
1605                let distance = self.congruence_distance(current, *next, distance_memo);
1606                let next_cost = cost_so_far.clone() + distance;
1607                todo.push(HeapState {
1608                    item: Connection {
1609                        current,
1610                        next: *next,
1611                        justification: Justification::Congruence,
1612                        is_rewrite_forward: true,
1613                    },
1614                    cost: next_cost,
1615                });
1616            }
1617        }
1618
1619        let total_cost = path_cost.get(&end);
1620
1621        let left_connections;
1622        let mut right_connections = vec![];
1623
1624        // we would like to assert that we found a path better than the normal one
1625        // but since proof sizes are saturated this is not true
1626        /*let dist = self.distance_between(start, end, distance_memo);
1627        if *total_cost.unwrap() > dist {
1628            panic!(
1629                "Found cost greater than baseline {} vs {}",
1630                total_cost.unwrap(),
1631                dist
1632            );
1633        }*/
1634        if *total_cost.unwrap() >= self.distance_between(start, end, distance_memo) {
1635            let (a_left_connections, a_right_connections) = self.get_path_unoptimized(start, end);
1636            left_connections = a_left_connections;
1637            right_connections = a_right_connections;
1638        } else {
1639            let mut current = end;
1640            let mut connections = vec![];
1641            while current != start {
1642                let prev = last.get(&current);
1643                if let Some(prev_connection) = prev {
1644                    connections.push(prev_connection.clone());
1645                    current = prev_connection.current;
1646                } else {
1647                    break;
1648                }
1649            }
1650            connections.reverse();
1651            self.populate_path_length(end, &connections, distance_memo);
1652            left_connections = connections;
1653        }
1654
1655        Some((left_connections, right_connections))
1656    }
1657
1658    fn greedy_short_explanations(
1659        &mut self,
1660        start: Id,
1661        end: Id,
1662        congruence_neighbors: &[Vec<Id>],
1663        distance_memo: &mut DistanceMemo,
1664        mut fuel: usize,
1665    ) {
1666        let mut todo_congruence = VecDeque::new();
1667        todo_congruence.push_back((start, end));
1668
1669        while !todo_congruence.is_empty() {
1670            let (start, end) = todo_congruence.pop_front().unwrap();
1671            let eclass_size = self.find_all_enodes(start).len();
1672            if fuel < eclass_size {
1673                continue;
1674            }
1675            fuel = fuel.saturating_sub(eclass_size);
1676
1677            let (left_connections, right_connections) = self
1678                .shortest_path_modulo_congruence(start, end, congruence_neighbors, distance_memo)
1679                .unwrap();
1680
1681            //assert!(Explanation::new(self.explain_enodes(start, end, &mut Default::default())).make_flat_explanation().len()-1 <= total_cost);
1682
1683            for (i, connection) in left_connections
1684                .iter()
1685                .chain(right_connections.iter().rev())
1686                .enumerate()
1687            {
1688                let mut next = connection.next;
1689                let mut current = connection.current;
1690                if i >= left_connections.len() {
1691                    std::mem::swap(&mut next, &mut current);
1692                }
1693                if let Justification::Congruence = connection.justification {
1694                    let current_node = self.node(current).clone();
1695                    let next_node = self.node(next).clone();
1696                    for (left_child, right_child) in current_node
1697                        .children()
1698                        .iter()
1699                        .zip(next_node.children().iter())
1700                    {
1701                        todo_congruence.push_back((*left_child, *right_child));
1702                    }
1703                }
1704            }
1705        }
1706    }
1707
1708    #[allow(clippy::too_many_arguments)]
1709    fn tarjan_ocla(
1710        &self,
1711        enode: Id,
1712        children: &HashMap<Id, Vec<Id>>,
1713        common_ancestor_queries: &HashMap<Id, Vec<Id>>,
1714        black_set: &mut HashSet<Id>,
1715        unionfind: &mut UnionFind,
1716        ancestor: &mut Vec<Id>,
1717        common_ancestor: &mut HashMap<(Id, Id), Id>,
1718    ) {
1719        ancestor[usize::from(enode)] = enode;
1720        for child in children[&enode].iter() {
1721            self.tarjan_ocla(
1722                *child,
1723                children,
1724                common_ancestor_queries,
1725                black_set,
1726                unionfind,
1727                ancestor,
1728                common_ancestor,
1729            );
1730            unionfind.union(enode, *child);
1731            ancestor[usize::from(unionfind.find(enode))] = enode;
1732        }
1733
1734        if common_ancestor_queries.get(&enode).is_some() {
1735            black_set.insert(enode);
1736            for other in common_ancestor_queries.get(&enode).unwrap() {
1737                if black_set.contains(other) {
1738                    let ancestor = ancestor[usize::from(unionfind.find(*other))];
1739                    common_ancestor.insert((enode, *other), ancestor);
1740                    common_ancestor.insert((*other, enode), ancestor);
1741                }
1742            }
1743        }
1744    }
1745
1746    fn parent(&self, enode: Id) -> Id {
1747        self.explainfind[usize::from(enode)].parent_connection.next
1748    }
1749
1750    fn calculate_common_ancestor<N: Analysis<L>>(
1751        &self,
1752        classes: &HashMap<Id, EClass<L, N::Data>>,
1753        congruence_neighbors: &[Vec<Id>],
1754    ) -> HashMap<(Id, Id), Id> {
1755        let mut common_ancestor_queries = HashMap::default();
1756        for (s_int, others) in congruence_neighbors.iter().enumerate() {
1757            let start = &Id::from(s_int);
1758            for other in others {
1759                for (left, right) in self
1760                    .node(*start)
1761                    .children()
1762                    .iter()
1763                    .zip(self.node(*other).children().iter())
1764                {
1765                    if left != right {
1766                        if common_ancestor_queries.get(start).is_none() {
1767                            common_ancestor_queries.insert(*start, vec![]);
1768                        }
1769                        if common_ancestor_queries.get(other).is_none() {
1770                            common_ancestor_queries.insert(*other, vec![]);
1771                        }
1772                        common_ancestor_queries.get_mut(start).unwrap().push(*other);
1773                        common_ancestor_queries.get_mut(other).unwrap().push(*start);
1774                    }
1775                }
1776            }
1777        }
1778
1779        let mut common_ancestor = HashMap::default();
1780        let mut unionfind = UnionFind::default();
1781        let mut ancestor = vec![];
1782        for i in 0..self.explainfind.len() {
1783            unionfind.make_set();
1784            ancestor.push(Id::from(i));
1785        }
1786        for (eclass, _) in classes.iter() {
1787            let enodes = self.find_all_enodes(*eclass);
1788            let mut children: HashMap<Id, Vec<Id>> = HashMap::default();
1789            for enode in &enodes {
1790                children.insert(*enode, vec![]);
1791            }
1792            for enode in &enodes {
1793                if self.parent(*enode) != *enode {
1794                    children.get_mut(&self.parent(*enode)).unwrap().push(*enode);
1795                }
1796            }
1797
1798            let mut black_set = HashSet::default();
1799
1800            let mut parent = *enodes.iter().next().unwrap();
1801            while parent != self.parent(parent) {
1802                parent = self.parent(parent);
1803            }
1804            self.tarjan_ocla(
1805                parent,
1806                &children,
1807                &common_ancestor_queries,
1808                &mut black_set,
1809                &mut unionfind,
1810                &mut ancestor,
1811                &mut common_ancestor,
1812            );
1813        }
1814
1815        common_ancestor
1816    }
1817
1818    fn calculate_shortest_explanations<N: Analysis<L>>(
1819        &mut self,
1820        start: Id,
1821        end: Id,
1822        classes: &HashMap<Id, EClass<L, N::Data>>,
1823        unionfind: &UnionFind,
1824    ) {
1825        let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
1826        self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
1827        let mut parent_distance = vec![(Id::from(0), BigUint::zero()); self.explainfind.len()];
1828        for (i, entry) in parent_distance.iter_mut().enumerate() {
1829            entry.0 = Id::from(i);
1830        }
1831
1832        let mut distance_memo = DistanceMemo {
1833            parent_distance,
1834            common_ancestor: self.calculate_common_ancestor::<N>(classes, &congruence_neighbors),
1835            tree_depth: self.calculate_tree_depths(),
1836        };
1837
1838        let fuel = GREEDY_NUM_ITERS * self.explainfind.len();
1839        self.greedy_short_explanations(start, end, &congruence_neighbors, &mut distance_memo, fuel);
1840    }
1841}
1842
#[cfg(test)]
mod tests {
    use super::super::*;

    /// Length of the flattened explanation of `lhs = rhs`, as strings.
    fn flat_explanation_len(
        egraph: &mut EGraph<SymbolLang, ()>,
        lhs: &RecExpr<SymbolLang>,
        rhs: &RecExpr<SymbolLang>,
    ) -> usize {
        egraph.explain_equivalence(lhs, rhs).get_flat_strings().len()
    }

    /// `(f a)` and `(f b)` are first equated through the chain a=c, c=d, d=b,
    /// then through the shortcut (f a)=g, g=(f b). The expected explanation
    /// length is 4 before the shortcut and with length optimization off, and
    /// 3 once length optimization is on.
    #[test]
    fn simple_explain() {
        use SymbolLang as S;

        crate::init_logger();
        let mut egraph = EGraph::<S, ()>::default().with_explanations_enabled();

        let fa: RecExpr<S> = "(f a)".parse().unwrap();
        let fb: RecExpr<S> = "(f b)".parse().unwrap();
        egraph.add_expr(&fa);
        egraph.add_expr(&fb);
        for leaf in &["c", "d"] {
            egraph.add_expr(&leaf.parse().unwrap());
        }

        for &(from, to, rule) in &[("a", "c", "ac"), ("c", "d", "cd"), ("d", "b", "db")] {
            egraph.union_instantiations(
                &from.parse().unwrap(),
                &to.parse().unwrap(),
                &Default::default(),
                rule.to_string(),
            );
        }

        egraph.rebuild();

        assert_eq!(egraph.add_expr(&fa), egraph.add_expr(&fb));
        // Ask three times in a row; the answer must not change.
        for _ in 0..3 {
            assert_eq!(flat_explanation_len(&mut egraph, &fa, &fb), 4);
        }

        for &(from, to, rule) in &[("(f a)", "g", "fag"), ("g", "(f b)", "gfb")] {
            egraph.union_instantiations(
                &from.parse().unwrap(),
                &to.parse().unwrap(),
                &Default::default(),
                rule.to_string(),
            );
        }

        egraph.rebuild();

        egraph = egraph.without_explanation_length_optimization();
        assert_eq!(flat_explanation_len(&mut egraph, &fa, &fb), 4);

        egraph = egraph.with_explanation_length_optimization();
        for _ in 0..2 {
            assert_eq!(flat_explanation_len(&mut egraph, &fa, &fb), 3);
        }

        egraph.dot().to_dot("target/foo.dot").unwrap();
    }
}
1950
/// Unions `a = b`, then `c = (f a)` and `d = (f b)` via `union_trusted`, and
/// checks that the flat explanation of `c = d` has four entries.
#[test]
fn simple_explain_union_trusted() {
    use crate::{EGraph, SymbolLang};

    crate::init_logger();
    let mut egraph = EGraph::<SymbolLang, ()>::new(()).with_explanations_enabled();

    let leaves: Vec<_> = ["a", "b", "c", "d"]
        .iter()
        .map(|name| egraph.add_uncanonical(SymbolLang::leaf(*name)))
        .collect();
    let (a, b, c, d) = (leaves[0], leaves[1], leaves[2], leaves[3]);

    egraph.union_trusted(a, b, "a=b");
    egraph.rebuild();

    let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
    let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
    egraph.union_trusted(c, fa, "c=fa");
    egraph.union_trusted(d, fb, "d=fb");
    egraph.rebuild();

    let flat_len = egraph
        .explain_equivalence(&"c".parse().unwrap(), &"d".parse().unwrap())
        .make_flat_explanation()
        .len();
    assert_eq!(flat_len, 4);
}