1use crate::Location;
18use leo_span::Symbol;
19
20use indexmap::{IndexMap, IndexSet};
21use std::{fmt::Debug, hash::Hash, rc::Rc};
22
/// A directed graph whose nodes are `Location`s.
///
/// NOTE(review): presumably records dependencies between composite (struct/record)
/// definitions — confirm edge direction against the code that builds it.
pub type CompositeGraph = DiGraph<Location>;

/// A directed graph whose nodes are `Location`s.
///
/// NOTE(review): presumably records caller → callee relationships between functions —
/// confirm edge direction against the code that builds it.
pub type CallGraph = DiGraph<Location>;

/// A directed graph whose nodes are `Symbol`s.
///
/// NOTE(review): presumably records import relationships between programs — confirm
/// edge direction against the code that builds it.
pub type ImportGraph = DiGraph<Symbol>;
31
/// The set of capabilities every node type stored in a `DiGraph` must provide.
///
/// Nodes are hashed and compared for set/map membership, cloned when results are
/// handed back to callers, and `Debug`-formatted as part of `DiGraphError`.
pub trait GraphNode: 'static + Clone + Debug + Hash + Eq + PartialEq {}

/// Any type that meets the required bounds is automatically usable as a graph node.
impl<T: 'static + Clone + Debug + Hash + Eq + PartialEq> GraphNode for T {}
36
37#[derive(Debug)]
39pub enum DiGraphError<N: GraphNode> {
40 CycleDetected(Vec<N>),
42}
43
44#[derive(Clone, Debug, PartialEq, Eq)]
46pub struct DiGraph<N: GraphNode> {
47 nodes: IndexSet<Rc<N>>,
49
50 edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
53}
54
55impl<N: GraphNode> Default for DiGraph<N> {
56 fn default() -> Self {
57 Self { nodes: IndexSet::new(), edges: IndexMap::new() }
58 }
59}
60
61impl<N: GraphNode> DiGraph<N> {
62 pub fn new(nodes: IndexSet<N>) -> Self {
64 let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
65 Self { nodes, edges: IndexMap::new() }
66 }
67
68 pub fn add_node(&mut self, node: N) {
70 self.nodes.insert(Rc::new(node));
71 }
72
73 pub fn nodes(&self) -> impl Iterator<Item = &N> {
75 self.nodes.iter().map(|rc| rc.as_ref())
76 }
77
78 pub fn add_edge(&mut self, from: N, to: N) {
80 let from_rc = self.get_or_insert(from);
82 let to_rc = self.get_or_insert(to);
83
84 self.edges.entry(from_rc).or_default().insert(to_rc);
86 }
87
88 pub fn remove_node(&mut self, node: &N) -> bool {
90 if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
91 self.edges.shift_remove(&rc_node);
93
94 for targets in self.edges.values_mut() {
96 targets.shift_remove(&rc_node);
97 }
98 true
99 } else {
100 false
101 }
102 }
103
104 pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
106 self.edges
107 .get(node) .into_iter()
109 .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
110 }
111
112 pub fn contains_node(&self, node: N) -> bool {
114 self.nodes.contains(&Rc::new(node))
115 }
116
117 pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
120 self.post_order_with_filter(|_| true)
121 }
122
123 pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
128 where
129 F: Fn(&N) -> bool,
130 {
131 let mut finished = IndexSet::with_capacity(self.nodes.len());
133
134 for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
137 if !finished.contains(node_rc) {
139 let mut discovered = IndexSet::new();
141 if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
143 let mut path = vec![cycle_node.as_ref().clone()];
144 while let Some(next) = discovered.pop() {
146 path.push(next.as_ref().clone());
148 if Rc::ptr_eq(&next, &cycle_node) {
150 break;
151 }
152 }
153 path.reverse();
155 return Err(DiGraphError::CycleDetected(path));
157 }
158 }
159 }
160
161 Ok(finished.iter().map(|rc| (**rc).clone()).collect())
163 }
164
165 pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
167 let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
168 self.nodes.retain(|n| keep.contains(n));
170 self.edges.retain(|n, _| keep.contains(n));
171 for targets in self.edges.values_mut() {
173 targets.retain(|t| keep.contains(t));
174 }
175 }
176
177 fn contains_cycle_from(
182 &self,
183 node: &Rc<N>,
184 discovered: &mut IndexSet<Rc<N>>,
185 finished: &mut IndexSet<Rc<N>>,
186 ) -> Option<Rc<N>> {
187 discovered.insert(node.clone());
189
190 if let Some(children) = self.edges.get(node) {
192 for child in children {
193 if discovered.contains(child) {
195 return Some(child.clone());
198 }
199 if !finished.contains(child)
201 && let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished)
202 {
203 return Some(cycle_node);
204 }
205 }
206 }
207
208 discovered.pop();
210 finished.insert(node.clone());
212 None
213 }
214
215 fn get_or_insert(&mut self, node: N) -> Rc<N> {
217 if let Some(existing) = self.nodes.get(&node) {
218 return existing.clone();
219 }
220 let rc = Rc::new(node);
221 self.nodes.insert(rc.clone());
222 rc
223 }
224}
225
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `graph` is acyclic and that its post order equals `expected`.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        // `expect` (rather than `assert!(is_ok())`) reports the detected cycle on failure.
        let order: Vec<N> = graph.post_order().expect("graph should be acyclic").into_iter().collect();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 4);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(6, 2);
        graph.add_edge(2, 1);
        graph.add_edge(2, 4);
        graph.add_edge(4, 3);
        graph.add_edge(4, 5);
        graph.add_edge(6, 7);
        graph.add_edge(7, 9);
        graph.add_edge(9, 8);

        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);

        let result = graph.post_order();
        assert!(result.is_err());

        let DiGraphError::CycleDetected(cycle) = result.unwrap_err();
        let expected = Vec::from([1u32, 2, 4, 1]);
        assert_eq!(cycle, expected);
    }

    /// A node with an edge to itself is reported as a two-element cycle.
    #[test]
    fn test_self_loop_is_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 1);

        let DiGraphError::CycleDetected(cycle) = graph.post_order().unwrap_err();
        assert_eq!(cycle, vec![1, 1]);
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    /// The filter selects DFS roots, so a cycle unreachable from any selected root is not reported.
    #[test]
    fn test_post_order_with_filter() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 2);
        graph.add_edge(2, 1);
        graph.add_edge(3, 4);

        let order = graph.post_order_with_filter(|n| *n == 3).expect("no cycle is reachable from 3");
        assert_eq!(order.into_iter().collect::<Vec<_>>(), vec![4, 3]);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(1, 3);

        assert!(graph.remove_node(&2));
        assert!(!graph.remove_node(&2));
        assert!(!graph.contains_node(2));
        assert!(graph.contains_node(1));
        assert_eq!(graph.nodes().copied().collect::<Vec<_>>(), vec![1, 3]);
        assert_eq!(graph.neighbors(&1).copied().collect::<Vec<_>>(), vec![3]);
        check_post_order(&graph, &[3, 1]);
    }

    #[test]
    fn test_neighbors() {
        let mut graph = DiGraph::<u32>::new(IndexSet::from([7]));
        graph.add_edge(1, 3);
        graph.add_edge(1, 2);
        // Re-adding an existing edge must not duplicate or reorder it.
        graph.add_edge(1, 3);

        assert_eq!(graph.neighbors(&1).copied().collect::<Vec<_>>(), vec![3, 2]);
        assert_eq!(graph.neighbors(&2).count(), 0);
        assert_eq!(graph.neighbors(&7).count(), 0);
        assert_eq!(graph.neighbors(&42).count(), 0);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 5);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(2, 5);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        let mut nodes = IndexSet::new();
        nodes.insert(1);
        nodes.insert(2);
        nodes.insert(3);

        graph.retain_nodes(&nodes);

        let mut expected = DiGraph::<u32>::new(IndexSet::new());
        expected.add_edge(1, 2);
        expected.add_edge(1, 3);
        expected.add_edge(2, 3);
        // Node 3 keeps its (now empty) adjacency entry because its only target was dropped.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }
}