egg/egraph.rs
1use crate::*;
2use std::{
3 borrow::BorrowMut,
4 fmt::{self, Debug, Display},
5 marker::PhantomData,
6};
7
8#[cfg(feature = "serde-1")]
9use serde::{Deserialize, Serialize};
10
11use log::*;
12
13/** A data structure to keep track of equalities between expressions.
14
15Check out the [background tutorial](crate::tutorials::_01_background)
16for more information on e-graphs in general.
17
18# E-graphs in `egg`
19
20In `egg`, the main types associated with e-graphs are
21[`EGraph`], [`EClass`], [`Language`], and [`Id`].
22
23[`EGraph`] and [`EClass`] are all generic over a
24[`Language`], meaning that types actually floating around in the
25egraph are all user-defined.
26In particular, the e-nodes are elements of your [`Language`].
27[`EGraph`]s and [`EClass`]es are additionally parameterized by some
28[`Analysis`], abritrary data associated with each e-class.
29
30Many methods of [`EGraph`] deal with [`Id`]s, which represent e-classes.
31Because eclasses are frequently merged, many [`Id`]s will refer to the
32same e-class.
33
34You can use the `egraph[id]` syntax to get an [`EClass`] from an [`Id`], because
35[`EGraph`] implements
36`Index` and `IndexMut`.
37
38Enabling the `serde-1` feature on this crate will allow you to
39de/serialize [`EGraph`]s using [`serde`](https://serde.rs/).
40You must call [`EGraph::rebuild`] after deserializing an e-graph!
41
42[`add`]: EGraph::add()
43[`union`]: EGraph::union()
44[`rebuild`]: EGraph::rebuild()
45[equivalence relation]: https://en.wikipedia.org/wiki/Equivalence_relation
46[congruence relation]: https://en.wikipedia.org/wiki/Congruence_relation
47[dot]: Dot
48[extract]: Extractor
49[sound]: https://itinerarium.github.io/phoneme-synthesis/?w=/'igraf/
50**/
51#[derive(Clone)]
52#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
53pub struct EGraph<L: Language, N: Analysis<L>> {
54 /// The `Analysis` given when creating this `EGraph`.
55 pub analysis: N,
56 /// The `Explain` used to explain equivalences in this `EGraph`.
57 pub(crate) explain: Option<Explain<L>>,
58 unionfind: UnionFind,
59 /// Stores the original node represented by each non-canonical id
60 nodes: Vec<L>,
61 /// Stores each enode's `Id`, not the `Id` of the eclass.
62 /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new
63 /// unions can cause them to become out of date.
64 #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
65 memo: HashMap<L, Id>,
66 /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
67 /// not the canonical id of the eclass.
68 pending: Vec<Id>,
69 analysis_pending: UniqueQueue<Id>,
70 #[cfg_attr(
71 feature = "serde-1",
72 serde(bound(
73 serialize = "N::Data: Serialize",
74 deserialize = "N::Data: for<'a> Deserialize<'a>",
75 ))
76 )]
77 pub(crate) classes: HashMap<Id, EClass<L, N::Data>>,
78 #[cfg_attr(feature = "serde-1", serde(skip))]
79 #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))]
80 classes_by_op: HashMap<L::Discriminant, HashSet<Id>>,
81 /// Whether or not reading operation are allowed on this e-graph.
82 /// Mutating operations will set this to `false`, and
83 /// [`EGraph::rebuild`] will set it to true.
84 /// Reading operations require this to be `true`.
85 /// Only manually set it if you know what you're doing.
86 #[cfg_attr(feature = "serde-1", serde(skip))]
87 pub clean: bool,
88}
89
/// Serde `default` hook for the skipped `classes_by_op` field: a freshly
/// deserialized e-graph starts with an empty operator index.
#[cfg(feature = "serde-1")]
fn default_classes_by_op<K>() -> HashMap<K, HashSet<Id>> {
    let empty: HashMap<K, HashSet<Id>> = Default::default();
    empty
}
94
95impl<L: Language, N: Analysis<L> + Default> Default for EGraph<L, N> {
96 fn default() -> Self {
97 Self::new(N::default())
98 }
99}
100
101// manual debug impl to avoid L: Language bound on EGraph defn
102impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
103 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104 f.debug_struct("EGraph")
105 .field("memo", &self.memo)
106 .field("classes", &self.classes)
107 .finish()
108 }
109}
110
111impl<L: Language, N: Analysis<L>> EGraph<L, N> {
112 /// Creates a new, empty `EGraph` with the given `Analysis`
113 pub fn new(analysis: N) -> Self {
114 Self {
115 analysis,
116 classes: Default::default(),
117 unionfind: Default::default(),
118 nodes: Default::default(),
119 clean: false,
120 explain: None,
121 pending: Default::default(),
122 memo: Default::default(),
123 analysis_pending: Default::default(),
124 classes_by_op: Default::default(),
125 }
126 }
127
128 /// Returns an iterator over the eclasses in the egraph.
129 pub fn classes(&self) -> impl ExactSizeIterator<Item = &EClass<L, N::Data>> {
130 self.classes.values()
131 }
132
133 /// Returns an mutating iterator over the eclasses in the egraph.
134 pub fn classes_mut(&mut self) -> impl ExactSizeIterator<Item = &mut EClass<L, N::Data>> {
135 self.classes.values_mut()
136 }
137
138 /// Returns an iterator over the eclasses that contain a given op.
139 pub fn classes_for_op(
140 &self,
141 op: &L::Discriminant,
142 ) -> Option<impl ExactSizeIterator<Item = Id> + '_> {
143 self.classes_by_op.get(op).map(|s| s.iter().copied())
144 }
145
146 /// Exposes the actual nodes in the egraph.
147 ///
148 /// Un-canonical id's can be used to index into this.
149 /// In normal circumstances, you should not need to use this.
150 pub fn nodes(&self) -> &[L] {
151 &self.nodes
152 }
153
154 /// Returns `true` if the egraph is empty
155 /// # Example
156 /// ```
157 /// use egg::{*, SymbolLang as S};
158 /// let mut egraph = EGraph::<S, ()>::default();
159 /// assert!(egraph.is_empty());
160 /// egraph.add(S::leaf("foo"));
161 /// assert!(!egraph.is_empty());
162 /// ```
163 pub fn is_empty(&self) -> bool {
164 self.memo.is_empty()
165 }
166
167 /// Returns the number of enodes in the `EGraph`.
168 ///
169 /// Actually returns the size of the hashcons index.
170 /// # Example
171 /// ```
172 /// use egg::{*, SymbolLang as S};
173 /// let mut egraph = EGraph::<S, ()>::default();
174 /// let x = egraph.add(S::leaf("x"));
175 /// let y = egraph.add(S::leaf("y"));
176 /// // only one eclass
177 /// egraph.union(x, y);
178 /// egraph.rebuild();
179 ///
180 /// assert_eq!(egraph.total_size(), 2);
181 /// assert_eq!(egraph.number_of_classes(), 1);
182 /// ```
183 pub fn total_size(&self) -> usize {
184 self.memo.len()
185 }
186
187 /// Iterates over the classes, returning the total number of nodes.
188 pub fn total_number_of_nodes(&self) -> usize {
189 self.classes().map(|c| c.len()).sum()
190 }
191
192 /// Returns the number of eclasses in the egraph.
193 pub fn number_of_classes(&self) -> usize {
194 self.classes.len()
195 }
196
197 /// Enable explanations for this `EGraph`.
198 /// This allows the egraph to explain why two expressions are
199 /// equivalent with the [`explain_equivalence`](EGraph::explain_equivalence) function.
200 pub fn with_explanations_enabled(mut self) -> Self {
201 if self.explain.is_some() {
202 return self;
203 }
204 if self.total_size() > 0 {
205 panic!("Need to set explanations enabled before adding any expressions to the egraph.");
206 }
207 self.explain = Some(Explain::new());
208 self
209 }
210
211 /// By default, egg runs a greedy algorithm to reduce the size of resulting explanations (without complexity overhead).
212 /// Use this function to turn this algorithm off.
213 pub fn without_explanation_length_optimization(mut self) -> Self {
214 if let Some(explain) = &mut self.explain {
215 explain.optimize_explanation_lengths = false;
216 self
217 } else {
218 panic!("Need to set explanations enabled before setting length optimization.");
219 }
220 }
221
222 /// By default, egg runs a greedy algorithm to reduce the size of resulting explanations (without complexity overhead).
223 /// Use this function to turn this algorithm on again if you have turned it off.
224 pub fn with_explanation_length_optimization(mut self) -> Self {
225 if let Some(explain) = &mut self.explain {
226 explain.optimize_explanation_lengths = true;
227 self
228 } else {
229 panic!("Need to set explanations enabled before setting length optimization.");
230 }
231 }
232
233 /// Make a copy of the egraph with the same nodes, but no unions between them.
234 pub fn copy_without_unions(&self, analysis: N) -> Self {
235 if self.explain.is_none() {
236 panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions");
237 }
238 let mut egraph = Self::new(analysis);
239 for node in &self.nodes {
240 egraph.add(node.clone());
241 }
242 egraph
243 }
244
245 /// Performs the union between two egraphs.
246 pub fn egraph_union(&mut self, other: &EGraph<L, N>) {
247 let right_unions = other.get_union_equalities();
248 for (left, right, why) in right_unions {
249 self.union_instantiations(
250 &other.id_to_pattern(left, &Default::default()).0.ast,
251 &other.id_to_pattern(right, &Default::default()).0.ast,
252 &Default::default(),
253 why,
254 );
255 }
256 self.rebuild();
257 }
258
259 fn from_enodes(enodes: Vec<(L, Id)>, analysis: N) -> Self {
260 let mut egraph = Self::new(analysis);
261 let mut ids: HashMap<Id, Id> = Default::default();
262
263 loop {
264 let mut did_something = false;
265
266 for (enode, id) in &enodes {
267 let valid = enode.children().iter().all(|c| ids.contains_key(c));
268 if !valid {
269 continue;
270 }
271
272 let mut enode = enode.clone().map_children(|c| ids[&c]);
273
274 if egraph.lookup(&mut enode).is_some() {
275 continue;
276 }
277
278 let added = egraph.add(enode);
279 if let Some(existing) = ids.get(id) {
280 egraph.union(*existing, added);
281 } else {
282 ids.insert(*id, added);
283 }
284
285 did_something = true;
286 }
287
288 if !did_something {
289 break;
290 }
291 }
292
293 egraph
294 }
295
296 /// A intersection algorithm between two egraphs.
297 /// The intersection is correct for all terms that are equal in both egraphs.
298 /// Be wary, though, because terms which are not represented in both egraphs
299 /// are not captured in the intersection.
300 /// The runtime of this algorithm is O(|E1| * |E2|), where |E1| and |E2| are the number of enodes in each egraph.
301 pub fn egraph_intersect(&self, other: &EGraph<L, N>, analysis: N) -> EGraph<L, N> {
302 let mut product_map: HashMap<(Id, Id), Id> = Default::default();
303 let mut enodes = vec![];
304
305 for class1 in self.classes() {
306 for class2 in other.classes() {
307 self.intersect_classes(other, &mut enodes, class1.id, class2.id, &mut product_map);
308 }
309 }
310
311 Self::from_enodes(enodes, analysis)
312 }
313
314 fn get_product_id(class1: Id, class2: Id, product_map: &mut HashMap<(Id, Id), Id>) -> Id {
315 if let Some(id) = product_map.get(&(class1, class2)) {
316 *id
317 } else {
318 let id = Id::from(product_map.len());
319 product_map.insert((class1, class2), id);
320 id
321 }
322 }
323
324 fn intersect_classes(
325 &self,
326 other: &EGraph<L, N>,
327 res: &mut Vec<(L, Id)>,
328 class1: Id,
329 class2: Id,
330 product_map: &mut HashMap<(Id, Id), Id>,
331 ) {
332 let res_id = Self::get_product_id(class1, class2, product_map);
333 for node1 in &self.classes[&class1].nodes {
334 for node2 in &other.classes[&class2].nodes {
335 if node1.matches(node2) {
336 let children1 = node1.children();
337 let children2 = node2.children();
338 let mut new_node = node1.clone();
339 let children = new_node.children_mut();
340 for (i, (child1, child2)) in children1.iter().zip(children2.iter()).enumerate()
341 {
342 let prod = Self::get_product_id(
343 self.find(*child1),
344 other.find(*child2),
345 product_map,
346 );
347 children[i] = prod;
348 }
349
350 res.push((new_node, res_id));
351 }
352 }
353 }
354 }
355
356 /// Pick a representative term for a given Id.
357 ///
358 /// Calling this function on an uncanonical `Id` returns a representative based on the how it
359 /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical),
360 /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical))
361 pub fn id_to_expr(&self, id: Id) -> RecExpr<L> {
362 let mut res = Default::default();
363 let mut cache = Default::default();
364 self.id_to_expr_internal(&mut res, id, &mut cache);
365 res
366 }
367
368 fn id_to_expr_internal(
369 &self,
370 res: &mut RecExpr<L>,
371 node_id: Id,
372 cache: &mut HashMap<Id, Id>,
373 ) -> Id {
374 if let Some(existing) = cache.get(&node_id) {
375 return *existing;
376 }
377 let new_node = self
378 .id_to_node(node_id)
379 .clone()
380 .map_children(|child| self.id_to_expr_internal(res, child, cache));
381 let res_id = res.add(new_node);
382 cache.insert(node_id, res_id);
383 res_id
384 }
385
386 /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep
387 pub fn id_to_node(&self, id: Id) -> &L {
388 &self.nodes[usize::from(id)]
389 }
390
391 /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term.
392 /// When an eclass listed in the given substitutions is found, it creates a variable.
393 /// It also adds this variable and the corresponding Id value to the resulting [`Subst`]
394 /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr).
395 pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap<Id, Id>) -> (Pattern<L>, Subst) {
396 let mut res = Default::default();
397 let mut subst = Default::default();
398 let mut cache = Default::default();
399 self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache);
400 (Pattern::new(res), subst)
401 }
402
403 fn id_to_pattern_internal(
404 &self,
405 res: &mut PatternAst<L>,
406 node_id: Id,
407 var_substitutions: &HashMap<Id, Id>,
408 subst: &mut Subst,
409 cache: &mut HashMap<Id, Id>,
410 ) -> Id {
411 if let Some(existing) = cache.get(&node_id) {
412 return *existing;
413 }
414 let res_id = if let Some(existing) = var_substitutions.get(&node_id) {
415 let var = format!("?{}", node_id).parse().unwrap();
416 subst.insert(var, *existing);
417 res.add(ENodeOrVar::Var(var))
418 } else {
419 let new_node = self.id_to_node(node_id).clone().map_children(|child| {
420 self.id_to_pattern_internal(res, child, var_substitutions, subst, cache)
421 });
422 res.add(ENodeOrVar::ENode(new_node))
423 };
424 cache.insert(node_id, res_id);
425 res_id
426 }
427
428 /// Get all the unions ever found in the egraph in terms of enode ids.
429 pub fn get_union_equalities(&self) -> UnionEqualities {
430 if let Some(explain) = &self.explain {
431 explain.get_union_equalities()
432 } else {
433 panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get union equalities");
434 }
435 }
436
437 /// Disable explanations for this `EGraph`.
438 pub fn with_explanations_disabled(mut self) -> Self {
439 self.explain = None;
440 self
441 }
442
443 /// Check if explanations are enabled.
444 pub fn are_explanations_enabled(&self) -> bool {
445 self.explain.is_some()
446 }
447
448 /// Get the number of congruences between nodes in the egraph.
449 /// Only available when explanations are enabled.
450 pub fn get_num_congr(&mut self) -> usize {
451 if let Some(explain) = &mut self.explain {
452 explain
453 .with_nodes(&self.nodes)
454 .get_num_congr::<N>(&self.classes, &self.unionfind)
455 } else {
456 panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
457 }
458 }
459
460 /// Get the number of nodes in the egraph used for explanations.
461 pub fn get_explanation_num_nodes(&mut self) -> usize {
462 if let Some(explain) = &mut self.explain {
463 explain.with_nodes(&self.nodes).get_num_nodes()
464 } else {
465 panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
466 }
467 }
468
469 /// When explanations are enabled, this function
470 /// produces an [`Explanation`] describing why two expressions are equivalent.
471 ///
472 /// The [`Explanation`] can be used in it's default tree form or in a less compact
473 /// flattened form. Each of these also has a s-expression string representation,
474 /// given by [`get_flat_string`](Explanation::get_flat_string) and [`get_string`](Explanation::get_string).
475 pub fn explain_equivalence(
476 &mut self,
477 left_expr: &RecExpr<L>,
478 right_expr: &RecExpr<L>,
479 ) -> Explanation<L> {
480 let left = self.add_expr_uncanonical(left_expr);
481 let right = self.add_expr_uncanonical(right_expr);
482
483 self.explain_id_equivalence(left, right)
484 }
485
486 /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraph::id_to_expr)`(left),`
487 /// [`id_to_expr`](EGraph::id_to_expr)`(right))` but more efficient
488 ///
489 /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing
490 /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important
491 /// to control explanations
492 pub fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation<L> {
493 if self.find(left) != self.find(right) {
494 panic!(
495 "Tried to explain equivalence between non-equal terms {:?} and {:?}",
496 self.id_to_expr(left),
497 self.id_to_expr(left)
498 );
499 }
500 if let Some(explain) = &mut self.explain {
501 explain.with_nodes(&self.nodes).explain_equivalence::<N>(
502 left,
503 right,
504 &mut self.unionfind,
505 &self.classes,
506 )
507 } else {
508 panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
509 }
510 }
511
512 /// Get an explanation for why an expression matches a pattern.
513 pub fn explain_matches(
514 &mut self,
515 left_expr: &RecExpr<L>,
516 right_pattern: &PatternAst<L>,
517 subst: &Subst,
518 ) -> Explanation<L> {
519 let left = self.add_expr_uncanonical(left_expr);
520 let right = self.add_instantiation_noncanonical(right_pattern, subst);
521
522 if self.find(left) != self.find(right) {
523 panic!(
524 "Tried to explain equivalence between non-equal terms {:?} and {:?}",
525 left_expr, right_pattern
526 );
527 }
528 if let Some(explain) = &mut self.explain {
529 explain.with_nodes(&self.nodes).explain_equivalence::<N>(
530 left,
531 right,
532 &mut self.unionfind,
533 &self.classes,
534 )
535 } else {
536 panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.");
537 }
538 }
539
540 /// Canonicalizes an eclass id.
541 ///
542 /// This corresponds to the `find` operation on the egraph's
543 /// underlying unionfind data structure.
544 ///
545 /// # Example
546 /// ```
547 /// use egg::{*, SymbolLang as S};
548 /// let mut egraph = EGraph::<S, ()>::default();
549 /// let x = egraph.add(S::leaf("x"));
550 /// let y = egraph.add(S::leaf("y"));
551 /// assert_ne!(egraph.find(x), egraph.find(y));
552 ///
553 /// egraph.union(x, y);
554 /// egraph.rebuild();
555 /// assert_eq!(egraph.find(x), egraph.find(y));
556 /// ```
557 pub fn find(&self, id: Id) -> Id {
558 self.unionfind.find(id)
559 }
560
561 /// This is private, but internals should use this whenever
562 /// possible because it does path compression.
563 fn find_mut(&mut self, id: Id) -> Id {
564 self.unionfind.find_mut(id)
565 }
566
567 /// Creates a [`Dot`] to visualize this egraph. See [`Dot`].
568 pub fn dot(&self) -> Dot<L, N> {
569 Dot {
570 egraph: self,
571 config: vec![],
572 use_anchors: true,
573 }
574 }
575}
576
577/// Translates `EGraph<L, A>` into `EGraph<L2, A2>`. For common cases, you don't
578/// need to implement this manually. See the provided [`SimpleLanguageMapper`].
579pub trait LanguageMapper<L, A>
580where
581 L: Language,
582 A: Analysis<L>,
583{
584 /// The target language to translate into.
585 type L2: Language;
586
587 /// The target analysis to transate into.
588 type A2: Analysis<Self::L2>;
589
590 /// Translate a node of `L` into a node of `L2`.
591 fn map_node(&self, node: L) -> Self::L2;
592
593 /// Translate `L::Discriminant` into `L2::Discriminant`
594 fn map_discriminant(
595 &self,
596 discriminant: L::Discriminant,
597 ) -> <Self::L2 as Language>::Discriminant;
598
599 /// Translate an analysis of type `A` into an analysis of `A2`.
600 fn map_analysis(&self, analysis: A) -> Self::A2;
601
602 /// Translate `A::Data` into `A2::Data`.
603 fn map_data(&self, data: A::Data) -> <Self::A2 as Analysis<Self::L2>>::Data;
604
605 /// Translate an [`EClass`] over `L` into an [`EClass`] over `L2`.
606 fn map_eclass(
607 &self,
608 src_eclass: EClass<L, A::Data>,
609 ) -> EClass<Self::L2, <Self::A2 as Analysis<Self::L2>>::Data> {
610 EClass {
611 id: src_eclass.id,
612 nodes: src_eclass
613 .nodes
614 .into_iter()
615 .map(|l| self.map_node(l))
616 .collect(),
617 data: self.map_data(src_eclass.data),
618 parents: src_eclass.parents,
619 }
620 }
621
622 /// Map an `EGraph` over `L` into an `EGraph` over `L2`.
623 fn map_egraph(&self, src_egraph: EGraph<L, A>) -> EGraph<Self::L2, Self::A2> {
624 let kv_map = |(k, v): (L, Id)| (self.map_node(k), v);
625 EGraph {
626 analysis: self.map_analysis(src_egraph.analysis),
627 explain: None,
628 unionfind: src_egraph.unionfind,
629 memo: src_egraph.memo.into_iter().map(kv_map).collect(),
630 pending: src_egraph.pending,
631 nodes: src_egraph
632 .nodes
633 .into_iter()
634 .map(|x| self.map_node(x))
635 .collect(),
636 analysis_pending: src_egraph.analysis_pending,
637 classes: src_egraph
638 .classes
639 .into_iter()
640 .map(|(id, eclass)| (id, self.map_eclass(eclass)))
641 .collect(),
642 classes_by_op: src_egraph
643 .classes_by_op
644 .into_iter()
645 .map(|(k, v)| (self.map_discriminant(k), v))
646 .collect(),
647 clean: src_egraph.clean,
648 }
649 }
650}
651
/// An implementation of [`LanguageMapper`] that can convert an [`EGraph`] over one
/// language into an [`EGraph`] over a different language in common cases.
///
/// Specifically, you can use this if you have
/// [`conversion`](https://doc.rust-lang.org/1.76.0/core/convert/index.html)
/// implemented between your source and target language, as well as your source and
/// target analysis (and between their `Discriminant` and analysis `Data` types).
///
/// Here is an example of how to use this. Consider a case where you have a newtype
/// wrapper over an existing language type:
///
/// ```rust
/// use egg::*;
///
/// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// struct MyLang(SymbolLang);
/// # impl Language for MyLang {
/// #     type Discriminant = <SymbolLang as Language>::Discriminant;
/// #
/// #     fn matches(&self, other: &Self) -> bool {
/// #         self.0.matches(&other.0)
/// #     }
/// #
/// #     fn children(&self) -> &[Id] {
/// #         self.0.children()
/// #     }
/// #
/// #     fn children_mut(&mut self) -> &mut [Id] {
/// #         self.0.children_mut()
/// #     }
/// #
/// #     fn discriminant(&self) -> Self::Discriminant {
/// #         self.0.discriminant()
/// #     }
/// # }
///
/// // some external library function
/// pub fn external(egraph: EGraph<SymbolLang, ()>) { }
///
/// fn do_thing(egraph: EGraph<MyLang, ()>) {
///     // how do I call external?
///     external(todo!())
/// }
/// ```
///
/// By providing an implementation of `From<MyLang> for SymbolLang`, we can
/// construct `SimpleLanguageMapper` and use it to translate our [`EGraph`] into the
/// right type.
///
/// ```rust
/// # use egg::*;
/// # #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// # struct MyLang(SymbolLang);
/// # impl Language for MyLang {
/// #     type Discriminant = <SymbolLang as Language>::Discriminant;
/// #
/// #     fn matches(&self, other: &Self) -> bool {
/// #         self.0.matches(&other.0)
/// #     }
/// #
/// #     fn children(&self) -> &[Id] {
/// #         self.0.children()
/// #     }
/// #
/// #     fn children_mut(&mut self) -> &mut [Id] {
/// #         self.0.children_mut()
/// #     }
/// #
/// #     fn discriminant(&self) -> Self::Discriminant {
/// #         self.0.discriminant()
/// #     }
/// # }
/// # pub fn external(egraph: EGraph<SymbolLang, ()>) { }
/// impl From<MyLang> for SymbolLang {
///     fn from(value: MyLang) -> Self {
///         value.0
///     }
/// }
///
/// fn do_thing(egraph: EGraph<MyLang, ()>) {
///     external(SimpleLanguageMapper::default().map_egraph(egraph))
/// }
/// ```
///
/// Note that we do not need to provide any conversion for the analysis, because it
/// is the same in both source and target e-graphs.
pub struct SimpleLanguageMapper<L2, A2> {
    // Zero-sized marker tying the mapper to its target language and analysis.
    _phantom: PhantomData<(L2, A2)>,
}
741
742impl<L, A> Default for SimpleLanguageMapper<L, A> {
743 fn default() -> Self {
744 SimpleLanguageMapper {
745 _phantom: PhantomData,
746 }
747 }
748}
749
750impl<L, A, L2, A2> LanguageMapper<L, A> for SimpleLanguageMapper<L2, A2>
751where
752 L: Language,
753 A: Analysis<L>,
754 L2: Language + From<L>,
755 A2: Analysis<L2> + From<A>,
756 <L2 as Language>::Discriminant: From<<L as Language>::Discriminant>,
757 <A2 as Analysis<L2>>::Data: From<<A as Analysis<L>>::Data>,
758{
759 type L2 = L2;
760 type A2 = A2;
761
762 fn map_node(&self, node: L) -> Self::L2 {
763 node.into()
764 }
765
766 fn map_discriminant(
767 &self,
768 discriminant: <L as Language>::Discriminant,
769 ) -> <Self::L2 as Language>::Discriminant {
770 discriminant.into()
771 }
772
773 fn map_analysis(&self, analysis: A) -> Self::A2 {
774 analysis.into()
775 }
776
777 fn map_data(&self, data: <A as Analysis<L>>::Data) -> <Self::A2 as Analysis<Self::L2>>::Data {
778 data.into()
779 }
780}
781
782/// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class.
783impl<L: Language, N: Analysis<L>> std::ops::Index<Id> for EGraph<L, N> {
784 type Output = EClass<L, N::Data>;
785 fn index(&self, id: Id) -> &Self::Output {
786 let id = self.find(id);
787 self.classes
788 .get(&id)
789 .unwrap_or_else(|| panic!("Invalid id {}", id))
790 }
791}
792
793/// Given an `Id` using the `&mut egraph[id]` syntax, retrieve a mutable
794/// reference to the e-class.
795impl<L: Language, N: Analysis<L>> std::ops::IndexMut<Id> for EGraph<L, N> {
796 fn index_mut(&mut self, id: Id) -> &mut Self::Output {
797 let id = self.find_mut(id);
798 self.classes
799 .get_mut(&id)
800 .unwrap_or_else(|| panic!("Invalid id {}", id))
801 }
802}
803
804impl<L: Language, N: Analysis<L>> EGraph<L, N> {
805 /// Adds a [`RecExpr`] to the [`EGraph`], returning the id of the RecExpr's eclass.
806 ///
807 /// # Example
808 /// ```
809 /// use egg::{*, SymbolLang as S};
810 /// let mut egraph = EGraph::<S, ()>::default();
811 /// let x = egraph.add(S::leaf("x"));
812 /// let y = egraph.add(S::leaf("y"));
813 /// let plus = egraph.add(S::new("+", vec![x, y]));
814 /// let plus_recexpr = "(+ x y)".parse().unwrap();
815 /// assert_eq!(plus, egraph.add_expr(&plus_recexpr));
816 /// ```
817 ///
818 /// [`add_expr`]: EGraph::add_expr()
819 pub fn add_expr(&mut self, expr: &RecExpr<L>) -> Id {
820 let id = self.add_expr_uncanonical(expr);
821 self.find(id)
822 }
823
824 /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical
825 ///
826 /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
827 pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
828 let mut new_ids = Vec::with_capacity(expr.len());
829 let mut new_node_q = Vec::with_capacity(expr.len());
830 for node in expr {
831 let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
832 let size_before = self.unionfind.size();
833 let next_id = self.add_uncanonical(new_node);
834 if self.unionfind.size() > size_before {
835 new_node_q.push(true);
836 } else {
837 new_node_q.push(false);
838 }
839 new_ids.push(next_id);
840 }
841 *new_ids.last().unwrap()
842 }
843
844 /// Adds a [`Pattern`] and a substitution to the [`EGraph`], returning
845 /// the eclass of the instantiated pattern.
846 pub fn add_instantiation(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
847 let id = self.add_instantiation_noncanonical(pat, subst);
848 self.find(id)
849 }
850
851 /// Similar to [`add_instantiation`](EGraph::add_instantiation) but the `Id` returned may not be
852 /// canonical
853 ///
854 /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling
855 /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an correspond to the
856 /// instantiation of the pattern
857 fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
858 let mut new_ids = Vec::with_capacity(pat.len());
859 let mut new_node_q = Vec::with_capacity(pat.len());
860 for node in pat {
861 match node {
862 ENodeOrVar::Var(var) => {
863 let id = self.find(subst[*var]);
864 new_ids.push(id);
865 new_node_q.push(false);
866 }
867 ENodeOrVar::ENode(node) => {
868 let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
869 let size_before = self.unionfind.size();
870 let next_id = self.add_uncanonical(new_node);
871 if self.unionfind.size() > size_before {
872 new_node_q.push(true);
873 } else {
874 new_node_q.push(false);
875 }
876
877 new_ids.push(next_id);
878 }
879 }
880 }
881 *new_ids.last().unwrap()
882 }
883
884 /// Lookup the eclass of the given enode.
885 ///
886 /// You can pass in either an owned enode or a `&mut` enode,
887 /// in which case the enode's children will be canonicalized.
888 ///
889 /// # Example
890 /// ```
891 /// # use egg::*;
892 /// let mut egraph: EGraph<SymbolLang, ()> = Default::default();
893 /// let a = egraph.add(SymbolLang::leaf("a"));
894 /// let b = egraph.add(SymbolLang::leaf("b"));
895 ///
896 /// // lookup will find this node if its in the egraph
897 /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]);
898 /// assert_eq!(egraph.lookup(node_f_ab.clone()), None);
899 /// let id = egraph.add(node_f_ab.clone());
900 /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id));
901 ///
902 /// // if the query node isn't canonical, and its passed in by &mut instead of owned,
903 /// // its children will be canonicalized
904 /// egraph.union(a, b);
905 /// egraph.rebuild();
906 /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id));
907 /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a]));
908 /// ```
909 pub fn lookup<B>(&self, enode: B) -> Option<Id>
910 where
911 B: BorrowMut<L>,
912 {
913 self.lookup_internal(enode).map(|id| self.find(id))
914 }
915
916 fn lookup_internal<B>(&self, mut enode: B) -> Option<Id>
917 where
918 B: BorrowMut<L>,
919 {
920 let enode = enode.borrow_mut();
921 enode.update_children(|id| self.find(id));
922 self.memo.get(enode).copied()
923 }
924
925 /// Lookup the eclass of the given [`RecExpr`].
926 ///
927 /// Equivalent to the last value in [`EGraph::lookup_expr_ids`].
928 pub fn lookup_expr(&self, expr: &RecExpr<L>) -> Option<Id> {
929 self.lookup_expr_ids(expr)
930 .and_then(|ids| ids.last().copied())
931 }
932
933 /// Lookup the eclasses of all the nodes in the given [`RecExpr`].
934 pub fn lookup_expr_ids(&self, expr: &RecExpr<L>) -> Option<Vec<Id>> {
935 let mut new_ids = Vec::with_capacity(expr.len());
936 for node in expr {
937 let node = node.clone().map_children(|i| new_ids[usize::from(i)]);
938 let id = self.lookup(node)?;
939 new_ids.push(id)
940 }
941 Some(new_ids)
942 }
943
944 /// Adds an enode to the [`EGraph`].
945 ///
946 /// When adding an enode, to the egraph, [`add`] it performs
947 /// _hashconsing_ (sometimes called interning in other contexts).
948 ///
949 /// Hashconsing ensures that only one copy of that enode is in the egraph.
950 /// If a copy is in the egraph, then [`add`] simply returns the id of the
951 /// eclass in which the enode was found.
952 ///
953 /// Like [`union`](EGraph::union), this modifies the e-graph.
954 ///
955 /// [`add`]: EGraph::add()
956 pub fn add(&mut self, enode: L) -> Id {
957 let id = self.add_uncanonical(enode);
958 self.find(id)
959 }
960
    /// Similar to [`add`](EGraph::add), but the `Id` returned may not be canonical.
    ///
    /// If an equal e-node (after canonicalizing its children) is already in the
    /// e-graph, no new e-class is created:
    /// - without explanations, the id stored in the hashcons for that e-node is returned;
    /// - with explanations, the id previously assigned to this exact `enode` is
    ///   reused if there is one. Otherwise a fresh id is allocated for `enode`,
    ///   unioned into the existing e-class, and justified as a congruence.
    ///
    /// Otherwise a new e-class is created, [`Analysis::modify`] is run on it, and
    /// [`EGraph::clean`] is set to `false`.
    ///
    /// When explanations are enabled, calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
    /// correspond to the parameter `enode`.
    ///
    /// # Example
    /// ```
    /// # use egg::*;
    /// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_enabled();
    /// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
    /// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
    /// egraph.union(a, b);
    /// egraph.rebuild();
    ///
    /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
    /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
    ///
    /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
    /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap());
    /// ```
    ///
    /// When explanations are not enabled, calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
    /// produce an expression with equivalent but not necessarily identical children.
    ///
    /// # Example
    /// ```
    /// # use egg::*;
    /// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_disabled();
    /// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
    /// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
    /// egraph.union(a, b);
    /// egraph.rebuild();
    ///
    /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
    /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
    ///
    /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
    /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap());
    /// ```
    pub fn add_uncanonical(&mut self, mut enode: L) -> Id {
        // Keep the node exactly as the caller gave it; `lookup_internal` below
        // canonicalizes `enode`'s children in place.
        let original = enode.clone();
        if let Some(existing_id) = self.lookup_internal(&mut enode) {
            let id = self.find(existing_id);
            // when explanations are enabled, we need a new representative for this expr
            if let Some(explain) = self.explain.as_mut() {
                if let Some(existing_explain) = explain.uncanon_memo.get(&original) {
                    // This exact (uncanonical) node was added before: reuse its id.
                    *existing_explain
                } else {
                    // Allocate a fresh id standing for `original`. `self.nodes` is
                    // indexed by id, so the push must line up with `make_set`.
                    let new_id = self.unionfind.make_set();
                    explain.add(original.clone(), new_id);
                    debug_assert_eq!(Id::from(self.nodes.len()), new_id);
                    self.nodes.push(original);
                    // Place the new id in the existing e-class and justify it as a congruence.
                    self.unionfind.union(id, new_id);
                    explain.union(existing_id, new_id, Justification::Congruence);
                    new_id
                }
            } else {
                existing_id
            }
        } else {
            // Not present: the canonical `enode` goes into the hashcons, `original`
            // is used for the analysis and the node table.
            let id = self.make_new_eclass(enode, original.clone());
            if let Some(explain) = self.explain.as_mut() {
                explain.add(original, id);
            }

            // now that we updated explanations, run the analysis for the new eclass
            N::modify(self, id);
            self.clean = false;
            id
        }
    }
1032
1033 /// This function makes a new eclass in the egraph (but doesn't touch explanations)
1034 fn make_new_eclass(&mut self, enode: L, original: L) -> Id {
1035 let id = self.unionfind.make_set();
1036 log::trace!(" ...adding to {}", id);
1037 let class = EClass {
1038 id,
1039 nodes: vec![enode.clone()],
1040 data: N::make(self, &original),
1041 parents: Default::default(),
1042 };
1043
1044 debug_assert_eq!(Id::from(self.nodes.len()), id);
1045 self.nodes.push(original);
1046
1047 // add this enode to the parent lists of its children
1048 enode.for_each(|child| {
1049 self[child].parents.push(id);
1050 });
1051
1052 // TODO is this needed?
1053 self.pending.push(id);
1054
1055 self.classes.insert(id, class);
1056 assert!(self.memo.insert(enode, id).is_none());
1057
1058 id
1059 }
1060
1061 /// Checks whether two [`RecExpr`]s are equivalent.
1062 /// Returns a list of id where both expression are represented.
1063 /// In most cases, there will none or exactly one id.
1064 ///
1065 pub fn equivs(&self, expr1: &RecExpr<L>, expr2: &RecExpr<L>) -> Vec<Id> {
1066 let pat1 = Pattern::from(expr1);
1067 let pat2 = Pattern::from(expr2);
1068 let matches1 = pat1.search(self);
1069 trace!("Matches1: {:?}", matches1);
1070
1071 let matches2 = pat2.search(self);
1072 trace!("Matches2: {:?}", matches2);
1073
1074 let mut equiv_eclasses = Vec::new();
1075
1076 for m1 in &matches1 {
1077 for m2 in &matches2 {
1078 if self.find(m1.eclass) == self.find(m2.eclass) {
1079 equiv_eclasses.push(m1.eclass)
1080 }
1081 }
1082 }
1083
1084 equiv_eclasses
1085 }
1086
1087 /// Given two patterns and a substitution, add the patterns
1088 /// and union them.
1089 ///
1090 /// When explanations are enabled [`with_explanations_enabled`](Runner::with_explanations_enabled), use
1091 /// this function instead of [`union`](EGraph::union).
1092 ///
1093 /// Returns the id of the new eclass, along with
1094 /// a `bool` indicating whether a union occured.
1095 pub fn union_instantiations(
1096 &mut self,
1097 from_pat: &PatternAst<L>,
1098 to_pat: &PatternAst<L>,
1099 subst: &Subst,
1100 rule_name: impl Into<Symbol>,
1101 ) -> (Id, bool) {
1102 let id1 = self.add_instantiation_noncanonical(from_pat, subst);
1103 let id2 = self.add_instantiation_noncanonical(to_pat, subst);
1104
1105 let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
1106 (self.find(id1), did_union)
1107 }
1108
1109 /// Unions two e-classes, using a given reason to justify it.
1110 ///
1111 /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing
1112 /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important
1113 /// to control explanations
1114 pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into<Symbol>) -> bool {
1115 self.perform_union(from, to, Some(Justification::Rule(reason.into())))
1116 }
1117
1118 /// Unions two eclasses given their ids.
1119 ///
1120 /// The given ids need not be canonical.
1121 /// The returned `bool` indicates whether a union is necessary,
1122 /// so it's `false` if they were already equivalent.
1123 ///
1124 /// When explanations are enabled, this function behaves like [`EGraph::union_trusted`],
1125 /// and it lists the call site as the proof reason.
1126 /// You should prefer [`union_instantiations`](EGraph::union_instantiations) when
1127 /// you want the proofs to always be meaningful.
1128 /// Alternatively you can use [`EGraph::union_trusted`] using uncanonical `Id`s obtained from
1129 /// functions like [`EGraph::add_uncanonical`]
1130 /// See [`explain_equivalence`](Runner::explain_equivalence) for a more detailed
1131 /// explanation of the feature.
1132 #[track_caller]
1133 pub fn union(&mut self, id1: Id, id2: Id) -> bool {
1134 if self.explain.is_some() {
1135 let caller = std::panic::Location::caller();
1136 self.union_trusted(id1, id2, caller.to_string())
1137 } else {
1138 self.perform_union(id1, id2, None)
1139 }
1140 }
1141
    /// Core union routine used by [`EGraph::union`], [`EGraph::union_trusted`],
    /// [`EGraph::union_instantiations`] and the congruence pass of `process_unions`.
    ///
    /// `enode_id1` and `enode_id2` need not be canonical. Their canonical
    /// classes are merged, but the ids as given are what gets recorded in the
    /// explanation log. The e-graph is always marked not clean, even when no
    /// merge happens.
    ///
    /// Returns `true` if two distinct classes were merged, and `false` if the ids
    /// were already equivalent.
    ///
    /// # Panics
    ///
    /// Panics if explanations are enabled, the classes are distinct, and `rule` is `None`.
    fn perform_union(&mut self, enode_id1: Id, enode_id2: Id, rule: Option<Justification>) -> bool {
        // Let the analysis observe the request before anything is modified.
        N::pre_union(self, enode_id1, enode_id2, &rule);

        self.clean = false;
        let mut id1 = self.find_mut(enode_id1);
        let mut id2 = self.find_mut(enode_id2);
        if id1 == id2 {
            // Already equivalent. A rule-justified union is still handed to the
            // explanation log as an alternate rewrite.
            if let Some(Justification::Rule(_)) = rule {
                if let Some(explain) = &mut self.explain {
                    explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap());
                }
            }
            return false;
        }
        // make sure class2 has fewer parents
        // (class2's parent list is the one re-queued into `pending` below)
        let class1_parents = self.classes[&id1].parents.len();
        let class2_parents = self.classes[&id2].parents.len();
        if class1_parents < class2_parents {
            std::mem::swap(&mut id1, &mut id2);
        }

        // Explanations connect the original (possibly uncanonical) ids.
        if let Some(explain) = &mut self.explain {
            explain.union(enode_id1, enode_id2, rule.unwrap());
        }

        // make id1 the new root
        self.unionfind.union(id1, id2);

        assert_ne!(id1, id2);
        let class2 = self.classes.remove(&id2).unwrap();
        let class1 = self.classes.get_mut(&id1).unwrap();
        assert_eq!(id1, class1.id);

        // class2's parents now have children that changed class; queue them so
        // `process_unions` re-canonicalizes them and checks for new congruences.
        self.pending.extend(class2.parents.iter().copied());
        // Merge analysis data. Each flag of the result causes the parents of the
        // corresponding original class to be queued for re-analysis.
        let did_merge = self.analysis.merge(&mut class1.data, class2.data);
        if did_merge.0 {
            self.analysis_pending.extend(class1.parents.iter().copied());
        }
        if did_merge.1 {
            self.analysis_pending.extend(class2.parents.iter().copied());
        }

        concat_vecs(&mut class1.nodes, class2.nodes);
        concat_vecs(&mut class1.parents, class2.parents);

        N::modify(self, id1);
        true
    }
1190
1191 /// Update the analysis data of an e-class.
1192 ///
1193 /// This also propagates the changes through the e-graph,
1194 /// so [`Analysis::make`] and [`Analysis::merge`] will get
1195 /// called for other parts of the e-graph on rebuild.
1196 pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) {
1197 let id = self.find_mut(id);
1198 let class = self.classes.get_mut(&id).unwrap();
1199 class.data = new_data;
1200 self.analysis_pending.extend(class.parents.iter().copied());
1201 N::modify(self, id)
1202 }
1203
1204 /// Returns a more debug-able representation of the egraph.
1205 ///
1206 /// [`EGraph`]s implement [`Debug`], but it ain't pretty. It
1207 /// prints a lot of stuff you probably don't care about.
1208 /// This method returns a wrapper that implements [`Debug`] in a
1209 /// slightly nicer way, just dumping enodes in each eclass.
1210 ///
1211 /// [`Debug`]: std::fmt::Debug
1212 pub fn dump(&self) -> impl Debug + '_ {
1213 EGraphDump(self)
1214 }
1215}
1216
1217impl<L: Language + Display, N: Analysis<L>> EGraph<L, N> {
1218 /// Panic if the given eclass doesn't contain the given patterns
1219 ///
1220 /// Useful for testing.
1221 pub fn check_goals(&self, id: Id, goals: &[Pattern<L>]) {
1222 let (cost, best) = Extractor::new(self, AstSize).find_best(id);
1223 println!("End ({}): {}", cost, best.pretty(80));
1224
1225 for (i, goal) in goals.iter().enumerate() {
1226 println!("Trying to prove goal {}: {}", i, goal.pretty(40));
1227 let matches = goal.search_eclass(self, id);
1228 if matches.is_none() {
1229 let best = Extractor::new(self, AstSize).find_best(id).1;
1230 panic!(
1231 "Could not prove goal {}:\n\
1232 {}\n\
1233 Best thing found:\n\
1234 {}",
1235 i,
1236 goal.pretty(40),
1237 best.pretty(40),
1238 );
1239 }
1240 }
1241 }
1242}
1243
// All the rebuilding stuff
impl<L: Language, N: Analysis<L>> EGraph<L, N> {
    /// Canonicalizes and deduplicates every e-class's node list and refills `classes_by_op`.
    ///
    /// For each class, every node's children are canonicalized through the
    /// union-find, then the node list is sorted and deduplicated. The
    /// `classes_by_op` index (keyed by each node's `discriminant()`) is cleared
    /// and refilled from the deduplicated nodes.
    ///
    /// Returns the total shrinkage of the classes caused by deduplication.
    #[inline(never)]
    fn rebuild_classes(&mut self) -> usize {
        // Moved out of `self` and restored at the end; presumably so the `add`
        // closure below can borrow it while `self.classes` is borrowed mutably.
        // Entries are emptied rather than dropped.
        let mut classes_by_op = std::mem::take(&mut self.classes_by_op);
        classes_by_op.values_mut().for_each(|ids| ids.clear());

        let mut trimmed = 0;
        let uf = &mut self.unionfind;

        for class in self.classes.values_mut() {
            let old_len = class.len();
            class
                .nodes
                .iter_mut()
                .for_each(|n| n.update_children(|id| uf.find_mut(id)));
            class.nodes.sort_unstable();
            class.nodes.dedup();

            trimmed += old_len - class.nodes.len();

            // Registers this class under the discriminant of node `n`.
            let mut add = |n: &L| {
                classes_by_op
                    .entry(n.discriminant())
                    .or_default()
                    .insert(class.id)
            };

            // we can go through the ops in order to dedup them, because we
            // just sorted them
            let mut nodes = class.nodes.iter();
            if let Some(mut prev) = nodes.next() {
                add(prev);
                for n in nodes {
                    if !prev.matches(n) {
                        add(n);
                        prev = n;
                    }
                }
            }
        }

        // Debug builds: each op's class list must not contain duplicates.
        #[cfg(debug_assertions)]
        for ids in classes_by_op.values_mut() {
            let unique: HashSet<Id> = ids.iter().copied().collect();
            assert_eq!(ids.len(), unique.len());
        }

        self.classes_by_op = classes_by_op;
        trimmed
    }

    /// Consistency check of the hashcons (`memo`), used by `rebuild` via `debug_assert!`.
    ///
    /// Builds a fresh node -> class map from the current class contents. It
    /// asserts that a node appearing in two classes only appears in equivalent
    /// ones, and that `memo` maps every node to that same canonical class.
    /// Panics on any violation; otherwise returns `true`.
    #[inline(never)]
    fn check_memo(&self) -> bool {
        let mut test_memo = HashMap::default();

        for (&id, class) in self.classes.iter() {
            assert_eq!(class.id, id);
            for node in &class.nodes {
                if let Some(old) = test_memo.insert(node, id) {
                    assert_eq!(
                        self.find(old),
                        self.find(id),
                        "Found unexpected equivalence for {:?}\n{:?}\nvs\n{:?}",
                        node,
                        self[self.find(id)].nodes,
                        self[self.find(old)].nodes,
                    );
                }
            }
        }

        for (n, e) in test_memo {
            assert_eq!(e, self.find(e));
            assert_eq!(
                Some(e),
                self.memo.get(n).map(|id| self.find(*id)),
                "Entry for {:?} at {} in test_memo was incorrect",
                n,
                e
            );
        }

        true
    }

    /// Drains the `pending` and `analysis_pending` worklists until both are empty.
    ///
    /// Each id in either worklist indexes `self.nodes`.
    /// - A `pending` node is re-canonicalized and re-inserted into `memo`. If
    ///   `memo` already held an entry for it, the two classes are unioned with a
    ///   `Congruence` justification.
    /// - An `analysis_pending` node has its data recomputed with
    ///   [`Analysis::make`] and merged into its class. If the class data changed,
    ///   the class's parents are queued again and [`Analysis::modify`] is run.
    ///
    /// Returns the number of unions that actually merged two classes.
    #[inline(never)]
    fn process_unions(&mut self) -> usize {
        let mut n_unions = 0;

        // Processing one worklist can refill the other (`perform_union` pushes
        // to both), so loop until both are empty.
        while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
            while let Some(class) = self.pending.pop() {
                let mut node = self.nodes[usize::from(class)].clone();
                node.update_children(|id| self.find_mut(id));
                if let Some(memo_class) = self.memo.insert(node, class) {
                    let did_something =
                        self.perform_union(memo_class, class, Some(Justification::Congruence));
                    n_unions += did_something as usize;
                }
            }

            while let Some(class_id) = self.analysis_pending.pop() {
                let node = self.nodes[usize::from(class_id)].clone();
                let class_id = self.find_mut(class_id);
                let node_data = N::make(self, &node);
                let class = self.classes.get_mut(&class_id).unwrap();

                let did_merge = self.analysis.merge(&mut class.data, node_data);
                if did_merge.0 {
                    self.analysis_pending.extend(class.parents.iter().copied());
                    N::modify(self, class_id)
                }
            }
        }

        assert!(self.pending.is_empty());
        assert!(self.analysis_pending.is_empty());

        n_unions
    }

    /// Restores the egraph invariants of congruence and enode uniqueness.
    ///
    /// As mentioned
    /// [in the tutorial](tutorials/_01_background/index.html#invariants-and-rebuilding),
    /// `egg` takes a lazy approach to maintaining the egraph invariants.
    /// The `rebuild` method allows the user to manually restore those
    /// invariants at a time of their choosing. It's a reasonably
    /// fast, linear-ish traversal through the egraph.
    ///
    /// After modifying an e-graph with [`add`](EGraph::add) or
    /// [`union`](EGraph::union), you must call `rebuild` to restore
    /// invariants before any query operations, otherwise the results
    /// may be stale or incorrect.
    ///
    /// This will set [`EGraph::clean`] to `true`.
    ///
    /// Returns the number of unions performed while restoring congruence.
    ///
    /// # Example
    /// ```
    /// use egg::{*, SymbolLang as S};
    /// let mut egraph = EGraph::<S, ()>::default();
    /// let x = egraph.add(S::leaf("x"));
    /// let y = egraph.add(S::leaf("y"));
    /// let ax = egraph.add_expr(&"(+ a x)".parse().unwrap());
    /// let ay = egraph.add_expr(&"(+ a y)".parse().unwrap());
    ///
    /// // Union x and y
    /// egraph.union(x, y);
    /// // Classes: [x y] [ax] [ay] [a]
    /// assert_eq!(egraph.find(x), egraph.find(y));
    ///
    /// // Rebuilding restores the congruence invariant, finding
    /// // that ax and ay are equivalent.
    /// egraph.rebuild();
    /// // Classes: [x y] [ax ay] [a]
    /// assert_eq!(egraph.number_of_classes(), 3);
    /// assert_eq!(egraph.find(ax), egraph.find(ay));
    /// ```
    pub fn rebuild(&mut self) -> usize {
        let old_hc_size = self.memo.len();
        let old_n_eclasses = self.number_of_classes();

        let start = Instant::now();

        // Congruence closure and analysis propagation first, then clean up
        // the per-class node lists and the `classes_by_op` index.
        let n_unions = self.process_unions();
        let trimmed_nodes = self.rebuild_classes();

        let elapsed = start.elapsed();
        info!(
            concat!(
                "REBUILT! in {}.{:03}s\n",
                " Old: hc size {}, eclasses: {}\n",
                " New: hc size {}, eclasses: {}\n",
                " unions: {}, trimmed nodes: {}"
            ),
            elapsed.as_secs(),
            elapsed.subsec_millis(),
            old_hc_size,
            old_n_eclasses,
            self.memo.len(),
            self.number_of_classes(),
            n_unions,
            trimmed_nodes,
        );

        debug_assert!(self.check_memo());
        self.clean = true;
        n_unions
    }

    /// Runs the explanation checker over `rules` against this e-graph's node table.
    ///
    /// # Panics
    ///
    /// Panics if explanations are not enabled.
    pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite<L, N>]) -> bool {
        if let Some(explain) = &mut self.explain {
            explain.with_nodes(&self.nodes).check_each_explain(rules)
        } else {
            panic!("Can't check explain when explanations are off");
        }
    }
}
1442
1443struct EGraphDump<'a, L: Language, N: Analysis<L>>(&'a EGraph<L, N>);
1444
1445impl<'a, L: Language, N: Analysis<L>> Debug for EGraphDump<'a, L, N> {
1446 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1447 let mut ids: Vec<Id> = self.0.classes().map(|c| c.id).collect();
1448 ids.sort();
1449 for id in ids {
1450 let mut nodes = self.0[id].nodes.clone();
1451 nodes.sort();
1452 writeln!(f, "{} ({:?}): {:?}", id, self.0[id].data, nodes)?
1453 }
1454 Ok(())
1455 }
1456}
1457
#[cfg(test)]
mod tests {

    use super::*;

    /// Adds a few nodes, unions two leaves via `union_instantiations`, rebuilds,
    /// and checks the resulting class structure.
    #[test]
    fn simple_add() {
        use SymbolLang as S;

        crate::init_logger();
        let mut egraph = EGraph::<S, ()>::default();

        // Adding an identical leaf twice is hash-consed into the same e-class.
        let x = egraph.add(S::leaf("x"));
        let x2 = egraph.add(S::leaf("x"));
        assert_eq!(x, x2);
        let _plus = egraph.add(S::new("+", vec![x, x2]));

        // Instantiating "y" adds a new leaf, so a real union must happen.
        let (_, did_union) = egraph.union_instantiations(
            &"x".parse().unwrap(),
            &"y".parse().unwrap(),
            &Default::default(),
            "union x and y".to_string(),
        );
        assert!(did_union);
        egraph.rebuild();

        // After rebuilding, `y` lives in `x`'s class; the classes are [x y] and [(+ x x)].
        assert_eq!(egraph.lookup(S::leaf("y")), Some(egraph.find(x)));
        assert_eq!(egraph.number_of_classes(), 2);
    }

    #[cfg(all(feature = "serde-1", feature = "serde_json"))]
    #[test]
    fn test_serde() {
        // Compile-time checks that `EGraph` implements the serde traits.
        fn ser(_: &impl Serialize) {}
        fn de<'a>(_: &impl Deserialize<'a>) {}

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        egraph.add_expr(&"(foo bar baz)".parse().unwrap());
        ser(&egraph);
        de(&egraph);

        let json_rep = serde_json::to_string_pretty(&egraph).unwrap();
        println!("{}", json_rep);
    }
}