1pub use crate::union_find::Id;
2use crate::union_find::UnionFind;
3use chomsky_types::Loc;
4use dashmap::DashMap;
5use std::hash::Hash;
6use std::sync::atomic::{AtomicUsize, Ordering};
7
/// A node type that can be stored in an [`EGraph`].
///
/// Nodes must be hashable, comparable and cloneable: the e-graph in this file
/// uses them as hash-map keys and keeps them in per-class vectors.
pub trait Language: Hash + Eq + Clone + Ord {
    /// Returns the ids of this node's children.
    fn children(&self) -> Vec<Id>;

    /// Builds a node of the same type from `self` and the id mapping `f`.
    ///
    /// The e-graph calls this with the union-find `find` function to obtain
    /// the canonical form of a node, so implementations are presumably
    /// expected to replace every child id with `f(child)` and leave the rest
    /// of the node unchanged — confirm against the implementors.
    fn map_children(&self, f: impl FnMut(Id) -> Id) -> Self;
}
12
/// One equivalence class of an [`EGraph`]: a set of nodes plus analysis data.
#[derive(Debug, Clone)]
pub struct EClass<L: Language, D> {
    /// Id this class was created with; also its key in `EGraph::classes`.
    pub id: Id,
    /// Nodes that belong to this class.
    pub nodes: Vec<L>,
    /// Analysis data attached to this class (`Analysis::Data`).
    pub data: D,
}
19
/// Per-class data computed and maintained alongside an [`EGraph`].
///
/// The analysis value itself is created through `Default` when the e-graph is
/// constructed.
pub trait Analysis<L: Language>: Default {
    /// Data stored in every e-class.
    type Data;

    /// Produces the data for a class that is being created for `enode`.
    fn make(egraph: &EGraph<L, Self>, enode: &L) -> Self::Data;

    /// Folds `from` into `to` when two classes are unioned.
    ///
    /// The implementations in this file return `true` only when `to` was
    /// modified.
    fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> bool;

    /// Hook invoked by `EGraph::add_with_loc` with the class data and the
    /// location supplied by the caller. The default does nothing.
    fn on_add(&self, _data: &mut Self::Data, _loc: Loc) {}

    /// Decides whether two classes carrying this data may be unioned.
    ///
    /// `EGraph::union` leaves both classes untouched when this returns
    /// `false`. The default accepts every pair.
    fn is_compatible(&self, _data1: &Self::Data, _data2: &Self::Data) -> bool {
        true
    }
}
29
/// An e-graph whose state lives in concurrent containers, so every operation
/// takes `&self`.
#[derive(Debug)]
pub struct EGraph<L: Language, A: Analysis<L> = ()> {
    /// Resolves any id to the root id of the class it currently belongs to.
    pub union_find: UnionFind,
    /// Live classes keyed by id; `union` removes the class that gets absorbed.
    pub classes: DashMap<Id, EClass<L, A::Data>>,
    /// Maps a node to the id of the class it was registered under.
    pub memo: DashMap<L, Id>,
    /// The analysis instance; locked for reading by `on_add`/`is_compatible`
    /// calls and for writing by `merge` calls.
    pub analysis: std::sync::RwLock<A>,
    /// Source of fresh ids; incremented once per newly created class.
    pub next_id: AtomicUsize,
    /// Set (unit values) of roots produced by `union` that `rebuild` has not
    /// processed yet.
    pub dirty: DashMap<Id, ()>,
}
39
40impl<L: Language> Analysis<L> for () {
41 type Data = ();
42 fn make(_egraph: &EGraph<L, Self>, _enode: &L) -> Self::Data {
43 ()
44 }
45 fn merge(&mut self, _to: &mut Self::Data, _from: Self::Data) -> bool {
46 false
47 }
48}
49
/// Analysis that attaches source locations to e-classes.
///
/// `Debug` is derived so that an `EGraph` parameterised with this analysis can
/// itself be formatted with `{:?}`; the type is a stateless unit struct, so it
/// is also `Clone` and `Copy`.
#[derive(Debug, Default, Clone, Copy)]
pub struct DebugAnalysis;
53
/// Per-class data of [`DebugAnalysis`]: the locations recorded for the class.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DebugData {
    /// Recorded locations; the analysis only pushes a location that is not
    /// already present, so the list stays free of duplicates.
    pub locs: Vec<Loc>,
}
58
/// Implemented by analysis data that can expose recorded locations.
pub trait HasDebugInfo {
    /// Returns the locations held by this value as a slice.
    fn get_locs(&self) -> &[Loc];
}
62
63impl HasDebugInfo for DebugData {
64 fn get_locs(&self) -> &[Loc] {
65 &self.locs
66 }
67}
68
69impl HasDebugInfo for () {
70 fn get_locs(&self) -> &[Loc] {
71 &[]
72 }
73}
74
75impl<L: Language> Analysis<L> for DebugAnalysis {
76 type Data = DebugData;
77
78 fn make(_egraph: &EGraph<L, Self>, _enode: &L) -> Self::Data {
79 DebugData { locs: Vec::new() }
80 }
81
82 fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> bool {
83 let old_len = to.locs.len();
84 for loc in from.locs {
85 if !to.locs.contains(&loc) {
86 to.locs.push(loc);
87 }
88 }
89 to.locs.len() > old_len
90 }
91
92 fn on_add(&self, data: &mut Self::Data, loc: Loc) {
93 if !data.locs.contains(&loc) {
94 data.locs.push(loc);
95 }
96 }
97}
98
99impl<L: Language, A: Analysis<L>> EGraph<L, A> {
100 pub fn get_class(&self, id: Id) -> dashmap::mapref::one::Ref<'_, Id, EClass<L, A::Data>> {
101 let root = self.union_find.find(id);
102 self.classes.get(&root).expect("Class not found")
103 }
104
105 pub fn new() -> Self {
106 Self {
107 union_find: UnionFind::new(),
108 classes: DashMap::new(),
109 memo: DashMap::new(),
110 analysis: std::sync::RwLock::new(A::default()),
111 next_id: AtomicUsize::new(0),
112 dirty: DashMap::new(),
113 }
114 }
115
116 pub fn add(&self, enode: L) -> Id {
117 let canonical = enode.map_children(|id| self.union_find.find(id));
118 if let Some(id) = self.memo.get(&canonical) {
119 return self.union_find.find(*id);
120 }
121
122 let id = Id::from(self.next_id.fetch_add(1, Ordering::SeqCst));
123
124 let data = A::make(self, &canonical);
125
126 self.memo.insert(canonical.clone(), id);
127 let eclass = EClass {
128 id,
129 nodes: vec![canonical],
130 data,
131 };
132 self.classes.insert(id, eclass);
133
134 id
135 }
136
137 pub fn add_with_loc(&self, enode: L, loc: Loc) -> Id {
138 let id = self.add(enode);
139 let root = self.union_find.find(id);
140 if let Some(mut eclass) = self.classes.get_mut(&root) {
141 let analysis = self.analysis.read().unwrap();
142 analysis.on_add(&mut eclass.data, loc);
143 }
144 root
145 }
146
147 pub fn union(&self, id1: Id, id2: Id) -> Id {
148 let root1 = self.union_find.find(id1);
149 let root2 = self.union_find.find(id2);
150 if root1 == root2 {
151 return root1;
152 }
153
154 {
156 let analysis = self.analysis.read().unwrap();
157 let data1 = &self.classes.get(&root1).unwrap().data;
158 let data2 = &self.classes.get(&root2).unwrap().data;
159
160 if !analysis.is_compatible(data1, data2) {
161 return root1;
163 }
164 }
165
166 let new_root = self.union_find.union(root1, root2);
167 let old_root = if new_root == root1 { root2 } else { root1 };
168
169 self.dirty.insert(new_root, ());
170
171 if let Some((_, old_class)) = self.classes.remove(&old_root) {
172 let mut new_class = self.classes.get_mut(&new_root).unwrap();
173 for node in old_class.nodes {
174 if !new_class.nodes.contains(&node) {
175 new_class.nodes.push(node);
176 }
177 }
178
179 let mut analysis = self.analysis.write().unwrap();
181 analysis.merge(&mut new_class.data, old_class.data);
182 }
183
184 new_root
185 }
186
187 pub fn rebuild(&self) {
188 while !self.dirty.is_empty() {
189 let mut todo = Vec::new();
190 let dirty_list: Vec<Id> = self.dirty.iter().map(|e| *e.key()).collect();
191 self.dirty.clear();
192
193 for id in dirty_list {
194 let root = self.union_find.find(id);
195 if let Some(mut eclass) = self.classes.get_mut(&root) {
196 let mut new_nodes = Vec::new();
197 for node in eclass.nodes.drain(..) {
198 let canonical = node.map_children(|child| self.union_find.find(child));
199 if let Some(old_id) = self.memo.get(&canonical) {
200 let old_root = self.union_find.find(*old_id);
201 if old_root != root {
202 todo.push((old_root, root));
203 }
204 }
205 self.memo.insert(canonical.clone(), root);
206 new_nodes.push(canonical);
207 }
208 eclass.nodes = new_nodes;
209 }
210 }
211
212 for (id1, id2) in todo {
213 self.union(id1, id2);
214 }
215 }
216 }
217}