Skip to main content

chomsky_uir/
egraph.rs

1pub use crate::union_find::Id;
2use crate::union_find::UnionFind;
3use chomsky_types::Loc;
4use dashmap::DashMap;
5use std::hash::Hash;
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8pub trait Language: Hash + Eq + Clone + Ord {
9    fn children(&self) -> Vec<Id>;
10    fn map_children(&self, f: impl FnMut(Id) -> Id) -> Self;
11}
12
13#[derive(Debug, Clone)]
14pub struct EClass<L: Language, D> {
15    pub id: Id,
16    pub nodes: Vec<L>,
17    pub data: D,
18}
19
20pub trait Analysis<L: Language>: Default {
21    type Data;
22    fn make(egraph: &EGraph<L, Self>, enode: &L) -> Self::Data;
23    fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> bool;
24    fn on_add(&self, _data: &mut Self::Data, _loc: Loc) {}
25    fn is_compatible(&self, _data1: &Self::Data, _data2: &Self::Data) -> bool {
26        true
27    }
28}
29
30#[derive(Debug)]
31pub struct EGraph<L: Language, A: Analysis<L> = ()> {
32    pub union_find: UnionFind,
33    pub classes: DashMap<Id, EClass<L, A::Data>>,
34    pub memo: DashMap<L, Id>,
35    pub analysis: std::sync::RwLock<A>,
36    pub next_id: AtomicUsize,
37    pub dirty: DashMap<Id, ()>,
38}
39
40impl<L: Language> Analysis<L> for () {
41    type Data = ();
42    fn make(_egraph: &EGraph<L, Self>, _enode: &L) -> Self::Data {
43        ()
44    }
45    fn merge(&mut self, _to: &mut Self::Data, _from: Self::Data) -> bool {
46        false
47    }
48}
49
/// A specialized analysis for tracking debug information (Loc).
///
/// The type is stateless; each class's locations live in its [`DebugData`].
/// `Debug` is derived so that `EGraph<_, DebugAnalysis>` (whose `Debug` derive
/// requires `A: Debug`) can be debug-printed; `Clone`/`Copy` are free for a
/// unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct DebugAnalysis;
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct DebugData {
56    pub locs: Vec<Loc>,
57}
58
59pub trait HasDebugInfo {
60    fn get_locs(&self) -> &[Loc];
61}
62
63impl HasDebugInfo for DebugData {
64    fn get_locs(&self) -> &[Loc] {
65        &self.locs
66    }
67}
68
69impl HasDebugInfo for () {
70    fn get_locs(&self) -> &[Loc] {
71        &[]
72    }
73}
74
75impl<L: Language> Analysis<L> for DebugAnalysis {
76    type Data = DebugData;
77
78    fn make(_egraph: &EGraph<L, Self>, _enode: &L) -> Self::Data {
79        DebugData { locs: Vec::new() }
80    }
81
82    fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> bool {
83        let old_len = to.locs.len();
84        for loc in from.locs {
85            if !to.locs.contains(&loc) {
86                to.locs.push(loc);
87            }
88        }
89        to.locs.len() > old_len
90    }
91
92    fn on_add(&self, data: &mut Self::Data, loc: Loc) {
93        if !data.locs.contains(&loc) {
94            data.locs.push(loc);
95        }
96    }
97}
98
99impl<L: Language, A: Analysis<L>> EGraph<L, A> {
100    pub fn get_class(&self, id: Id) -> dashmap::mapref::one::Ref<'_, Id, EClass<L, A::Data>> {
101        let root = self.union_find.find(id);
102        self.classes.get(&root).expect("Class not found")
103    }
104
105    pub fn new() -> Self {
106        Self {
107            union_find: UnionFind::new(),
108            classes: DashMap::new(),
109            memo: DashMap::new(),
110            analysis: std::sync::RwLock::new(A::default()),
111            next_id: AtomicUsize::new(0),
112            dirty: DashMap::new(),
113        }
114    }
115
116    pub fn add(&self, enode: L) -> Id {
117        let canonical = enode.map_children(|id| self.union_find.find(id));
118        if let Some(id) = self.memo.get(&canonical) {
119            return self.union_find.find(*id);
120        }
121
122        let id = Id::from(self.next_id.fetch_add(1, Ordering::SeqCst));
123
124        let data = A::make(self, &canonical);
125
126        self.memo.insert(canonical.clone(), id);
127        let eclass = EClass {
128            id,
129            nodes: vec![canonical],
130            data,
131        };
132        self.classes.insert(id, eclass);
133
134        id
135    }
136
137    pub fn add_with_loc(&self, enode: L, loc: Loc) -> Id {
138        let id = self.add(enode);
139        let root = self.union_find.find(id);
140        if let Some(mut eclass) = self.classes.get_mut(&root) {
141            let analysis = self.analysis.read().unwrap();
142            analysis.on_add(&mut eclass.data, loc);
143        }
144        root
145    }
146
147    pub fn union(&self, id1: Id, id2: Id) -> Id {
148        let root1 = self.union_find.find(id1);
149        let root2 = self.union_find.find(id2);
150        if root1 == root2 {
151            return root1;
152        }
153
154        // --- Conflict Detection ---
155        {
156            let analysis = self.analysis.read().unwrap();
157            let data1 = &self.classes.get(&root1).unwrap().data;
158            let data2 = &self.classes.get(&root2).unwrap().data;
159
160            if !analysis.is_compatible(data1, data2) {
161                // Return one of the roots without merging
162                return root1;
163            }
164        }
165
166        let new_root = self.union_find.union(root1, root2);
167        let old_root = if new_root == root1 { root2 } else { root1 };
168
169        self.dirty.insert(new_root, ());
170
171        if let Some((_, old_class)) = self.classes.remove(&old_root) {
172            let mut new_class = self.classes.get_mut(&new_root).unwrap();
173            for node in old_class.nodes {
174                if !new_class.nodes.contains(&node) {
175                    new_class.nodes.push(node);
176                }
177            }
178
179            // Actual analysis merge
180            let mut analysis = self.analysis.write().unwrap();
181            analysis.merge(&mut new_class.data, old_class.data);
182        }
183
184        new_root
185    }
186
187    pub fn rebuild(&self) {
188        while !self.dirty.is_empty() {
189            let mut todo = Vec::new();
190            let dirty_list: Vec<Id> = self.dirty.iter().map(|e| *e.key()).collect();
191            self.dirty.clear();
192
193            for id in dirty_list {
194                let root = self.union_find.find(id);
195                if let Some(mut eclass) = self.classes.get_mut(&root) {
196                    let mut new_nodes = Vec::new();
197                    for node in eclass.nodes.drain(..) {
198                        let canonical = node.map_children(|child| self.union_find.find(child));
199                        if let Some(old_id) = self.memo.get(&canonical) {
200                            let old_root = self.union_find.find(*old_id);
201                            if old_root != root {
202                                todo.push((old_root, root));
203                            }
204                        }
205                        self.memo.insert(canonical.clone(), root);
206                        new_nodes.push(canonical);
207                    }
208                    eclass.nodes = new_nodes;
209                }
210            }
211
212            for (id1, id2) in todo {
213                self.union(id1, id2);
214            }
215        }
216    }
217}