1use alloc::{
2 collections::{BTreeMap, BTreeSet, VecDeque},
3 vec::Vec,
4};
5
6use crate::GlobalItemIndex;
7
8#[derive(Debug)]
11pub struct CycleError(BTreeSet<GlobalItemIndex>);
12
13impl CycleError {
14 pub fn new(nodes: impl IntoIterator<Item = GlobalItemIndex>) -> Self {
15 Self(nodes.into_iter().collect())
16 }
17
18 pub fn into_node_ids(self) -> impl ExactSizeIterator<Item = GlobalItemIndex> {
19 self.0.into_iter()
20 }
21}
22
23#[derive(Default, Clone)]
41pub struct CallGraph {
42 nodes: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>>,
44}
45
46impl CallGraph {
47 pub fn out_edges(&self, gid: GlobalItemIndex) -> &[GlobalItemIndex] {
49 self.nodes.get(&gid).map(Vec::as_slice).unwrap_or(&[])
50 }
51
52 pub fn get_or_insert_node(&mut self, id: GlobalItemIndex) -> &mut Vec<GlobalItemIndex> {
57 self.nodes.entry(id).or_default()
58 }
59
60 pub fn add_edge(
69 &mut self,
70 caller: GlobalItemIndex,
71 callee: GlobalItemIndex,
72 ) -> Result<(), CycleError> {
73 if caller == callee {
74 return Err(CycleError::new([caller]));
75 }
76
77 self.get_or_insert_node(callee);
79 let callees = self.get_or_insert_node(caller);
81 if callees.contains(&callee) {
83 return Ok(());
84 }
85
86 callees.push(callee);
87 Ok(())
88 }
89
90 pub fn remove_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
92 if let Some(out_edges) = self.nodes.get_mut(&caller) {
93 out_edges.retain(|n| *n != callee);
94 }
95 }
96
97 pub fn num_predecessors(&self, id: GlobalItemIndex) -> usize {
100 self.nodes.iter().filter(|(_, out_edges)| out_edges.contains(&id)).count()
101 }
102
103 pub fn toposort(&self) -> Result<Vec<GlobalItemIndex>, CycleError> {
109 if self.nodes.is_empty() {
110 return Ok(vec![]);
111 }
112
113 let num_nodes = self.nodes.len();
114 let mut output = Vec::with_capacity(num_nodes);
115
116 let mut in_degree: BTreeMap<GlobalItemIndex, usize> =
118 self.nodes.keys().map(|&k| (k, 0)).collect();
119 for out_edges in self.nodes.values() {
120 for &succ in out_edges {
121 *in_degree.entry(succ).or_default() += 1;
122 }
123 }
124
125 let mut queue: VecDeque<GlobalItemIndex> =
127 in_degree.iter().filter(|&(_, °)| deg == 0).map(|(&n, _)| n).collect();
128
129 while let Some(id) = queue.pop_front() {
131 output.push(id);
132 for &mid in self.out_edges(id) {
133 let deg = in_degree.get_mut(&mid).unwrap();
134 *deg -= 1;
135 if *deg == 0 {
136 queue.push_back(mid);
137 }
138 }
139 }
140
141 if output.len() != num_nodes {
143 let visited: BTreeSet<GlobalItemIndex> = output.iter().copied().collect();
144 let mut in_cycle = BTreeSet::default();
145 for (&n, out_edges) in self.nodes.iter() {
146 if visited.contains(&n) {
147 continue;
148 }
149 in_cycle.insert(n);
150 for &succ in out_edges {
151 if !visited.contains(&succ) {
152 in_cycle.insert(succ);
153 }
154 }
155 }
156 Err(CycleError(in_cycle))
157 } else {
158 Ok(output)
159 }
160 }
161
162 pub fn subgraph(&self, root: GlobalItemIndex) -> Self {
165 let mut worklist = VecDeque::from_iter([root]);
166 let mut graph = Self::default();
167 let mut visited = BTreeSet::default();
168
169 while let Some(gid) = worklist.pop_front() {
170 if !visited.insert(gid) {
171 continue;
172 }
173
174 let new_successors = graph.get_or_insert_node(gid);
175 let prev_successors = self.out_edges(gid);
176 worklist.extend(prev_successors.iter().cloned());
177 new_successors.extend_from_slice(prev_successors);
178 }
179
180 graph
181 }
182
183 fn reverse_reachable(&self, root: GlobalItemIndex) -> BTreeSet<GlobalItemIndex> {
185 let mut predecessors: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>> =
187 self.nodes.keys().map(|&k| (k, Vec::new())).collect();
188 for (&node, out_edges) in self.nodes.iter() {
189 for &succ in out_edges {
190 predecessors.entry(succ).or_default().push(node);
191 }
192 }
193
194 let mut worklist = VecDeque::from_iter([root]);
196 let mut visited = BTreeSet::default();
197
198 while let Some(gid) = worklist.pop_front() {
199 if !visited.insert(gid) {
200 continue;
201 }
202
203 if let Some(preds) = predecessors.get(&gid) {
204 worklist.extend(preds.iter().copied());
205 }
206 }
207
208 visited
209 }
210
211 pub fn toposort_caller(
219 &self,
220 caller: GlobalItemIndex,
221 ) -> Result<Vec<GlobalItemIndex>, CycleError> {
222 let subgraph = self.subgraph(caller);
224 let num_nodes = subgraph.nodes.len();
225 let mut output = Vec::with_capacity(num_nodes);
226
227 let mut in_degree: BTreeMap<GlobalItemIndex, usize> =
229 subgraph.nodes.keys().map(|&k| (k, 0)).collect();
230 for out_edges in subgraph.nodes.values() {
231 for &succ in out_edges {
232 *in_degree.entry(succ).or_default() += 1;
233 }
234 }
235
236 let caller_has_predecessors = in_degree.get(&caller).copied().unwrap_or(0) > 0;
239
240 in_degree.insert(caller, 0);
243
244 let mut queue = VecDeque::from_iter([caller]);
246
247 while let Some(id) = queue.pop_front() {
249 output.push(id);
250 for &mid in subgraph.out_edges(id) {
251 if mid == caller {
253 continue;
254 }
255 let deg = in_degree.get_mut(&mid).unwrap();
256 *deg -= 1;
257 if *deg == 0 {
258 queue.push_back(mid);
259 }
260 }
261 }
262
263 let has_cycle = caller_has_predecessors || output.len() != num_nodes;
266 if has_cycle {
267 let visited: BTreeSet<GlobalItemIndex> = output.iter().copied().collect();
268 let mut in_cycle = BTreeSet::default();
269
270 for (&n, out_edges) in subgraph.nodes.iter() {
272 if !visited.contains(&n) {
273 in_cycle.insert(n);
274 for &succ in out_edges {
275 if !visited.contains(&succ) {
276 in_cycle.insert(succ);
277 }
278 }
279 }
280 }
281
282 if caller_has_predecessors {
285 in_cycle.extend(subgraph.reverse_reachable(caller));
286 }
287
288 Err(CycleError(in_cycle))
289 } else {
290 Ok(output)
291 }
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GlobalItemIndex, ModuleIndex, ast::ItemIndex};

    const A: ModuleIndex = ModuleIndex::const_new(1);
    const B: ModuleIndex = ModuleIndex::const_new(2);
    const P1: ItemIndex = ItemIndex::const_new(1);
    const P2: ItemIndex = ItemIndex::const_new(2);
    const P3: ItemIndex = ItemIndex::const_new(3);
    const A1: GlobalItemIndex = GlobalItemIndex { module: A, index: P1 };
    const A2: GlobalItemIndex = GlobalItemIndex { module: A, index: P2 };
    const A3: GlobalItemIndex = GlobalItemIndex { module: A, index: P3 };
    const B1: GlobalItemIndex = GlobalItemIndex { module: B, index: P1 };
    const B2: GlobalItemIndex = GlobalItemIndex { module: B, index: P2 };
    const B3: GlobalItemIndex = GlobalItemIndex { module: B, index: P3 };

    #[test]
    fn callgraph_add_edge() {
        let graph = callgraph_simple();

        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 1);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 2);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2, A3]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    #[test]
    fn callgraph_add_edge_with_cycle() {
        let graph = callgraph_cycle();

        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 2);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 1);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[A2]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    /// Adding an edge that already exists must not duplicate it.
    #[test]
    fn callgraph_add_edge_is_idempotent() {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(A1, A2).expect("duplicate A1 -> A2 must be accepted");

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.num_predecessors(A2), 1);
    }

    #[test]
    fn callgraph_remove_edge() {
        let mut graph = callgraph_simple();

        graph.remove_edge(A2, B2);
        assert_eq!(graph.out_edges(A2), &[A3]);
        assert_eq!(graph.num_predecessors(B2), 1);

        // Removing an absent edge, including one from a caller with no such callee, is a no-op.
        graph.remove_edge(A2, B2);
        graph.remove_edge(B1, A1);
        assert_eq!(graph.out_edges(A2), &[A3]);
        assert_eq!(graph.out_edges(B1), &[B2]);
    }

    #[test]
    fn callgraph_remove_edge_breaks_cycle() {
        let mut graph = callgraph_cycle();
        graph.remove_edge(A3, A2);

        let sorted = graph
            .toposort()
            .expect("expected valid topological ordering once the cycle is broken");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_subgraph() {
        let graph = callgraph_simple();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_with_cycle_subgraph() {
        let graph = callgraph_cycle();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort() {
        let graph = callgraph_simple();

        let sorted = graph.toposort().expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_toposort_empty_graph() {
        let graph = CallGraph::default();

        let sorted = graph.toposort().expect("expected an empty graph to sort trivially");
        assert!(sorted.is_empty());
    }

    #[test]
    fn callgraph_toposort_caller() {
        let graph = callgraph_simple();

        let sorted = graph.toposort_caller(A2).expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A2, B2, B3, A3]);
    }

    /// A caller with no callees sorts to just itself.
    #[test]
    fn callgraph_toposort_caller_leaf() {
        let graph = callgraph_simple();

        let sorted = graph.toposort_caller(A3).expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A3]);
    }

    #[test]
    fn callgraph_with_cycle_toposort() {
        let graph = callgraph_cycle();

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort_caller_with_reachable_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A1)
            .expect_err("expected toposort_caller to fail when a reachable cycle exists");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort_caller_root_closing_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A2)
            .expect_err("expected toposort_caller to detect cycle closing back into root");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_add_edge_with_self_cycle_is_error() {
        let mut graph = CallGraph::default();

        let err = graph.add_edge(A1, A1).expect_err("expected self-edge to be rejected");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A1]);
    }

    #[test]
    fn callgraph_rootless_cycle_toposort_is_error() {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, B1).expect("A1 -> B1 must be accepted");
        graph.add_edge(B1, A1).expect("B1 -> A1 must be accepted");

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A1, B1]);
    }

    #[test]
    fn callgraph_toposort_whole_graph_cycle_without_roots() {
        let graph = callgraph_cycle_without_roots();
        let err = graph.toposort().expect_err(
            "expected topological sort to fail when every node is blocked behind a cycle",
        );
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A1, A2, A3]);
    }

    /// `CycleError` stores its nodes as an ordered set, so ids come back sorted and deduplicated.
    #[test]
    fn cycle_error_into_node_ids_is_sorted_and_deduplicated() {
        let ids = CycleError::new([B1, A2, B1, A1]).into_node_ids();
        assert_eq!(ids.len(), 3);
        assert_eq!(ids.collect::<Vec<_>>(), &[A1, A2, B1]);
    }

    /// Builds an acyclic graph with the edges
    /// `A1 -> A2`, `B1 -> B2`, `A2 -> B2`, `A2 -> A3`, `B2 -> B3`, `B3 -> A3`.
    fn callgraph_simple() -> CallGraph {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(B1, B2).expect("B1 -> B2 must be accepted");
        graph.add_edge(A2, B2).expect("A2 -> B2 must be accepted");
        graph.add_edge(A2, A3).expect("A2 -> A3 must be accepted");
        graph.add_edge(B2, B3).expect("B2 -> B3 must be accepted");
        graph.add_edge(B3, A3).expect("B3 -> A3 must be accepted");

        graph
    }

    /// Builds a graph with the cycle `A2 -> B2 -> B3 -> A3 -> A2`.
    ///
    /// The cycle is entered from the roots `A1` (via `A1 -> A2`) and `B1` (via `B1 -> B2`).
    fn callgraph_cycle() -> CallGraph {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(B1, B2).expect("B1 -> B2 must be accepted");
        graph.add_edge(A2, B2).expect("A2 -> B2 must be accepted");
        graph.add_edge(B2, B3).expect("B2 -> B3 must be accepted");
        graph.add_edge(B3, A3).expect("B3 -> A3 must be accepted");
        graph.add_edge(A3, A2).expect("A3 -> A2 must be accepted");

        graph
    }

    /// Builds the three-node cycle `A1 -> A2 -> A3 -> A1`, in which no node is a root.
    fn callgraph_cycle_without_roots() -> CallGraph {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(A2, A3).expect("A2 -> A3 must be accepted");
        graph.add_edge(A3, A1).expect("A3 -> A1 must be accepted");

        graph
    }
}