jthread/
directed_graph.rs

1use std::fmt::Debug;
2
3use std::collections::{HashMap, HashSet};
4use std::fmt::Formatter;
5use std::hash::Hash;
6
/// Errors reported by [`DirectedGraph`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DGError {
    /// The requested edge would close a cycle in the graph.
    EdgeCreatesCycle,
}

impl std::fmt::Display for DGError {
    fn fmt(&self, w: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            DGError::EdgeCreatesCycle => w.write_str("Edge creates cycle"),
        }
    }
}

/// Marks `DGError` as a standard error so it can be boxed as `Box<dyn std::error::Error>`
/// and propagated with `?` alongside other error types.
impl std::error::Error for DGError {}
19
/// A directed graph stored as an adjacency list.
///
/// Each key is a source node and maps to the destinations of its outgoing edges.
#[derive(Debug, Clone)]
pub struct DirectedGraph<T> {
    adj_list: HashMap<T, Vec<T>>,
}

impl<T> Default for DirectedGraph<T> {
    /// Returns an empty graph with no nodes or edges.
    fn default() -> Self {
        Self {
            adj_list: HashMap::new(),
        }
    }
}
23
24impl<T: Hash + PartialEq + Eq + Clone> DirectedGraph<T> {
25    pub fn new() -> Self {
26        Self {
27            adj_list: HashMap::new(),
28        }
29    }
30
31    pub fn add_edge_with_check(&mut self, src: T, dest: T) -> Result<(), DGError> {
32        // Temporarily add the edge
33        self.adj_list
34            .entry(src.clone())
35            .or_insert(vec![])
36            .push(dest.clone());
37
38        if self.is_cyclic() {
39            // If a cycle is detected, remove the edge and return an error
40            if let Some(edges) = self.adj_list.get_mut(&src) {
41                edges.retain(|x| *x != dest);
42            }
43            Err(DGError::EdgeCreatesCycle)
44        } else {
45            // If no cycle is detected, keep the edge and return Ok
46            Ok(())
47        }
48    }
49
50    // Method to check if the graph is cyclic
51    fn is_cyclic(&self) -> bool {
52        let mut visited = HashSet::new();
53        let mut rec_stack = HashSet::new();
54
55        for node in self.adj_list.keys() {
56            if !visited.contains(node) && self.is_cyclic_util(node, &mut visited, &mut rec_stack) {
57                return true;
58            }
59        }
60        false
61    }
62
63    fn is_cyclic_util(
64        &self,
65        node: &T,
66        visited: &mut HashSet<T>,
67        rec_stack: &mut HashSet<T>,
68    ) -> bool {
69        if rec_stack.contains(&node) {
70            return true;
71        }
72        if visited.contains(&node) {
73            return false;
74        }
75
76        visited.insert(node.clone());
77        rec_stack.insert(node.clone());
78
79        if let Some(neighbors) = self.adj_list.get(&node) {
80            for neighbor in neighbors {
81                if self.is_cyclic_util(&neighbor, visited, rec_stack) {
82                    return true;
83                }
84            }
85        }
86
87        rec_stack.remove(&node);
88        false
89    }
90}
91
#[cfg(test)]
mod directed_graph_tests {
    use super::*;

    #[test]
    fn test_adding_edge_no_cycle() {
        let mut graph = DirectedGraph::new();
        assert_eq!(graph.add_edge_with_check(0, 1), Ok(()));
        assert_eq!(graph.add_edge_with_check(1, 2), Ok(()));
    }

    #[test]
    fn test_adding_edge_creates_cycle() {
        let mut graph = DirectedGraph::new();
        graph.add_edge_with_check(0, 1).unwrap();
        graph.add_edge_with_check(1, 2).unwrap();
        assert_eq!(
            graph.add_edge_with_check(2, 0),
            Err(DGError::EdgeCreatesCycle)
        );
    }

    #[test]
    fn test_rejected_edge_is_not_kept() {
        let mut graph = DirectedGraph::new();
        graph.add_edge_with_check(0, 1).unwrap();
        assert_eq!(
            graph.add_edge_with_check(1, 0),
            Err(DGError::EdgeCreatesCycle)
        );
        // Node 1 must have no stored outgoing edges after the rejection.
        assert_eq!(graph.adj_list.get(&1).map_or(0, Vec::len), 0);
        // The graph stays usable for valid edges afterwards.
        assert_eq!(graph.add_edge_with_check(1, 2), Ok(()));
    }

    #[test]
    fn test_diamond_is_not_a_cycle() {
        // Two paths reconverge on node 3: a shared descendant, but no cycle.
        let mut graph = DirectedGraph::new();
        for (src, dest) in [(0, 1), (0, 2), (1, 3), (2, 3)] {
            assert_eq!(graph.add_edge_with_check(src, dest), Ok(()));
        }
        assert_eq!(
            graph.add_edge_with_check(3, 0),
            Err(DGError::EdgeCreatesCycle)
        );
    }

    #[test]
    fn test_empty_graph() {
        let graph = DirectedGraph::<i32>::new();
        assert!(graph.adj_list.is_empty());
    }

    #[test]
    fn test_single_node_self_loop() {
        let mut graph = DirectedGraph::new();
        assert_eq!(
            graph.add_edge_with_check(0, 0),
            Err(DGError::EdgeCreatesCycle)
        );
    }

    #[test]
    fn test_error_display() {
        assert_eq!(DGError::EdgeCreatesCycle.to_string(), "Edge creates cycle");
    }
}
122}