Skip to main content

leo_ast/common/graph/
mod.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::Location;
18use leo_span::Symbol;
19
20use indexmap::{IndexMap, IndexSet};
21use std::{fmt::Debug, hash::Hash, rc::Rc};
22
23/// A struct dependency graph.
24pub type CompositeGraph = DiGraph<Location>;
25
26/// A call graph.
27pub type CallGraph = DiGraph<Location>;
28
29/// An import dependency graph.
30pub type ImportGraph = DiGraph<Symbol>;
31
/// Marker trait for values that can be stored as nodes of a [`DiGraph`].
///
/// It is blanket-implemented for every `'static` type that is `Clone`, `Debug`, `Eq` and `Hash`
/// (`Eq` already implies `PartialEq`).
pub trait GraphNode: 'static + Clone + Debug + Eq + Hash {}

impl<T> GraphNode for T where T: 'static + Clone + Debug + Eq + Hash {}
36
37/// Errors in directed graph operations.
38#[derive(Debug)]
39pub enum DiGraphError<N: GraphNode> {
40    /// An error that is emitted when a cycle is detected in the directed graph. Contains the path of the cycle.
41    CycleDetected(Vec<N>),
42}
43
44/// A directed graph using reference-counted nodes.
45#[derive(Clone, Debug, PartialEq, Eq)]
46pub struct DiGraph<N: GraphNode> {
47    /// The set of nodes in the graph.
48    nodes: IndexSet<Rc<N>>,
49
50    /// The directed edges in the graph.
51    /// Each entry in the map is a node in the graph, and the set of nodes that it points to.
52    edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
53}
54
55impl<N: GraphNode> Default for DiGraph<N> {
56    fn default() -> Self {
57        Self { nodes: IndexSet::new(), edges: IndexMap::new() }
58    }
59}
60
61impl<N: GraphNode> DiGraph<N> {
62    /// Initializes a new `DiGraph` from a set of source nodes.
63    pub fn new(nodes: IndexSet<N>) -> Self {
64        let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
65        Self { nodes, edges: IndexMap::new() }
66    }
67
68    /// Adds a node to the graph.
69    pub fn add_node(&mut self, node: N) {
70        self.nodes.insert(Rc::new(node));
71    }
72
73    /// Returns an iterator over the nodes in the graph.
74    pub fn nodes(&self) -> impl Iterator<Item = &N> {
75        self.nodes.iter().map(|rc| rc.as_ref())
76    }
77
78    /// Adds an edge to the graph.
79    pub fn add_edge(&mut self, from: N, to: N) {
80        // Add `from` and `to` to the set of nodes if they are not already in the set.
81        let from_rc = self.get_or_insert(from);
82        let to_rc = self.get_or_insert(to);
83
84        // Add the edge to the adjacency list.
85        self.edges.entry(from_rc).or_default().insert(to_rc);
86    }
87
88    /// Removes a node and all associated edges from the graph.
89    pub fn remove_node(&mut self, node: &N) -> bool {
90        if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
91            // Remove all outgoing edges from the node
92            self.edges.shift_remove(&rc_node);
93
94            // Remove all incoming edges to the node
95            for targets in self.edges.values_mut() {
96                targets.shift_remove(&rc_node);
97            }
98            true
99        } else {
100            false
101        }
102    }
103
104    /// Returns an iterator to the immediate neighbors of a given node.
105    pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
106        self.edges
107            .get(node) // ← no Rc::from() needed!
108            .into_iter()
109            .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
110    }
111
112    /// Returns `true` if the graph contains the given node.
113    pub fn contains_node(&self, node: N) -> bool {
114        self.nodes.contains(&Rc::new(node))
115    }
116
117    /// Returns the post-order ordering of the graph.
118    /// Detects if there is a cycle in the graph.
119    pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
120        self.post_order_with_filter(|_| true)
121    }
122
123    /// Returns the post-order ordering of the graph but only considering a subset of the nodes that
124    /// satisfy the given filter.
125    ///
126    /// Detects if there is a cycle in the graph.
127    pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
128    where
129        F: Fn(&N) -> bool,
130    {
131        // The set of nodes that do not need to be visited again.
132        let mut finished = IndexSet::with_capacity(self.nodes.len());
133
134        // Perform a depth-first search of the graph, starting from `node`, for each node in the graph that satisfies
135        // `is_entry_point`.
136        for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
137            // If the node has not been explored, explore it.
138            if !finished.contains(node_rc) {
139                // The set of nodes that are on the path to the current node in the searc
140                let mut discovered = IndexSet::new();
141                // Check if there is a cycle in the graph starting from `node`.
142                if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
143                    let mut path = vec![cycle_node.as_ref().clone()];
144                    // Backtrack through the discovered nodes to find the cycle.
145                    while let Some(next) = discovered.pop() {
146                        // Add the node to the path.
147                        path.push(next.as_ref().clone());
148                        // If the node is the same as the first node in the path, we have found the cycle.
149                        if Rc::ptr_eq(&next, &cycle_node) {
150                            break;
151                        }
152                    }
153                    // Reverse the path to get the cycle in the correct order.
154                    path.reverse();
155                    // A cycle was detected. Return the path of the cycle.
156                    return Err(DiGraphError::CycleDetected(path));
157                }
158            }
159        }
160
161        // No cycle was found. Return the set of nodes in topological order.
162        Ok(finished.iter().map(|rc| (**rc).clone()).collect())
163    }
164
165    /// Retains a subset of the nodes, and removes all edges in which the source or destination is not in the subset.
166    pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
167        let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
168        // Remove the nodes from the set of nodes.
169        self.nodes.retain(|n| keep.contains(n));
170        self.edges.retain(|n, _| keep.contains(n));
171        // Remove the edges that reference the nodes.
172        for targets in self.edges.values_mut() {
173            targets.retain(|t| keep.contains(t));
174        }
175    }
176
177    // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search.
178    // If there is no cycle, returns `None`.
179    // If there is a cycle, returns the node that was most recently discovered.
180    // Nodes are added to `finished` in post-order order.
181    fn contains_cycle_from(
182        &self,
183        node: &Rc<N>,
184        discovered: &mut IndexSet<Rc<N>>,
185        finished: &mut IndexSet<Rc<N>>,
186    ) -> Option<Rc<N>> {
187        // Add the node to the set of discovered nodes.
188        discovered.insert(node.clone());
189
190        // Check each outgoing edge of the node.
191        if let Some(children) = self.edges.get(node) {
192            for child in children {
193                // If the node already been discovered, there is a cycle.
194                if discovered.contains(child) {
195                    // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle.
196                    // Note that this case is always hit when there is a cycle.
197                    return Some(child.clone());
198                }
199                // If the node has not been explored, explore it.
200                if !finished.contains(child)
201                    && let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished)
202                {
203                    return Some(cycle_node);
204                }
205            }
206        }
207
208        // Remove the node from the set of discovered nodes.
209        discovered.pop();
210        // Add the node to the set of finished nodes.
211        finished.insert(node.clone());
212        None
213    }
214
215    /// Helper: get or insert Rc<N> into the graph.
216    fn get_or_insert(&mut self, node: N) -> Rc<N> {
217        if let Some(existing) = self.nodes.get(&node) {
218            return existing.clone();
219        }
220        let rc = Rc::new(node);
221        self.nodes.insert(rc.clone());
222        rc
223    }
224}
225
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `graph` is acyclic and that its post-order equals `expected`.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        // Surface the offending cycle instead of a bare `is_ok()` assertion failure.
        let order: Vec<N> = match graph.post_order() {
            Ok(order) => order.into_iter().collect(),
            Err(DiGraphError::CycleDetected(cycle)) => panic!("unexpected cycle: {cycle:?}"),
        };
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 4);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        // F -> B
        graph.add_edge(6, 2);
        // B -> A
        graph.add_edge(2, 1);
        // B -> D
        graph.add_edge(2, 4);
        // D -> C
        graph.add_edge(4, 3);
        // D -> E
        graph.add_edge(4, 5);
        // F -> G
        graph.add_edge(6, 7);
        // G -> I
        graph.add_edge(7, 9);
        // I -> H
        graph.add_edge(9, 8);

        // A, C, E, D, B, H, I, G, F.
        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);

        let Err(DiGraphError::CycleDetected(cycle)) = graph.post_order() else {
            panic!("expected a cycle to be detected");
        };
        assert_eq!(cycle, vec![1u32, 2, 4, 1]);
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 5);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(2, 5);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        let mut nodes = IndexSet::new();
        nodes.insert(1);
        nodes.insert(2);
        nodes.insert(3);

        graph.retain_nodes(&nodes);

        let mut expected = DiGraph::<u32>::new(IndexSet::new());
        expected.add_edge(1, 2);
        expected.add_edge(1, 3);
        expected.add_edge(2, 3);
        // Node 3 lost all of its outgoing edges but keeps an (empty) adjacency entry.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }

    #[test]
    fn test_add_node() {
        let mut graph = DiGraph::<u32>::default();

        graph.add_node(7);
        // Re-adding an existing node is a no-op.
        graph.add_node(7);
        graph.add_node(8);

        assert_eq!(graph.nodes().copied().collect::<Vec<_>>(), [7, 8]);
        check_post_order(&graph, &[7, 8]);
    }

    #[test]
    fn test_neighbors() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        // Duplicate edges are ignored.
        graph.add_edge(1, 2);

        assert_eq!(graph.neighbors(&1).copied().collect::<Vec<_>>(), [2, 3]);
        assert_eq!(graph.neighbors(&2).count(), 0);
        // Unknown nodes simply have no neighbors.
        assert_eq!(graph.neighbors(&42).count(), 0);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 1);
        assert!(graph.post_order().is_err());

        assert!(graph.remove_node(&2));
        assert!(!graph.remove_node(&2));
        assert!(!graph.contains_node(2));
        assert!(graph.contains_node(1));

        // Both the outgoing edge `2 -> 3` and the incoming edge `1 -> 2` are gone, which breaks
        // the cycle.
        assert_eq!(graph.neighbors(&1).count(), 0);
        assert_eq!(graph.neighbors(&3).copied().collect::<Vec<_>>(), [1]);
        check_post_order(&graph, &[1, 3]);
    }

    #[test]
    fn test_post_order_with_filter() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(3, 4);

        // Only `3` is an entry point, so `1` and `2` are never visited.
        let order: Vec<u32> = graph.post_order_with_filter(|n| *n == 3).unwrap().into_iter().collect();
        assert_eq!(order, [4, 3]);

        // Nodes reachable from an entry point are included even if they fail the filter.
        let order: Vec<u32> = graph.post_order_with_filter(|n| *n == 1).unwrap().into_iter().collect();
        assert_eq!(order, [2, 1]);
    }
}