egg/
egraph.rs

1use crate::*;
2use std::{
3    borrow::BorrowMut,
4    fmt::{self, Debug, Display},
5    marker::PhantomData,
6};
7
8#[cfg(feature = "serde-1")]
9use serde::{Deserialize, Serialize};
10
11use log::*;
12
/** A data structure to keep track of equalities between expressions.

Check out the [background tutorial](crate::tutorials::_01_background)
for more information on e-graphs in general.

# E-graphs in `egg`

In `egg`, the main types associated with e-graphs are
[`EGraph`], [`EClass`], [`Language`], and [`Id`].

[`EGraph`] and [`EClass`] are both generic over a
[`Language`], meaning that the types actually floating around in the
e-graph are all user-defined.
In particular, the e-nodes are elements of your [`Language`].
[`EGraph`]s and [`EClass`]es are additionally parameterized by some
[`Analysis`], arbitrary data associated with each e-class.

Many methods of [`EGraph`] deal with [`Id`]s, which represent e-classes.
Because e-classes are frequently merged, many [`Id`]s will refer to the
same e-class.

You can use the `egraph[id]` syntax to get an [`EClass`] from an [`Id`], because
[`EGraph`] implements
`Index` and `IndexMut`.

Enabling the `serde-1` feature on this crate will allow you to
de/serialize [`EGraph`]s using [`serde`](https://serde.rs/).
You must call [`EGraph::rebuild`] after deserializing an e-graph!

[`add`]: EGraph::add()
[`union`]: EGraph::union()
[`rebuild`]: EGraph::rebuild()
[equivalence relation]: https://en.wikipedia.org/wiki/Equivalence_relation
[congruence relation]: https://en.wikipedia.org/wiki/Congruence_relation
[dot]: Dot
[extract]: Extractor
[sound]: https://itinerarium.github.io/phoneme-synthesis/?w=/'igraf/
**/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct EGraph<L: Language, N: Analysis<L>> {
    /// The `Analysis` given when creating this `EGraph`.
    pub analysis: N,
    /// The `Explain` used to explain equivalences in this `EGraph`.
    /// `None` unless explanations have been enabled.
    pub(crate) explain: Option<Explain<L>>,
    /// Union-find over every `Id` handed out; [`EGraph::find`] uses it to
    /// canonicalize ids.
    unionfind: UnionFind,
    /// Stores the original node represented by each non-canonical id
    nodes: Vec<L>,
    /// Stores each enode's `Id`, not the `Id` of the eclass.
    /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new
    /// unions can cause them to become out of date.
    #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
    memo: HashMap<L, Id>,
    /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
    /// not the canonical id of the eclass.
    pending: Vec<Id>,
    /// Ids awaiting analysis processing.
    /// NOTE(review): meaning inferred from the name; confirm against `rebuild`.
    analysis_pending: UniqueQueue<Id>,
    /// The e-classes, keyed by canonical `Id` (indexing calls `find` before
    /// looking a class up here).
    #[cfg_attr(
        feature = "serde-1",
        serde(bound(
            serialize = "N::Data: Serialize",
            deserialize = "N::Data: for<'a> Deserialize<'a>",
        ))
    )]
    pub(crate) classes: HashMap<Id, EClass<L, N::Data>>,
    /// Index from an operator discriminant to the ids of the e-classes that
    /// contain an e-node with that operator; read by [`EGraph::classes_for_op`].
    /// Not serialized: a deserialized e-graph starts with this empty.
    #[cfg_attr(feature = "serde-1", serde(skip))]
    #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))]
    classes_by_op: HashMap<L::Discriminant, HashSet<Id>>,
    /// Whether or not reading operations are allowed on this e-graph.
    /// Mutating operations will set this to `false`, and
    /// [`EGraph::rebuild`] will set it to true.
    /// Reading operations require this to be `true`.
    /// Only manually set it if you know what you're doing.
    #[cfg_attr(feature = "serde-1", serde(skip))]
    pub clean: bool,
}
89
/// Serde `default` for the skipped `classes_by_op` field: an empty index.
#[cfg(feature = "serde-1")]
fn default_classes_by_op<K>() -> HashMap<K, HashSet<Id>> {
    Default::default()
}
94
95impl<L: Language, N: Analysis<L> + Default> Default for EGraph<L, N> {
96    fn default() -> Self {
97        Self::new(N::default())
98    }
99}
100
101// manual debug impl to avoid L: Language bound on EGraph defn
102impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
103    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104        f.debug_struct("EGraph")
105            .field("memo", &self.memo)
106            .field("classes", &self.classes)
107            .finish()
108    }
109}
110
111impl<L: Language, N: Analysis<L>> EGraph<L, N> {
112    /// Creates a new, empty `EGraph` with the given `Analysis`
113    pub fn new(analysis: N) -> Self {
114        Self {
115            analysis,
116            classes: Default::default(),
117            unionfind: Default::default(),
118            nodes: Default::default(),
119            clean: false,
120            explain: None,
121            pending: Default::default(),
122            memo: Default::default(),
123            analysis_pending: Default::default(),
124            classes_by_op: Default::default(),
125        }
126    }
127
128    /// Returns an iterator over the eclasses in the egraph.
129    pub fn classes(&self) -> impl ExactSizeIterator<Item = &EClass<L, N::Data>> {
130        self.classes.values()
131    }
132
133    /// Returns an mutating iterator over the eclasses in the egraph.
134    pub fn classes_mut(&mut self) -> impl ExactSizeIterator<Item = &mut EClass<L, N::Data>> {
135        self.classes.values_mut()
136    }
137
138    /// Returns an iterator over the eclasses that contain a given op.
139    pub fn classes_for_op(
140        &self,
141        op: &L::Discriminant,
142    ) -> Option<impl ExactSizeIterator<Item = Id> + '_> {
143        self.classes_by_op.get(op).map(|s| s.iter().copied())
144    }
145
146    /// Exposes the actual nodes in the egraph.
147    ///
148    /// Un-canonical id's can be used to index into this.
149    /// In normal circumstances, you should not need to use this.
150    pub fn nodes(&self) -> &[L] {
151        &self.nodes
152    }
153
154    /// Returns `true` if the egraph is empty
155    /// # Example
156    /// ```
157    /// use egg::{*, SymbolLang as S};
158    /// let mut egraph = EGraph::<S, ()>::default();
159    /// assert!(egraph.is_empty());
160    /// egraph.add(S::leaf("foo"));
161    /// assert!(!egraph.is_empty());
162    /// ```
163    pub fn is_empty(&self) -> bool {
164        self.memo.is_empty()
165    }
166
167    /// Returns the number of enodes in the `EGraph`.
168    ///
169    /// Actually returns the size of the hashcons index.
170    /// # Example
171    /// ```
172    /// use egg::{*, SymbolLang as S};
173    /// let mut egraph = EGraph::<S, ()>::default();
174    /// let x = egraph.add(S::leaf("x"));
175    /// let y = egraph.add(S::leaf("y"));
176    /// // only one eclass
177    /// egraph.union(x, y);
178    /// egraph.rebuild();
179    ///
180    /// assert_eq!(egraph.total_size(), 2);
181    /// assert_eq!(egraph.number_of_classes(), 1);
182    /// ```
183    pub fn total_size(&self) -> usize {
184        self.memo.len()
185    }
186
187    /// Iterates over the classes, returning the total number of nodes.
188    pub fn total_number_of_nodes(&self) -> usize {
189        self.classes().map(|c| c.len()).sum()
190    }
191
192    /// Returns the number of eclasses in the egraph.
193    pub fn number_of_classes(&self) -> usize {
194        self.classes.len()
195    }
196
197    /// Enable explanations for this `EGraph`.
198    /// This allows the egraph to explain why two expressions are
199    /// equivalent with the [`explain_equivalence`](EGraph::explain_equivalence) function.
200    pub fn with_explanations_enabled(mut self) -> Self {
201        if self.explain.is_some() {
202            return self;
203        }
204        if self.total_size() > 0 {
205            panic!("Need to set explanations enabled before adding any expressions to the egraph.");
206        }
207        self.explain = Some(Explain::new());
208        self
209    }
210
211    /// By default, egg runs a greedy algorithm to reduce the size of resulting explanations (without complexity overhead).
212    /// Use this function to turn this algorithm off.
213    pub fn without_explanation_length_optimization(mut self) -> Self {
214        if let Some(explain) = &mut self.explain {
215            explain.optimize_explanation_lengths = false;
216            self
217        } else {
218            panic!("Need to set explanations enabled before setting length optimization.");
219        }
220    }
221
222    /// By default, egg runs a greedy algorithm to reduce the size of resulting explanations (without complexity overhead).
223    /// Use this function to turn this algorithm on again if you have turned it off.
224    pub fn with_explanation_length_optimization(mut self) -> Self {
225        if let Some(explain) = &mut self.explain {
226            explain.optimize_explanation_lengths = true;
227            self
228        } else {
229            panic!("Need to set explanations enabled before setting length optimization.");
230        }
231    }
232
233    /// Make a copy of the egraph with the same nodes, but no unions between them.
234    pub fn copy_without_unions(&self, analysis: N) -> Self {
235        if self.explain.is_none() {
236            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions");
237        }
238        let mut egraph = Self::new(analysis);
239        for node in &self.nodes {
240            egraph.add(node.clone());
241        }
242        egraph
243    }
244
245    /// Performs the union between two egraphs.
246    pub fn egraph_union(&mut self, other: &EGraph<L, N>) {
247        let right_unions = other.get_union_equalities();
248        for (left, right, why) in right_unions {
249            self.union_instantiations(
250                &other.id_to_pattern(left, &Default::default()).0.ast,
251                &other.id_to_pattern(right, &Default::default()).0.ast,
252                &Default::default(),
253                why,
254            );
255        }
256        self.rebuild();
257    }
258
259    fn from_enodes(enodes: Vec<(L, Id)>, analysis: N) -> Self {
260        let mut egraph = Self::new(analysis);
261        let mut ids: HashMap<Id, Id> = Default::default();
262
263        loop {
264            let mut did_something = false;
265
266            for (enode, id) in &enodes {
267                let valid = enode.children().iter().all(|c| ids.contains_key(c));
268                if !valid {
269                    continue;
270                }
271
272                let mut enode = enode.clone().map_children(|c| ids[&c]);
273
274                if egraph.lookup(&mut enode).is_some() {
275                    continue;
276                }
277
278                let added = egraph.add(enode);
279                if let Some(existing) = ids.get(id) {
280                    egraph.union(*existing, added);
281                } else {
282                    ids.insert(*id, added);
283                }
284
285                did_something = true;
286            }
287
288            if !did_something {
289                break;
290            }
291        }
292
293        egraph
294    }
295
296    /// A intersection algorithm between two egraphs.
297    /// The intersection is correct for all terms that are equal in both egraphs.
298    /// Be wary, though, because terms which are not represented in both egraphs
299    /// are not captured in the intersection.
300    /// The runtime of this algorithm is O(|E1| * |E2|), where |E1| and |E2| are the number of enodes in each egraph.
301    pub fn egraph_intersect(&self, other: &EGraph<L, N>, analysis: N) -> EGraph<L, N> {
302        let mut product_map: HashMap<(Id, Id), Id> = Default::default();
303        let mut enodes = vec![];
304
305        for class1 in self.classes() {
306            for class2 in other.classes() {
307                self.intersect_classes(other, &mut enodes, class1.id, class2.id, &mut product_map);
308            }
309        }
310
311        Self::from_enodes(enodes, analysis)
312    }
313
314    fn get_product_id(class1: Id, class2: Id, product_map: &mut HashMap<(Id, Id), Id>) -> Id {
315        if let Some(id) = product_map.get(&(class1, class2)) {
316            *id
317        } else {
318            let id = Id::from(product_map.len());
319            product_map.insert((class1, class2), id);
320            id
321        }
322    }
323
324    fn intersect_classes(
325        &self,
326        other: &EGraph<L, N>,
327        res: &mut Vec<(L, Id)>,
328        class1: Id,
329        class2: Id,
330        product_map: &mut HashMap<(Id, Id), Id>,
331    ) {
332        let res_id = Self::get_product_id(class1, class2, product_map);
333        for node1 in &self.classes[&class1].nodes {
334            for node2 in &other.classes[&class2].nodes {
335                if node1.matches(node2) {
336                    let children1 = node1.children();
337                    let children2 = node2.children();
338                    let mut new_node = node1.clone();
339                    let children = new_node.children_mut();
340                    for (i, (child1, child2)) in children1.iter().zip(children2.iter()).enumerate()
341                    {
342                        let prod = Self::get_product_id(
343                            self.find(*child1),
344                            other.find(*child2),
345                            product_map,
346                        );
347                        children[i] = prod;
348                    }
349
350                    res.push((new_node, res_id));
351                }
352            }
353        }
354    }
355
356    /// Pick a representative term for a given Id.
357    ///
358    /// Calling this function on an uncanonical `Id` returns a representative based on the how it
359    /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical),
360    /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical))
361    pub fn id_to_expr(&self, id: Id) -> RecExpr<L> {
362        let mut res = Default::default();
363        let mut cache = Default::default();
364        self.id_to_expr_internal(&mut res, id, &mut cache);
365        res
366    }
367
368    fn id_to_expr_internal(
369        &self,
370        res: &mut RecExpr<L>,
371        node_id: Id,
372        cache: &mut HashMap<Id, Id>,
373    ) -> Id {
374        if let Some(existing) = cache.get(&node_id) {
375            return *existing;
376        }
377        let new_node = self
378            .id_to_node(node_id)
379            .clone()
380            .map_children(|child| self.id_to_expr_internal(res, child, cache));
381        let res_id = res.add(new_node);
382        cache.insert(node_id, res_id);
383        res_id
384    }
385
386    /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep
387    pub fn id_to_node(&self, id: Id) -> &L {
388        &self.nodes[usize::from(id)]
389    }
390
391    /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term.
392    /// When an eclass listed in the given substitutions is found, it creates a variable.
393    /// It also adds this variable and the corresponding Id value to the resulting [`Subst`]
394    /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr).
395    pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap<Id, Id>) -> (Pattern<L>, Subst) {
396        let mut res = Default::default();
397        let mut subst = Default::default();
398        let mut cache = Default::default();
399        self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache);
400        (Pattern::new(res), subst)
401    }
402
403    fn id_to_pattern_internal(
404        &self,
405        res: &mut PatternAst<L>,
406        node_id: Id,
407        var_substitutions: &HashMap<Id, Id>,
408        subst: &mut Subst,
409        cache: &mut HashMap<Id, Id>,
410    ) -> Id {
411        if let Some(existing) = cache.get(&node_id) {
412            return *existing;
413        }
414        let res_id = if let Some(existing) = var_substitutions.get(&node_id) {
415            let var = format!("?{}", node_id).parse().unwrap();
416            subst.insert(var, *existing);
417            res.add(ENodeOrVar::Var(var))
418        } else {
419            let new_node = self.id_to_node(node_id).clone().map_children(|child| {
420                self.id_to_pattern_internal(res, child, var_substitutions, subst, cache)
421            });
422            res.add(ENodeOrVar::ENode(new_node))
423        };
424        cache.insert(node_id, res_id);
425        res_id
426    }
427
428    /// Get all the unions ever found in the egraph in terms of enode ids.
429    pub fn get_union_equalities(&self) -> UnionEqualities {
430        if let Some(explain) = &self.explain {
431            explain.get_union_equalities()
432        } else {
433            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get union equalities");
434        }
435    }
436
437    /// Disable explanations for this `EGraph`.
438    pub fn with_explanations_disabled(mut self) -> Self {
439        self.explain = None;
440        self
441    }
442
443    /// Check if explanations are enabled.
444    pub fn are_explanations_enabled(&self) -> bool {
445        self.explain.is_some()
446    }
447
448    /// Get the number of congruences between nodes in the egraph.
449    /// Only available when explanations are enabled.
450    pub fn get_num_congr(&mut self) -> usize {
451        if let Some(explain) = &mut self.explain {
452            explain
453                .with_nodes(&self.nodes)
454                .get_num_congr::<N>(&self.classes, &self.unionfind)
455        } else {
456            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
457        }
458    }
459
460    /// Get the number of nodes in the egraph used for explanations.
461    pub fn get_explanation_num_nodes(&mut self) -> usize {
462        if let Some(explain) = &mut self.explain {
463            explain.with_nodes(&self.nodes).get_num_nodes()
464        } else {
465            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
466        }
467    }
468
469    /// When explanations are enabled, this function
470    /// produces an [`Explanation`] describing why two expressions are equivalent.
471    ///
472    /// The [`Explanation`] can be used in it's default tree form or in a less compact
473    /// flattened form. Each of these also has a s-expression string representation,
474    /// given by [`get_flat_string`](Explanation::get_flat_string) and [`get_string`](Explanation::get_string).
475    pub fn explain_equivalence(
476        &mut self,
477        left_expr: &RecExpr<L>,
478        right_expr: &RecExpr<L>,
479    ) -> Explanation<L> {
480        let left = self.add_expr_uncanonical(left_expr);
481        let right = self.add_expr_uncanonical(right_expr);
482
483        self.explain_id_equivalence(left, right)
484    }
485
486    /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraph::id_to_expr)`(left),`
487    /// [`id_to_expr`](EGraph::id_to_expr)`(right))` but more efficient
488    ///
489    /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing
490    /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important
491    /// to control explanations
492    pub fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation<L> {
493        if self.find(left) != self.find(right) {
494            panic!(
495                "Tried to explain equivalence between non-equal terms {:?} and {:?}",
496                self.id_to_expr(left),
497                self.id_to_expr(left)
498            );
499        }
500        if let Some(explain) = &mut self.explain {
501            explain.with_nodes(&self.nodes).explain_equivalence::<N>(
502                left,
503                right,
504                &mut self.unionfind,
505                &self.classes,
506            )
507        } else {
508            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
509        }
510    }
511
512    /// Get an explanation for why an expression matches a pattern.
513    pub fn explain_matches(
514        &mut self,
515        left_expr: &RecExpr<L>,
516        right_pattern: &PatternAst<L>,
517        subst: &Subst,
518    ) -> Explanation<L> {
519        let left = self.add_expr_uncanonical(left_expr);
520        let right = self.add_instantiation_noncanonical(right_pattern, subst);
521
522        if self.find(left) != self.find(right) {
523            panic!(
524                "Tried to explain equivalence between non-equal terms {:?} and {:?}",
525                left_expr, right_pattern
526            );
527        }
528        if let Some(explain) = &mut self.explain {
529            explain.with_nodes(&self.nodes).explain_equivalence::<N>(
530                left,
531                right,
532                &mut self.unionfind,
533                &self.classes,
534            )
535        } else {
536            panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.");
537        }
538    }
539
540    /// Canonicalizes an eclass id.
541    ///
542    /// This corresponds to the `find` operation on the egraph's
543    /// underlying unionfind data structure.
544    ///
545    /// # Example
546    /// ```
547    /// use egg::{*, SymbolLang as S};
548    /// let mut egraph = EGraph::<S, ()>::default();
549    /// let x = egraph.add(S::leaf("x"));
550    /// let y = egraph.add(S::leaf("y"));
551    /// assert_ne!(egraph.find(x), egraph.find(y));
552    ///
553    /// egraph.union(x, y);
554    /// egraph.rebuild();
555    /// assert_eq!(egraph.find(x), egraph.find(y));
556    /// ```
557    pub fn find(&self, id: Id) -> Id {
558        self.unionfind.find(id)
559    }
560
561    /// This is private, but internals should use this whenever
562    /// possible because it does path compression.
563    fn find_mut(&mut self, id: Id) -> Id {
564        self.unionfind.find_mut(id)
565    }
566
567    /// Creates a [`Dot`] to visualize this egraph. See [`Dot`].
568    pub fn dot(&self) -> Dot<L, N> {
569        Dot {
570            egraph: self,
571            config: vec![],
572            use_anchors: true,
573        }
574    }
575}
576
577/// Translates `EGraph<L, A>` into `EGraph<L2, A2>`. For common cases, you don't
578/// need to implement this manually. See the provided [`SimpleLanguageMapper`].
579pub trait LanguageMapper<L, A>
580where
581    L: Language,
582    A: Analysis<L>,
583{
584    /// The target language to translate into.
585    type L2: Language;
586
587    /// The target analysis to transate into.
588    type A2: Analysis<Self::L2>;
589
590    /// Translate a node of `L` into a node of `L2`.
591    fn map_node(&self, node: L) -> Self::L2;
592
593    /// Translate `L::Discriminant` into `L2::Discriminant`
594    fn map_discriminant(
595        &self,
596        discriminant: L::Discriminant,
597    ) -> <Self::L2 as Language>::Discriminant;
598
599    /// Translate an analysis of type `A` into an analysis of `A2`.
600    fn map_analysis(&self, analysis: A) -> Self::A2;
601
602    /// Translate `A::Data` into `A2::Data`.
603    fn map_data(&self, data: A::Data) -> <Self::A2 as Analysis<Self::L2>>::Data;
604
605    /// Translate an [`EClass`] over `L` into an [`EClass`] over `L2`.
606    fn map_eclass(
607        &self,
608        src_eclass: EClass<L, A::Data>,
609    ) -> EClass<Self::L2, <Self::A2 as Analysis<Self::L2>>::Data> {
610        EClass {
611            id: src_eclass.id,
612            nodes: src_eclass
613                .nodes
614                .into_iter()
615                .map(|l| self.map_node(l))
616                .collect(),
617            data: self.map_data(src_eclass.data),
618            parents: src_eclass.parents,
619        }
620    }
621
622    /// Map an `EGraph` over `L` into an `EGraph` over `L2`.
623    fn map_egraph(&self, src_egraph: EGraph<L, A>) -> EGraph<Self::L2, Self::A2> {
624        let kv_map = |(k, v): (L, Id)| (self.map_node(k), v);
625        EGraph {
626            analysis: self.map_analysis(src_egraph.analysis),
627            explain: None,
628            unionfind: src_egraph.unionfind,
629            memo: src_egraph.memo.into_iter().map(kv_map).collect(),
630            pending: src_egraph.pending,
631            nodes: src_egraph
632                .nodes
633                .into_iter()
634                .map(|x| self.map_node(x))
635                .collect(),
636            analysis_pending: src_egraph.analysis_pending,
637            classes: src_egraph
638                .classes
639                .into_iter()
640                .map(|(id, eclass)| (id, self.map_eclass(eclass)))
641                .collect(),
642            classes_by_op: src_egraph
643                .classes_by_op
644                .into_iter()
645                .map(|(k, v)| (self.map_discriminant(k), v))
646                .collect(),
647            clean: src_egraph.clean,
648        }
649    }
650}
651
/// An implementation of [`LanguageMapper`] that can convert an [`EGraph`] over one
/// language into an [`EGraph`] over a different language in common cases.
///
/// Specifically, you can use this if you have
/// [`conversion`](https://doc.rust-lang.org/1.76.0/core/convert/index.html)
/// implemented between your source and target language, as well as your source and
/// target analysis.
///
/// Here is an example of how to use this. Consider a case where you have a newtype
/// wrapper over an existing language type:
///
/// ```rust
/// use egg::*;
///
/// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// struct MyLang(SymbolLang);
/// # impl Language for MyLang {
/// #     type Discriminant = <SymbolLang as Language>::Discriminant;
/// #
/// #     fn matches(&self, other: &Self) -> bool {
/// #         self.0.matches(&other.0)
/// #     }
/// #
/// #     fn children(&self) -> &[Id] {
/// #         self.0.children()
/// #     }
/// #
/// #     fn children_mut(&mut self) -> &mut [Id] {
/// #         self.0.children_mut()
/// #     }
/// #
/// #     fn discriminant(&self) -> Self::Discriminant {
/// #         self.0.discriminant()
/// #     }
/// # }
///
/// // some external library function
/// pub fn external(egraph: EGraph<SymbolLang, ()>) { }
///
/// fn do_thing(egraph: EGraph<MyLang, ()>) {
///   // how do I call external?
///   external(todo!())
/// }
/// ```
///
/// By providing an implementation of `From<MyLang> for SymbolLang`, we can
/// construct `SimpleLanguageMapper` and use it to translate our [`EGraph`] into the
/// right type.
///
/// ```rust
/// # use egg::*;
/// # #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// # struct MyLang(SymbolLang);
/// # impl Language for MyLang {
/// #     type Discriminant = <SymbolLang as Language>::Discriminant;
/// #
/// #     fn matches(&self, other: &Self) -> bool {
/// #         self.0.matches(&other.0)
/// #     }
/// #
/// #     fn children(&self) -> &[Id] {
/// #         self.0.children()
/// #     }
/// #
/// #     fn children_mut(&mut self) -> &mut [Id] {
/// #         self.0.children_mut()
/// #     }
/// #
/// #     fn discriminant(&self) -> Self::Discriminant {
/// #         self.0.discriminant()
/// #     }
/// # }
/// # pub fn external(egraph: EGraph<SymbolLang, ()>) { }
/// impl From<MyLang> for SymbolLang {
///     fn from(value: MyLang) -> Self {
///         value.0
///     }
/// }
///
/// fn do_thing(egraph: EGraph<MyLang, ()>) {
///     external(SimpleLanguageMapper::default().map_egraph(egraph))
/// }
/// ```
///
/// Note that we do not need to provide any conversion for the analysis, because it
/// is the same in both source and target e-graphs.
pub struct SimpleLanguageMapper<L2, A2> {
    /// Records the target language and analysis types; no data is stored.
    _phantom: PhantomData<(L2, A2)>,
}
741
742impl<L, A> Default for SimpleLanguageMapper<L, A> {
743    fn default() -> Self {
744        SimpleLanguageMapper {
745            _phantom: PhantomData,
746        }
747    }
748}
749
750impl<L, A, L2, A2> LanguageMapper<L, A> for SimpleLanguageMapper<L2, A2>
751where
752    L: Language,
753    A: Analysis<L>,
754    L2: Language + From<L>,
755    A2: Analysis<L2> + From<A>,
756    <L2 as Language>::Discriminant: From<<L as Language>::Discriminant>,
757    <A2 as Analysis<L2>>::Data: From<<A as Analysis<L>>::Data>,
758{
759    type L2 = L2;
760    type A2 = A2;
761
762    fn map_node(&self, node: L) -> Self::L2 {
763        node.into()
764    }
765
766    fn map_discriminant(
767        &self,
768        discriminant: <L as Language>::Discriminant,
769    ) -> <Self::L2 as Language>::Discriminant {
770        discriminant.into()
771    }
772
773    fn map_analysis(&self, analysis: A) -> Self::A2 {
774        analysis.into()
775    }
776
777    fn map_data(&self, data: <A as Analysis<L>>::Data) -> <Self::A2 as Analysis<Self::L2>>::Data {
778        data.into()
779    }
780}
781
782/// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class.
783impl<L: Language, N: Analysis<L>> std::ops::Index<Id> for EGraph<L, N> {
784    type Output = EClass<L, N::Data>;
785    fn index(&self, id: Id) -> &Self::Output {
786        let id = self.find(id);
787        self.classes
788            .get(&id)
789            .unwrap_or_else(|| panic!("Invalid id {}", id))
790    }
791}
792
793/// Given an `Id` using the `&mut egraph[id]` syntax, retrieve a mutable
794/// reference to the e-class.
795impl<L: Language, N: Analysis<L>> std::ops::IndexMut<Id> for EGraph<L, N> {
796    fn index_mut(&mut self, id: Id) -> &mut Self::Output {
797        let id = self.find_mut(id);
798        self.classes
799            .get_mut(&id)
800            .unwrap_or_else(|| panic!("Invalid id {}", id))
801    }
802}
803
804impl<L: Language, N: Analysis<L>> EGraph<L, N> {
805    /// Adds a [`RecExpr`] to the [`EGraph`], returning the id of the RecExpr's eclass.
806    ///
807    /// # Example
808    /// ```
809    /// use egg::{*, SymbolLang as S};
810    /// let mut egraph = EGraph::<S, ()>::default();
811    /// let x = egraph.add(S::leaf("x"));
812    /// let y = egraph.add(S::leaf("y"));
813    /// let plus = egraph.add(S::new("+", vec![x, y]));
814    /// let plus_recexpr = "(+ x y)".parse().unwrap();
815    /// assert_eq!(plus, egraph.add_expr(&plus_recexpr));
816    /// ```
817    ///
818    /// [`add_expr`]: EGraph::add_expr()
819    pub fn add_expr(&mut self, expr: &RecExpr<L>) -> Id {
820        let id = self.add_expr_uncanonical(expr);
821        self.find(id)
822    }
823
824    /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical
825    ///
826    /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
827    pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
828        let mut new_ids = Vec::with_capacity(expr.len());
829        let mut new_node_q = Vec::with_capacity(expr.len());
830        for node in expr {
831            let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
832            let size_before = self.unionfind.size();
833            let next_id = self.add_uncanonical(new_node);
834            if self.unionfind.size() > size_before {
835                new_node_q.push(true);
836            } else {
837                new_node_q.push(false);
838            }
839            new_ids.push(next_id);
840        }
841        *new_ids.last().unwrap()
842    }
843
844    /// Adds a [`Pattern`] and a substitution to the [`EGraph`], returning
845    /// the eclass of the instantiated pattern.
846    pub fn add_instantiation(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
847        let id = self.add_instantiation_noncanonical(pat, subst);
848        self.find(id)
849    }
850
851    /// Similar to [`add_instantiation`](EGraph::add_instantiation) but the `Id` returned may not be
852    /// canonical
853    ///
854    /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling
855    /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the
856    /// instantiation of the pattern
857    fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
858        let mut new_ids = Vec::with_capacity(pat.len());
859        let mut new_node_q = Vec::with_capacity(pat.len());
860        for node in pat {
861            match node {
862                ENodeOrVar::Var(var) => {
863                    let id = self.find(subst[*var]);
864                    new_ids.push(id);
865                    new_node_q.push(false);
866                }
867                ENodeOrVar::ENode(node) => {
868                    let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
869                    let size_before = self.unionfind.size();
870                    let next_id = self.add_uncanonical(new_node);
871                    if self.unionfind.size() > size_before {
872                        new_node_q.push(true);
873                    } else {
874                        new_node_q.push(false);
875                    }
876
877                    new_ids.push(next_id);
878                }
879            }
880        }
881        *new_ids.last().unwrap()
882    }
883
884    /// Lookup the eclass of the given enode.
885    ///
886    /// You can pass in either an owned enode or a `&mut` enode,
887    /// in which case the enode's children will be canonicalized.
888    ///
889    /// # Example
890    /// ```
891    /// # use egg::*;
892    /// let mut egraph: EGraph<SymbolLang, ()> = Default::default();
893    /// let a = egraph.add(SymbolLang::leaf("a"));
894    /// let b = egraph.add(SymbolLang::leaf("b"));
895    ///
896    /// // lookup will find this node if its in the egraph
897    /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]);
898    /// assert_eq!(egraph.lookup(node_f_ab.clone()), None);
899    /// let id = egraph.add(node_f_ab.clone());
900    /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id));
901    ///
902    /// // if the query node isn't canonical, and its passed in by &mut instead of owned,
903    /// // its children will be canonicalized
904    /// egraph.union(a, b);
905    /// egraph.rebuild();
906    /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id));
907    /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a]));
908    /// ```
909    pub fn lookup<B>(&self, enode: B) -> Option<Id>
910    where
911        B: BorrowMut<L>,
912    {
913        self.lookup_internal(enode).map(|id| self.find(id))
914    }
915
916    fn lookup_internal<B>(&self, mut enode: B) -> Option<Id>
917    where
918        B: BorrowMut<L>,
919    {
920        let enode = enode.borrow_mut();
921        enode.update_children(|id| self.find(id));
922        self.memo.get(enode).copied()
923    }
924
925    /// Lookup the eclass of the given [`RecExpr`].
926    ///
927    /// Equivalent to the last value in [`EGraph::lookup_expr_ids`].
928    pub fn lookup_expr(&self, expr: &RecExpr<L>) -> Option<Id> {
929        self.lookup_expr_ids(expr)
930            .and_then(|ids| ids.last().copied())
931    }
932
933    /// Lookup the eclasses of all the nodes in the given [`RecExpr`].
934    pub fn lookup_expr_ids(&self, expr: &RecExpr<L>) -> Option<Vec<Id>> {
935        let mut new_ids = Vec::with_capacity(expr.len());
936        for node in expr {
937            let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
938            let id = self.lookup(node)?;
939            new_ids.push(id)
940        }
941        Some(new_ids)
942    }
943
944    /// Adds an enode to the [`EGraph`].
945    ///
946    /// When adding an enode, to the egraph, [`add`] it performs
947    /// _hashconsing_ (sometimes called interning in other contexts).
948    ///
949    /// Hashconsing ensures that only one copy of that enode is in the egraph.
950    /// If a copy is in the egraph, then [`add`] simply returns the id of the
951    /// eclass in which the enode was found.
952    ///
953    /// Like [`union`](EGraph::union), this modifies the e-graph.
954    ///
955    /// [`add`]: EGraph::add()
956    pub fn add(&mut self, enode: L) -> Id {
957        let id = self.add_uncanonical(enode);
958        self.find(id)
959    }
960
    /// Similar to [`add`](EGraph::add) but the `Id` returned may not be canonical.
    ///
    /// When explanations are enabled, calling [`id_to_expr`](EGraph::id_to_expr) on this `Id`
    /// yields an expression corresponding to the parameter `enode`.
    ///
    /// # Example
    /// ```
    /// # use egg::*;
    /// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_enabled();
    /// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
    /// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
    /// egraph.union(a, b);
    /// egraph.rebuild();
    ///
    /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
    /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
    ///
    /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
    /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap());
    /// ```
    ///
    /// When explanations are not enabled, calling [`id_to_expr`](EGraph::id_to_expr) on this `Id`
    /// produces an expression with equivalent but not necessarily identical children.
    ///
    /// # Example
    /// ```
    /// # use egg::*;
    /// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_disabled();
    /// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
    /// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
    /// egraph.union(a, b);
    /// egraph.rebuild();
    ///
    /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
    /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
    ///
    /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
    /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap());
    /// ```
    pub fn add_uncanonical(&mut self, mut enode: L) -> Id {
        // Keep the enode exactly as the caller passed it: `lookup_internal` below
        // canonicalizes the children of `enode` in place.
        let original = enode.clone();
        if let Some(existing_id) = self.lookup_internal(&mut enode) {
            // The canonicalized enode is already in the hashcons.
            let id = self.find(existing_id);
            // when explanations are enabled, we need a new representative for this expr
            if let Some(explain) = self.explain.as_mut() {
                if let Some(existing_explain) = explain.uncanon_memo.get(&original) {
                    // An id is already recorded for this exact uncanonical enode: reuse it.
                    *existing_explain
                } else {
                    // Allocate a fresh id for the uncanonical enode, record it in the
                    // explanation and in `nodes`, then merge it with `id` in the
                    // union-find and justify that merge by congruence.
                    let new_id = self.unionfind.make_set();
                    explain.add(original.clone(), new_id);
                    debug_assert_eq!(Id::from(self.nodes.len()), new_id);
                    self.nodes.push(original);
                    self.unionfind.union(id, new_id);
                    explain.union(existing_id, new_id, Justification::Congruence);
                    new_id
                }
            } else {
                // Without explanations, the (possibly non-canonical) memo id is returned.
                existing_id
            }
        } else {
            // Not present yet: create a new e-class for it.
            let id = self.make_new_eclass(enode, original.clone());
            if let Some(explain) = self.explain.as_mut() {
                explain.add(original, id);
            }

            // now that we updated explanations, run the analysis for the new eclass
            N::modify(self, id);
            self.clean = false;
            id
        }
    }
1032
1033    /// This function makes a new eclass in the egraph (but doesn't touch explanations)
1034    fn make_new_eclass(&mut self, enode: L, original: L) -> Id {
1035        let id = self.unionfind.make_set();
1036        log::trace!("  ...adding to {}", id);
1037        let class = EClass {
1038            id,
1039            nodes: vec![enode.clone()],
1040            data: N::make(self, &original),
1041            parents: Default::default(),
1042        };
1043
1044        debug_assert_eq!(Id::from(self.nodes.len()), id);
1045        self.nodes.push(original);
1046
1047        // add this enode to the parent lists of its children
1048        enode.for_each(|child| {
1049            self[child].parents.push(id);
1050        });
1051
1052        // TODO is this needed?
1053        self.pending.push(id);
1054
1055        self.classes.insert(id, class);
1056        assert!(self.memo.insert(enode, id).is_none());
1057
1058        id
1059    }
1060
1061    /// Checks whether two [`RecExpr`]s are equivalent.
1062    /// Returns a list of id where both expression are represented.
1063    /// In most cases, there will none or exactly one id.
1064    ///
1065    pub fn equivs(&self, expr1: &RecExpr<L>, expr2: &RecExpr<L>) -> Vec<Id> {
1066        let pat1 = Pattern::from(expr1);
1067        let pat2 = Pattern::from(expr2);
1068        let matches1 = pat1.search(self);
1069        trace!("Matches1: {:?}", matches1);
1070
1071        let matches2 = pat2.search(self);
1072        trace!("Matches2: {:?}", matches2);
1073
1074        let mut equiv_eclasses = Vec::new();
1075
1076        for m1 in &matches1 {
1077            for m2 in &matches2 {
1078                if self.find(m1.eclass) == self.find(m2.eclass) {
1079                    equiv_eclasses.push(m1.eclass)
1080                }
1081            }
1082        }
1083
1084        equiv_eclasses
1085    }
1086
1087    /// Given two patterns and a substitution, add the patterns
1088    /// and union them.
1089    ///
1090    /// When explanations are enabled [`with_explanations_enabled`](Runner::with_explanations_enabled), use
1091    /// this function instead of [`union`](EGraph::union).
1092    ///
1093    /// Returns the id of the new eclass, along with
1094    /// a `bool` indicating whether a union occured.
1095    pub fn union_instantiations(
1096        &mut self,
1097        from_pat: &PatternAst<L>,
1098        to_pat: &PatternAst<L>,
1099        subst: &Subst,
1100        rule_name: impl Into<Symbol>,
1101    ) -> (Id, bool) {
1102        let id1 = self.add_instantiation_noncanonical(from_pat, subst);
1103        let id2 = self.add_instantiation_noncanonical(to_pat, subst);
1104
1105        let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
1106        (self.find(id1), did_union)
1107    }
1108
1109    /// Unions two e-classes, using a given reason to justify it.
1110    ///
1111    /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing
1112    /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important
1113    /// to control explanations
1114    pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into<Symbol>) -> bool {
1115        self.perform_union(from, to, Some(Justification::Rule(reason.into())))
1116    }
1117
1118    /// Unions two eclasses given their ids.
1119    ///
1120    /// The given ids need not be canonical.
1121    /// The returned `bool` indicates whether a union is necessary,
1122    /// so it's `false` if they were already equivalent.
1123    ///
1124    /// When explanations are enabled, this function behaves like [`EGraph::union_trusted`],
1125    ///  and it lists the call site as the proof reason.
1126    /// You should prefer [`union_instantiations`](EGraph::union_instantiations) when
1127    ///  you want the proofs to always be meaningful.
1128    /// Alternatively you can use [`EGraph::union_trusted`] using uncanonical `Id`s obtained from
1129    ///  functions like [`EGraph::add_uncanonical`]
1130    /// See [`explain_equivalence`](Runner::explain_equivalence) for a more detailed
1131    /// explanation of the feature.
1132    #[track_caller]
1133    pub fn union(&mut self, id1: Id, id2: Id) -> bool {
1134        if self.explain.is_some() {
1135            let caller = std::panic::Location::caller();
1136            self.union_trusted(id1, id2, caller.to_string())
1137        } else {
1138            self.perform_union(id1, id2, None)
1139        }
1140    }
1141
    /// Shared implementation of the union operations.
    ///
    /// Merges the e-classes of `enode_id1` and `enode_id2` (which need not be canonical)
    /// and returns `true`. Returns `false` if they were already in the same e-class.
    /// When explanations are enabled, `rule` is recorded as the justification and is
    /// unwrapped, so callers must pass `Some` in that case.
    ///
    /// The surviving root is the class with at least as many parents. The absorbed
    /// class's parents are pushed onto `pending`, and analysis changes reported by
    /// [`Analysis::merge`] push parents onto `analysis_pending`. Both lists are drained
    /// during rebuild.
    fn perform_union(&mut self, enode_id1: Id, enode_id2: Id, rule: Option<Justification>) -> bool {
        // The analysis sees every union request, including ones that turn out to be no-ops.
        N::pre_union(self, enode_id1, enode_id2, &rule);

        self.clean = false;
        let mut id1 = self.find_mut(enode_id1);
        let mut id2 = self.find_mut(enode_id2);
        if id1 == id2 {
            // Already equivalent: only a rule-based justification is recorded,
            // as an alternate rewrite in the explanation.
            if let Some(Justification::Rule(_)) = rule {
                if let Some(explain) = &mut self.explain {
                    explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap());
                }
            }
            return false;
        }
        // make sure class2 has fewer parents
        let class1_parents = self.classes[&id1].parents.len();
        let class2_parents = self.classes[&id2].parents.len();
        if class1_parents < class2_parents {
            std::mem::swap(&mut id1, &mut id2);
        }

        // The explanation connects the ids as given (not their roots).
        if let Some(explain) = &mut self.explain {
            explain.union(enode_id1, enode_id2, rule.unwrap());
        }

        // make id1 the new root
        self.unionfind.union(id1, id2);

        assert_ne!(id1, id2);
        let class2 = self.classes.remove(&id2).unwrap();
        let class1 = self.classes.get_mut(&id1).unwrap();
        assert_eq!(id1, class1.id);

        // Parents of the absorbed class must be re-canonicalized on rebuild.
        self.pending.extend(class2.parents.iter().copied());
        let did_merge = self.analysis.merge(&mut class1.data, class2.data);
        // `did_merge.0` queues class1's parents and `did_merge.1` queues class2's
        // parents for re-analysis.
        if did_merge.0 {
            self.analysis_pending.extend(class1.parents.iter().copied());
        }
        if did_merge.1 {
            self.analysis_pending.extend(class2.parents.iter().copied());
        }

        concat_vecs(&mut class1.nodes, class2.nodes);
        concat_vecs(&mut class1.parents, class2.parents);

        // Let the analysis react to the merged class.
        N::modify(self, id1);
        true
    }
1190
1191    /// Update the analysis data of an e-class.
1192    ///
1193    /// This also propagates the changes through the e-graph,
1194    /// so [`Analysis::make`] and [`Analysis::merge`] will get
1195    /// called for other parts of the e-graph on rebuild.
1196    pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) {
1197        let id = self.find_mut(id);
1198        let class = self.classes.get_mut(&id).unwrap();
1199        class.data = new_data;
1200        self.analysis_pending.extend(class.parents.iter().copied());
1201        N::modify(self, id)
1202    }
1203
1204    /// Returns a more debug-able representation of the egraph.
1205    ///
1206    /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It
1207    /// prints a lot of stuff you probably don't care about.
1208    /// This method returns a wrapper that implements [`Debug`] in a
1209    /// slightly nicer way, just dumping enodes in each eclass.
1210    ///
1211    /// [`Debug`]: std::fmt::Debug
1212    pub fn dump(&self) -> impl Debug + '_ {
1213        EGraphDump(self)
1214    }
1215}
1216
1217impl<L: Language + Display, N: Analysis<L>> EGraph<L, N> {
1218    /// Panic if the given eclass doesn't contain the given patterns
1219    ///
1220    /// Useful for testing.
1221    pub fn check_goals(&self, id: Id, goals: &[Pattern<L>]) {
1222        let (cost, best) = Extractor::new(self, AstSize).find_best(id);
1223        println!("End ({}): {}", cost, best.pretty(80));
1224
1225        for (i, goal) in goals.iter().enumerate() {
1226            println!("Trying to prove goal {}: {}", i, goal.pretty(40));
1227            let matches = goal.search_eclass(self, id);
1228            if matches.is_none() {
1229                let best = Extractor::new(self, AstSize).find_best(id).1;
1230                panic!(
1231                    "Could not prove goal {}:\n\
1232                     {}\n\
1233                     Best thing found:\n\
1234                     {}",
1235                    i,
1236                    goal.pretty(40),
1237                    best.pretty(40),
1238                );
1239            }
1240        }
1241    }
1242}
1243
1244// All the rebuilding stuff
1245impl<L: Language, N: Analysis<L>> EGraph<L, N> {
    /// Canonicalizes, sorts and deduplicates the enodes of every e-class, and refills
    /// `classes_by_op`.
    ///
    /// `classes_by_op` maps an enode discriminant to the ids of the classes that contain
    /// an enode with that discriminant.
    ///
    /// Returns the number of enodes removed by deduplication.
    #[inline(never)]
    fn rebuild_classes(&mut self) -> usize {
        // Work on a local copy of the map (moved out of `self`) and put it back at the end.
        // Existing entries are cleared rather than removed.
        let mut classes_by_op = std::mem::take(&mut self.classes_by_op);
        classes_by_op.values_mut().for_each(|ids| ids.clear());

        let mut trimmed = 0;
        let uf = &mut self.unionfind;

        for class in self.classes.values_mut() {
            let old_len = class.len();
            // Canonicalize every child; enodes that became identical collapse after sort + dedup.
            class
                .nodes
                .iter_mut()
                .for_each(|n| n.update_children(|id| uf.find_mut(id)));
            class.nodes.sort_unstable();
            class.nodes.dedup();

            trimmed += old_len - class.nodes.len();

            let mut add = |n: &L| {
                classes_by_op
                    .entry(n.discriminant())
                    .or_default()
                    .insert(class.id)
            };

            // we can go through the ops in order to dedup them, because we
            // just sorted them
            let mut nodes = class.nodes.iter();
            if let Some(mut prev) = nodes.next() {
                add(prev);
                for n in nodes {
                    if !prev.matches(n) {
                        add(n);
                        prev = n;
                    }
                }
            }
        }

        // Debug builds only: no class id may appear twice for the same op.
        #[cfg(debug_assertions)]
        for ids in classes_by_op.values_mut() {
            let unique: HashSet<Id> = ids.iter().copied().collect();
            assert_eq!(ids.len(), unique.len());
        }

        self.classes_by_op = classes_by_op;
        trimmed
    }
1295
1296    #[inline(never)]
1297    fn check_memo(&self) -> bool {
1298        let mut test_memo = HashMap::default();
1299
1300        for (&id, class) in self.classes.iter() {
1301            assert_eq!(class.id, id);
1302            for node in &class.nodes {
1303                if let Some(old) = test_memo.insert(node, id) {
1304                    assert_eq!(
1305                        self.find(old),
1306                        self.find(id),
1307                        "Found unexpected equivalence for {:?}\n{:?}\nvs\n{:?}",
1308                        node,
1309                        self[self.find(id)].nodes,
1310                        self[self.find(old)].nodes,
1311                    );
1312                }
1313            }
1314        }
1315
1316        for (n, e) in test_memo {
1317            assert_eq!(e, self.find(e));
1318            assert_eq!(
1319                Some(e),
1320                self.memo.get(n).map(|id| self.find(*id)),
1321                "Entry for {:?} at {} in test_memo was incorrect",
1322                n,
1323                e
1324            );
1325        }
1326
1327        true
1328    }
1329
    /// Drains the `pending` and `analysis_pending` work lists until both are empty.
    ///
    /// For each id in `pending`, the enode stored at that index of `nodes` is
    /// re-canonicalized and re-inserted into the hashcons. If the hashcons already held
    /// an entry for it, that class and this one are unioned by congruence.
    ///
    /// For each id in `analysis_pending`, [`Analysis::make`] is re-run on its enode and
    /// merged into the class data. If `merge` reports a change (`did_merge.0`), the
    /// class's parents are queued again and [`Analysis::modify`] is called.
    ///
    /// Returns the number of unions that actually merged two classes.
    #[inline(never)]
    fn process_unions(&mut self) -> usize {
        let mut n_unions = 0;

        // Either phase can refill the other's list (unions queue analysis work and
        // vice versa), so loop until both are empty.
        while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
            while let Some(class) = self.pending.pop() {
                let mut node = self.nodes[usize::from(class)].clone();
                node.update_children(|id| self.find_mut(id));
                if let Some(memo_class) = self.memo.insert(node, class) {
                    // `perform_union` returns false when they were already equivalent.
                    let did_something =
                        self.perform_union(memo_class, class, Some(Justification::Congruence));
                    n_unions += did_something as usize;
                }
            }

            while let Some(class_id) = self.analysis_pending.pop() {
                let node = self.nodes[usize::from(class_id)].clone();
                let class_id = self.find_mut(class_id);
                let node_data = N::make(self, &node);
                let class = self.classes.get_mut(&class_id).unwrap();

                let did_merge = self.analysis.merge(&mut class.data, node_data);
                if did_merge.0 {
                    self.analysis_pending.extend(class.parents.iter().copied());
                    N::modify(self, class_id)
                }
            }
        }

        assert!(self.pending.is_empty());
        assert!(self.analysis_pending.is_empty());

        n_unions
    }
1364
1365    /// Restores the egraph invariants of congruence and enode uniqueness.
1366    ///
1367    /// As mentioned
1368    /// [in the tutorial](tutorials/_01_background/index.html#invariants-and-rebuilding),
1369    /// `egg` takes a lazy approach to maintaining the egraph invariants.
1370    /// The `rebuild` method allows the user to manually restore those
1371    /// invariants at a time of their choosing. It's a reasonably
1372    /// fast, linear-ish traversal through the egraph.
1373    ///
1374    /// After modifying an e-graph with [`add`](EGraph::add) or
1375    /// [`union`](EGraph::union), you must call `rebuild` to restore
1376    /// invariants before any query operations, otherwise the results
1377    /// may be stale or incorrect.
1378    ///
1379    /// This will set [`EGraph::clean`] to `true`.
1380    ///
1381    /// # Example
1382    /// ```
1383    /// use egg::{*, SymbolLang as S};
1384    /// let mut egraph = EGraph::<S, ()>::default();
1385    /// let x = egraph.add(S::leaf("x"));
1386    /// let y = egraph.add(S::leaf("y"));
1387    /// let ax = egraph.add_expr(&"(+ a x)".parse().unwrap());
1388    /// let ay = egraph.add_expr(&"(+ a y)".parse().unwrap());
1389
1390    /// // Union x and y
1391    /// egraph.union(x, y);
1392    /// // Classes: [x y] [ax] [ay] [a]
1393    /// assert_eq!(egraph.find(x), egraph.find(y));
1394    ///
1395    /// // Rebuilding restores the congruence invariant, finding
1396    /// // that ax and ay are equivalent.
1397    /// egraph.rebuild();
1398    /// // Classes: [x y] [ax ay] [a]
1399    /// assert_eq!(egraph.number_of_classes(), 3);
1400    /// assert_eq!(egraph.find(ax), egraph.find(ay));
1401    /// ```
1402    pub fn rebuild(&mut self) -> usize {
1403        let old_hc_size = self.memo.len();
1404        let old_n_eclasses = self.number_of_classes();
1405
1406        let start = Instant::now();
1407
1408        let n_unions = self.process_unions();
1409        let trimmed_nodes = self.rebuild_classes();
1410
1411        let elapsed = start.elapsed();
1412        info!(
1413            concat!(
1414                "REBUILT! in {}.{:03}s\n",
1415                "  Old: hc size {}, eclasses: {}\n",
1416                "  New: hc size {}, eclasses: {}\n",
1417                "  unions: {}, trimmed nodes: {}"
1418            ),
1419            elapsed.as_secs(),
1420            elapsed.subsec_millis(),
1421            old_hc_size,
1422            old_n_eclasses,
1423            self.memo.len(),
1424            self.number_of_classes(),
1425            n_unions,
1426            trimmed_nodes,
1427        );
1428
1429        debug_assert!(self.check_memo());
1430        self.clean = true;
1431        n_unions
1432    }
1433
1434    pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite<L, N>]) -> bool {
1435        if let Some(explain) = &mut self.explain {
1436            explain.with_nodes(&self.nodes).check_each_explain(rules)
1437        } else {
1438            panic!("Can't check explain when explanations are off");
1439        }
1440    }
1441}
1442
1443struct EGraphDump<'a, L: Language, N: Analysis<L>>(&'a EGraph<L, N>);
1444
1445impl<'a, L: Language, N: Analysis<L>> Debug for EGraphDump<'a, L, N> {
1446    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1447        let mut ids: Vec<Id> = self.0.classes().map(|c| c.id).collect();
1448        ids.sort();
1449        for id in ids {
1450            let mut nodes = self.0[id].nodes.clone();
1451            nodes.sort();
1452            writeln!(f, "{} ({:?}): {:?}", id, self.0[id].data, nodes)?
1453        }
1454        Ok(())
1455    }
1456}
1457
#[cfg(test)]
mod tests {

    use super::*;

    /// Smoke test: hash-consed adds followed by a pattern-based union and a rebuild.
    #[test]
    fn simple_add() {
        use SymbolLang as S;

        crate::init_logger();
        let mut egraph = EGraph::<S, ()>::default();

        // The same leaf is added twice, exercising hash-consing.
        let x = egraph.add(S::leaf("x"));
        let x_again = egraph.add(S::leaf("x"));
        let _plus = egraph.add(S::new("+", vec![x, x_again]));

        let from: PatternAst<S> = "x".parse().unwrap();
        let to: PatternAst<S> = "y".parse().unwrap();
        let subst: Subst = Default::default();
        egraph.union_instantiations(&from, &to, &subst, "union x and y".to_string());
        egraph.rebuild();
    }

    /// The e-graph must implement both serde traits and serialize to JSON.
    #[cfg(all(feature = "serde-1", feature = "serde_json"))]
    #[test]
    fn test_serde() {
        fn ser(_: &impl Serialize) {}
        fn de<'a>(_: &impl Deserialize<'a>) {}

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        let expr: RecExpr<SymbolLang> = "(foo bar baz)".parse().unwrap();
        egraph.add_expr(&expr);
        // These calls only compile if the serde traits are implemented.
        ser(&egraph);
        de(&egraph);

        let json = serde_json::to_string_pretty(&egraph).unwrap();
        println!("{}", json);
    }
}