1#![allow(clippy::only_used_in_recursion)]
2use crate::Symbol;
3use crate::{
4 util::pretty_print, Analysis, EClass, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language,
5 PatternAst, RecExpr, Rewrite, UnionFind, Var,
6};
7
8use std::cmp::Ordering;
9use std::collections::{BinaryHeap, VecDeque};
10use std::fmt::{self, Debug, Display, Formatter};
11use std::ops::{Deref, DerefMut};
12use std::rc::Rc;
13
14use num_bigint::BigUint;
15use num_traits::identities::{One, Zero};
16use symbolic_expressions::Sexp;
17
/// Size/cost of a proof. Arbitrary precision (`BigUint`) so that large proof
/// sizes cannot overflow.
type ProofCost = BigUint;

/// Tuning constant; not referenced in the visible part of this file —
/// presumably bounds congruence work when shortening explanations (TODO confirm).
const CONGRUENCE_LIMIT: usize = 2;
/// Tuning constant; not referenced in the visible part of this file —
/// presumably the number of greedy shortening iterations (TODO confirm).
const GREEDY_NUM_ITERS: usize = 2;
22
/// The reason two e-nodes were recorded as equal.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub enum Justification {
    /// Equal because the rewrite rule with this name was applied.
    Rule(Symbol),
    /// Equal because the two nodes have the same operator and pairwise-equal
    /// children (explained child-by-child in `explain_adjacent`).
    Congruence,
}
33
/// A directed edge between two e-node ids in the explanation forest.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct Connection {
    /// Id the edge points to.
    next: Id,
    /// Id the edge starts from.
    current: Id,
    /// Why `current` and `next` are equal.
    justification: Justification,
    /// `true` when the justification rewrites `current` into `next`; `false`
    /// when this edge is the reverse of such a rewrite (see `union`, which
    /// records one edge of each kind).
    is_rewrite_forward: bool,
}
42
/// Per-e-node explanation bookkeeping; stored in `Explain::explainfind`,
/// indexed by id.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode {
    /// Every edge recorded for this node by `union` or `alternate_rewrite`.
    neighbors: Vec<Connection>,
    /// Edge to this node's parent in its explanation tree. A self-edge
    /// (`next == current`) marks a root, as created by `add`.
    parent_connection: Connection,
}
50
/// Records the justification for every union so that equalities between
/// e-nodes can later be explained as proofs.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct Explain<L: Language> {
    /// One entry per node passed to `add`, indexed by `usize::from(id)`.
    explainfind: Vec<ExplainNode>,
    /// Maps each node passed to `add` to the id it was registered under.
    #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
    #[cfg_attr(
        feature = "serde-1",
        serde(bound(
            serialize = "L: serde::Serialize",
            deserialize = "L: serde::Deserialize<'de>",
        ))
    )]
    pub uncanon_memo: HashMap<L, Id>,
    /// When `true`, `explain_equivalence` first runs the shortest-explanation
    /// computation before building the proof.
    pub optimize_explanation_lengths: bool,
    /// `(from, to) -> (cost, next hop)` memo used by `get_path` to follow
    /// shorter paths; not serialized.
    #[cfg_attr(feature = "serde-1", serde(skip))]
    shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
}
75
/// Temporary view pairing an [`Explain`] with the node table its ids index
/// into; derefs to the wrapped `Explain`.
pub(crate) struct ExplainNodes<'a, L: Language> {
    explain: &'a mut Explain<L>,
    /// Node for each id, looked up by `usize::from(id)` in `node`.
    nodes: &'a [L],
}
80
/// Caches used while computing explanation distances.
#[derive(Default)]
struct DistanceMemo {
    /// Indexed by node id (see `distance_between`); presumably
    /// `(endpoint reached, distance to it)` — TODO confirm.
    parent_distance: Vec<(Id, ProofCost)>,
    /// Precomputed common ancestor for `(left, right)` pairs, consulted by
    /// `distance_between` before computing one directly.
    common_ancestor: HashMap<(Id, Id), Id>,
    /// Presumably the tree depth of each node (cf. `calculate_tree_depths`) —
    /// TODO confirm.
    tree_depth: HashMap<Id, ProofCost>,
}
87
/// A proof as a sequence of tree terms; subproofs may be shared via `Rc`.
pub type TreeExplanation<L> = Vec<Rc<TreeTerm<L>>>;

/// A proof as a sequence of flat terms (also used for a flat term's children).
pub type FlatExplanation<L> = Vec<FlatTerm<L>>;

/// Rule-justified equalities as `(from, to, rule name)`.
pub type UnionEqualities = Vec<(Id, Id, Symbol)>;

/// Memo of single-edge explanations keyed by `(current, next)`.
type ExplainCache<L> = HashMap<(Id, Id), Rc<TreeTerm<L>>>;
/// Memo of the plain (annotation-free) tree term for each node id.
type NodeExplanationCache<L> = HashMap<Id, Rc<TreeTerm<L>>>;
115
/// A proof that two terms are equal, in tree form with a lazily computed
/// flat form.
pub struct Explanation<L: Language> {
    /// The proof as a sequence of tree terms.
    pub explanation_trees: TreeExplanation<L>,
    /// Cached flattened proof; filled by `make_flat_explanation`.
    flat_explanation: Option<FlatExplanation<L>>,
}
127
128impl<L: Language + Display + FromOp> Display for Explanation<L> {
129 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
130 let s = self.get_sexp().to_string();
131 f.write_str(&s)
132 }
133}
134
135impl<L: Language + Display + FromOp> Explanation<L> {
136 pub fn get_flat_string(&mut self) -> String {
156 self.get_flat_strings().join("\n")
157 }
158
159 pub fn get_string(&self) -> String {
189 self.to_string()
190 }
191
192 pub fn get_string_with_let(&self) -> String {
224 let mut s = "".to_string();
225 pretty_print(&mut s, &self.get_sexp_with_let(), 100, 0).unwrap();
226 s
227 }
228
229 pub fn get_flat_strings(&mut self) -> Vec<String> {
232 self.make_flat_explanation()
233 .iter()
234 .map(|e| e.to_string())
235 .collect()
236 }
237
238 fn get_sexp(&self) -> Sexp {
239 let mut items = vec![Sexp::String("Explanation".to_string())];
240 for e in self.explanation_trees.iter() {
241 items.push(e.get_sexp());
242 }
243
244 Sexp::List(items)
245 }
246
247 pub fn get_tree_size(&self) -> ProofCost {
250 let mut seen = Default::default();
251 let mut seen_adjacent = Default::default();
252 let mut sum: ProofCost = BigUint::zero();
253 for e in self.explanation_trees.iter() {
254 sum += self.tree_size(&mut seen, &mut seen_adjacent, e);
255 }
256 sum
257 }
258
259 fn tree_size(
260 &self,
261 seen: &mut HashSet<*const TreeTerm<L>>,
262 seen_adjacent: &mut HashSet<(Id, Id)>,
263 current: &Rc<TreeTerm<L>>,
264 ) -> ProofCost {
265 if !seen.insert(&**current as *const TreeTerm<L>) {
266 return BigUint::zero();
267 }
268 let mut my_size: ProofCost = BigUint::zero();
269 if current.forward_rule.is_some() {
270 my_size += 1_u32;
271 }
272 if current.backward_rule.is_some() {
273 my_size += 1_u32;
274 }
275 assert!(my_size.is_zero() || my_size.is_one());
276 if my_size.is_one() {
277 if !seen_adjacent.insert((current.current, current.last)) {
278 return BigUint::zero();
279 } else {
280 seen_adjacent.insert((current.last, current.current));
281 }
282 }
283
284 for child_proof in ¤t.child_proofs {
285 for child in child_proof {
286 my_size += self.tree_size(seen, seen_adjacent, child);
287 }
288 }
289 my_size
290 }
291
292 fn get_sexp_with_let(&self) -> Sexp {
293 let mut shared: HashSet<*const TreeTerm<L>> = Default::default();
294 let mut to_let_bind = vec![];
295 for term in &self.explanation_trees {
296 self.find_to_let_bind(term.clone(), &mut shared, &mut to_let_bind);
297 }
298
299 let mut bindings: HashMap<*const TreeTerm<L>, Sexp> = Default::default();
300 let mut generated_bindings: Vec<(Sexp, Sexp)> = Default::default();
301 for to_bind in to_let_bind {
302 if bindings.get(&(&*to_bind as *const TreeTerm<L>)).is_none() {
303 let name = Sexp::String("v_".to_string() + &generated_bindings.len().to_string());
304 let ast = to_bind.get_sexp_with_bindings(&bindings);
305 generated_bindings.push((name.clone(), ast));
306 bindings.insert(&*to_bind as *const TreeTerm<L>, name);
307 }
308 }
309
310 let mut items = vec![Sexp::String("Explanation".to_string())];
311 for e in self.explanation_trees.iter() {
312 if let Some(existing) = bindings.get(&(&**e as *const TreeTerm<L>)) {
313 items.push(existing.clone());
314 } else {
315 items.push(e.get_sexp_with_bindings(&bindings));
316 }
317 }
318
319 let mut result = Sexp::List(items);
320
321 for (name, expr) in generated_bindings.into_iter().rev() {
322 let let_expr = Sexp::List(vec![name, expr]);
323 result = Sexp::List(vec![Sexp::String("let".to_string()), let_expr, result]);
324 }
325
326 result
327 }
328
329 fn find_to_let_bind(
332 &self,
333 term: Rc<TreeTerm<L>>,
334 shared: &mut HashSet<*const TreeTerm<L>>,
335 to_let_bind: &mut Vec<Rc<TreeTerm<L>>>,
336 ) {
337 if !term.child_proofs.is_empty() {
338 if shared.insert(&*term as *const TreeTerm<L>) {
339 for proof in &term.child_proofs {
340 for child in proof {
341 self.find_to_let_bind(child.clone(), shared, to_let_bind);
342 }
343 }
344 } else {
345 to_let_bind.push(term);
346 }
347 }
348 }
349}
350
351impl<L: Language> Explanation<L> {
352 pub fn new(explanation_trees: TreeExplanation<L>) -> Explanation<L> {
354 Explanation {
355 explanation_trees,
356 flat_explanation: None,
357 }
358 }
359
360 pub fn make_flat_explanation(&mut self) -> &FlatExplanation<L> {
362 if self.flat_explanation.is_some() {
363 return self.flat_explanation.as_ref().unwrap();
364 } else {
365 self.flat_explanation = Some(TreeTerm::flatten_proof(&self.explanation_trees));
366 self.flat_explanation.as_ref().unwrap()
367 }
368 }
369
370 pub fn check_proof<'a, R, N>(&mut self, rules: R)
373 where
374 R: IntoIterator<Item = &'a Rewrite<L, N>>,
375 L: 'a,
376 N: Analysis<L> + 'a,
377 {
378 let rules: Vec<&Rewrite<L, N>> = rules.into_iter().collect();
379 let rule_table = Explain::make_rule_table(rules.as_slice());
380 self.make_flat_explanation();
381 let flat_explanation = self.flat_explanation.as_ref().unwrap();
382 assert!(!flat_explanation[0].has_rewrite_forward());
383 assert!(!flat_explanation[0].has_rewrite_backward());
384 for i in 0..flat_explanation.len() - 1 {
385 let current = &flat_explanation[i];
386 let next = &flat_explanation[i + 1];
387
388 let has_forward = next.has_rewrite_forward();
389 let has_backward = next.has_rewrite_backward();
390 assert!(has_forward ^ has_backward);
391
392 if has_forward {
393 assert!(self.check_rewrite_at(current, next, &rule_table, true));
394 } else {
395 assert!(self.check_rewrite_at(current, next, &rule_table, false));
396 }
397 }
398 }
399
400 fn check_rewrite_at<N: Analysis<L>>(
401 &self,
402 current: &FlatTerm<L>,
403 next: &FlatTerm<L>,
404 table: &HashMap<Symbol, &Rewrite<L, N>>,
405 is_forward: bool,
406 ) -> bool {
407 if is_forward && next.forward_rule.is_some() {
408 let rule_name = next.forward_rule.as_ref().unwrap();
409 if let Some(rule) = table.get(rule_name) {
410 Explanation::check_rewrite(current, next, rule)
411 } else {
412 true
414 }
415 } else if !is_forward && next.backward_rule.is_some() {
416 let rule_name = next.backward_rule.as_ref().unwrap();
417 if let Some(rule) = table.get(rule_name) {
418 Explanation::check_rewrite(next, current, rule)
419 } else {
420 true
421 }
422 } else {
423 for (left, right) in current.children.iter().zip(next.children.iter()) {
424 if !self.check_rewrite_at(left, right, table, is_forward) {
425 return false;
426 }
427 }
428 true
429 }
430 }
431
432 fn check_rewrite<'a, N: Analysis<L>>(
434 current: &'a FlatTerm<L>,
435 next: &'a FlatTerm<L>,
436 rewrite: &Rewrite<L, N>,
437 ) -> bool {
438 if let Some(lhs) = rewrite.searcher.get_pattern_ast() {
439 if let Some(rhs) = rewrite.applier.get_pattern_ast() {
440 let rewritten = current.rewrite(lhs, rhs);
441 if &rewritten != next {
442 return false;
443 }
444 }
445 }
446 true
447 }
448}
449
/// One term of a tree-shaped explanation: an operator whose children are each
/// justified by their own sub-explanation.
#[derive(Debug, Clone)]
pub struct TreeTerm<L: Language> {
    /// The operator of this term.
    pub node: L,
    /// Name of the rule applied right-to-left to reach this term, if any.
    pub backward_rule: Option<Symbol>,
    /// Name of the rule applied left-to-right to reach this term, if any.
    pub forward_rule: Option<Symbol>,
    /// One explanation per child of `node`; its first term is the child's
    /// initial form and its last term the child's final form.
    pub child_proofs: Vec<TreeExplanation<L>>,

    /// E-node id of the previous step; set for rule steps in
    /// `explain_adjacent`, otherwise `Id::from(0)`.
    last: Id,
    /// E-node id of this step; set for rule steps in `explain_adjacent`,
    /// otherwise `Id::from(0)`.
    current: Id,
}
478
479impl<L: Language> TreeTerm<L> {
480 pub fn new(node: L, child_proofs: Vec<TreeExplanation<L>>) -> TreeTerm<L> {
482 TreeTerm {
483 node,
484 backward_rule: None,
485 forward_rule: None,
486 child_proofs,
487 current: Id::from(0),
488 last: Id::from(0),
489 }
490 }
491
492 fn flatten_proof(proof: &[Rc<TreeTerm<L>>]) -> FlatExplanation<L> {
493 let mut flat_proof: FlatExplanation<L> = vec![];
494 for tree in proof {
495 let mut explanation = tree.flatten_explanation();
496
497 if !flat_proof.is_empty()
498 && !explanation[0].has_rewrite_forward()
499 && !explanation[0].has_rewrite_backward()
500 {
501 let last = flat_proof.pop().unwrap();
502 explanation[0].combine_rewrites(&last);
503 }
504
505 flat_proof.extend(explanation);
506 }
507
508 flat_proof
509 }
510
511 pub fn get_initial_flat_term(&self) -> FlatTerm<L> {
513 FlatTerm {
514 node: self.node.clone(),
515 backward_rule: self.backward_rule,
516 forward_rule: self.forward_rule,
517 children: self
518 .child_proofs
519 .iter()
520 .map(|child_proof| child_proof[0].get_initial_flat_term())
521 .collect(),
522 }
523 }
524
525 pub fn get_last_flat_term(&self) -> FlatTerm<L> {
527 FlatTerm {
528 node: self.node.clone(),
529 backward_rule: self.backward_rule,
530 forward_rule: self.forward_rule,
531 children: self
532 .child_proofs
533 .iter()
534 .map(|child_proof| child_proof[child_proof.len() - 1].get_last_flat_term())
535 .collect(),
536 }
537 }
538
539 pub fn flatten_explanation(&self) -> FlatExplanation<L> {
541 let mut proof = vec![];
542 let mut child_proofs = vec![];
543 let mut representative_terms = vec![];
544 for child_explanation in &self.child_proofs {
545 let flat_proof = TreeTerm::flatten_proof(child_explanation);
546 representative_terms.push(flat_proof[0].remove_rewrites());
547 child_proofs.push(flat_proof);
548 }
549
550 proof.push(FlatTerm::new(
551 self.node.clone(),
552 representative_terms.clone(),
553 ));
554
555 for (i, child_proof) in child_proofs.iter().enumerate() {
556 proof.last_mut().unwrap().children[i] = child_proof[0].clone();
558
559 for child in child_proof.iter().skip(1) {
560 let mut children = vec![];
561 for (j, rep_term) in representative_terms.iter().enumerate() {
562 if j == i {
563 children.push(child.clone());
564 } else {
565 children.push(rep_term.clone());
566 }
567 }
568
569 proof.push(FlatTerm::new(self.node.clone(), children));
570 }
571 representative_terms[i] = child_proof.last().unwrap().remove_rewrites();
572 }
573
574 proof[0].backward_rule = self.backward_rule;
575 proof[0].forward_rule = self.forward_rule;
576
577 proof
578 }
579}
580
/// A fully expanded term in a flat explanation. The subterm where a rewrite
/// applies carries the rule name. `PartialEq` is implemented by hand and
/// ignores these annotations.
#[derive(Debug, Clone, Eq)]
pub struct FlatTerm<L: Language> {
    /// The operator of this term.
    pub node: L,
    /// Rule applied right-to-left at this subterm, if any.
    pub backward_rule: Option<Symbol>,
    /// Rule applied left-to-right at this subterm, if any.
    pub forward_rule: Option<Symbol>,
    /// Child terms, one per child of `node`.
    pub children: FlatExplanation<L>,
}
606
607impl<L: Language + Display + FromOp> Display for FlatTerm<L> {
608 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
609 let s = self.get_sexp().to_string();
610 write!(f, "{}", s)
611 }
612}
613
614impl<L: Language> PartialEq for FlatTerm<L> {
615 fn eq(&self, other: &FlatTerm<L>) -> bool {
616 if !self.node.matches(&other.node) {
617 return false;
618 }
619
620 for (child1, child2) in self.children.iter().zip(other.children.iter()) {
621 if !child1.eq(child2) {
622 return false;
623 }
624 }
625 true
626 }
627}
628
629impl<L: Language> FlatTerm<L> {
630 pub fn remove_rewrites(&self) -> FlatTerm<L> {
632 FlatTerm::new(
633 self.node.clone(),
634 self.children
635 .iter()
636 .map(|child| child.remove_rewrites())
637 .collect(),
638 )
639 }
640
641 fn combine_rewrites(&mut self, other: &FlatTerm<L>) {
642 if other.forward_rule.is_some() {
643 assert!(self.forward_rule.is_none());
644 self.forward_rule = other.forward_rule;
645 }
646
647 if other.backward_rule.is_some() {
648 assert!(self.backward_rule.is_none());
649 self.backward_rule = other.backward_rule;
650 }
651
652 for (left, right) in self.children.iter_mut().zip(other.children.iter()) {
653 left.combine_rewrites(right);
654 }
655 }
656}
657
658impl<L: Language> Default for Explain<L> {
659 fn default() -> Self {
660 Self::new()
661 }
662}
663
664impl<L: Language + Display + FromOp> FlatTerm<L> {
665 pub fn get_string(&self) -> String {
668 self.get_sexp().to_string()
669 }
670
671 fn get_sexp(&self) -> Sexp {
672 let op = Sexp::String(self.node.to_string());
673 let mut expr = if self.node.is_leaf() {
674 op
675 } else {
676 let mut vec = vec![op];
677 for child in &self.children {
678 vec.push(child.get_sexp());
679 }
680 Sexp::List(vec)
681 };
682
683 if let Some(rule_name) = &self.backward_rule {
684 expr = Sexp::List(vec![
685 Sexp::String("Rewrite<=".to_string()),
686 Sexp::String((*rule_name).to_string()),
687 expr,
688 ]);
689 }
690
691 if let Some(rule_name) = &self.forward_rule {
692 expr = Sexp::List(vec![
693 Sexp::String("Rewrite=>".to_string()),
694 Sexp::String((*rule_name).to_string()),
695 expr,
696 ]);
697 }
698
699 expr
700 }
701
702 pub fn get_recexpr(&self) -> RecExpr<L> {
704 self.remove_rewrites().to_string().parse().unwrap()
705 }
706}
707
708impl<L: Language + Display + FromOp> Display for TreeTerm<L> {
709 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
710 let mut buf = String::new();
711 let width = 80;
712 pretty_print(&mut buf, &self.get_sexp(), width, 1).unwrap();
713 write!(f, "{}", buf)
714 }
715}
716
717impl<L: Language + Display + FromOp> TreeTerm<L> {
718 fn get_sexp(&self) -> Sexp {
720 self.get_sexp_with_bindings(&Default::default())
721 }
722
723 fn get_sexp_with_bindings(&self, bindings: &HashMap<*const TreeTerm<L>, Sexp>) -> Sexp {
724 let op = Sexp::String(self.node.to_string());
725 let mut expr = if self.node.is_leaf() {
726 op
727 } else {
728 let mut vec = vec![op];
729 for child in &self.child_proofs {
730 assert!(!child.is_empty());
731 if child.len() == 1 {
732 if let Some(existing) = bindings.get(&(&*child[0] as *const TreeTerm<L>)) {
733 vec.push(existing.clone());
734 } else {
735 vec.push(child[0].get_sexp_with_bindings(bindings));
736 }
737 } else {
738 let mut child_expressions = vec![Sexp::String("Explanation".to_string())];
739 for child_explanation in child.iter() {
740 if let Some(existing) =
741 bindings.get(&(&**child_explanation as *const TreeTerm<L>))
742 {
743 child_expressions.push(existing.clone());
744 } else {
745 child_expressions
746 .push(child_explanation.get_sexp_with_bindings(bindings));
747 }
748 }
749 vec.push(Sexp::List(child_expressions));
750 }
751 }
752 Sexp::List(vec)
753 };
754
755 if let Some(rule_name) = &self.backward_rule {
756 expr = Sexp::List(vec![
757 Sexp::String("Rewrite<=".to_string()),
758 Sexp::String((*rule_name).to_string()),
759 expr,
760 ]);
761 }
762
763 if let Some(rule_name) = &self.forward_rule {
764 expr = Sexp::List(vec![
765 Sexp::String("Rewrite=>".to_string()),
766 Sexp::String((*rule_name).to_string()),
767 expr,
768 ]);
769 }
770
771 expr
772 }
773}
774
775impl<L: Language> FlatTerm<L> {
776 pub fn new(node: L, children: FlatExplanation<L>) -> FlatTerm<L> {
778 FlatTerm {
779 node,
780 backward_rule: None,
781 forward_rule: None,
782 children,
783 }
784 }
785
786 pub fn rewrite(&self, lhs: &PatternAst<L>, rhs: &PatternAst<L>) -> FlatTerm<L> {
789 let mut bindings = Default::default();
790 self.make_bindings(lhs, lhs.len() - 1, &mut bindings);
791 FlatTerm::from_pattern(rhs, rhs.len() - 1, &bindings)
792 }
793
794 pub fn has_rewrite_forward(&self) -> bool {
796 self.forward_rule.is_some()
797 || self
798 .children
799 .iter()
800 .any(|child| child.has_rewrite_forward())
801 }
802
803 pub fn has_rewrite_backward(&self) -> bool {
805 self.backward_rule.is_some()
806 || self
807 .children
808 .iter()
809 .any(|child| child.has_rewrite_backward())
810 }
811
812 fn from_pattern(
813 pattern: &[ENodeOrVar<L>],
814 location: usize,
815 bindings: &HashMap<Var, &FlatTerm<L>>,
816 ) -> FlatTerm<L> {
817 match &pattern[location] {
818 ENodeOrVar::Var(var) => (*bindings.get(var).unwrap()).clone(),
819 ENodeOrVar::ENode(node) => {
820 let children = node.fold(vec![], |mut acc, child| {
821 acc.push(FlatTerm::from_pattern(
822 pattern,
823 usize::from(child),
824 bindings,
825 ));
826 acc
827 });
828 FlatTerm::new(node.clone(), children)
829 }
830 }
831 }
832
833 fn make_bindings<'a>(
834 &'a self,
835 pattern: &[ENodeOrVar<L>],
836 location: usize,
837 bindings: &mut HashMap<Var, &'a FlatTerm<L>>,
838 ) {
839 match &pattern[location] {
840 ENodeOrVar::Var(var) => {
841 if let Some(existing) = bindings.get(var) {
842 if existing != &self {
843 panic!(
844 "Invalid proof: binding for variable {:?} does not match between {:?} \n and \n {:?}",
845 var, existing, self);
846 }
847 } else {
848 bindings.insert(*var, self);
849 }
850 }
851 ENodeOrVar::ENode(node) => {
852 assert!(node.matches(&self.node));
854 let mut counter = 0;
855 node.for_each(|child| {
856 self.children[counter].make_bindings(pattern, usize::from(child), bindings);
857 counter += 1;
858 });
859 }
860 }
861 }
862}
863
864#[derive(Clone, Eq, PartialEq)]
866struct HeapState<I> {
867 cost: ProofCost,
868 item: I,
869}
870impl<I: Eq + PartialEq> Ord for HeapState<I> {
874 fn cmp(&self, other: &Self) -> Ordering {
875 other
879 .cost
880 .cmp(&self.cost)
881 .then_with(|| self.cost.cmp(&other.cost))
882 }
883}
884
885impl<I: Eq + PartialEq> PartialOrd for HeapState<I> {
887 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
888 Some(self.cmp(other))
889 }
890}
891
892impl<L: Language> Explain<L> {
893 fn make_rule_table<'a, N: Analysis<L>>(
894 rules: &[&'a Rewrite<L, N>],
895 ) -> HashMap<Symbol, &'a Rewrite<L, N>> {
896 let mut table: HashMap<Symbol, &'a Rewrite<L, N>> = Default::default();
897 for r in rules {
898 table.insert(r.name, r);
899 }
900 table
901 }
902 pub fn new() -> Self {
903 Explain {
904 explainfind: vec![],
905 uncanon_memo: Default::default(),
906 shortest_explanation_memo: Default::default(),
907 optimize_explanation_lengths: true,
908 }
909 }
910
911 pub(crate) fn add(&mut self, node: L, set: Id) -> Id {
912 assert_eq!(self.explainfind.len(), usize::from(set));
913 self.uncanon_memo.insert(node, set);
914 self.explainfind.push(ExplainNode {
915 neighbors: vec![],
916 parent_connection: Connection {
917 justification: Justification::Congruence,
918 is_rewrite_forward: false,
919 next: set,
920 current: set,
921 },
922 });
923 set
924 }
925
926 fn make_leader(&mut self, node: Id) {
928 let next = self.explainfind[usize::from(node)].parent_connection.next;
929 if next != node {
930 self.make_leader(next);
931 let node_connection = &self.explainfind[usize::from(node)].parent_connection;
932 let pconnection = Connection {
933 justification: node_connection.justification.clone(),
934 is_rewrite_forward: !node_connection.is_rewrite_forward,
935 next: node,
936 current: next,
937 };
938 self.explainfind[usize::from(next)].parent_connection = pconnection;
939 }
940 }
941
942 pub(crate) fn alternate_rewrite(&mut self, node1: Id, node2: Id, justification: Justification) {
943 if node1 == node2 {
944 return;
945 }
946 if let Some((cost, _)) = self.shortest_explanation_memo.get(&(node1, node2)) {
947 if cost.is_zero() || cost.is_one() {
948 return;
949 }
950 }
951
952 let lconnection = Connection {
953 justification: justification.clone(),
954 is_rewrite_forward: true,
955 next: node2,
956 current: node1,
957 };
958
959 let rconnection = Connection {
960 justification,
961 is_rewrite_forward: false,
962 next: node1,
963 current: node2,
964 };
965
966 self.explainfind[usize::from(node1)]
967 .neighbors
968 .push(lconnection);
969 self.explainfind[usize::from(node2)]
970 .neighbors
971 .push(rconnection);
972 self.shortest_explanation_memo
973 .insert((node1, node2), (BigUint::one(), node2));
974 self.shortest_explanation_memo
975 .insert((node2, node1), (BigUint::one(), node1));
976 }
977
978 pub(crate) fn union(&mut self, node1: Id, node2: Id, justification: Justification) {
979 if let Justification::Congruence = justification {
980 }
982
983 self.make_leader(node1);
984 self.explainfind[usize::from(node1)].parent_connection.next = node2;
985
986 if let Justification::Rule(_) = justification {
987 self.shortest_explanation_memo
988 .insert((node1, node2), (BigUint::one(), node2));
989 self.shortest_explanation_memo
990 .insert((node2, node1), (BigUint::one(), node1));
991 }
992
993 let pconnection = Connection {
994 justification: justification.clone(),
995 is_rewrite_forward: true,
996 next: node2,
997 current: node1,
998 };
999 let other_pconnection = Connection {
1000 justification,
1001 is_rewrite_forward: false,
1002 next: node1,
1003 current: node2,
1004 };
1005 self.explainfind[usize::from(node1)]
1006 .neighbors
1007 .push(pconnection.clone());
1008 self.explainfind[usize::from(node2)]
1009 .neighbors
1010 .push(other_pconnection);
1011 self.explainfind[usize::from(node1)].parent_connection = pconnection;
1012 }
1013 pub(crate) fn get_union_equalities(&self) -> UnionEqualities {
1014 let mut equalities = vec![];
1015 for node in &self.explainfind {
1016 for neighbor in &node.neighbors {
1017 if neighbor.is_rewrite_forward {
1018 if let Justification::Rule(r) = neighbor.justification {
1019 equalities.push((neighbor.current, neighbor.next, r));
1020 }
1021 }
1022 }
1023 }
1024 equalities
1025 }
1026
1027 pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> {
1028 ExplainNodes {
1029 explain: self,
1030 nodes,
1031 }
1032 }
1033}
1034
1035impl<'a, L: Language> Deref for ExplainNodes<'a, L> {
1036 type Target = Explain<L>;
1037
1038 fn deref(&self) -> &Self::Target {
1039 self.explain
1040 }
1041}
1042
1043impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> {
1044 fn deref_mut(&mut self) -> &mut Self::Target {
1045 &mut *self.explain
1046 }
1047}
1048
1049impl<'x, L: Language> ExplainNodes<'x, L> {
1050 pub(crate) fn node(&self, node_id: Id) -> &L {
1051 &self.nodes[usize::from(node_id)]
1052 }
1053 fn node_to_explanation(
1054 &self,
1055 node_id: Id,
1056 cache: &mut NodeExplanationCache<L>,
1057 ) -> Rc<TreeTerm<L>> {
1058 if let Some(existing) = cache.get(&node_id) {
1059 existing.clone()
1060 } else {
1061 let node = self.node(node_id).clone();
1062 let children = node.fold(vec![], |mut sofar, child| {
1063 sofar.push(vec![self.node_to_explanation(child, cache)]);
1064 sofar
1065 });
1066 let res = Rc::new(TreeTerm::new(node, children));
1067 cache.insert(node_id, res.clone());
1068 res
1069 }
1070 }
1071
1072 fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm<L> {
1073 let node = self.node(node_id).clone();
1074 let children = node.fold(vec![], |mut sofar, child| {
1075 sofar.push(self.node_to_flat_explanation(child));
1076 sofar
1077 });
1078 FlatTerm::new(node, children)
1079 }
1080
1081 pub fn check_each_explain<N: Analysis<L>>(&self, rules: &[&Rewrite<L, N>]) -> bool {
1082 let rule_table = Explain::make_rule_table(rules);
1083 for i in 0..self.explainfind.len() {
1084 let explain_node = &self.explainfind[i];
1085
1086 if explain_node.parent_connection.next != Id::from(i) {
1087 let mut current_explanation = self.node_to_flat_explanation(Id::from(i));
1088 let mut next_explanation =
1089 self.node_to_flat_explanation(explain_node.parent_connection.next);
1090 if let Justification::Rule(rule_name) =
1091 &explain_node.parent_connection.justification
1092 {
1093 if let Some(rule) = rule_table.get(rule_name) {
1094 if !explain_node.parent_connection.is_rewrite_forward {
1095 std::mem::swap(&mut current_explanation, &mut next_explanation);
1096 }
1097 if !Explanation::check_rewrite(
1098 ¤t_explanation,
1099 &next_explanation,
1100 rule,
1101 ) {
1102 return false;
1103 }
1104 }
1105 }
1106 }
1107 }
1108 true
1109 }
1110
1111 pub(crate) fn explain_equivalence<N: Analysis<L>>(
1112 &mut self,
1113 left: Id,
1114 right: Id,
1115 unionfind: &mut UnionFind,
1116 classes: &HashMap<Id, EClass<L, N::Data>>,
1117 ) -> Explanation<L> {
1118 if self.optimize_explanation_lengths {
1119 self.calculate_shortest_explanations::<N>(left, right, classes, unionfind);
1120 }
1121
1122 let mut cache = Default::default();
1123 let mut enode_cache = Default::default();
1124 Explanation::new(self.explain_enodes(left, right, &mut cache, &mut enode_cache, false))
1125 }
1126
1127 fn common_ancestor(&self, mut left: Id, mut right: Id) -> Id {
1128 let mut seen_left: HashSet<Id> = Default::default();
1129 let mut seen_right: HashSet<Id> = Default::default();
1130 loop {
1131 seen_left.insert(left);
1132 if seen_right.contains(&left) {
1133 return left;
1134 }
1135
1136 seen_right.insert(right);
1137 if seen_left.contains(&right) {
1138 return right;
1139 }
1140
1141 let next_left = self.explainfind[usize::from(left)].parent_connection.next;
1142 let next_right = self.explainfind[usize::from(right)].parent_connection.next;
1143 assert!(next_left != left || next_right != right);
1144 left = next_left;
1145 right = next_right;
1146 }
1147 }
1148
1149 fn get_connections(&self, mut node: Id, ancestor: Id) -> Vec<Connection> {
1150 if node == ancestor {
1151 return vec![];
1152 }
1153
1154 let mut nodes = vec![];
1155 loop {
1156 let next = self.explainfind[usize::from(node)].parent_connection.next;
1157 nodes.push(
1158 self.explainfind[usize::from(node)]
1159 .parent_connection
1160 .clone(),
1161 );
1162 if next == ancestor {
1163 return nodes;
1164 }
1165 assert!(next != node);
1166 node = next;
1167 }
1168 }
1169
1170 fn get_path_unoptimized(&self, left: Id, right: Id) -> (Vec<Connection>, Vec<Connection>) {
1171 let ancestor = self.common_ancestor(left, right);
1172 let left_connections = self.get_connections(left, ancestor);
1173 let right_connections = self.get_connections(right, ancestor);
1174 (left_connections, right_connections)
1175 }
1176
1177 fn get_neighbor(&self, current: Id, next: Id) -> Connection {
1178 for neighbor in &self.explainfind[usize::from(current)].neighbors {
1179 if neighbor.next == next {
1180 if let Justification::Rule(_) = neighbor.justification {
1181 return neighbor.clone();
1182 }
1183 }
1184 }
1185 Connection {
1186 justification: Justification::Congruence,
1187 current,
1188 next,
1189 is_rewrite_forward: true,
1190 }
1191 }
1192
1193 fn get_path(&self, mut left: Id, right: Id) -> (Vec<Connection>, Vec<Connection>) {
1194 let mut left_connections = vec![];
1195 loop {
1196 if left == right {
1197 return (left_connections, vec![]);
1198 }
1199 if let Some((_, next)) = self.shortest_explanation_memo.get(&(left, right)) {
1200 left_connections.push(self.get_neighbor(left, *next));
1201 left = *next;
1202 } else {
1203 break;
1204 }
1205 }
1206
1207 let (restleft, right_connections) = self.get_path_unoptimized(left, right);
1208 left_connections.extend(restleft);
1209 (left_connections, right_connections)
1210 }
1211
1212 fn explain_enodes(
1213 &self,
1214 left: Id,
1215 right: Id,
1216 cache: &mut ExplainCache<L>,
1217 node_explanation_cache: &mut NodeExplanationCache<L>,
1218 use_unoptimized: bool,
1219 ) -> TreeExplanation<L> {
1220 let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)];
1221 let (left_connections, right_connections) = if use_unoptimized {
1222 self.get_path_unoptimized(left, right)
1223 } else {
1224 self.get_path(left, right)
1225 };
1226
1227 for (i, connection) in left_connections
1228 .iter()
1229 .chain(right_connections.iter().rev())
1230 .enumerate()
1231 {
1232 let mut connection = connection.clone();
1233 if i >= left_connections.len() {
1234 connection.is_rewrite_forward = !connection.is_rewrite_forward;
1235 std::mem::swap(&mut connection.next, &mut connection.current);
1236 }
1237
1238 proof.push(self.explain_adjacent(
1239 connection,
1240 cache,
1241 node_explanation_cache,
1242 use_unoptimized,
1243 ));
1244 }
1245 proof
1246 }
1247
1248 fn explain_adjacent(
1249 &self,
1250 connection: Connection,
1251 cache: &mut ExplainCache<L>,
1252 node_explanation_cache: &mut NodeExplanationCache<L>,
1253 use_unoptimized: bool,
1254 ) -> Rc<TreeTerm<L>> {
1255 let fingerprint = (connection.current, connection.next);
1256
1257 if let Some(answer) = cache.get(&fingerprint) {
1258 return answer.clone();
1259 }
1260
1261 let term = match connection.justification {
1262 Justification::Rule(name) => {
1263 let mut rewritten =
1264 (*self.node_to_explanation(connection.next, node_explanation_cache)).clone();
1265 if connection.is_rewrite_forward {
1266 rewritten.forward_rule = Some(name);
1267 } else {
1268 rewritten.backward_rule = Some(name);
1269 }
1270
1271 rewritten.current = connection.next;
1272 rewritten.last = connection.current;
1273
1274 Rc::new(rewritten)
1275 }
1276 Justification::Congruence => {
1277 let current_node = self.node(connection.current);
1279 let next_node = self.node(connection.next);
1280 assert!(current_node.matches(next_node));
1281 let mut subproofs = vec![];
1282
1283 for (left_child, right_child) in current_node
1284 .children()
1285 .iter()
1286 .zip(next_node.children().iter())
1287 {
1288 subproofs.push(self.explain_enodes(
1289 *left_child,
1290 *right_child,
1291 cache,
1292 node_explanation_cache,
1293 use_unoptimized,
1294 ));
1295 }
1296 Rc::new(TreeTerm::new(current_node.clone(), subproofs))
1297 }
1298 };
1299
1300 cache.insert(fingerprint, term.clone());
1301
1302 term
1303 }
1304
1305 fn find_all_enodes(&self, eclass: Id) -> HashSet<Id> {
1306 let mut enodes = HashSet::default();
1307 let mut todo = vec![eclass];
1308
1309 while let Some(current) = todo.pop() {
1310 if enodes.insert(current) {
1311 for neighbor in &self.explainfind[usize::from(current)].neighbors {
1312 todo.push(neighbor.next);
1313 }
1314 }
1315 }
1316 enodes
1317 }
1318
1319 fn add_tree_depths(&self, node: Id, depths: &mut HashMap<Id, ProofCost>) -> ProofCost {
1320 if depths.get(&node).is_none() {
1321 let parent = self.parent(node);
1322 let depth = if parent == node {
1323 BigUint::zero()
1324 } else {
1325 self.add_tree_depths(parent, depths) + 1_u32
1326 };
1327
1328 depths.insert(node, depth);
1329 }
1330
1331 depths.get(&node).unwrap().clone()
1332 }
1333
1334 fn calculate_tree_depths(&self) -> HashMap<Id, ProofCost> {
1335 let mut depths = HashMap::default();
1336 for i in 0..self.explainfind.len() {
1337 self.add_tree_depths(Id::from(i), &mut depths);
1338 }
1339 depths
1340 }
1341
1342 fn replace_distance(&mut self, current: Id, next: Id, right: Id, distance: ProofCost) {
1343 self.shortest_explanation_memo
1344 .insert((current, right), (distance, next));
1345 }
1346
1347 fn populate_path_length(
1348 &mut self,
1349 right: Id,
1350 left_connections: &[Connection],
1351 distance_memo: &mut DistanceMemo,
1352 ) {
1353 self.shortest_explanation_memo
1354 .insert((right, right), (BigUint::zero(), right));
1355 for connection in left_connections.iter().rev() {
1356 let next = connection.next;
1357 let current = connection.current;
1358 let next_cost = self
1359 .shortest_explanation_memo
1360 .get(&(next, right))
1361 .unwrap()
1362 .0
1363 .clone();
1364 let dist = self.connection_distance(connection, distance_memo);
1365 self.replace_distance(current, next, right, next_cost + dist);
1366 }
1367 }
1368
1369 fn distance_between(
1370 &mut self,
1371 left: Id,
1372 right: Id,
1373 distance_memo: &mut DistanceMemo,
1374 ) -> ProofCost {
1375 if left == right {
1376 return BigUint::zero();
1377 }
1378 let ancestor = if let Some(a) = distance_memo.common_ancestor.get(&(left, right)) {
1379 *a
1380 } else {
1381 self.common_ancestor(left, right)
1383 };
1384 self.calculate_parent_distance(left, ancestor, distance_memo);
1386 self.calculate_parent_distance(right, ancestor, distance_memo);
1387
1388 let a = self.calculate_parent_distance(ancestor, Id::from(usize::MAX), distance_memo);
1390 let b = self.calculate_parent_distance(left, Id::from(usize::MAX), distance_memo);
1391 let c = self.calculate_parent_distance(right, Id::from(usize::MAX), distance_memo);
1392
1393 assert!(
1394 distance_memo.parent_distance[usize::from(ancestor)].0
1395 == distance_memo.parent_distance[usize::from(left)].0
1396 );
1397 assert!(
1398 distance_memo.parent_distance[usize::from(ancestor)].0
1399 == distance_memo.parent_distance[usize::from(right)].0
1400 );
1401
1402 b + c - (a << 1)
1404
1405 }
1407
1408 fn congruence_distance(
1409 &mut self,
1410 current: Id,
1411 next: Id,
1412 distance_memo: &mut DistanceMemo,
1413 ) -> ProofCost {
1414 let current_node = self.node(current).clone();
1415 let next_node = self.node(next).clone();
1416 let mut cost: ProofCost = BigUint::zero();
1417 for (left_child, right_child) in current_node
1418 .children()
1419 .iter()
1420 .zip(next_node.children().iter())
1421 {
1422 cost += self.distance_between(*left_child, *right_child, distance_memo);
1423 }
1424 cost
1425 }
1426
1427 fn connection_distance(
1428 &mut self,
1429 connection: &Connection,
1430 distance_memo: &mut DistanceMemo,
1431 ) -> ProofCost {
1432 match connection.justification {
1433 Justification::Congruence => {
1434 self.congruence_distance(connection.current, connection.next, distance_memo)
1435 }
1436 Justification::Rule(_) => BigUint::one(),
1437 }
1438 }
1439
1440 fn calculate_parent_distance(
1441 &mut self,
1442 enode: Id,
1443 ancestor: Id,
1444 distance_memo: &mut DistanceMemo,
1445 ) -> ProofCost {
1446 loop {
1447 let parent = distance_memo.parent_distance[usize::from(enode)].0;
1448 let dist = distance_memo.parent_distance[usize::from(enode)].1.clone();
1449 if self.parent(parent) == parent {
1450 break;
1451 }
1452
1453 let parent_parent = distance_memo.parent_distance[usize::from(parent)].0;
1454 if parent_parent != parent {
1455 let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1.clone();
1456 distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
1457 } else {
1458 if ancestor == Id::from(usize::MAX) {
1459 break;
1460 }
1461 if distance_memo.tree_depth.get(&parent).unwrap()
1462 <= distance_memo.tree_depth.get(&ancestor).unwrap()
1463 {
1464 break;
1465 }
1466
1467 let connection = &self.explainfind[usize::from(parent)].parent_connection;
1469 let current = connection.current;
1470 let next = connection.next;
1471 let cost = match connection.justification {
1472 Justification::Congruence => {
1473 self.congruence_distance(current, next, distance_memo)
1474 }
1475 Justification::Rule(_) => BigUint::one(),
1476 };
1477 distance_memo.parent_distance[usize::from(parent)] = (self.parent(parent), cost);
1478 }
1479 }
1480
1481 distance_memo.parent_distance[usize::from(enode)].1.clone()
1485 }
1486
1487 fn find_congruence_neighbors<N: Analysis<L>>(
1488 &self,
1489 classes: &HashMap<Id, EClass<L, N::Data>>,
1490 congruence_neighbors: &mut [Vec<Id>],
1491 unionfind: &UnionFind,
1492 ) {
1493 let mut counter = 0;
1494 for node in &self.explainfind {
1496 if let Justification::Congruence = node.parent_connection.justification {
1497 let current = node.parent_connection.current;
1498 let next = node.parent_connection.next;
1499 congruence_neighbors[usize::from(current)].push(next);
1500 congruence_neighbors[usize::from(next)].push(current);
1501 counter += 1;
1502 }
1503 }
1504
1505 'outer: for eclass in classes.keys() {
1506 let enodes = self.find_all_enodes(*eclass);
1507 let mut cannon_enodes: HashMap<L, Vec<Id>> = Default::default();
1509 for enode in &enodes {
1510 let cannon = self
1511 .node(*enode)
1512 .clone()
1513 .map_children(|child| unionfind.find(child));
1514 if let Some(others) = cannon_enodes.get_mut(&cannon) {
1515 for other in others.iter() {
1516 congruence_neighbors[usize::from(*enode)].push(*other);
1517 congruence_neighbors[usize::from(*other)].push(*enode);
1518 counter += 1;
1519 }
1520 others.push(*enode);
1521 } else {
1522 counter += 1;
1523 cannon_enodes.insert(cannon, vec![*enode]);
1524 }
1525 if counter > CONGRUENCE_LIMIT * self.explainfind.len() {
1527 break 'outer;
1528 }
1529 }
1530 }
1531 }
1532
1533 pub fn get_num_congr<N: Analysis<L>>(
1534 &self,
1535 classes: &HashMap<Id, EClass<L, N::Data>>,
1536 unionfind: &UnionFind,
1537 ) -> usize {
1538 let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
1539 self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
1540 let mut count = 0;
1541 for v in congruence_neighbors {
1542 count += v.len();
1543 }
1544
1545 count / 2
1546 }
1547
1548 pub fn get_num_nodes(&self) -> usize {
1549 self.explainfind.len()
1550 }
1551
1552 fn shortest_path_modulo_congruence(
1553 &mut self,
1554 start: Id,
1555 end: Id,
1556 congruence_neighbors: &[Vec<Id>],
1557 distance_memo: &mut DistanceMemo,
1558 ) -> Option<(Vec<Connection>, Vec<Connection>)> {
1559 let mut todo = BinaryHeap::new();
1560 todo.push(HeapState {
1561 cost: BigUint::zero(),
1562 item: Connection {
1563 current: start,
1564 next: start,
1565 justification: Justification::Congruence,
1566 is_rewrite_forward: true,
1567 },
1568 });
1569
1570 let mut last = HashMap::default();
1571 let mut path_cost = HashMap::default();
1572
1573 'outer: loop {
1574 if todo.is_empty() {
1575 break 'outer;
1576 }
1577 let state = todo.pop().unwrap();
1578 let connection = state.item;
1579 let cost_so_far = state.cost.clone();
1580 let current = connection.next;
1581
1582 if last.get(¤t).is_some() {
1583 continue 'outer;
1584 } else {
1585 last.insert(current, connection);
1586 path_cost.insert(current, cost_so_far.clone());
1587 }
1588
1589 if current == end {
1590 break;
1591 }
1592
1593 for neighbor in &self.explainfind[usize::from(current)].neighbors {
1594 if let Justification::Rule(_) = neighbor.justification {
1595 let neighbor_cost = cost_so_far.clone() + 1_u32;
1596 todo.push(HeapState {
1597 item: neighbor.clone(),
1598 cost: neighbor_cost,
1599 });
1600 }
1601 }
1602
1603 for other in congruence_neighbors[usize::from(current)].iter() {
1604 let next = other;
1605 let distance = self.congruence_distance(current, *next, distance_memo);
1606 let next_cost = cost_so_far.clone() + distance;
1607 todo.push(HeapState {
1608 item: Connection {
1609 current,
1610 next: *next,
1611 justification: Justification::Congruence,
1612 is_rewrite_forward: true,
1613 },
1614 cost: next_cost,
1615 });
1616 }
1617 }
1618
1619 let total_cost = path_cost.get(&end);
1620
1621 let left_connections;
1622 let mut right_connections = vec![];
1623
1624 if *total_cost.unwrap() >= self.distance_between(start, end, distance_memo) {
1635 let (a_left_connections, a_right_connections) = self.get_path_unoptimized(start, end);
1636 left_connections = a_left_connections;
1637 right_connections = a_right_connections;
1638 } else {
1639 let mut current = end;
1640 let mut connections = vec![];
1641 while current != start {
1642 let prev = last.get(¤t);
1643 if let Some(prev_connection) = prev {
1644 connections.push(prev_connection.clone());
1645 current = prev_connection.current;
1646 } else {
1647 break;
1648 }
1649 }
1650 connections.reverse();
1651 self.populate_path_length(end, &connections, distance_memo);
1652 left_connections = connections;
1653 }
1654
1655 Some((left_connections, right_connections))
1656 }
1657
1658 fn greedy_short_explanations(
1659 &mut self,
1660 start: Id,
1661 end: Id,
1662 congruence_neighbors: &[Vec<Id>],
1663 distance_memo: &mut DistanceMemo,
1664 mut fuel: usize,
1665 ) {
1666 let mut todo_congruence = VecDeque::new();
1667 todo_congruence.push_back((start, end));
1668
1669 while !todo_congruence.is_empty() {
1670 let (start, end) = todo_congruence.pop_front().unwrap();
1671 let eclass_size = self.find_all_enodes(start).len();
1672 if fuel < eclass_size {
1673 continue;
1674 }
1675 fuel = fuel.saturating_sub(eclass_size);
1676
1677 let (left_connections, right_connections) = self
1678 .shortest_path_modulo_congruence(start, end, congruence_neighbors, distance_memo)
1679 .unwrap();
1680
1681 for (i, connection) in left_connections
1684 .iter()
1685 .chain(right_connections.iter().rev())
1686 .enumerate()
1687 {
1688 let mut next = connection.next;
1689 let mut current = connection.current;
1690 if i >= left_connections.len() {
1691 std::mem::swap(&mut next, &mut current);
1692 }
1693 if let Justification::Congruence = connection.justification {
1694 let current_node = self.node(current).clone();
1695 let next_node = self.node(next).clone();
1696 for (left_child, right_child) in current_node
1697 .children()
1698 .iter()
1699 .zip(next_node.children().iter())
1700 {
1701 todo_congruence.push_back((*left_child, *right_child));
1702 }
1703 }
1704 }
1705 }
1706 }
1707
1708 #[allow(clippy::too_many_arguments)]
1709 fn tarjan_ocla(
1710 &self,
1711 enode: Id,
1712 children: &HashMap<Id, Vec<Id>>,
1713 common_ancestor_queries: &HashMap<Id, Vec<Id>>,
1714 black_set: &mut HashSet<Id>,
1715 unionfind: &mut UnionFind,
1716 ancestor: &mut Vec<Id>,
1717 common_ancestor: &mut HashMap<(Id, Id), Id>,
1718 ) {
1719 ancestor[usize::from(enode)] = enode;
1720 for child in children[&enode].iter() {
1721 self.tarjan_ocla(
1722 *child,
1723 children,
1724 common_ancestor_queries,
1725 black_set,
1726 unionfind,
1727 ancestor,
1728 common_ancestor,
1729 );
1730 unionfind.union(enode, *child);
1731 ancestor[usize::from(unionfind.find(enode))] = enode;
1732 }
1733
1734 if common_ancestor_queries.get(&enode).is_some() {
1735 black_set.insert(enode);
1736 for other in common_ancestor_queries.get(&enode).unwrap() {
1737 if black_set.contains(other) {
1738 let ancestor = ancestor[usize::from(unionfind.find(*other))];
1739 common_ancestor.insert((enode, *other), ancestor);
1740 common_ancestor.insert((*other, enode), ancestor);
1741 }
1742 }
1743 }
1744 }
1745
1746 fn parent(&self, enode: Id) -> Id {
1747 self.explainfind[usize::from(enode)].parent_connection.next
1748 }
1749
1750 fn calculate_common_ancestor<N: Analysis<L>>(
1751 &self,
1752 classes: &HashMap<Id, EClass<L, N::Data>>,
1753 congruence_neighbors: &[Vec<Id>],
1754 ) -> HashMap<(Id, Id), Id> {
1755 let mut common_ancestor_queries = HashMap::default();
1756 for (s_int, others) in congruence_neighbors.iter().enumerate() {
1757 let start = &Id::from(s_int);
1758 for other in others {
1759 for (left, right) in self
1760 .node(*start)
1761 .children()
1762 .iter()
1763 .zip(self.node(*other).children().iter())
1764 {
1765 if left != right {
1766 if common_ancestor_queries.get(start).is_none() {
1767 common_ancestor_queries.insert(*start, vec![]);
1768 }
1769 if common_ancestor_queries.get(other).is_none() {
1770 common_ancestor_queries.insert(*other, vec![]);
1771 }
1772 common_ancestor_queries.get_mut(start).unwrap().push(*other);
1773 common_ancestor_queries.get_mut(other).unwrap().push(*start);
1774 }
1775 }
1776 }
1777 }
1778
1779 let mut common_ancestor = HashMap::default();
1780 let mut unionfind = UnionFind::default();
1781 let mut ancestor = vec![];
1782 for i in 0..self.explainfind.len() {
1783 unionfind.make_set();
1784 ancestor.push(Id::from(i));
1785 }
1786 for (eclass, _) in classes.iter() {
1787 let enodes = self.find_all_enodes(*eclass);
1788 let mut children: HashMap<Id, Vec<Id>> = HashMap::default();
1789 for enode in &enodes {
1790 children.insert(*enode, vec![]);
1791 }
1792 for enode in &enodes {
1793 if self.parent(*enode) != *enode {
1794 children.get_mut(&self.parent(*enode)).unwrap().push(*enode);
1795 }
1796 }
1797
1798 let mut black_set = HashSet::default();
1799
1800 let mut parent = *enodes.iter().next().unwrap();
1801 while parent != self.parent(parent) {
1802 parent = self.parent(parent);
1803 }
1804 self.tarjan_ocla(
1805 parent,
1806 &children,
1807 &common_ancestor_queries,
1808 &mut black_set,
1809 &mut unionfind,
1810 &mut ancestor,
1811 &mut common_ancestor,
1812 );
1813 }
1814
1815 common_ancestor
1816 }
1817
1818 fn calculate_shortest_explanations<N: Analysis<L>>(
1819 &mut self,
1820 start: Id,
1821 end: Id,
1822 classes: &HashMap<Id, EClass<L, N::Data>>,
1823 unionfind: &UnionFind,
1824 ) {
1825 let mut congruence_neighbors = vec![vec![]; self.explainfind.len()];
1826 self.find_congruence_neighbors::<N>(classes, &mut congruence_neighbors, unionfind);
1827 let mut parent_distance = vec![(Id::from(0), BigUint::zero()); self.explainfind.len()];
1828 for (i, entry) in parent_distance.iter_mut().enumerate() {
1829 entry.0 = Id::from(i);
1830 }
1831
1832 let mut distance_memo = DistanceMemo {
1833 parent_distance,
1834 common_ancestor: self.calculate_common_ancestor::<N>(classes, &congruence_neighbors),
1835 tree_depth: self.calculate_tree_depths(),
1836 };
1837
1838 let fuel = GREEDY_NUM_ITERS * self.explainfind.len();
1839 self.greedy_short_explanations(start, end, &congruence_neighbors, &mut distance_memo, fuel);
1840 }
1841}
1842
#[cfg(test)]
mod tests {
    use super::super::*;

    /// Proves `a = b` through a chain of three rules, so `(f a) = (f b)` by congruence.
    /// Then adds a two-rule shortcut through `g` and checks the flattened explanation
    /// lengths with and without explanation-length optimization.
    #[test]
    fn simple_explain() {
        use SymbolLang as S;

        crate::init_logger();
        let mut egraph = EGraph::<S, ()>::default().with_explanations_enabled();

        let fa: RecExpr<S> = "(f a)".parse().unwrap();
        let fb: RecExpr<S> = "(f b)".parse().unwrap();
        egraph.add_expr(&fa);
        egraph.add_expr(&fb);
        egraph.add_expr(&"c".parse().unwrap());
        egraph.add_expr(&"d".parse().unwrap());

        // a = c = d = b, one named rule per step.
        for &(lhs, rhs, rule) in &[("a", "c", "ac"), ("c", "d", "cd"), ("d", "b", "db")] {
            egraph.union_instantiations(
                &lhs.parse().unwrap(),
                &rhs.parse().unwrap(),
                &Default::default(),
                rule.to_string(),
            );
        }
        egraph.rebuild();

        assert_eq!(egraph.add_expr(&fa), egraph.add_expr(&fb));

        // Number of steps in the flattened explanation of (f a) = (f b).
        let flat_len = |egraph: &mut EGraph<S, ()>| {
            egraph
                .explain_equivalence(&fa, &fb)
                .get_flat_strings()
                .len()
        };

        // Asking repeatedly must keep giving the same answer.
        for _ in 0..3 {
            assert_eq!(flat_len(&mut egraph), 4);
        }

        // Shortcut: (f a) = g = (f b).
        for &(lhs, rhs, rule) in &[("(f a)", "g", "fag"), ("g", "(f b)", "gfb")] {
            egraph.union_instantiations(
                &lhs.parse().unwrap(),
                &rhs.parse().unwrap(),
                &Default::default(),
                rule.to_string(),
            );
        }
        egraph.rebuild();

        egraph = egraph.without_explanation_length_optimization();
        assert_eq!(flat_len(&mut egraph), 4);

        egraph = egraph.with_explanation_length_optimization();
        assert_eq!(flat_len(&mut egraph), 3);
        assert_eq!(flat_len(&mut egraph), 3);

        egraph.dot().to_dot("target/foo.dot").unwrap();
    }
}
1950
/// `c = (f a)` and `d = (f b)` are trusted unions, and `a = b` makes the two applications
/// congruent. The flattened explanation of `c = d` therefore has four entries.
#[test]
fn simple_explain_union_trusted() {
    use crate::{EGraph, SymbolLang};

    crate::init_logger();
    let mut egraph = EGraph::new(()).with_explanations_enabled();

    let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
    let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
    let c = egraph.add_uncanonical(SymbolLang::leaf("c"));
    let d = egraph.add_uncanonical(SymbolLang::leaf("d"));

    egraph.union_trusted(a, b, "a=b");
    egraph.rebuild();

    let f_of_a = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
    let f_of_b = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
    egraph.union_trusted(c, f_of_a, "c=fa");
    egraph.union_trusted(d, f_of_b, "d=fb");
    egraph.rebuild();

    let mut explanation =
        egraph.explain_equivalence(&"c".parse().unwrap(), &"d".parse().unwrap());
    assert_eq!(explanation.make_flat_explanation().len(), 4);
}