Skip to main content

leo_ast/common/graph/
mod.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::Location;
18use leo_span::Symbol;
19
20use indexmap::{IndexMap, IndexSet};
21use std::{fmt::Debug, hash::Hash, rc::Rc};
22
/// A struct dependency graph.
///
/// Nodes are [`Location`]s.
pub type CompositeGraph = DiGraph<Location>;

/// A call graph.
///
/// Nodes are [`Location`]s.
pub type CallGraph = DiGraph<Location>;

/// An import dependency graph.
///
/// Nodes are [`Symbol`]s.
pub type ImportGraph = DiGraph<Symbol>;
31
/// The set of capabilities a type needs in order to serve as a node in a graph:
/// it must be owned (`'static`), clonable, comparable, hashable and debug-printable.
pub trait GraphNode: 'static + Clone + Debug + Eq + PartialEq + Hash {}

/// Blanket implementation: any type with the required capabilities is a `GraphNode`.
impl<T: 'static + Clone + Debug + Eq + PartialEq + Hash> GraphNode for T {}
36
/// Errors in directed graph operations.
#[derive(Debug)]
pub enum DiGraphError<N: GraphNode> {
    /// An error that is emitted when a cycle is detected in the directed graph. Contains the path of the cycle.
    ///
    /// The path lists the nodes in the order the edges are followed, and the node that closes the
    /// cycle appears both first and last (e.g. `[1, 2, 4, 1]` for `1 -> 2 -> 4 -> 1`).
    CycleDetected(Vec<N>),
}
43
/// A directed graph using reference-counted nodes.
///
/// Every node value is stored behind an `Rc`, so the node set and the adjacency lists can refer
/// to the same node without cloning `N`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DiGraph<N: GraphNode> {
    /// The set of nodes in the graph.
    nodes: IndexSet<Rc<N>>,

    /// The directed edges in the graph.
    /// Each entry in the map is a node in the graph, and the set of nodes that it points to.
    /// A node without outgoing edges may have no entry at all, or an entry with an empty set.
    edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
}
54
55impl<N: GraphNode> Default for DiGraph<N> {
56    fn default() -> Self {
57        Self { nodes: IndexSet::new(), edges: IndexMap::new() }
58    }
59}
60
61impl<N: GraphNode> DiGraph<N> {
62    /// Initializes a new `DiGraph` from a set of source nodes.
63    pub fn new(nodes: IndexSet<N>) -> Self {
64        let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
65        Self { nodes, edges: IndexMap::new() }
66    }
67
68    /// Adds a node to the graph.
69    pub fn add_node(&mut self, node: N) {
70        self.nodes.insert(Rc::new(node));
71    }
72
73    /// Returns an iterator over the nodes in the graph.
74    pub fn nodes(&self) -> impl Iterator<Item = &N> {
75        self.nodes.iter().map(|rc| rc.as_ref())
76    }
77
78    /// Adds an edge to the graph.
79    pub fn add_edge(&mut self, from: N, to: N) {
80        // Add `from` and `to` to the set of nodes if they are not already in the set.
81        let from_rc = self.get_or_insert(from);
82        let to_rc = self.get_or_insert(to);
83
84        // Add the edge to the adjacency list.
85        self.edges.entry(from_rc).or_default().insert(to_rc);
86    }
87
88    /// Removes a node and all associated edges from the graph.
89    pub fn remove_node(&mut self, node: &N) -> bool {
90        if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
91            // Remove all outgoing edges from the node
92            self.edges.shift_remove(&rc_node);
93
94            // Remove all incoming edges to the node
95            for targets in self.edges.values_mut() {
96                targets.shift_remove(&rc_node);
97            }
98            true
99        } else {
100            false
101        }
102    }
103
104    /// Returns an iterator to the immediate neighbors of a given node.
105    pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
106        self.edges
107            .get(node) // ← no Rc::from() needed!
108            .into_iter()
109            .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
110    }
111
112    /// Returns all the nodes that can be reached by a given node
113    pub fn transitive_closure(&self, node: &N) -> IndexSet<N> {
114        let mut res = IndexSet::new();
115        let mut queue: Vec<_> = self.neighbors(node).collect();
116
117        while let Some(cur) = queue.pop() {
118            if !res.contains(cur) {
119                res.insert(cur.clone());
120                queue.extend(self.neighbors(cur));
121            }
122        }
123        res
124    }
125
126    /// Returns `true` if the graph contains the given node.
127    pub fn contains_node(&self, node: N) -> bool {
128        self.nodes.contains(&Rc::new(node))
129    }
130
131    /// Returns the post-order ordering of the graph.
132    /// Detects if there is a cycle in the graph.
133    pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
134        self.post_order_with_filter(|_| true)
135    }
136
137    /// Returns the post-order ordering of the graph but only considering a subset of the nodes that
138    /// satisfy the given filter.
139    ///
140    /// Detects if there is a cycle in the graph.
141    pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
142    where
143        F: Fn(&N) -> bool,
144    {
145        // The set of nodes that do not need to be visited again.
146        let mut finished = IndexSet::with_capacity(self.nodes.len());
147
148        // Perform a depth-first search of the graph, starting from `node`, for each node in the graph that satisfies
149        // `is_entry_point`.
150        for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
151            // If the node has not been explored, explore it.
152            if !finished.contains(node_rc) {
153                // The set of nodes that are on the path to the current node in the searc
154                let mut discovered = IndexSet::new();
155                // Check if there is a cycle in the graph starting from `node`.
156                if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
157                    let mut path = vec![cycle_node.as_ref().clone()];
158                    // Backtrack through the discovered nodes to find the cycle.
159                    while let Some(next) = discovered.pop() {
160                        // Add the node to the path.
161                        path.push(next.as_ref().clone());
162                        // If the node is the same as the first node in the path, we have found the cycle.
163                        if Rc::ptr_eq(&next, &cycle_node) {
164                            break;
165                        }
166                    }
167                    // Reverse the path to get the cycle in the correct order.
168                    path.reverse();
169                    // A cycle was detected. Return the path of the cycle.
170                    return Err(DiGraphError::CycleDetected(path));
171                }
172            }
173        }
174
175        // No cycle was found. Return the set of nodes in topological order.
176        Ok(finished.iter().map(|rc| (**rc).clone()).collect())
177    }
178
179    /// Retains a subset of the nodes, and removes all edges in which the source or destination is not in the subset.
180    pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
181        let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
182        // Remove the nodes from the set of nodes.
183        self.nodes.retain(|n| keep.contains(n));
184        self.edges.retain(|n, _| keep.contains(n));
185        // Remove the edges that reference the nodes.
186        for targets in self.edges.values_mut() {
187            targets.retain(|t| keep.contains(t));
188        }
189    }
190
191    // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search.
192    // If there is no cycle, returns `None`.
193    // If there is a cycle, returns the node that was most recently discovered.
194    // Nodes are added to `finished` in post-order order.
195    fn contains_cycle_from(
196        &self,
197        node: &Rc<N>,
198        discovered: &mut IndexSet<Rc<N>>,
199        finished: &mut IndexSet<Rc<N>>,
200    ) -> Option<Rc<N>> {
201        // Add the node to the set of discovered nodes.
202        discovered.insert(node.clone());
203
204        // Check each outgoing edge of the node.
205        if let Some(children) = self.edges.get(node) {
206            for child in children {
207                // If the node already been discovered, there is a cycle.
208                if discovered.contains(child) {
209                    // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle.
210                    // Note that this case is always hit when there is a cycle.
211                    return Some(child.clone());
212                }
213                // If the node has not been explored, explore it.
214                if !finished.contains(child)
215                    && let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished)
216                {
217                    return Some(cycle_node);
218                }
219            }
220        }
221
222        // Remove the node from the set of discovered nodes.
223        discovered.pop();
224        // Add the node to the set of finished nodes.
225        finished.insert(node.clone());
226        None
227    }
228
229    /// Helper: get or insert Rc<N> into the graph.
230    fn get_or_insert(&mut self, node: N) -> Rc<N> {
231        if let Some(existing) = self.nodes.get(&node) {
232            return existing.clone();
233        }
234        let rc = Rc::new(node);
235        self.nodes.insert(rc.clone());
236        rc
237    }
238}
239
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a `u32` graph by adding the given `(from, to)` edges in the given order.
    fn graph_from_edges(edges: &[(u32, u32)]) -> DiGraph<u32> {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        for &(from, to) in edges {
            graph.add_edge(from, to);
        }
        graph
    }

    /// Asserts that `graph` has no cycle and that its post-order is exactly `expected`.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        let result = graph.post_order();
        assert!(result.is_ok());

        let order = result.unwrap().into_iter().collect::<Vec<N>>();
        assert_eq!(order, expected);
    }

    /// Asserts that the nodes reachable from `start` are exactly `expected`.
    fn check_closure(graph: &DiGraph<u32>, start: u32, expected: &[u32]) {
        let expected: IndexSet<u32> = expected.iter().copied().collect();
        assert_eq!(graph.transitive_closure(&start), expected);
    }

    #[test]
    fn test_post_order() {
        // A diamond (1 -> {2, 3} -> 4) followed by a tail (4 -> 5).
        let diamond = graph_from_edges(&[(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]);
        check_post_order(&diamond, &[5, 4, 2, 3, 1]);

        // A tree, with the nodes A..=I numbered 1..=9:
        // F -> B, B -> A, B -> D, D -> C, D -> E, F -> G, G -> I, I -> H.
        let tree = graph_from_edges(&[(6, 2), (2, 1), (2, 4), (4, 3), (4, 5), (6, 7), (7, 9), (9, 8)]);
        // A, C, E, D, B, H, I, G, F.
        check_post_order(&tree, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let graph = graph_from_edges(&[(1, 2), (2, 3), (2, 4), (4, 1)]);

        let result = graph.post_order();
        assert!(result.is_err());

        // The reported path starts and ends with the node that closes the cycle.
        let DiGraphError::CycleDetected(cycle) = result.unwrap_err();
        assert_eq!(cycle, vec![1u32, 2, 4, 1]);
    }

    #[test]
    fn test_transitive_closure() {
        // A graph with the cycle 1 -> 2 -> 4 -> 1, so node 2 can reach itself.
        let cyclic = graph_from_edges(&[(1, 2), (2, 3), (2, 4), (4, 1), (3, 5)]);
        check_closure(&cyclic, 2, &[4, 1, 2, 3, 5]);
        check_closure(&cyclic, 3, &[5]);
        check_closure(&cyclic, 5, &[]);

        // An acyclic graph.
        let acyclic = graph_from_edges(&[(1, 2), (1, 3), (2, 5), (3, 5), (3, 4)]);
        check_closure(&acyclic, 1, &[2, 5, 3, 4]);
        check_closure(&acyclic, 2, &[5]);
        check_closure(&acyclic, 3, &[5, 4]);
        check_closure(&acyclic, 4, &[]);
    }

    #[test]
    fn test_unconnected_graph() {
        // Without edges, the post-order is the insertion order of the nodes.
        let graph = DiGraph::<u32>::new((1..=5).collect());

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = graph_from_edges(&[(1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]);

        let nodes: IndexSet<u32> = IndexSet::from([1, 2, 3]);
        graph.retain_nodes(&nodes);

        let mut expected = graph_from_edges(&[(1, 2), (1, 3), (2, 3)]);
        // Node 3 loses its only edge (3 -> 4) but keeps an empty adjacency entry.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }
}
365}