// slotted_egraphs/egraph/union.rs
use crate::*;

3impl<L: Language, N: Analysis<L>> EGraph<L, N> {
4    pub fn union(&mut self, l: &AppliedId, r: &AppliedId) -> bool {
5        self.union_justified(l, r, None)
6    }
7
8    pub fn union_justified(&mut self, l: &AppliedId, r: &AppliedId, j: Option<String>) -> bool {
9        let subst = [
10            (String::from("a"), l.clone()),
11            (String::from("b"), r.clone()),
12        ]
13        .into_iter()
14        .collect();
15        let a = Pattern::parse("?a").unwrap();
16        let b = Pattern::parse("?b").unwrap();
17
18        self.union_instantiations(&a, &b, &subst, j)
19    }
20
21    pub fn union_instantiations(
22        &mut self,
23        from_pat: &Pattern<L>,
24        to_pat: &Pattern<L>,
25        subst: &Subst,
26        #[allow(unused)] justification: Option<String>,
27    ) -> bool {
28        let a = pattern_subst(self, from_pat, subst);
29        let b = pattern_subst(self, to_pat, subst);
30
31        #[allow(unused)]
32        let syn_a = self.synify_app_id(a.clone());
33        #[allow(unused)]
34        let syn_b = self.synify_app_id(b.clone());
35
36        let proof = ghost!(self.prove_explicit(&syn_a, &syn_b, justification));
37
38        let out = self.union_internal(&a, &b, proof);
39        self.rebuild_called_from_union_instantiations();
40        out
41    }
42
43    fn rebuild_called_from_union_instantiations(&mut self) {
44        self.rebuild();
45    }
46
47    pub(in crate::egraph) fn union_internal(
48        &mut self,
49        l: &AppliedId,
50        r: &AppliedId,
51        #[allow(unused)] proof: ProvenEq,
52    ) -> bool {
53        // normalize inputs
54        let pai_l = self.proven_find_applied_id(&l);
55        let pai_r = self.proven_find_applied_id(&r);
56
57        let proof = ghost!({
58            if CHECKS {
59                pai_l.proof.check(self);
60                pai_r.proof.check(self);
61            }
62
63            let a = self.prove_symmetry(pai_l.proof);
64            let a = self.prove_transitivity(a, proof);
65            let a = self.prove_transitivity(a, pai_r.proof);
66            if CHECKS {
67                assert_eq!(a.l.id, pai_l.elem.id);
68                assert_eq!(a.r.id, pai_r.elem.id);
69            }
70            a
71        });
72        self.union_leaders(pai_l.elem, pai_r.elem, proof)
73    }
74
75    fn union_leaders(&mut self, l: AppliedId, r: AppliedId, proof: ProvenEq) -> bool {
76        // early return, if union should not be made.
77        if self.eq(&l, &r) {
78            return false;
79        }
80
81        let cap = &l.slots() & &r.slots();
82
83        if l.slots() != cap {
84            self.shrink_slots(&l, &cap, proof.clone());
85            self.union_internal(&l, &r, proof);
86            return true;
87        }
88
89        if r.slots() != cap {
90            let flipped_proof = ghost!(self.prove_symmetry(proof.clone()));
91            self.shrink_slots(&r, &cap, flipped_proof);
92            self.union_internal(&l, &r, proof);
93            return true;
94        }
95
96        if l.id == r.id {
97            let id = l.id;
98
99            // l.m :: slots(id) -> X
100            // r.m :: slots(id) -> X
101            // perm :: slots(id) -> slots(id)
102            let perm = l.m.compose(&r.m.inverse());
103            if CHECKS {
104                assert!(perm.is_perm());
105                assert_eq!(&perm.keys(), &self.classes[&id].slots);
106            }
107
108            let proven_perm = ProvenPerm {
109                elem: perm,
110                #[cfg(feature = "explanations")]
111                proof,
112                #[cfg(feature = "explanations")]
113                reg: self.proof_registry.clone(),
114            };
115
116            if CHECKS {
117                #[cfg(feature = "explanations")]
118                assert_eq!(proven_perm.proof.l.id, id);
119
120                proven_perm.check();
121            }
122            let grp = &mut self.classes.get_mut(&id).unwrap().group;
123            if grp.contains(&proven_perm.to_slotmap()) {
124                return false;
125            }
126
127            grp.add(proven_perm);
128
129            self.touched_class(id, PendingType::Full);
130
131            true
132        } else {
133            let slot_size = |i| self.classes[&i].syn_enode.slots().len();
134
135            let size = |i| {
136                let c = &self.classes[&i];
137                c.nodes.len() + c.usages.len()
138            };
139
140            // we intend to deprecate `l` in favor of `r`.
141            // return true if this is the correct decision.
142            let right_order = |l, r| {
143                // we prefer e-classes with e-nodes with few slots (i.e. prefer constants over e-node with redundancies).
144                // It generates easier proofs.
145                let (ssl, ssr) = (slot_size(l), slot_size(r));
146                if ssl > ssr {
147                    return true;
148                }
149                if ssl < ssr {
150                    return false;
151                }
152
153                // prefer bigger e-classes, because then we need to update less.
154                size(l) <= size(r)
155            };
156
157            if right_order(l.id, r.id) {
158                self.move_to(&l, &r, proof)
159            } else {
160                let proof = ghost!(self.prove_symmetry(proof));
161                self.move_to(&r, &l, proof)
162            }
163
164            true
165        }
166    }
167
168    fn assert_ty(&self, m: &SlotMap, keys: &SmallHashSet<Slot>, values: &SmallHashSet<Slot>) {
169        assert!(m.keys().is_subset(keys));
170        assert!(m.values().is_subset(values));
171    }
172
173    // moves everything from `from` to `to`.
174    fn move_to(&mut self, from: &AppliedId, to: &AppliedId, #[allow(unused)] proof: ProvenEq) {
175        if CHECKS {
176            assert_eq!(from.slots(), to.slots());
177            #[cfg(feature = "explanations")]
178            assert_eq!(from.id, proof.l.id);
179            #[cfg(feature = "explanations")]
180            assert_eq!(to.id, proof.r.id);
181        }
182
183        {
184            let analysis_from = self.analysis_data(from.id).clone();
185            let analysis_to = self.analysis_data_mut(to.id);
186            let old_analysis_to = analysis_to.clone();
187            let new_analysis_to = N::merge(analysis_from, analysis_to.clone());
188            let changed = old_analysis_to != new_analysis_to;
189            *analysis_to = new_analysis_to;
190
191            if changed {
192                self.modify_queue.push(to.id);
193                self.touched_class(to.id, PendingType::OnlyAnalysis);
194            }
195        }
196
197        // from.m :: slots(from.id) -> X
198        // to.m :: slots(to.id) -> X
199        let map = to.m.compose_partial(&from.m.inverse());
200        if CHECKS {
201            self.assert_ty(&map, &self.slots(to.id), &self.slots(from.id));
202        }
203
204        let app_id = self.mk_sem_applied_id(to.id, map.clone());
205        let pai = ProvenAppliedId {
206            elem: app_id,
207
208            #[cfg(feature = "explanations")]
209            proof,
210        };
211        self.unionfind_set(from.id, pai);
212
213        // who updates the usages? raw_add_to_class & raw_remove_from_class do that.
214
215        let from_nodes = self.classes.get(&from.id).unwrap().nodes.clone();
216        for (sh, psn) in from_nodes {
217            self.raw_remove_from_class(from.id, sh.clone());
218            // if `sh` contains redundant slots, these won't be covered by 'map'.
219            // Thus we need compose_fresh.
220            let new_bij = psn.elem.compose_fresh(&map.inverse());
221
222            let src_id = psn.src_id;
223
224            self.raw_add_to_class(to.id, (sh.clone(), new_bij), src_id);
225            self.pending.insert(sh, PendingType::Full);
226        }
227
228        // re-add the group equations as well.
229
230        // This basically calls self.union(from, from * perm) for each perm generator in the group of from.
231        // from.m :: slots(from.id) -> C
232        // to.m :: slots(to.id) -> C
233
234        // f :: slots(from.id) -> slots(to.id)
235        // Note that f is a partial map, because some slots might have become redundant.
236        let f = from.m.compose_partial(&to.m.inverse());
237
238        let change_permutation_from_from_to_to = |x: Perm| -> Perm {
239            let perm: Perm = x
240                .iter()
241                .filter_map(|(x, y)| {
242                    if f.contains_key(x) && f.contains_key(y) {
243                        Some((f[x], f[y]))
244                    } else {
245                        None
246                    }
247                })
248                .collect();
249
250            if CHECKS {
251                assert!(perm.is_perm());
252                assert_eq!(perm.keys(), self.classes[&to.id].slots);
253            }
254
255            perm
256        };
257        #[cfg(feature = "explanations")]
258        let prf = self.proven_find_applied_id(&from).proof;
259        #[cfg(feature = "explanations")]
260        let prf_rev = self.prove_symmetry(prf.clone());
261
262        let change_proven_permutation_from_from_to_to = |proven_perm: ProvenPerm| {
263            let new_perm = change_permutation_from_from_to_to(proven_perm.elem);
264            #[cfg(feature = "explanations")]
265            let new_proof = self.prove_transitivity(
266                prf_rev.clone(),
267                self.prove_transitivity(proven_perm.proof, prf.clone()),
268            );
269            ProvenPerm {
270                elem: new_perm,
271                #[cfg(feature = "explanations")]
272                proof: new_proof,
273                #[cfg(feature = "explanations")]
274                reg: self.proof_registry.clone(),
275            }
276        };
277
278        let set = self.classes[&from.id]
279            .group
280            .generators()
281            .into_iter()
282            .map(change_proven_permutation_from_from_to_to)
283            .collect();
284
285        if self.classes.get_mut(&to.id).unwrap().group.add_set(set) {
286            self.touched_class(to.id, PendingType::Full);
287        }
288
289        // touched because the class is now dead and no e-nodes should point to it.
290        self.touched_class(from.id, PendingType::Full);
291    }
292}