// slotted_egraphs — source file `egraph/add.rs`

1use vec_collections::AbstractVecSet;
2
3use crate::*;
4
// Syntactic add (`add_syn_expr` / `add_syn`):
6impl<L: Language, N: Analysis<L>> EGraph<L, N> {
7    pub fn add_syn_expr(&mut self, re: RecExpr<L>) -> AppliedId {
8        let mut n = re.node;
9        let mut refs: Vec<&mut AppliedId> = n.applied_id_occurrences_mut();
10        if CHECKS {
11            assert_eq!(re.children.len(), refs.len());
12        }
13        for (i, child) in re.children.into_iter().enumerate() {
14            *(refs[i]) = self.add_syn_expr(child);
15        }
16        self.add_syn(n)
17    }
18
19    pub fn add_syn(&mut self, enode: L) -> AppliedId {
20        #[cfg(not(feature = "explanations"))]
21        {
22            self.add(enode)
23        }
24
25        #[cfg(feature = "explanations")]
26        {
27            let enode = self.synify_enode(enode);
28
29            self.add(enode.clone());
30
31            if let Some(x) = self.lookup_syn(&enode) {
32                if CHECKS {
33                    assert_eq!(enode.slots(), x.slots());
34                }
35                return x;
36            }
37
38            let old_slots = enode.slots();
39            let fresh_to_old = Bijection::bijection_from_fresh_to(&old_slots);
40            let old_to_fresh = fresh_to_old.inverse();
41            let new_enode = enode.apply_slotmap(&old_to_fresh);
42            let c = self.alloc_eclass(&old_to_fresh.values(), new_enode.clone());
43
44            let pc = self.pc_find(&self.refl_pc(c));
45
46            self.handle_congruence(pc);
47
48            let c_a = self.mk_syn_applied_id(c, fresh_to_old.clone());
49            if CHECKS {
50                assert_eq!(enode.slots(), c_a.slots());
51            }
52
53            c_a
54        }
55    }
56
57    #[cfg(feature = "explanations")]
58    fn lookup_syn(&self, enode: &L) -> Option<AppliedId> {
59        let (sh, bij) = enode.weak_shape();
60        let i = self.syn_hashcons.get(&sh)?;
61
62        // bij :: SHAPE -> X
63        // i :: slots(i.id) -> SHAPE
64        let i = i.apply_slotmap(&bij);
65        Some(i)
66    }
67}
68
// Semantic add (`add_expr` / `add` / `lookup`):
70impl<L: Language, N: Analysis<L>> EGraph<L, N> {
71    pub fn add_expr(&mut self, re: RecExpr<L>) -> AppliedId {
72        let mut n = re.node;
73        let mut refs: Vec<&mut AppliedId> = n.applied_id_occurrences_mut();
74        if CHECKS {
75            assert_eq!(re.children.len(), refs.len());
76        }
77        for (i, child) in re.children.into_iter().enumerate() {
78            *(refs[i]) = self.add_expr(child);
79        }
80        self.add(n)
81    }
82
83    pub fn add(&mut self, enode: L) -> AppliedId {
84        self.add_internal(self.shape_called_from_add(enode))
85    }
86
87    fn shape_called_from_add(&self, enode: L) -> (L, Bijection) {
88        self.shape(&enode)
89    }
90
91    // self.add(x) = y implies that x.slots() is a superset of y.slots().
92    // x.slots() - y.slots() are redundant slots.
93    pub(in crate::egraph) fn add_internal(&mut self, t: (L, SlotMap)) -> AppliedId {
94        if let Some(x) = self.lookup_internal(&t) {
95            return x;
96        }
97
98        // TODO this code is kinda exactly what add_syn is supposed to do anyways. There's probably a way to write this more concisely.
99        // We convert the enode to "syn" so that semantic_add will compute the necessary redundancy proofs.
100        let enode = t.0.refresh_private().apply_slotmap(&t.1);
101        let enode = self.synify_enode(enode);
102
103        let syn = self.mk_singleton_class(enode);
104        self.semify_app_id(syn)
105    }
106
107    pub fn lookup(&self, n: &L) -> Option<AppliedId> {
108        self.lookup_internal(&self.shape(n))
109    }
110
111    pub(in crate::egraph) fn lookup_internal(
112        &self,
113        (shape, n_bij): &(L, Bijection),
114    ) -> Option<AppliedId> {
115        let i = self.hashcons.get(&shape)?;
116        let c = &self.classes[i];
117        let cn_bij = &c.nodes[&shape].elem;
118
119        // X = shape.slots()
120        // Y = n.slots()
121        // Z = c.slots()
122        // n_bij :: X -> Y
123        // cn_bij :: X -> Z
124        // out :: Z -> Y
125        let out = cn_bij.inverse().compose(&n_bij);
126
127        // Note that ENodes in an EClass can have redundant slots.
128        // They shouldn't come up in the AppliedId.
129        let out = out.iter().filter(|(x, _)| c.slots.contains(x)).collect();
130
131        let app_id = self.mk_sem_applied_id(*i, out);
132
133        if CHECKS {
134            assert_eq!(&c.slots, &app_id.m.keys());
135        }
136
137        Some(app_id)
138    }
139}
140
141impl<L: Language, N: Analysis<L>> EGraph<L, N> {
142    // returns a syn applied id.
143    fn mk_singleton_class(&mut self, syn_enode: L) -> AppliedId {
144        let old_slots = syn_enode.slots();
145
146        let fresh_to_old = Bijection::bijection_from_fresh_to(&old_slots);
147        let old_to_fresh = fresh_to_old.inverse();
148
149        // allocate new class & slot set.
150        let fresh_slots = old_to_fresh.values();
151        let syn_enode_fresh = syn_enode.apply_slotmap_fresh(&old_to_fresh);
152        let i = self.alloc_eclass(&fresh_slots, syn_enode_fresh.clone());
153
154        // we use semantic_add so that the redundancy, symmetry and congruence checks run on it.
155        let t = syn_enode_fresh.weak_shape();
156        self.raw_add_to_class(i, t.clone(), i);
157        self.pending.insert(t.0, PendingType::Full);
158        self.modify_queue.push(i);
159        self.rebuild_called_from_add();
160
161        self.mk_syn_applied_id(i, fresh_to_old)
162    }
163
164    fn rebuild_called_from_add(&mut self) {
165        self.rebuild();
166    }
167
168    // adds (sh, bij) to the eclass `id`.
169    // TODO src_id should be optional!
170    pub(in crate::egraph) fn raw_add_to_class(
171        &mut self,
172        id: Id,
173        (sh, bij): (L, Bijection),
174        src_id: Id,
175    ) {
176        let psn = ProvenSourceNode { elem: bij, src_id };
177
178        let tmp1 = self
179            .classes
180            .get_mut(&id)
181            .unwrap()
182            .nodes
183            .insert(sh.clone(), psn);
184        let tmp2 = self.hashcons.insert(sh.clone(), id);
185        if CHECKS {
186            assert!(tmp1.is_none());
187            assert!(tmp2.is_none());
188        }
189        for ref_id in sh.ids() {
190            let usages = &mut self.classes.get_mut(&ref_id).unwrap().usages;
191            usages.insert(sh.clone());
192        }
193    }
194
195    pub(in crate::egraph) fn raw_remove_from_class(&mut self, id: Id, sh: L) -> ProvenSourceNode {
196        let opt_psn = self.classes.get_mut(&id).unwrap().nodes.remove(&sh);
197        let opt_id = self.hashcons.remove(&sh);
198        if CHECKS {
199            assert!(opt_psn.is_some());
200            assert!(opt_id.is_some());
201        }
202        for ref_id in sh.ids() {
203            let usages = &mut self.classes.get_mut(&ref_id).unwrap().usages;
204            usages.remove(&sh);
205        }
206
207        opt_psn.unwrap()
208    }
209}
210
211impl<L: Language, N: Analysis<L>> EGraph<L, N> {
212    // TODO make the public API auto "fresh" slots.
213    #[allow(unused_variables)]
214    pub fn alloc_empty_eclass(&mut self, slots: &SmallHashSet<Slot>) -> Id {
215        panic!("Can't use alloc_empty_eclass if explanations are enabled!");
216    }
217
218    pub(in crate::egraph) fn alloc_eclass(
219        &mut self,
220        slots: &SmallHashSet<Slot>,
221        syn_enode: L,
222    ) -> Id {
223        let c_id = Id(self.unionfind_len()); // Pick the next unused Id.
224
225        let syn_slots = syn_enode.slots();
226        let proven_perm =
227            ProvenPerm::identity(c_id, &slots, &syn_slots, self.proof_registry.clone());
228
229        let c = EClass {
230            nodes: HashMap::default(),
231            group: Group::identity(&proven_perm),
232            slots: slots.clone(),
233            usages: HashSet::default(),
234            syn_enode: syn_enode.clone(),
235            analysis_data: N::make(&self, &syn_enode),
236        };
237        self.classes.insert(c_id, c);
238
239        {
240            // add syn_enode to the hashcons.
241            let (sh, bij) = syn_enode.weak_shape();
242
243            if CHECKS {
244                assert!(!self.syn_hashcons.contains_key(&sh));
245            }
246
247            let app_id = self.mk_syn_applied_id(c_id, bij.inverse());
248            self.syn_hashcons.insert(sh, app_id);
249        }
250
251        let syn_app_id = self.mk_syn_identity_applied_id(c_id);
252        let pai = self.refl_pai(&syn_app_id);
253        self.unionfind_set(c_id, pai);
254
255        c_id
256    }
257}