miden_assembly/linker/
callgraph.rs

1use alloc::{
2    collections::{BTreeMap, BTreeSet, VecDeque},
3    vec::Vec,
4};
5
6use crate::GlobalItemIndex;
7
/// Represents the inability to construct a topological ordering of the nodes in a [CallGraph]
/// due to a cycle in the graph, which can happen due to recursion.
///
/// The wrapped set holds the nodes which were reported as being involved in the cycle. Because
/// it is an ordered set, each node appears at most once, and the nodes are visited in ascending
/// order.
#[derive(Debug)]
pub struct CycleError(BTreeSet<GlobalItemIndex>);
12
13impl CycleError {
14    pub fn into_node_ids(self) -> impl ExactSizeIterator<Item = GlobalItemIndex> {
15        self.0.into_iter()
16    }
17}
18
19// CALL GRAPH
20// ================================================================================================
21
/// A [CallGraph] is a directed, acyclic graph which represents all of the edges between procedures
/// formed by a caller/callee relationship.
///
/// More precisely, this graph can be used to perform the following analyses:
///
/// - What is the maximum call stack depth for a program?
/// - Are there any recursive procedure calls?
/// - Are there procedures which are unreachable from the program entrypoint, i.e. dead code?
/// - What is the set of procedures which are reachable from a given procedure, and which of those
///   are (un)conditionally called?
///
/// A [CallGraph] is the actual graph underpinning the conceptual "module graph" of the linker, and
/// the two are intrinsically linked to one another (i.e. a [CallGraph] is meaningless without
/// the corresponding [super::Linker] state).
///
/// NOTE: Acyclicity is not enforced while the graph is being built (see [CallGraph::add_edge]);
/// it is verified by [CallGraph::toposort].
#[derive(Default, Clone)]
pub struct CallGraph {
    /// The adjacency list of the call graph: maps each procedure to the list of procedures it
    /// calls. Keys are kept in ascending order by the map.
    nodes: BTreeMap<GlobalItemIndex, Vec<GlobalItemIndex>>,
}
41
42impl CallGraph {
43    /// Gets the set of edges from the given caller to its callees in the graph.
44    pub fn out_edges(&self, gid: GlobalItemIndex) -> &[GlobalItemIndex] {
45        self.nodes.get(&gid).map(|out_edges| out_edges.as_slice()).unwrap_or(&[])
46    }
47
48    /// Inserts a node in the graph for `id`, if not already present.
49    ///
50    /// Returns the set of [GlobalItemIndex] which are the outbound neighbors of `id` in the
51    /// graph, i.e. the callees of a call-like instruction.
52    pub fn get_or_insert_node(&mut self, id: GlobalItemIndex) -> &mut Vec<GlobalItemIndex> {
53        self.nodes.entry(id).or_default()
54    }
55
56    /// Add an edge in the call graph from `caller` to `callee`.
57    ///
58    /// This operation is unchecked, i.e. it is possible to introduce cycles in the graph using it.
59    /// As a result, it is essential that the caller either know that adding the edge does _not_
60    /// introduce a cycle, or that [Self::toposort] is run once the graph is built, in order to
61    /// verify that the graph is valid and has no cycles.
62    ///
63    /// NOTE: This function will panic if you attempt to add an edge from a function to itself,
64    /// which trivially introduces a cycle. All other cycle-inducing edges must be caught by a
65    /// call to [Self::toposort].
66    pub fn add_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
67        assert_ne!(caller, callee, "a procedure cannot call itself");
68
69        // Make sure the callee is in the graph
70        self.get_or_insert_node(callee);
71        // Make sure the caller is in the graph
72        let callees = self.get_or_insert_node(caller);
73        // If the caller already references the callee, we're done
74        if callees.contains(&callee) {
75            return;
76        }
77
78        callees.push(callee);
79    }
80
81    /// Removes the edge between `caller` and `callee` from the graph
82    pub fn remove_edge(&mut self, caller: GlobalItemIndex, callee: GlobalItemIndex) {
83        if let Some(out_edges) = self.nodes.get_mut(&caller) {
84            out_edges.retain(|n| *n != callee);
85        }
86    }
87
88    /// Returns the number of predecessors of `id` in the graph, i.e.
89    /// the number of procedures which call `id`.
90    pub fn num_predecessors(&self, id: GlobalItemIndex) -> usize {
91        self.nodes.iter().filter(|(_, out_edges)| out_edges.contains(&id)).count()
92    }
93
94    /// Construct the topological ordering of all nodes in the call graph.
95    ///
96    /// Returns `Err` if a cycle is detected in the graph
97    pub fn toposort(&self) -> Result<Vec<GlobalItemIndex>, CycleError> {
98        if self.nodes.is_empty() {
99            return Ok(vec![]);
100        }
101
102        let mut output = Vec::with_capacity(self.nodes.len());
103        let mut graph = self.clone();
104
105        // Build the set of roots by finding all nodes
106        // that have no predecessors
107        let mut has_preds = BTreeSet::default();
108        for (_node, out_edges) in graph.nodes.iter() {
109            for succ in out_edges.iter() {
110                has_preds.insert(*succ);
111            }
112        }
113        let mut roots =
114            VecDeque::from_iter(graph.nodes.keys().copied().filter(|n| !has_preds.contains(n)));
115
116        // If all nodes have predecessors, there must be a cycle, so just pick a node and let the
117        // algorithm find the cycle for that node so we have a useful error. Set a flag so that we
118        // can assert that the cycle was actually found as a sanity check
119        let mut expect_cycle = false;
120        if roots.is_empty() {
121            expect_cycle = true;
122            roots.extend(graph.nodes.keys().next());
123        }
124
125        let mut successors = Vec::with_capacity(4);
126        while let Some(id) = roots.pop_front() {
127            output.push(id);
128            successors.clear();
129            successors.extend(graph.nodes[&id].iter().copied());
130            for mid in successors.drain(..) {
131                graph.remove_edge(id, mid);
132                if graph.num_predecessors(mid) == 0 {
133                    roots.push_back(mid);
134                }
135            }
136        }
137
138        let has_cycle = graph
139            .nodes
140            .iter()
141            .any(|(n, out_edges)| !output.contains(n) || !out_edges.is_empty());
142        if has_cycle {
143            let mut in_cycle = BTreeSet::default();
144            for (n, edges) in graph.nodes.iter() {
145                if edges.is_empty() {
146                    continue;
147                }
148                in_cycle.insert(*n);
149                in_cycle.extend(edges.as_slice());
150            }
151            Err(CycleError(in_cycle))
152        } else {
153            assert!(!expect_cycle, "we expected a cycle to be found, but one was not identified");
154            Ok(output)
155        }
156    }
157
158    /// Gets a new graph which is a subgraph of `self` containing all of the nodes reachable from
159    /// `root`, and nothing else.
160    pub fn subgraph(&self, root: GlobalItemIndex) -> Self {
161        let mut worklist = VecDeque::from_iter([root]);
162        let mut graph = Self::default();
163        let mut visited = BTreeSet::default();
164
165        while let Some(gid) = worklist.pop_front() {
166            if !visited.insert(gid) {
167                continue;
168            }
169
170            let new_successors = graph.get_or_insert_node(gid);
171            let prev_successors = self.out_edges(gid);
172            worklist.extend(prev_successors.iter().cloned());
173            new_successors.extend_from_slice(prev_successors);
174        }
175
176        graph
177    }
178
179    /// Constructs the topological ordering of nodes in the call graph, for which `caller` is an
180    /// ancestor.
181    ///
182    /// # Errors
183    /// Returns an error if a cycle is detected in the graph.
184    pub fn toposort_caller(
185        &self,
186        caller: GlobalItemIndex,
187    ) -> Result<Vec<GlobalItemIndex>, CycleError> {
188        let mut output = Vec::with_capacity(self.nodes.len());
189
190        // Build a subgraph of `self` containing only those nodes reachable from `caller`
191        let mut graph = self.subgraph(caller);
192
193        // Remove all predecessor edges to `caller`
194        graph.nodes.iter_mut().for_each(|(_pred, out_edges)| {
195            out_edges.retain(|n| *n != caller);
196        });
197
198        let mut roots = VecDeque::from_iter([caller]);
199        let mut successors = Vec::with_capacity(4);
200        while let Some(id) = roots.pop_front() {
201            output.push(id);
202            successors.clear();
203            successors.extend(graph.nodes[&id].iter().copied());
204            for mid in successors.drain(..) {
205                graph.remove_edge(id, mid);
206                if graph.num_predecessors(mid) == 0 {
207                    roots.push_back(mid);
208                }
209            }
210        }
211
212        let has_cycle = graph
213            .nodes
214            .iter()
215            .any(|(n, out_edges)| output.contains(n) && !out_edges.is_empty());
216        if has_cycle {
217            let mut in_cycle = BTreeSet::default();
218            for (n, edges) in graph.nodes.iter() {
219                if edges.is_empty() {
220                    continue;
221                }
222                in_cycle.insert(*n);
223                in_cycle.extend(edges.as_slice());
224            }
225            Err(CycleError(in_cycle))
226        } else {
227            Ok(output)
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GlobalItemIndex, ModuleIndex, ast::ItemIndex};

    // Two modules, `a` and `b`, with three procedures each
    const A: ModuleIndex = ModuleIndex::const_new(1);
    const B: ModuleIndex = ModuleIndex::const_new(2);
    const P1: ItemIndex = ItemIndex::const_new(1);
    const P2: ItemIndex = ItemIndex::const_new(2);
    const P3: ItemIndex = ItemIndex::const_new(3);
    const A1: GlobalItemIndex = GlobalItemIndex { module: A, index: P1 };
    const A2: GlobalItemIndex = GlobalItemIndex { module: A, index: P2 };
    const A3: GlobalItemIndex = GlobalItemIndex { module: A, index: P3 };
    const B1: GlobalItemIndex = GlobalItemIndex { module: B, index: P1 };
    const B2: GlobalItemIndex = GlobalItemIndex { module: B, index: P2 };
    const B3: GlobalItemIndex = GlobalItemIndex { module: B, index: P3 };

    #[test]
    fn callgraph_add_edge() {
        let graph = callgraph_simple();

        // Verify the graph structure
        assert_predecessor_counts(
            &graph,
            &[(A1, 0), (B1, 0), (A2, 1), (B2, 2), (B3, 1), (A3, 2)],
        );
        assert_out_edges(
            &graph,
            &[
                (A1, &[A2]),
                (B1, &[B2]),
                (A2, &[B2, A3]),
                (B2, &[B3]),
                (A3, &[]),
                (B3, &[A3]),
            ],
        );
    }

    #[test]
    fn callgraph_add_edge_with_cycle() {
        let graph = callgraph_cycle();

        // Verify the graph structure
        assert_predecessor_counts(
            &graph,
            &[(A1, 0), (B1, 0), (A2, 2), (B2, 2), (B3, 1), (A3, 1)],
        );
        assert_out_edges(
            &graph,
            &[
                (A1, &[A2]),
                (B1, &[B2]),
                (A2, &[B2]),
                (B2, &[B3]),
                (A3, &[A2]),
                (B3, &[A3]),
            ],
        );
    }

    #[test]
    fn callgraph_subgraph() {
        let reachable: Vec<_> = callgraph_simple().subgraph(A2).nodes.into_keys().collect();

        assert_eq!(reachable, [A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_with_cycle_subgraph() {
        let reachable: Vec<_> = callgraph_cycle().subgraph(A2).nodes.into_keys().collect();

        assert_eq!(reachable, [A2, A3, B2, B3]);
    }

    #[test]
    fn callgraph_toposort() {
        let sorted = callgraph_simple().toposort().expect("expected valid topological ordering");

        assert_eq!(sorted, [A1, B1, A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_toposort_caller() {
        let sorted = callgraph_simple()
            .toposort_caller(A2)
            .expect("expected valid topological ordering");

        assert_eq!(sorted, [A2, B2, B3, A3]);
    }

    #[test]
    fn callgraph_with_cycle_toposort() {
        let err = callgraph_cycle()
            .toposort()
            .expect_err("expected topological sort to fail with cycle");
        let in_cycle: Vec<_> = err.into_node_ids().collect();

        assert_eq!(in_cycle, [A2, A3, B2, B3]);
    }

    /// Checks the number of callers of each listed node, given as `(node, expected count)`.
    fn assert_predecessor_counts(graph: &CallGraph, expected: &[(GlobalItemIndex, usize)]) {
        for (node, count) in expected {
            assert_eq!(graph.num_predecessors(*node), *count);
        }
    }

    /// Checks the callees of each listed node, given as `(node, expected callees in order)`.
    fn assert_out_edges(graph: &CallGraph, expected: &[(GlobalItemIndex, &[GlobalItemIndex])]) {
        for (node, callees) in expected {
            assert_eq!(graph.out_edges(*node), *callees);
        }
    }

    /// Builds a graph by adding the given `(caller, callee)` edges in order.
    fn graph_from_edges(edges: &[(GlobalItemIndex, GlobalItemIndex)]) -> CallGraph {
        let mut graph = CallGraph::default();
        for (caller, callee) in edges.iter().copied() {
            graph.add_edge(caller, callee);
        }
        graph
    }

    /// An acyclic graph:
    ///
    /// a::a1 -> a::a2 -> a::a3
    ///            |        ^
    ///            v        |
    /// b::b1 -> b::b2 -> b::b3
    fn callgraph_simple() -> CallGraph {
        graph_from_edges(&[(A1, A2), (B1, B2), (A2, B2), (A2, A3), (B2, B3), (B3, A3)])
    }

    /// A graph with the cycle `a2 -> b2 -> b3 -> a3 -> a2`:
    ///
    /// a::a1 -> a::a2 <- a::a3
    ///            |        ^
    ///            v        |
    /// b::b1 -> b::b2 -> b::b3
    fn callgraph_cycle() -> CallGraph {
        graph_from_edges(&[(A1, A2), (B1, B2), (A2, B2), (B2, B3), (B3, A3), (A3, A2)])
    }
}