1use alloc::{
2 collections::{BTreeMap, BTreeSet, VecDeque},
3 vec::Vec,
4};
5
6use crate::GlobalItemIndex;
7
8#[derive(Debug)]
11pub struct CycleError(BTreeSet<GlobalItemIndex>);
12
13impl CycleError {
14 pub fn into_node_ids(self) -> impl ExactSizeIterator<Item = GlobalItemIndex> {
15 self.0.into_iter()
16 }
17}
18
19#[derive(Default, Clone)]
37pub struct CallGraph {
38 nodes: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>>,
40}
41
42impl CallGraph {
43 pub fn out_edges(&self, gid: GlobalItemIndex) -> &[GlobalItemIndex] {
45 self.nodes.get(&gid).map(|out_edges| out_edges.as_slice()).unwrap_or(&[])
46 }
47
48 pub fn get_or_insert_node(&mut self, id: GlobalItemIndex) -> &mut Vec<GlobalItemIndex> {
53 self.nodes.entry(id).or_default()
54 }
55
56 pub fn add_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
67 assert_ne!(caller, callee, "a procedure cannot call itself");
68
69 self.get_or_insert_node(callee);
71 let callees = self.get_or_insert_node(caller);
73 if callees.contains(&callee) {
75 return;
76 }
77
78 callees.push(callee);
79 }
80
81 pub fn remove_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
83 if let Some(out_edges) = self.nodes.get_mut(&caller) {
84 out_edges.retain(|n| *n != callee);
85 }
86 }
87
88 pub fn num_predecessors(&self, id: GlobalItemIndex) -> usize {
91 self.nodes.iter().filter(|(_, out_edges)| out_edges.contains(&id)).count()
92 }
93
94 pub fn toposort(&self) -> Result<Vec<GlobalItemIndex>, CycleError> {
98 if self.nodes.is_empty() {
99 return Ok(vec![]);
100 }
101
102 let mut output = Vec::with_capacity(self.nodes.len());
103 let mut graph = self.clone();
104
105 let mut has_preds = BTreeSet::default();
108 for (_node, out_edges) in graph.nodes.iter() {
109 for succ in out_edges.iter() {
110 has_preds.insert(*succ);
111 }
112 }
113 let mut roots =
114 VecDeque::from_iter(graph.nodes.keys().copied().filter(|n| !has_preds.contains(n)));
115
116 let mut expect_cycle = false;
120 if roots.is_empty() {
121 expect_cycle = true;
122 roots.extend(graph.nodes.keys().next());
123 }
124
125 let mut successors = Vec::with_capacity(4);
126 while let Some(id) = roots.pop_front() {
127 output.push(id);
128 successors.clear();
129 successors.extend(graph.nodes[&id].iter().copied());
130 for mid in successors.drain(..) {
131 graph.remove_edge(id, mid);
132 if graph.num_predecessors(mid) == 0 {
133 roots.push_back(mid);
134 }
135 }
136 }
137
138 let has_cycle = graph
139 .nodes
140 .iter()
141 .any(|(n, out_edges)| !output.contains(n) || !out_edges.is_empty());
142 if has_cycle {
143 let mut in_cycle = BTreeSet::default();
144 for (n, edges) in graph.nodes.iter() {
145 if edges.is_empty() {
146 continue;
147 }
148 in_cycle.insert(*n);
149 in_cycle.extend(edges.as_slice());
150 }
151 Err(CycleError(in_cycle))
152 } else {
153 assert!(!expect_cycle, "we expected a cycle to be found, but one was not identified");
154 Ok(output)
155 }
156 }
157
158 pub fn subgraph(&self, root: GlobalItemIndex) -> Self {
161 let mut worklist = VecDeque::from_iter([root]);
162 let mut graph = Self::default();
163 let mut visited = BTreeSet::default();
164
165 while let Some(gid) = worklist.pop_front() {
166 if !visited.insert(gid) {
167 continue;
168 }
169
170 let new_successors = graph.get_or_insert_node(gid);
171 let prev_successors = self.out_edges(gid);
172 worklist.extend(prev_successors.iter().cloned());
173 new_successors.extend_from_slice(prev_successors);
174 }
175
176 graph
177 }
178
179 pub fn toposort_caller(
185 &self,
186 caller: GlobalItemIndex,
187 ) -> Result<Vec<GlobalItemIndex>, CycleError> {
188 let mut output = Vec::with_capacity(self.nodes.len());
189
190 let mut graph = self.subgraph(caller);
192
193 graph.nodes.iter_mut().for_each(|(_pred, out_edges)| {
195 out_edges.retain(|n| *n != caller);
196 });
197
198 let mut roots = VecDeque::from_iter([caller]);
199 let mut successors = Vec::with_capacity(4);
200 while let Some(id) = roots.pop_front() {
201 output.push(id);
202 successors.clear();
203 successors.extend(graph.nodes[&id].iter().copied());
204 for mid in successors.drain(..) {
205 graph.remove_edge(id, mid);
206 if graph.num_predecessors(mid) == 0 {
207 roots.push_back(mid);
208 }
209 }
210 }
211
212 let has_cycle = graph
213 .nodes
214 .iter()
215 .any(|(n, out_edges)| output.contains(n) && !out_edges.is_empty());
216 if has_cycle {
217 let mut in_cycle = BTreeSet::default();
218 for (n, edges) in graph.nodes.iter() {
219 if edges.is_empty() {
220 continue;
221 }
222 in_cycle.insert(*n);
223 in_cycle.extend(edges.as_slice());
224 }
225 Err(CycleError(in_cycle))
226 } else {
227 Ok(output)
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GlobalItemIndex, ModuleIndex, ast::ItemIndex};

    const A: ModuleIndex = ModuleIndex::const_new(1);
    const B: ModuleIndex = ModuleIndex::const_new(2);
    const P1: ItemIndex = ItemIndex::const_new(1);
    const P2: ItemIndex = ItemIndex::const_new(2);
    const P3: ItemIndex = ItemIndex::const_new(3);
    const A1: GlobalItemIndex = GlobalItemIndex { module: A, index: P1 };
    const A2: GlobalItemIndex = GlobalItemIndex { module: A, index: P2 };
    const A3: GlobalItemIndex = GlobalItemIndex { module: A, index: P3 };
    const B1: GlobalItemIndex = GlobalItemIndex { module: B, index: P1 };
    const B2: GlobalItemIndex = GlobalItemIndex { module: B, index: P2 };
    const B3: GlobalItemIndex = GlobalItemIndex { module: B, index: P3 };

    /// Returns the node ids of `graph` in ascending order.
    fn node_ids(graph: &CallGraph) -> Vec<GlobalItemIndex> {
        graph.nodes.keys().copied().collect()
    }

    #[test]
    fn callgraph_add_edge() {
        let graph = callgraph_simple();

        let expected_preds = [(A1, 0), (B1, 0), (A2, 1), (B2, 2), (B3, 1), (A3, 2)];
        for &(node, count) in expected_preds.iter() {
            assert_eq!(graph.num_predecessors(node), count);
        }

        let expected_callees: [(GlobalItemIndex, &[GlobalItemIndex]); 6] = [
            (A1, &[A2]),
            (B1, &[B2]),
            (A2, &[B2, A3]),
            (B2, &[B3]),
            (A3, &[]),
            (B3, &[A3]),
        ];
        for &(node, callees) in expected_callees.iter() {
            assert_eq!(graph.out_edges(node), callees);
        }
    }

    #[test]
    fn callgraph_add_edge_with_cycle() {
        let graph = callgraph_cycle();

        let expected_preds = [(A1, 0), (B1, 0), (A2, 2), (B2, 2), (B3, 1), (A3, 1)];
        for &(node, count) in expected_preds.iter() {
            assert_eq!(graph.num_predecessors(node), count);
        }

        let expected_callees: [(GlobalItemIndex, &[GlobalItemIndex]); 6] = [
            (A1, &[A2]),
            (B1, &[B2]),
            (A2, &[B2]),
            (B2, &[B3]),
            (A3, &[A2]),
            (B3, &[A3]),
        ];
        for &(node, callees) in expected_callees.iter() {
            assert_eq!(graph.out_edges(node), callees);
        }
    }

    #[test]
    fn callgraph_subgraph() {
        let subgraph = callgraph_simple().subgraph(A2);
        assert_eq!(node_ids(&subgraph), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_with_cycle_subgraph() {
        let subgraph = callgraph_cycle().subgraph(A2);
        assert_eq!(node_ids(&subgraph), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort() {
        let sorted = callgraph_simple()
            .toposort()
            .expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_toposort_caller() {
        let sorted = callgraph_simple()
            .toposort_caller(A2)
            .expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_with_cycle_toposort() {
        let err = callgraph_cycle()
            .toposort()
            .expect_err("expected topological sort to fail with cycle");
        let in_cycle: Vec<_> = err.0.into_iter().collect();
        assert_eq!(in_cycle, &[A2, A3, B2, B3]);
    }

    /// Builds a graph containing exactly `edges`, added in the given order.
    fn callgraph_from_edges(edges: &[(GlobalItemIndex, GlobalItemIndex)]) -> CallGraph {
        let mut graph = CallGraph::default();
        for &(caller, callee) in edges.iter() {
            graph.add_edge(caller, callee);
        }
        graph
    }

    /// Acyclic graph with two entry points (A1, B1).
    ///
    /// Edges: A1 -> A2, B1 -> B2, A2 -> B2, A2 -> A3, B2 -> B3, B3 -> A3.
    fn callgraph_simple() -> CallGraph {
        callgraph_from_edges(&[(A1, A2), (B1, B2), (A2, B2), (A2, A3), (B2, B3), (B3, A3)])
    }

    /// Like `callgraph_simple`, but with the edge A2 -> A3 replaced by A3 -> A2. This closes the
    /// cycle A2 -> B2 -> B3 -> A3 -> A2.
    fn callgraph_cycle() -> CallGraph {
        callgraph_from_edges(&[(A1, A2), (B1, B2), (A2, B2), (B2, B3), (B3, A3), (A3, A2)])
    }
}