slotted_egraphs/egraph/
mod.rs

1use crate::*;
2
3mod find;
4pub use find::*;
5
6mod add;
7pub use add::*;
8
9mod union;
10pub use union::*;
11
12mod rebuild;
13pub use rebuild::*;
14
15mod check;
16pub use check::*;
17
18mod analysis;
19pub use analysis::*;
20use vec_collections::AbstractVecSet;
21
22use std::cell::RefCell;
23
24// invariants:
25// 1. If two ENodes (that are in the EGraph) have equal .shape(), they have to be in the same eclass.
26// 2. enode.slots() is always a superset of c.slots, if enode is within the eclass c.
27//    if ENode::Lam(si) = enode, then we require i to not be in c.slots.
28//    In practice, si will always be Slot(0).
29// 3. AppliedId::m is always a bijection. (eg. c1(s0, s1, s0) is illegal!)
30//    AppliedId::m also always has the same keys as the class expects slots.
31// 4. Slot(0) should not be in EClass::slots of any class.
/// A data structure to efficiently represent congruence relations on terms with binders.
pub struct EGraph<L: Language, N: Analysis<L> = ()> {
    // An entry (l, r(sa, sb)) in unionfind corresponds to the equality l(s0, s1, s2) = r(sa, sb), where sa, sb in {s0, s1, s2}.
    // It normalizes the eclass.
    // Each Id i that is an output of the unionfind itself has unionfind[i] = (i, identity()).

    // We use RefCell to allow for interior mutability, so that find(&self) can do path compression.
    unionfind: RefCell<Vec<ProvenAppliedId>>,

    // If a class doesn't have unionfind[x].id = x, then it doesn't contain nodes / usages.
    // It's "shallow" if you will.
    pub(crate) classes: HashMap<Id, EClass<L, N>>,

    // For each shape contained in the EGraph, maps to the EClass that contains it.
    hashcons: HashMap<L, Id>,

    // For each (syn_slotset applied) non-normalized (i.e. "syntactic") weak shape, find the e-class that has this as syn_enode.
    // TODO remove this if explanations are disabled.
    syn_hashcons: HashMap<L, AppliedId>,

    // E-Nodes that need to be re-processed, stored as shapes.
    // The value says how much of the e-node has to be re-processed.
    pending: HashMap<L, PendingType>,

    // TODO remove this if explanations are disabled.
    pub(crate) proof_registry: ProofRegistry,

    // The substitution method; `with_subst_method` initializes it to `Some(S::new_boxed())`.
    pub(crate) subst_method: Option<Box<dyn SubstMethod<L, N>>>,

    // The analysis value that was handed to the constructor.
    pub analysis: N,

    // N::modify(_) will be run on these classes.
    // We delay handling modify so that all invariants can be rebuilt first.
    modify_queue: Vec<Id>,
}
66
/// Describes how much work is outstanding for a pending e-node.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum PendingType {
    OnlyAnalysis, // only the analysis needs to be updated.
    Full,         // the e-node, its strong shape & the analysis need to be updated.
}
72
/// Each E-Class can be understood "semantically" or "syntactically":
/// - semantically means that it respects the equations already in the e-graph, and hence doesn't differentiate between equal things.
/// - syntactically means that it only talks about the single representative term associated to each E-Class, recursively obtainable using syn_enode.
#[derive(Clone)]
pub(crate) struct EClass<L: Language, N: Analysis<L>> {
    // The set of equivalent ENodes that make up this eclass.
    // For (sh, psn) in nodes: sh.apply_slotmap(&psn.elem) represents the actual ENode.
    nodes: HashMap<L, ProvenSourceNode>,

    // The slots of this e-class.
    // All other slots are considered "redundant" (or they have to be qualified by an ENode::Lam).
    // Should not contain Slot(0).
    slots: SmallHashSet<Slot>,

    // Shows which Shapes refer to this EClass.
    usages: HashSet<L>,

    // Expresses the self-symmetries of this e-class.
    pub(crate) group: Group<ProvenPerm>,

    // The syntactic representative e-node of this class (see the type-level docs).
    // TODO remove this if explanations are disabled.
    syn_enode: L,

    // The analysis data attached to this class.
    analysis_data: N::Data,
}
97
98impl<L: Language, N: Analysis<L> + Default> Default for EGraph<L, N> {
99    fn default() -> Self {
100        EGraph::new(N::default())
101    }
102}
103
104impl<L: Language, N: Analysis<L>> EGraph<L, N> {
105    /// Creates an empty e-graph.
106    pub fn new(analysis: N) -> Self {
107        Self::with_subst_method::<SynExprSubst>(analysis)
108    }
109
110    /// Creates an empty e-graph, while specifying the substitution method to use.
111    pub fn with_subst_method<S: SubstMethod<L, N>>(analysis: N) -> Self {
112        EGraph {
113            unionfind: Default::default(),
114            classes: Default::default(),
115            hashcons: Default::default(),
116            syn_hashcons: Default::default(),
117            pending: Default::default(),
118            proof_registry: ProofRegistry::default(),
119            subst_method: Some(S::new_boxed()),
120            analysis,
121            modify_queue: Vec::new(),
122        }
123    }
124
125    pub fn slots(&self, id: Id) -> SmallHashSet<Slot> {
126        self.classes[&id].slots.clone()
127    }
128
129    pub(crate) fn syn_slots(&self, id: Id) -> SmallHashSet<Slot> {
130        self.classes[&id].syn_enode.slots()
131    }
132
133    pub fn analysis_data(&self, i: Id) -> &N::Data {
134        &self.classes[&self.find_id(i)].analysis_data
135    }
136
137    pub fn analysis_data_mut(&mut self, i: Id) -> &mut N::Data {
138        &mut self
139            .classes
140            .get_mut(&self.find_id(i))
141            .unwrap()
142            .analysis_data
143    }
144
145    pub fn enodes(&self, i: Id) -> HashSet<L> {
146        // We prevent this, as otherwise the output will have wrong slots.
147        assert!(self.is_alive(i), "Can't access e-nodes of dead class");
148
149        self.classes[&i]
150            .nodes
151            .iter()
152            .map(|(x, psn)| x.apply_slotmap(&psn.elem))
153            .collect()
154    }
155
156    // Generates fresh slots for redundant slots.
157    pub fn enodes_applied(&self, i: &AppliedId) -> Vec<L> {
158        let class = &self.classes[&i.id];
159        let class_slots = &class.slots;
160
161        let mut result = Vec::with_capacity(class.nodes.len());
162
163        for (x, psn) in &class.nodes {
164            let mut x = x.apply_slotmap(&psn.elem);
165
166            let mut map: SmallHashMap<Slot, Slot> = SmallHashMap::default();
167            for slot in x.all_slot_occurrences_mut() {
168                if !class_slots.contains(&slot) {
169                    if let Some(v) = map.get(slot) {
170                        *slot = *v;
171                    } else {
172                        let v = Slot::fresh();
173                        map.insert(slot.clone(), v.clone());
174                        *slot = v;
175                    }
176                }
177            }
178
179            let mut m = SlotMap::new();
180            for slot in x.slots() {
181                if !i.m.contains_key(slot) {
182                    m.insert(slot, Slot::fresh());
183                }
184            }
185
186            for (x, y) in i.m.iter() {
187                m.insert(x, y);
188            }
189
190            x = x.apply_slotmap(&m);
191            result.push(x);
192        }
193
194        result
195    }
196
197    // number of enodes in the egraph.
198    pub fn total_number_of_nodes(&self) -> usize {
199        self.hashcons.len()
200    }
201
202    /// Checks that two AppliedIds are semantically equal.
203    pub fn eq(&self, a: &AppliedId, b: &AppliedId) -> bool {
204        let a = self.find_applied_id(a);
205        let b = self.find_applied_id(b);
206
207        if CHECKS {
208            self.check_sem_applied_id(&a);
209            self.check_sem_applied_id(&b);
210        }
211
212        if a.id != b.id {
213            return false;
214        }
215        if a.m.values() != b.m.values() {
216            return false;
217        }
218        let id = a.id;
219
220        let perm = a.m.compose(&b.m.inverse());
221        if CHECKS {
222            assert!(perm.is_perm());
223            assert_eq!(&perm.values(), &self.classes[&id].slots);
224        }
225
226        self.classes[&id].group.contains(&perm)
227    }
228
229    // refreshes all internal slots of l.
230    pub(crate) fn refresh_internals(&self, l: &L) -> L {
231        let i = self.lookup(l).unwrap();
232        l.refresh_internals(i.slots())
233    }
234
235    // converts l to its class normal form, so that calling lookup on it yields the identity AppliedId.
236    pub(crate) fn class_nf(&self, l: &L) -> L {
237        let l = self.refresh_internals(l);
238        let i = self.lookup(&l).unwrap();
239
240        // needs to be `apply_slotmap_fresh` in case `l` has redundancies.
241        let l = l.apply_slotmap_fresh(&i.m);
242
243        if CHECKS {
244            let identity = self.mk_sem_identity_applied_id(i.id);
245            assert!(self.eq(&i, &identity));
246        }
247
248        l
249    }
250
251    /// Prints the contents of the E-Graph. Helpful for debugging.
252    pub fn dump(&self) {
253        println!("");
254        let mut v: Vec<(&Id, &EClass<L, N>)> = self.classes.iter().collect();
255        v.sort_by_key(|(x, _)| *x);
256
257        for (i, c) in v {
258            if c.nodes.len() == 0 {
259                continue;
260            }
261
262            let mut slot_order: Vec<Slot> = c.slots.iter().cloned().collect();
263            slot_order.sort();
264            let slot_str = slot_order
265                .iter()
266                .map(|x| x.to_string())
267                .collect::<Vec<_>>()
268                .join(", ");
269            println!("\n{:?}({}):", i, &slot_str);
270
271            println!(">> {:?}", &c.syn_enode);
272
273            for (sh, psn) in &c.nodes {
274                let n = sh.apply_slotmap(&psn.elem);
275
276                #[cfg(feature = "explanations")]
277                println!(" - {n:?}    [originally {:?}]", psn.src_id);
278
279                #[cfg(not(feature = "explanations"))]
280                println!(" - {n:?}");
281            }
282            for pp in &c.group.generators() {
283                println!(" -- {:?}", pp.elem);
284            }
285        }
286        println!("");
287    }
288
289    // The resulting e-nodes are written as they exist in the e-class.
290    pub(crate) fn usages(&self, i: Id) -> Vec<L> {
291        let mut out = Vec::new();
292        for x in &self.classes[&i].usages {
293            let j = self.lookup(x).unwrap().id;
294            let bij = &self.classes[&j].nodes[&x].elem;
295            let x = x.apply_slotmap(bij);
296            out.push(x);
297        }
298        out
299    }
300
301    pub(crate) fn shape(&self, e: &L) -> (L, Bijection) {
302        let (pnode, bij) = self.proven_shape(e);
303        (pnode.elem, bij)
304    }
305
306    pub(crate) fn proven_shape(&self, e: &L) -> (ProvenNode<L>, Bijection) {
307        self.proven_proven_shape(&self.refl_pn(e))
308    }
309
310    pub(crate) fn proven_proven_shape(&self, e: &ProvenNode<L>) -> (ProvenNode<L>, Bijection) {
311        self.proven_proven_pre_shape(&e).weak_shape()
312    }
313
314    pub(crate) fn proven_proven_pre_shape(&self, e: &ProvenNode<L>) -> ProvenNode<L> {
315        let e = self.proven_proven_find_enode(e);
316        self.proven_proven_get_group_compatible_variants(&e)
317            .into_iter()
318            .min_by_key(|pn| pn.weak_shape().0.elem.all_slot_occurrences())
319            .unwrap()
320    }
321
322    // We want to compute the shape of an e-node n := f(c[$x, $y], c[$y, $x]), where c[$x, $y] = c[$y, $x].
323    // The (strong) shape of f(c[$x, $y], c[$y, $x]) is f(c[$0, $1], c[$0, $1]), whereas the
324    //     weak     shape of f(...)                  is f(c[$0, $1], c[$1, $0]).
325    // Basically, the weak shape doesn't respect group symmetries, while the strong shape does.
326
327    // We first compute the set of e-nodes equivalent to n by group symmetries.
328    // This set would be
329    // {f(c[$x, $y], c[$y, $x]),
330    //  f(c[$y, $x], c[$y, $x]),
331    //  f(c[$x, $y], c[$x, $y]),
332    //  f(c[$y, $x], c[$x, $y])}
333    // This set is what the proven_proven_get_group_compatible_variants returns.
334    // Now: we want to compute the "weak shapes" of them, which means to replace names by numbers (by going through the slots left to right).
335    // When computing the weak shapes, we only have
336    // {f(c[$0, $1], c[$1, $0]),
337    //  f(c[$0, $1], c[$0, $1])}
338    // This is what get_group_compatible_weak_variants would return.
339    pub(crate) fn proven_proven_get_group_compatible_variants(
340        &self,
341        enode: &ProvenNode<L>,
342    ) -> Vec<ProvenNode<L>> {
343        // should only be called with an up-to-date e-node.
344        if CHECKS {
345            for x in enode.elem.applied_id_occurrences() {
346                assert!(self.is_alive(x.id));
347            }
348        }
349
350        let mut out = Vec::new();
351
352        // early-return, if groups are all trivial.
353        if enode
354            .elem
355            .ids()
356            .iter()
357            .all(|i| self.classes[i].group.is_trivial())
358        {
359            out.push(enode.clone());
360            return out;
361        }
362
363        let groups: Vec<Vec<ProvenPerm>> = enode
364            .elem
365            .applied_id_occurrences()
366            .iter()
367            .map(|x| self.classes[&x.id].group.all_perms().into_iter().collect())
368            .collect();
369
370        for l in cartesian(&groups) {
371            let pn = enode.clone();
372            let pn = self.chain_pn_map(&pn, |i, pai| self.chain_pai_pp(&pai, l[i]));
373            // TODO fix check.
374            // if CHECKS { pn.check_base(enode.base()); }
375            out.push(pn);
376        }
377
378        out
379    }
380
381    // for all AppliedIds that are contained in `enode`, permute their arguments as their groups allow.
382    // TODO every usage of this function hurts performance drastically. Which of them can I eliminate?
383    pub(crate) fn proven_get_group_compatible_variants(&self, enode: &L) -> Vec<ProvenNode<L>> {
384        self.proven_proven_get_group_compatible_variants(&self.refl_pn(enode))
385    }
386
387    pub(crate) fn get_group_compatible_variants(&self, enode: &L) -> Vec<L> {
388        self.proven_get_group_compatible_variants(enode)
389            .into_iter()
390            .map(|pnode| pnode.elem)
391            .collect()
392    }
393
394    pub(crate) fn get_group_compatible_weak_variants(&self, enode: &L) -> Vec<L> {
395        let set = self.get_group_compatible_variants(enode);
396        let mut shapes = SmallHashSet::empty();
397        let mut out = Vec::new();
398
399        for x in set {
400            let (sh, _) = x.weak_shape();
401            if shapes.contains(&sh) {
402                continue;
403            }
404            shapes.insert(sh);
405            out.push(x);
406        }
407
408        out
409    }
410
411    pub(crate) fn synify_app_id(&self, app: AppliedId) -> AppliedId {
412        let mut app = app;
413        for s in self.syn_slots(app.id) {
414            if !app.m.contains_key(s) {
415                app.m.insert(s, Slot::fresh());
416            }
417        }
418        app
419    }
420
421    pub(crate) fn synify_enode(&self, enode: L) -> L {
422        enode.map_applied_ids(|app| self.synify_app_id(app))
423    }
424
425    pub(crate) fn semify_app_id(&self, app: AppliedId) -> AppliedId {
426        let slots = self.slots(app.id);
427
428        let mut app = app;
429        for k in app.m.keys() {
430            if !slots.contains(&k) {
431                app.m.remove(k);
432            }
433        }
434        app
435    }
436
437    #[cfg(feature = "explanations")]
438    pub(crate) fn semify_enode(&self, enode: L) -> L {
439        enode.map_applied_ids(|app| self.semify_app_id(app))
440    }
441
442    /// Returns the canonical term corresponding to `i`.
443    ///
444    /// This function will use [EGraph::get_syn_node] repeatedly to build up this term.
445    pub fn get_syn_expr(&self, i: &AppliedId) -> RecExpr<L> {
446        let enode = self.get_syn_node(i);
447        let cs = enode
448            .applied_id_occurrences()
449            .iter()
450            .map(|x| self.get_syn_expr(x))
451            .collect();
452        RecExpr {
453            node: nullify_app_ids(&enode),
454            children: cs,
455        }
456    }
457
458    /// Returns the canonical e-node corresponding to `i`.
459    pub fn get_syn_node(&self, i: &AppliedId) -> L {
460        let syn = &self.classes[&i.id].syn_enode;
461        syn.apply_slotmap(&i.m)
462    }
463}
464
465impl PendingType {
466    pub(crate) fn merge(self, other: PendingType) -> PendingType {
467        match (self, other) {
468            (PendingType::Full, _) => PendingType::Full,
469            (_, PendingType::Full) => PendingType::Full,
470            (PendingType::OnlyAnalysis, PendingType::OnlyAnalysis) => PendingType::OnlyAnalysis,
471        }
472    }
473}
474
/// Lazily enumerates the cartesian product of the given factors.
///
/// Each yielded `Vec` holds one reference per factor, in factor order. The index of the first
/// factor advances fastest, so `{1,2} x {3} x {4,5}` yields
/// `(1,3,4), (2,3,4), (1,3,5), (2,3,5)`.
///
/// If `input` is empty, exactly one empty tuple is yielded. If any factor is empty, the product
/// is empty and nothing is yielded.
fn cartesian<'a, T>(input: &'a [Vec<T>]) -> impl Iterator<Item = Vec<&'a T>> + use<'a, T> {
    let mut indices = vec![0usize; input.len()];
    // An empty factor makes the whole product empty. Starting out exhausted also keeps us from
    // indexing into that empty factor below.
    let mut done = input.iter().any(Vec::is_empty);

    std::iter::from_fn(move || {
        if done {
            return None;
        }

        let out: Vec<&'a T> = input
            .iter()
            .zip(&indices)
            .map(|(factor, &idx)| &factor[idx])
            .collect();

        // Advance like an odometer: bump the first index that can still grow and reset all the
        // ones before it.
        for (idx, factor) in indices.iter_mut().zip(input) {
            *idx += 1;
            if *idx < factor.len() {
                return Some(out);
            }
            *idx = 0;
        }

        // Every index wrapped around: `out` is the last tuple.
        done = true;
        Some(out)
    })
}
498
/// Checks `cartesian` on `{1,2} x {3} x {4,5}`: both the number of tuples and their contents.
#[test]
fn cartesian1() {
    let v = [vec![1, 2], vec![3], vec![4, 5]];
    let mut tuples: Vec<Vec<i32>> = cartesian(&v)
        .map(|tuple| tuple.into_iter().copied().collect())
        .collect();
    assert_eq!(tuples.len(), 4);

    // Compare as a set: the enumeration order is not what this test is about.
    tuples.sort();
    assert_eq!(
        tuples,
        vec![vec![1, 3, 4], vec![1, 3, 5], vec![2, 3, 4], vec![2, 3, 5]]
    );
}