Skip to main content

miden_assembly/linker/
callgraph.rs

1use alloc::{
2    collections::{BTreeMap, BTreeSet, VecDeque},
3    vec::Vec,
4};
5
6use crate::GlobalItemIndex;
7
8/// Represents the inability to construct a topological ordering of the nodes in a [CallGraph]
9/// due to a cycle in the graph, which can happen due to recursion.
10#[derive(Debug)]
11pub struct CycleError(BTreeSet<GlobalItemIndex>);
12
13impl CycleError {
14    pub fn new(nodes: impl IntoIterator<Item = GlobalItemIndex>) -> Self {
15        Self(nodes.into_iter().collect())
16    }
17
18    pub fn into_node_ids(self) -> impl ExactSizeIterator<Item = GlobalItemIndex> {
19        self.0.into_iter()
20    }
21}
22
23// CALL GRAPH
24// ================================================================================================
25
26/// A [CallGraph] is a directed, acyclic graph which represents all of the edges between procedures
27/// formed by a caller/callee relationship.
28///
29/// More precisely, this graph can be used to perform the following analyses:
30///
31/// - What is the maximum call stack depth for a program?
32/// - Are there any recursive procedure calls?
33/// - Are there procedures which are unreachable from the program entrypoint?, i.e. dead code
34/// - What is the set of procedures which are reachable from a given procedure, and which of those
35///   are (un)conditionally called?
36///
37/// A [CallGraph] is the actual graph underpinning the conceptual "module graph" of the linker, and
38/// the two are intrinsically linked to one another (i.e. a [CallGraph] is meaningless without
39/// the corresponding [super::Linker] state).
40#[derive(Default, Clone)]
41pub struct CallGraph {
42    /// The adjacency matrix for procedures in the call graph
43    nodes: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>>,
44}
45
46impl CallGraph {
47    /// Gets the set of edges from the given caller to its callees in the graph.
48    pub fn out_edges(&self, gid: GlobalItemIndex) -> &[GlobalItemIndex] {
49        self.nodes.get(&gid).map(Vec::as_slice).unwrap_or(&[])
50    }
51
52    /// Inserts a node in the graph for `id`, if not already present.
53    ///
54    /// Returns the set of [GlobalItemIndex] which are the outbound neighbors of `id` in the
55    /// graph, i.e. the callees of a call-like instruction.
56    pub fn get_or_insert_node(&mut self, id: GlobalItemIndex) -> &mut Vec<GlobalItemIndex> {
57        self.nodes.entry(id).or_default()
58    }
59
60    /// Add an edge in the call graph from `caller` to `callee`.
61    ///
62    /// This operation is unchecked, i.e. it is possible to introduce cycles in the graph using it.
63    /// As a result, it is essential that the caller either know that adding the edge does _not_
64    /// introduce a cycle, or that [Self::toposort] is run once the graph is built, in order to
65    /// verify that the graph is valid and has no cycles.
66    ///
67    /// Returns an error if adding the edge would introduce a trivial self-cycle.
68    pub fn add_edge(
69        &mut self,
70        caller: GlobalItemIndex,
71        callee: GlobalItemIndex,
72    ) -> Result<(), CycleError> {
73        if caller == callee {
74            return Err(CycleError::new([caller]));
75        }
76
77        // Make sure the callee is in the graph
78        self.get_or_insert_node(callee);
79        // Make sure the caller is in the graph
80        let callees = self.get_or_insert_node(caller);
81        // If the caller already references the callee, we're done
82        if callees.contains(&callee) {
83            return Ok(());
84        }
85
86        callees.push(callee);
87        Ok(())
88    }
89
90    /// Removes the edge between `caller` and `callee` from the graph
91    pub fn remove_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
92        if let Some(out_edges) = self.nodes.get_mut(&caller) {
93            out_edges.retain(|n| *n != callee);
94        }
95    }
96
97    /// Returns the number of predecessors of `id` in the graph, i.e.
98    /// the number of procedures which call `id`.
99    pub fn num_predecessors(&self, id: GlobalItemIndex) -> usize {
100        self.nodes.iter().filter(|(_, out_edges)| out_edges.contains(&id)).count()
101    }
102
103    /// Construct the topological ordering of all nodes in the call graph.
104    ///
105    /// Uses Kahn's algorithm with pre-computed in-degrees for O(V + E) complexity.
106    ///
107    /// Returns `Err` if a cycle is detected in the graph
108    pub fn toposort(&self) -> Result<Vec<GlobalItemIndex>, CycleError> {
109        if self.nodes.is_empty() {
110            return Ok(vec![]);
111        }
112
113        let num_nodes = self.nodes.len();
114        let mut output = Vec::with_capacity(num_nodes);
115
116        // Compute in-degree for each node: O(V + E)
117        let mut in_degree: BTreeMap<GlobalItemIndex, usize> =
118            self.nodes.keys().map(|&k| (k, 0)).collect();
119        for out_edges in self.nodes.values() {
120            for &succ in out_edges {
121                *in_degree.entry(succ).or_default() += 1;
122            }
123        }
124
125        // Seed the queue with all zero-in-degree nodes: O(V)
126        let mut queue: VecDeque<GlobalItemIndex> =
127            in_degree.iter().filter(|&(_, &deg)| deg == 0).map(|(&n, _)| n).collect();
128
129        // Kahn's algorithm: process each node exactly once, each edge exactly once → O(V + E)
130        while let Some(id) = queue.pop_front() {
131            output.push(id);
132            for &mid in self.out_edges(id) {
133                let deg = in_degree.get_mut(&mid).unwrap();
134                *deg -= 1;
135                if *deg == 0 {
136                    queue.push_back(mid);
137                }
138            }
139        }
140
141        // If not all nodes were visited, the remaining nodes participate in cycles
142        if output.len() != num_nodes {
143            let visited: BTreeSet<GlobalItemIndex> = output.iter().copied().collect();
144            let mut in_cycle = BTreeSet::default();
145            for (&n, out_edges) in self.nodes.iter() {
146                if visited.contains(&n) {
147                    continue;
148                }
149                in_cycle.insert(n);
150                for &succ in out_edges {
151                    if !visited.contains(&succ) {
152                        in_cycle.insert(succ);
153                    }
154                }
155            }
156            Err(CycleError(in_cycle))
157        } else {
158            Ok(output)
159        }
160    }
161
162    /// Gets a new graph which is a subgraph of `self` containing all of the nodes reachable from
163    /// `root`, and nothing else.
164    pub fn subgraph(&self, root: GlobalItemIndex) -> Self {
165        let mut worklist = VecDeque::from_iter([root]);
166        let mut graph = Self::default();
167        let mut visited = BTreeSet::default();
168
169        while let Some(gid) = worklist.pop_front() {
170            if !visited.insert(gid) {
171                continue;
172            }
173
174            let new_successors = graph.get_or_insert_node(gid);
175            let prev_successors = self.out_edges(gid);
176            worklist.extend(prev_successors.iter().cloned());
177            new_successors.extend_from_slice(prev_successors);
178        }
179
180        graph
181    }
182
183    /// Computes the set of nodes in this graph which can reach `root`.
184    fn reverse_reachable(&self, root: GlobalItemIndex) -> BTreeSet<GlobalItemIndex> {
185        // Build reverse adjacency map: O(V + E)
186        let mut predecessors: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>> =
187            self.nodes.keys().map(|&k| (k, Vec::new())).collect();
188        for (&node, out_edges) in self.nodes.iter() {
189            for &succ in out_edges {
190                predecessors.entry(succ).or_default().push(node);
191            }
192        }
193
194        // BFS on reverse graph: O(V + E)
195        let mut worklist = VecDeque::from_iter([root]);
196        let mut visited = BTreeSet::default();
197
198        while let Some(gid) = worklist.pop_front() {
199            if !visited.insert(gid) {
200                continue;
201            }
202
203            if let Some(preds) = predecessors.get(&gid) {
204                worklist.extend(preds.iter().copied());
205            }
206        }
207
208        visited
209    }
210
211    /// Constructs the topological ordering of nodes in the call graph, for which `caller` is an
212    /// ancestor.
213    ///
214    /// Uses Kahn's algorithm with pre-computed in-degrees for O(V + E) complexity.
215    ///
216    /// # Errors
217    /// Returns an error if a cycle is detected in the graph.
218    pub fn toposort_caller(
219        &self,
220        caller: GlobalItemIndex,
221    ) -> Result<Vec<GlobalItemIndex>, CycleError> {
222        // Build a subgraph of `self` containing only those nodes reachable from `caller`
223        let subgraph = self.subgraph(caller);
224        let num_nodes = subgraph.nodes.len();
225        let mut output = Vec::with_capacity(num_nodes);
226
227        // Compute in-degree for each node in the subgraph: O(V + E)
228        let mut in_degree: BTreeMap<GlobalItemIndex, usize> =
229            subgraph.nodes.keys().map(|&k| (k, 0)).collect();
230        for out_edges in subgraph.nodes.values() {
231            for &succ in out_edges {
232                *in_degree.entry(succ).or_default() += 1;
233            }
234        }
235
236        // Check if any cycle closes back to `caller` (i.e. caller has predecessors in its
237        // own reachable subgraph)
238        let caller_has_predecessors = in_degree.get(&caller).copied().unwrap_or(0) > 0;
239
240        // Force `caller` as the root by zeroing its in-degree (equivalent to removing
241        // all back-edges to `caller`)
242        in_degree.insert(caller, 0);
243
244        // Seed queue with `caller` as the sole root
245        let mut queue = VecDeque::from_iter([caller]);
246
247        // Kahn's algorithm: O(V + E)
248        while let Some(id) = queue.pop_front() {
249            output.push(id);
250            for &mid in subgraph.out_edges(id) {
251                // Skip back-edges to caller (already processed as root)
252                if mid == caller {
253                    continue;
254                }
255                let deg = in_degree.get_mut(&mid).unwrap();
256                *deg -= 1;
257                if *deg == 0 {
258                    queue.push_back(mid);
259                }
260            }
261        }
262
263        // Detect cycles: either caller had predecessors in its subgraph (a cycle closes
264        // back to it), or not all nodes were reachable (an internal cycle)
265        let has_cycle = caller_has_predecessors || output.len() != num_nodes;
266        if has_cycle {
267            let visited: BTreeSet<GlobalItemIndex> = output.iter().copied().collect();
268            let mut in_cycle = BTreeSet::default();
269
270            // Collect nodes not processed by the sort (they're in internal cycles)
271            for (&n, out_edges) in subgraph.nodes.iter() {
272                if !visited.contains(&n) {
273                    in_cycle.insert(n);
274                    for &succ in out_edges {
275                        if !visited.contains(&succ) {
276                            in_cycle.insert(succ);
277                        }
278                    }
279                }
280            }
281
282            // If caller has back-edges, include all nodes participating in the cycle
283            // through caller
284            if caller_has_predecessors {
285                in_cycle.extend(subgraph.reverse_reachable(caller));
286            }
287
288            Err(CycleError(in_cycle))
289        } else {
290            Ok(output)
291        }
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GlobalItemIndex, ModuleIndex, ast::ItemIndex};

    const A: ModuleIndex = ModuleIndex::const_new(1);
    const B: ModuleIndex = ModuleIndex::const_new(2);
    const P1: ItemIndex = ItemIndex::const_new(1);
    const P2: ItemIndex = ItemIndex::const_new(2);
    const P3: ItemIndex = ItemIndex::const_new(3);
    const A1: GlobalItemIndex = GlobalItemIndex { module: A, index: P1 };
    const A2: GlobalItemIndex = GlobalItemIndex { module: A, index: P2 };
    const A3: GlobalItemIndex = GlobalItemIndex { module: A, index: P3 };
    const B1: GlobalItemIndex = GlobalItemIndex { module: B, index: P1 };
    const B2: GlobalItemIndex = GlobalItemIndex { module: B, index: P2 };
    const B3: GlobalItemIndex = GlobalItemIndex { module: B, index: P3 };

    #[test]
    fn callgraph_add_edge() {
        let graph = callgraph_simple();

        // Verify the graph structure
        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 1);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 2);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2, A3]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    #[test]
    fn callgraph_add_edge_with_cycle() {
        let graph = callgraph_cycle();

        // Verify the graph structure
        assert_eq!(graph.num_predecessors(A1), 0);
        assert_eq!(graph.num_predecessors(B1), 0);
        assert_eq!(graph.num_predecessors(A2), 2);
        assert_eq!(graph.num_predecessors(B2), 2);
        assert_eq!(graph.num_predecessors(B3), 1);
        assert_eq!(graph.num_predecessors(A3), 1);

        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.out_edges(B1), &[B2]);
        assert_eq!(graph.out_edges(A2), &[B2]);
        assert_eq!(graph.out_edges(B2), &[B3]);
        assert_eq!(graph.out_edges(A3), &[A2]);
        assert_eq!(graph.out_edges(B3), &[A3]);
    }

    #[test]
    fn callgraph_add_edge_is_idempotent() {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(A1, A2).expect("re-adding A1 -> A2 must be accepted");

        // The duplicate edge must not be recorded twice
        assert_eq!(graph.out_edges(A1), &[A2]);
        assert_eq!(graph.num_predecessors(A2), 1);
    }

    #[test]
    fn callgraph_get_or_insert_node_is_idempotent() {
        let mut graph = CallGraph::default();
        graph.get_or_insert_node(A1);
        graph.get_or_insert_node(A1);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.out_edges(A1), &[]);

        // Fetching an existing node must not reset its out-edges
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        assert_eq!(graph.get_or_insert_node(A1).as_slice(), &[A2]);
        assert_eq!(graph.nodes.len(), 2);
    }

    #[test]
    fn callgraph_remove_edge() {
        let mut graph = callgraph_simple();
        graph.remove_edge(A2, B2);

        assert_eq!(graph.out_edges(A2), &[A3]);
        assert_eq!(graph.num_predecessors(B2), 1);
        // Both endpoints of the removed edge stay in the graph
        assert_eq!(graph.nodes.len(), 6);

        // Removing an edge that no longer exists is a no-op
        graph.remove_edge(A2, B2);
        assert_eq!(graph.out_edges(A2), &[A3]);
        assert_eq!(graph.nodes.len(), 6);

        // Removing an edge from an unknown caller is a no-op and does not create a node
        let mut empty = CallGraph::default();
        empty.remove_edge(A1, A2);
        assert!(empty.nodes.is_empty());
    }

    #[test]
    fn callgraph_remove_edge_breaks_cycle() {
        let mut graph = callgraph_cycle();
        graph.remove_edge(A3, A2);

        let sorted = graph
            .toposort()
            .expect("expected valid topological ordering once the cycle is broken");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_subgraph() {
        let graph = callgraph_simple();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_with_cycle_subgraph() {
        let graph = callgraph_cycle();
        let subgraph = graph.subgraph(A2);

        assert_eq!(subgraph.nodes.keys().copied().collect::<Vec<_>>(), vec![A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort() {
        let graph = callgraph_simple();

        let sorted = graph.toposort().expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_toposort_empty_graph() {
        let graph = CallGraph::default();

        let sorted = graph.toposort().expect("an empty graph is trivially acyclic");
        assert!(sorted.is_empty());
    }

    #[test]
    fn callgraph_toposort_caller() {
        let graph = callgraph_simple();

        let sorted = graph.toposort_caller(A2).expect("expected valid topological ordering");
        assert_eq!(sorted.as_slice(), &[A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_toposort_caller_unknown_root() {
        let graph = CallGraph::default();

        // A root with no node in the graph is treated as a lone, acyclic node
        let sorted = graph.toposort_caller(A1).expect("a lone root is trivially acyclic");
        assert_eq!(sorted.as_slice(), &[A1]);
    }

    #[test]
    fn callgraph_with_cycle_toposort() {
        let graph = callgraph_cycle();

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_cycle_error_into_node_ids() {
        let graph = callgraph_cycle();

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        let ids = err.into_node_ids();
        assert_eq!(ids.len(), 4);
        assert_eq!(ids.collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort_caller_with_reachable_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A1)
            .expect_err("expected toposort_caller to fail when a reachable cycle exists");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort_caller_root_closing_cycle() {
        let graph = callgraph_cycle();

        let err = graph
            .toposort_caller(A2)
            .expect_err("expected toposort_caller to detect cycle closing back into root");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_add_edge_with_self_cycle_is_error() {
        let mut graph = CallGraph::default();

        let err = graph.add_edge(A1, A1).expect_err("expected self-edge to be rejected");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A1]);
    }

    #[test]
    fn callgraph_rootless_cycle_toposort_is_error() {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, B1).expect("A1 -> B1 must be accepted");
        graph.add_edge(B1, A1).expect("B1 -> A1 must be accepted");

        let err = graph.toposort().expect_err("expected topological sort to fail with cycle");
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A1, B1]);
    }

    #[test]
    fn callgraph_toposort_whole_graph_cycle_without_roots() {
        let graph = callgraph_cycle_without_roots();
        let err = graph.toposort().expect_err(
            "expected topological sort to fail when every node is blocked behind a cycle",
        );
        assert_eq!(err.0.into_iter().collect::<Vec<_>>(), &[A1, A2, A3]);
    }

    /// a::a1 -> a::a2 -> a::a3
    ///            |        ^
    ///            v        |
    /// b::b1 -> b::b2 -> b::b3
    fn callgraph_simple() -> CallGraph {
        // Construct the graph
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(B1, B2).expect("B1 -> B2 must be accepted");
        graph.add_edge(A2, B2).expect("A2 -> B2 must be accepted");
        graph.add_edge(A2, A3).expect("A2 -> A3 must be accepted");
        graph.add_edge(B2, B3).expect("B2 -> B3 must be accepted");
        graph.add_edge(B3, A3).expect("B3 -> A3 must be accepted");

        graph
    }

    /// a::a1 -> a::a2 <- a::a3
    ///            |        ^
    ///            v        |
    /// b::b1 -> b::b2 -> b::b3
    fn callgraph_cycle() -> CallGraph {
        // Construct the graph
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(B1, B2).expect("B1 -> B2 must be accepted");
        graph.add_edge(A2, B2).expect("A2 -> B2 must be accepted");
        graph.add_edge(B2, B3).expect("B2 -> B3 must be accepted");
        graph.add_edge(B3, A3).expect("B3 -> A3 must be accepted");
        graph.add_edge(A3, A2).expect("A3 -> A2 must be accepted");

        graph
    }

    /// a::a1 -> a::a2 -> a::a3
    ///   ^                 |
    ///   +-----------------+
    ///
    /// Every node has in-degree 1, so Kahn's algorithm starts with an empty queue.
    fn callgraph_cycle_without_roots() -> CallGraph {
        let mut graph = CallGraph::default();
        graph.add_edge(A1, A2).expect("A1 -> A2 must be accepted");
        graph.add_edge(A2, A3).expect("A2 -> A3 must be accepted");
        graph.add_edge(A3, A1).expect("A3 -> A1 must be accepted");

        graph
    }
}