1use crate::*;
2
3impl<L: Language, N: Analysis<L>> EGraph<L, N> {
4 pub fn union(&mut self, l: &AppliedId, r: &AppliedId) -> bool {
5 self.union_justified(l, r, None)
6 }
7
8 pub fn union_justified(&mut self, l: &AppliedId, r: &AppliedId, j: Option<String>) -> bool {
9 let subst = [
10 (String::from("a"), l.clone()),
11 (String::from("b"), r.clone()),
12 ]
13 .into_iter()
14 .collect();
15 let a = Pattern::parse("?a").unwrap();
16 let b = Pattern::parse("?b").unwrap();
17
18 self.union_instantiations(&a, &b, &subst, j)
19 }
20
21 pub fn union_instantiations(
22 &mut self,
23 from_pat: &Pattern<L>,
24 to_pat: &Pattern<L>,
25 subst: &Subst,
26 #[allow(unused)] justification: Option<String>,
27 ) -> bool {
28 let a = pattern_subst(self, from_pat, subst);
29 let b = pattern_subst(self, to_pat, subst);
30
31 #[allow(unused)]
32 let syn_a = self.synify_app_id(a.clone());
33 #[allow(unused)]
34 let syn_b = self.synify_app_id(b.clone());
35
36 let proof = ghost!(self.prove_explicit(&syn_a, &syn_b, justification));
37
38 let out = self.union_internal(&a, &b, proof);
39 self.rebuild_called_from_union_instantiations();
40 out
41 }
42
43 fn rebuild_called_from_union_instantiations(&mut self) {
44 self.rebuild();
45 }
46
47 pub(in crate::egraph) fn union_internal(
48 &mut self,
49 l: &AppliedId,
50 r: &AppliedId,
51 #[allow(unused)] proof: ProvenEq,
52 ) -> bool {
53 let pai_l = self.proven_find_applied_id(&l);
55 let pai_r = self.proven_find_applied_id(&r);
56
57 let proof = ghost!({
58 if CHECKS {
59 pai_l.proof.check(self);
60 pai_r.proof.check(self);
61 }
62
63 let a = self.prove_symmetry(pai_l.proof);
64 let a = self.prove_transitivity(a, proof);
65 let a = self.prove_transitivity(a, pai_r.proof);
66 if CHECKS {
67 assert_eq!(a.l.id, pai_l.elem.id);
68 assert_eq!(a.r.id, pai_r.elem.id);
69 }
70 a
71 });
72 self.union_leaders(pai_l.elem, pai_r.elem, proof)
73 }
74
75 fn union_leaders(&mut self, l: AppliedId, r: AppliedId, proof: ProvenEq) -> bool {
76 if self.eq(&l, &r) {
78 return false;
79 }
80
81 let cap = &l.slots() & &r.slots();
82
83 if l.slots() != cap {
84 self.shrink_slots(&l, &cap, proof.clone());
85 self.union_internal(&l, &r, proof);
86 return true;
87 }
88
89 if r.slots() != cap {
90 let flipped_proof = ghost!(self.prove_symmetry(proof.clone()));
91 self.shrink_slots(&r, &cap, flipped_proof);
92 self.union_internal(&l, &r, proof);
93 return true;
94 }
95
96 if l.id == r.id {
97 let id = l.id;
98
99 let perm = l.m.compose(&r.m.inverse());
103 if CHECKS {
104 assert!(perm.is_perm());
105 assert_eq!(&perm.keys(), &self.classes[&id].slots);
106 }
107
108 let proven_perm = ProvenPerm {
109 elem: perm,
110 #[cfg(feature = "explanations")]
111 proof,
112 #[cfg(feature = "explanations")]
113 reg: self.proof_registry.clone(),
114 };
115
116 if CHECKS {
117 #[cfg(feature = "explanations")]
118 assert_eq!(proven_perm.proof.l.id, id);
119
120 proven_perm.check();
121 }
122 let grp = &mut self.classes.get_mut(&id).unwrap().group;
123 if grp.contains(&proven_perm.to_slotmap()) {
124 return false;
125 }
126
127 grp.add(proven_perm);
128
129 self.touched_class(id, PendingType::Full);
130
131 true
132 } else {
133 let slot_size = |i| self.classes[&i].syn_enode.slots().len();
134
135 let size = |i| {
136 let c = &self.classes[&i];
137 c.nodes.len() + c.usages.len()
138 };
139
140 let right_order = |l, r| {
143 let (ssl, ssr) = (slot_size(l), slot_size(r));
146 if ssl > ssr {
147 return true;
148 }
149 if ssl < ssr {
150 return false;
151 }
152
153 size(l) <= size(r)
155 };
156
157 if right_order(l.id, r.id) {
158 self.move_to(&l, &r, proof)
159 } else {
160 let proof = ghost!(self.prove_symmetry(proof));
161 self.move_to(&r, &l, proof)
162 }
163
164 true
165 }
166 }
167
168 fn assert_ty(&self, m: &SlotMap, keys: &SmallHashSet<Slot>, values: &SmallHashSet<Slot>) {
169 assert!(m.keys().is_subset(keys));
170 assert!(m.values().is_subset(values));
171 }
172
173 fn move_to(&mut self, from: &AppliedId, to: &AppliedId, #[allow(unused)] proof: ProvenEq) {
175 if CHECKS {
176 assert_eq!(from.slots(), to.slots());
177 #[cfg(feature = "explanations")]
178 assert_eq!(from.id, proof.l.id);
179 #[cfg(feature = "explanations")]
180 assert_eq!(to.id, proof.r.id);
181 }
182
183 {
184 let analysis_from = self.analysis_data(from.id).clone();
185 let analysis_to = self.analysis_data_mut(to.id);
186 let old_analysis_to = analysis_to.clone();
187 let new_analysis_to = N::merge(analysis_from, analysis_to.clone());
188 let changed = old_analysis_to != new_analysis_to;
189 *analysis_to = new_analysis_to;
190
191 if changed {
192 self.modify_queue.push(to.id);
193 self.touched_class(to.id, PendingType::OnlyAnalysis);
194 }
195 }
196
197 let map = to.m.compose_partial(&from.m.inverse());
200 if CHECKS {
201 self.assert_ty(&map, &self.slots(to.id), &self.slots(from.id));
202 }
203
204 let app_id = self.mk_sem_applied_id(to.id, map.clone());
205 let pai = ProvenAppliedId {
206 elem: app_id,
207
208 #[cfg(feature = "explanations")]
209 proof,
210 };
211 self.unionfind_set(from.id, pai);
212
213 let from_nodes = self.classes.get(&from.id).unwrap().nodes.clone();
216 for (sh, psn) in from_nodes {
217 self.raw_remove_from_class(from.id, sh.clone());
218 let new_bij = psn.elem.compose_fresh(&map.inverse());
221
222 let src_id = psn.src_id;
223
224 self.raw_add_to_class(to.id, (sh.clone(), new_bij), src_id);
225 self.pending.insert(sh, PendingType::Full);
226 }
227
228 let f = from.m.compose_partial(&to.m.inverse());
237
238 let change_permutation_from_from_to_to = |x: Perm| -> Perm {
239 let perm: Perm = x
240 .iter()
241 .filter_map(|(x, y)| {
242 if f.contains_key(x) && f.contains_key(y) {
243 Some((f[x], f[y]))
244 } else {
245 None
246 }
247 })
248 .collect();
249
250 if CHECKS {
251 assert!(perm.is_perm());
252 assert_eq!(perm.keys(), self.classes[&to.id].slots);
253 }
254
255 perm
256 };
257 #[cfg(feature = "explanations")]
258 let prf = self.proven_find_applied_id(&from).proof;
259 #[cfg(feature = "explanations")]
260 let prf_rev = self.prove_symmetry(prf.clone());
261
262 let change_proven_permutation_from_from_to_to = |proven_perm: ProvenPerm| {
263 let new_perm = change_permutation_from_from_to_to(proven_perm.elem);
264 #[cfg(feature = "explanations")]
265 let new_proof = self.prove_transitivity(
266 prf_rev.clone(),
267 self.prove_transitivity(proven_perm.proof, prf.clone()),
268 );
269 ProvenPerm {
270 elem: new_perm,
271 #[cfg(feature = "explanations")]
272 proof: new_proof,
273 #[cfg(feature = "explanations")]
274 reg: self.proof_registry.clone(),
275 }
276 };
277
278 let set = self.classes[&from.id]
279 .group
280 .generators()
281 .into_iter()
282 .map(change_proven_permutation_from_from_to_to)
283 .collect();
284
285 if self.classes.get_mut(&to.id).unwrap().group.add_set(set) {
286 self.touched_class(to.id, PendingType::Full);
287 }
288
289 self.touched_class(from.id, PendingType::Full);
291 }
292}