1use crate::Location;
18use leo_span::Symbol;
19
20use indexmap::{IndexMap, IndexSet};
21use std::{fmt::Debug, hash::Hash, rc::Rc};
22
/// A [`DiGraph`] whose nodes are identified by [`Location`].
/// NOTE(review): presumably edges encode dependencies between composite definitions — confirm
/// against the code that builds this graph.
pub type CompositeGraph = DiGraph<Location>;

/// A [`DiGraph`] whose nodes are identified by [`Location`].
/// NOTE(review): presumably an edge `a -> b` means "`a` calls `b`" — confirm against the builder.
pub type CallGraph = DiGraph<Location>;

/// A [`DiGraph`] whose nodes are [`Symbol`]s.
/// NOTE(review): presumably the symbols name programs and edges are imports — confirm.
pub type ImportGraph = DiGraph<Symbol>;
31
/// Bounds required of a [`DiGraph`] node type: it must be cloneable, owned (`'static`),
/// comparable for equality, hashable, and printable with `{:?}`.
pub trait GraphNode
where
    Self: Clone + Debug + Eq + PartialEq + Hash + 'static,
{
}

// Blanket impl: any type satisfying the bounds is a `GraphNode` automatically.
impl<T: Clone + Debug + Eq + PartialEq + Hash + 'static> GraphNode for T {}
36
37#[derive(Debug)]
39pub enum DiGraphError<N: GraphNode> {
40 CycleDetected(Vec<N>),
42}
43
/// A directed graph over nodes of type `N`.
///
/// Each node is stored once behind an [`Rc`]; the edge map holds clones of those same `Rc`s
/// rather than separate copies of the node values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DiGraph<N: GraphNode> {
    /// Every node in the graph, including nodes with no edges, in insertion order.
    nodes: IndexSet<Rc<N>>,

    /// Adjacency map from a source node to the set of its direct successors.
    /// A node with no outgoing edges may have no entry here at all.
    edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
}
54
55impl<N: GraphNode> Default for DiGraph<N> {
56 fn default() -> Self {
57 Self { nodes: IndexSet::new(), edges: IndexMap::new() }
58 }
59}
60
61impl<N: GraphNode> DiGraph<N> {
62 pub fn new(nodes: IndexSet<N>) -> Self {
64 let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
65 Self { nodes, edges: IndexMap::new() }
66 }
67
68 pub fn add_node(&mut self, node: N) {
70 self.nodes.insert(Rc::new(node));
71 }
72
73 pub fn nodes(&self) -> impl Iterator<Item = &N> {
75 self.nodes.iter().map(|rc| rc.as_ref())
76 }
77
78 pub fn add_edge(&mut self, from: N, to: N) {
80 let from_rc = self.get_or_insert(from);
82 let to_rc = self.get_or_insert(to);
83
84 self.edges.entry(from_rc).or_default().insert(to_rc);
86 }
87
88 pub fn remove_node(&mut self, node: &N) -> bool {
90 if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
91 self.edges.shift_remove(&rc_node);
93
94 for targets in self.edges.values_mut() {
96 targets.shift_remove(&rc_node);
97 }
98 true
99 } else {
100 false
101 }
102 }
103
104 pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
106 self.edges
107 .get(node) .into_iter()
109 .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
110 }
111
112 pub fn transitive_closure(&self, node: &N) -> IndexSet<N> {
114 let mut res = IndexSet::new();
115 let mut queue: Vec<_> = self.neighbors(node).collect();
116
117 while let Some(cur) = queue.pop() {
118 if !res.contains(cur) {
119 res.insert(cur.clone());
120 queue.extend(self.neighbors(cur));
121 }
122 }
123 res
124 }
125
126 pub fn contains_node(&self, node: N) -> bool {
128 self.nodes.contains(&Rc::new(node))
129 }
130
131 pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
134 self.post_order_with_filter(|_| true)
135 }
136
137 pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
142 where
143 F: Fn(&N) -> bool,
144 {
145 let mut finished = IndexSet::with_capacity(self.nodes.len());
147
148 for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
151 if !finished.contains(node_rc) {
153 let mut discovered = IndexSet::new();
155 if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
157 let mut path = vec![cycle_node.as_ref().clone()];
158 while let Some(next) = discovered.pop() {
160 path.push(next.as_ref().clone());
162 if Rc::ptr_eq(&next, &cycle_node) {
164 break;
165 }
166 }
167 path.reverse();
169 return Err(DiGraphError::CycleDetected(path));
171 }
172 }
173 }
174
175 Ok(finished.iter().map(|rc| (**rc).clone()).collect())
177 }
178
179 pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
181 let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
182 self.nodes.retain(|n| keep.contains(n));
184 self.edges.retain(|n, _| keep.contains(n));
185 for targets in self.edges.values_mut() {
187 targets.retain(|t| keep.contains(t));
188 }
189 }
190
191 fn contains_cycle_from(
196 &self,
197 node: &Rc<N>,
198 discovered: &mut IndexSet<Rc<N>>,
199 finished: &mut IndexSet<Rc<N>>,
200 ) -> Option<Rc<N>> {
201 discovered.insert(node.clone());
203
204 if let Some(children) = self.edges.get(node) {
206 for child in children {
207 if discovered.contains(child) {
209 return Some(child.clone());
212 }
213 if !finished.contains(child)
215 && let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished)
216 {
217 return Some(cycle_node);
218 }
219 }
220 }
221
222 discovered.pop();
224 finished.insert(node.clone());
226 None
227 }
228
229 fn get_or_insert(&mut self, node: N) -> Rc<N> {
231 if let Some(existing) = self.nodes.get(&node) {
232 return existing.clone();
233 }
234 let rc = Rc::new(node);
235 self.nodes.insert(rc.clone());
236 rc
237 }
238}
239
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `graph` is acyclic and that its post-order equals `expected`.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        let order: Vec<N> = graph.post_order().expect("graph should be acyclic").into_iter().collect();
        assert_eq!(order, expected);
    }

    /// Runs `post_order` on `graph`, asserts that it reports a cycle, and returns that cycle.
    fn expect_cycle(graph: &DiGraph<u32>) -> Vec<u32> {
        match graph.post_order() {
            Ok(order) => panic!("expected a cycle, got post-order {order:?}"),
            Err(DiGraphError::CycleDetected(cycle)) => cycle,
        }
    }

    #[test]
    fn test_post_order() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 4);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(6, 2);
        graph.add_edge(2, 1);
        graph.add_edge(2, 4);
        graph.add_edge(4, 3);
        graph.add_edge(4, 5);
        graph.add_edge(6, 7);
        graph.add_edge(7, 9);
        graph.add_edge(9, 8);

        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);

        assert_eq!(expect_cycle(&graph), vec![1, 2, 4, 1]);
    }

    #[test]
    fn test_self_loop_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 1);

        assert_eq!(expect_cycle(&graph), vec![1, 1]);
    }

    #[test]
    fn test_cycle_not_through_root() {
        // The cycle 2 -> 3 -> 2 is entered from root 1, so 1 must not appear in the reported path.
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 2);

        assert_eq!(expect_cycle(&graph), vec![2, 3, 2]);
    }

    #[test]
    fn test_transitive_closure() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);
        graph.add_edge(3, 5);

        assert_eq!(graph.transitive_closure(&2), IndexSet::from([4, 1, 2, 3, 5]));
        assert_eq!(graph.transitive_closure(&3), IndexSet::from([5]));
        assert_eq!(graph.transitive_closure(&5), IndexSet::from([]));

        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 5);
        graph.add_edge(3, 5);
        graph.add_edge(3, 4);
        assert_eq!(graph.transitive_closure(&1), IndexSet::from([2, 5, 3, 4]));
        assert_eq!(graph.transitive_closure(&2), IndexSet::from([5]));
        assert_eq!(graph.transitive_closure(&3), IndexSet::from([5, 4]));
        assert_eq!(graph.transitive_closure(&4), IndexSet::from([]));
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_add_node_is_idempotent() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_node(1);
        graph.add_node(1);
        graph.add_edge(1, 2);

        assert_eq!(graph.nodes().copied().collect::<Vec<_>>(), vec![1, 2]);
        assert!(graph.contains_node(2));
        assert!(!graph.contains_node(3));
    }

    #[test]
    fn test_remove_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(1, 3);

        assert!(graph.remove_node(&2));
        assert!(!graph.remove_node(&2));
        assert!(!graph.contains_node(2));
        assert_eq!(graph.neighbors(&1).copied().collect::<Vec<_>>(), vec![3]);
        assert_eq!(graph.neighbors(&2).count(), 0);
        check_post_order(&graph, &[3, 1]);
    }

    #[test]
    fn test_post_order_with_filter() {
        let mut graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4]));
        graph.add_edge(1, 2);
        graph.add_edge(3, 4);

        // Only 3 is a root; 4 is included because it is reachable from 3.
        let order: Vec<u32> =
            graph.post_order_with_filter(|n| *n == 3).expect("graph should be acyclic").into_iter().collect();
        assert_eq!(order, vec![4, 3]);

        // A cycle reachable only from rejected roots is not explored.
        graph.add_edge(2, 1);
        assert!(graph.post_order_with_filter(|n| *n == 3).is_ok());
        assert_eq!(expect_cycle(&graph), vec![1, 2, 1]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 5);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(2, 5);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        let mut nodes = IndexSet::new();
        nodes.insert(1);
        nodes.insert(2);
        nodes.insert(3);

        graph.retain_nodes(&nodes);

        let mut expected = DiGraph::<u32>::new(IndexSet::new());
        expected.add_edge(1, 2);
        expected.add_edge(1, 3);
        expected.add_edge(2, 3);
        // Node 3 keeps its (now empty) edge-map entry after its only target, 4, was removed.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }
}